diff --git a/.gitignore b/.gitignore index 832236e0..25471f18 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ venv +p2venv *.pyc dist ldclient_py.egg-info build/ -test.py \ No newline at end of file +.idea +*.iml +test.py diff --git a/README.md b/README.md index 8d5e86b2..e7744bed 100644 --- a/README.md +++ b/README.md @@ -32,10 +32,11 @@ Development information (for developing this module itself) pip install -r requirements.txt pip install -r test-requirements.txt + pip install -r twisted-requirements.txt 2. Run tests: - $ py.test + $ py.test testing Learn more diff --git a/circle.yml b/circle.yml index f13b8ea2..ce6235da 100644 --- a/circle.yml +++ b/circle.yml @@ -1,13 +1,15 @@ dependencies: pre: - - pyenv shell 2.6.8; $(pyenv which pip) install --upgrade pip + - pyenv shell 2.7.10; $(pyenv which pip) install --upgrade pip - pyenv shell 3.3.3; $(pyenv which pip) install --upgrade pip - - pyenv shell 2.6.8; $(pyenv which pip) install -r test-requirements.txt + - pyenv shell 2.7.10; $(pyenv which pip) install -r test-requirements.txt - pyenv shell 3.3.3; $(pyenv which pip) install -r test-requirements.txt - - pyenv shell 2.6.8; $(pyenv which python) setup.py install + - pyenv shell 2.7.10; $(pyenv which pip) install -r twisted-requirements.txt + - pyenv shell 3.3.3; $(pyenv which pip) install -r twisted-requirements.txt + - pyenv shell 2.7.10; $(pyenv which python) setup.py install - pyenv shell 3.3.3; $(pyenv which python) setup.py install test: override: - - pyenv shell 2.6.8; $(pyenv which py.test) testing - - pyenv shell 3.3.3; $(pyenv which py.test) testing + - pyenv shell 2.7.10; $(pyenv which py.test) testing + - pyenv shell 3.3.3; $(pyenv which py.test) --ignore=testing/test_sse_twisted.py -s testing diff --git a/demo/demo_twisted.py b/demo/demo_twisted.py new file mode 100644 index 00000000..e2b26972 --- /dev/null +++ b/demo/demo_twisted.py @@ -0,0 +1,20 @@ +from __future__ import print_function +from ldclient.twisted import TwistedLDClient +from twisted.internet import task, defer + +@defer.inlineCallbacks +def main(reactor): + apiKey = 'whatever' + client = TwistedLDClient(apiKey) + user = { + u'key': u'xyz', + u'custom': { + u'bizzle': u'def' + } + } + val = yield client.toggle('foo', user) + yield client.flush() + print("Value: {}".format(val)) + +if __name__ == '__main__': + task.react(main) \ No newline at end of file diff --git a/ldclient/__init__.py b/ldclient/__init__.py index e5ce0490..eddad03b 100644 --- a/ldclient/__init__.py +++ b/ldclient/__init__.py @@ -55,10 +55,14 @@ def __init__(self, capacity = 10000, stream_uri = 'https://stream.launchdarkly.com', stream = False, - verify = True): + verify = True, + stream_processor_class = None, + feature_store_class = None): self._base_uri = base_uri.rstrip('\\') self._stream_uri = stream_uri.rstrip('\\') self._stream = stream + self._stream_processor_class = StreamProcessor if not stream_processor_class else stream_processor_class + self._feature_store_class = InMemoryFeatureStore if not feature_store_class else feature_store_class self._connect = connect_timeout self._read = read_timeout self._upload_limit = upload_limit @@ -135,27 +139,19 @@ def __init__(self, api_key, config): self.daemon = True self._api_key = api_key self._config = config - self._store = InMemoryFeatureStore() + self._store = config._feature_store_class() + self._running = False def run(self): log.debug("Starting stream processor") + self._running = True hdrs = _stream_headers(self._api_key) uri = self._config._stream_uri + "/" messages = SSEClient(uri, verify = self._config._verify, headers = hdrs) for msg in messages: - payload = json.loads(msg.data) - if msg.event == 'put/features': - self._store.init(payload) - elif msg.event == 'patch/features': - key = payload['path'][1:] - feature = payload['data'] - self._store.upsert(key, feature) - elif msg.event == 'delete/features': - key = payload['path'][1:] - version = payload['version'] - self._store.delete(key, version) - else: - log.warning('Unhandled event in stream processor: ' + msg.event) + if not self._running: + break + self.process_message(self._store, msg) def initialized(self): return self._store.initialized() @@ -163,6 +159,25 @@ def initialized(self): def get_feature(self, key): return self._store.get(key) + def stop(self): + self._running = False + + @staticmethod + def process_message(store, msg): + payload = json.loads(msg.data) + if msg.event == 'put': + store.init(payload) + elif msg.event == 'patch': + key = payload['path'][1:] + feature = payload['data'] + store.upsert(key, feature) + elif msg.event == 'delete': + key = payload['path'][1:] + version = payload['version'] + store.delete(key, version) + else: + log.warning('Unhandled event in stream processor: ' + msg.event) + class Consumer(Thread): def __init__(self, queue, api_key, config): Thread.__init__(self) @@ -251,8 +266,9 @@ def __init__(self, api_key, config = None): self._consumer = None self._offline = False self._lock = Lock() + self._stream_processor = None if self._config._stream: - self._stream_processor = StreamProcessor(api_key, config) + self._stream_processor = config._stream_processor_class(api_key, config) self._stream_processor.start() def _check_consumer(self): @@ -261,9 +277,11 @@ def _check_consumer(self): self._consumer = Consumer(self._queue, self._api_key, self._config) self._consumer.start() - def _stop_consumer(self): + def _stop_consumers(self): if self._consumer and self._consumer.is_alive(): self._consumer.stop() + if self._stream_processor and self._stream_processor.is_alive(): + self._stream_processor.stop() def _send(self, event): if self._offline: @@ -283,7 +301,7 @@ def identify(self, user): def set_offline(self): self._offline = True - self._stop_consumer() + self._stop_consumers() def set_online(self): self._offline = False @@ -339,8 +357,11 @@ def _toggle(self, key, user, default): def _headers(api_key): return {'Authorization': 'api_key ' + api_key, 'User-Agent': 'PythonClient/' + __version__, 'Content-Type': "application/json"} -def _stream_headers(api_key): - return {'Authorization': 'api_key ' + api_key, 'User-Agent': 'PythonClient/' + __version__, 'Accept': "text/event-stream"} +def _stream_headers(api_key, client="PythonClient"): + return {'Authorization': 'api_key ' + api_key, + 'User-Agent': 'PythonClient/' + __version__, + 'Cache-Control': 'no-cache', + 'Accept': "text/event-stream"} def _param_for_user(feature, user): if 'key' in user and user['key']: @@ -420,5 +441,4 @@ def _evaluate(feature, user): total += float(variation['weight']) / 100.0 if param < total: return variation['value'] - return None diff --git a/ldclient/twisted.py b/ldclient/twisted.py new file mode 100644 index 00000000..4f01acba --- /dev/null +++ b/ldclient/twisted.py @@ -0,0 +1,173 @@ +from __future__ import absolute_import +from functools import partial + +import json +from queue import Empty +import errno +from cachecontrol import CacheControl +from ldclient import LDClient, _headers, log, _evaluate, _stream_headers, StreamProcessor, Config +from ldclient.twisted_sse import TwistedSSEClient +from requests.packages.urllib3.exceptions import ProtocolError +from twisted.internet import task, defer +import txrequests + + +class TwistedLDClient(LDClient): + def __init__(self, api_key, config=None): + if config is None: + config = TwistedConfig.default() + super(TwistedLDClient, self).__init__(api_key, config) + self._session = CacheControl(txrequests.Session()) + + def _check_consumer(self): + if not self._consumer or not self._consumer.is_alive(): + self._consumer = TwistedConsumer(self._session, self._queue, self._api_key, self._config) + self._consumer.start() + + def flush(self): + if self._offline: + return defer.succeed(True) + self._check_consumer() + return self._consumer.flush() + + def toggle(self, key, user, default=False): + @defer.inlineCallbacks + def run(should_retry): + # noinspection PyBroadException + try: + if self._offline: + defer.returnValue(default) + val = yield self._toggle(key, user, default) + self._send({'kind': 'feature', 'key': key, 'user': user, 'value': val}) + defer.returnValue(val) + except ProtocolError as e: + inner = e.args[1] + if inner.errno == errno.ECONNRESET and should_retry: + log.warning('ProtocolError exception caught while getting flag. Retrying.') + d = yield run(False) + defer.returnValue(d) + else: + log.exception('Unhandled exception. Returning default value for flag.') + defer.returnValue(default) + except Exception: + log.exception('Unhandled exception. Returning default value for flag.') + defer.returnValue(default) + + return run(True) + + @defer.inlineCallbacks + def _toggle(self, key, user, default): + if self._config._stream and self._stream_processor.initialized(): + feature = self._stream_processor.get_feature(key) + else: + hdrs = _headers(self._api_key) + uri = self._config._base_uri + '/api/eval/features/' + key + r = yield self._session.get(uri, headers=hdrs, timeout=(self._config._connect, self._config._read)) + r.raise_for_status() + feature = r.json() + val = _evaluate(feature, user) + if val is None: + val = default + defer.returnValue(val) + + +class TwistedConfig(Config): + def __init__(self, *args, **kwargs): + super(TwistedConfig, self).__init__(*args, **kwargs) + self._stream_processor_class = TwistedStreamProcessor + + +class TwistedStreamProcessor(object): + + def __init__(self, api_key, config): + self._store = config._feature_store_class() + self.sse_client = TwistedSSEClient(config._stream_uri + "/", headers=_stream_headers(api_key, + "PythonTwistedClient"), + verify=config._verify, + on_event=partial(StreamProcessor.process_message, self._store)) + self.running = False + + def start(self): + self.sse_client.start() + self.running = True + + def stop(self): + self.sse_client.stop() + + def get_feature(self, key): + return self._store.get(key) + + def initialized(self): + return self._store.initialized() + + def is_alive(self): + return self.running + + +class TwistedConsumer(object): + def __init__(self, session, queue, api_key, config): + self._queue = queue + """ @type: queue.Queue """ + self._session = session + """ :type: txrequests.Session """ + + self._api_key = api_key + self._config = config + """ :type: Deferred """ + self._looping_call = None + """ :type: LoopingCall""" + + def start(self): + self._flushed = defer.Deferred() + self._looping_call = task.LoopingCall(self._consume) + self._looping_call.start(5) + + def stop(self): + self._looping_call.stop() + + def is_alive(self): + return self._looping_call is not None and self._looping_call.running + + def flush(self): + return self._consume() + + def _consume(self): + items = [] + try: + while True: + items.append(self._queue.get_nowait()) + except Empty: + pass + + if items: + return self.send_batch(items) + + @defer.inlineCallbacks + def send_batch(self, events): + @defer.inlineCallbacks + def do_send(should_retry): + # noinspection PyBroadException + try: + if isinstance(events, dict): + body = [events] + else: + body = events + hdrs = _headers(self._api_key) + uri = self._config._base_uri + '/api/events/bulk' + r = yield self._session.post(uri, headers=hdrs, timeout=(self._config._connect, self._config._read), + data=json.dumps(body)) + r.raise_for_status() + except ProtocolError as e: + inner = e.args[1] + if inner.errno == errno.ECONNRESET and should_retry: + log.warning('ProtocolError exception caught while sending events. Retrying.') + yield do_send(False) + else: + log.exception('Unhandled exception in event consumer. Analytics events were not processed.') + except: + log.exception('Unhandled exception in event consumer. Analytics events were not processed.') + try: + yield do_send(True) + finally: + for _ in events: + self._queue.task_done() \ No newline at end of file diff --git a/ldclient/twisted_sse.py b/ldclient/twisted_sse.py new file mode 100644 index 00000000..8770acc2 --- /dev/null +++ b/ldclient/twisted_sse.py @@ -0,0 +1,170 @@ +from __future__ import absolute_import + +from copy import deepcopy +from ldclient import log +from twisted.internet.defer import Deferred +from twisted.internet.ssl import ClientContextFactory +from twisted.web.client import Agent +from twisted.web.http_headers import Headers +from twisted.protocols.basic import LineReceiver + + +class NoValidationContextFactory(ClientContextFactory): + def getContext(self, *_): + return ClientContextFactory.getContext(self) + + +class TwistedSSEClient(object): + def __init__(self, url, headers, verify, on_event): + self.url = url + "/features" + self.verify = verify + self.headers = headers + self.on_event = on_event + self.on_error_retry = 30 + self.running = False + self.current_request = None + + def reconnect(self, old_protocol): + """ + :type old_protocol: EventSourceProtocol + """ + if not self.running: + return + + retry = old_protocol.retry + if not retry: + retry = 5 + from twisted.internet import reactor + reactor.callLater(retry, self.connect, old_protocol.last_id) + + def start(self): + self.running = True + self.connect() + + def connect(self, last_id=None): + """ + Connect to the event source URL + """ + headers = deepcopy(self.headers) + if last_id: + headers['Last-Event-ID'] = last_id + headers = dict([(x, [y.encode('utf-8')]) for x, y in headers.items()]) + url = self.url.encode('utf-8') + from twisted.internet import reactor + if self.verify: + agent = Agent(reactor) + else: + agent = Agent(reactor, NoValidationContextFactory()) + + d = agent.request( + 'GET', + url, + Headers(headers), + None) + self.current_request = d + d.addErrback(self.on_connect_error) + d.addCallback(self.on_response) + + def stop(self): + if self.running and self.current_request: + self.current_request.cancel() + + def on_response(self, response): + from twisted.internet import reactor + if response.code != 200: + log.error("non 200 response received: %d" % response.code) + reactor.callLater(self.on_error_retry, self.connect) + else: + finished = Deferred() + protocol = EventSourceProtocol(self.on_event, finished) + finished.addBoth(self.reconnect) + response.deliverBody(protocol) + return finished + + def on_connect_error(self, ignored): + """ + :type ignored: twisted.python.Failure + """ + from twisted.internet import reactor + ignored.printTraceback() + log.error("error connecting to endpoint {}: {}".format(self.url, ignored.getTraceback())) + reactor.callLater(self.on_error_retry, self.connect) + + +class EventSourceProtocol(LineReceiver): + def __init__(self, on_event, finished_deferred): + self.finished = finished_deferred + self.on_event = on_event + # Initialize the event and data buffers + self.event = '' + self.data = '' + self.id = None + self.last_id = None + self.retry = 5 # 5 second retry default + self.reset() + self.delimiter = b'\n' + + def reset(self): + self.event = 'message' + self.data = '' + self.id = None + self.retry = None + + def lineReceived(self, line): + if line == '': + # Dispatch event + self.dispatch_event() + else: + try: + field, value = line.split(':', 1) + # If value starts with a space, strip it. + value = lstrip(value) + except ValueError: + # We got a line with no colon, treat it as a field(ignore) + return + + if field == '': + # This is a comment; ignore + pass + elif field == 'data': + self.data += value + '\n' + elif field == 'event': + self.event = value + elif field == 'id': + self.id = value + pass + elif field == 'retry': + self.retry = value + pass + + def connectionLost(self, reason): + self.finished.callback(self) + + def dispatch_event(self): + """ + Dispatch the event + """ + # If last character is LF, strip it. + if self.data.endswith('\n'): + self.data = self.data[:-1] + log.debug("Dispatching event %s[%s]: %s", self.event, self.id, self.data) + event = Event(self.data, self.event, self.id, self.retry) + self.on_event(event) + if self.id: + self.last_id = self.id + self.reset() + + +class Event(object): + def __init__(self, data='', event='message', id=None, retry=None): + self.data = data + self.event = event + self.id = id + self.retry = retry + + def __str__(self, *args, **kwargs): + return self.data + + +def lstrip(value): + return value[1:] if value.startswith(' ') else value \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..df0d38d0 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +twisted = 1 \ No newline at end of file diff --git a/setup.py b/setup.py index 894e43e3..d10cb27e 100644 --- a/setup.py +++ b/setup.py @@ -11,11 +11,13 @@ # parse_requirements() returns generator of pip.req.InstallRequirement objects install_reqs = parse_requirements('requirements.txt', session=uuid.uuid1()) test_reqs = parse_requirements('test-requirements.txt', session=uuid.uuid1()) +twisted_reqs = parse_requirements('twisted-requirements.txt', session=uuid.uuid1()) # reqs is a list of requirement # e.g. ['django==1.5.1', 'mezzanine==1.4.6'] reqs = [str(ir.req) for ir in install_reqs] testreqs = [str(ir.req) for ir in test_reqs] +txreqs = [str(ir.req) for ir in twisted_reqs] class PyTest(Command): user_options = [] @@ -43,6 +45,9 @@ def run(self): 'Operating System :: OS Independent', 'Programming Language :: Python :: 2 :: Only', ], + extras_require={ + "twisted": txreqs + }, tests_require=testreqs, cmdclass = {'test': PyTest}, ) \ No newline at end of file diff --git a/test-requirements.txt b/test-requirements.txt index 9440c735..2b820b06 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1 +1,2 @@ -pytest==2.6.4 +pytest==2.7.2 +pytest-twisted==1.5 diff --git a/testing/sse_util.py b/testing/sse_util.py new file mode 100644 index 00000000..a05771c3 --- /dev/null +++ b/testing/sse_util.py @@ -0,0 +1,174 @@ +import json +import logging +from queue import Empty +import socket +import ssl +import threading + +import time +from twisted.internet import defer, reactor + +try: + import queue as queuemod +except: + import Queue as queuemod + +try: + from SimpleHTTPServer import SimpleHTTPRequestHandler + # noinspection PyPep8Naming + import SocketServer as socketserver + import urlparse +except ImportError: + # noinspection PyUnresolvedReferences + from http.server import SimpleHTTPRequestHandler + # noinspection PyUnresolvedReferences + import socketserver + # noinspection PyUnresolvedReferences + from urllib import parse as urlparse + + +class TestServer(socketserver.TCPServer): + allow_reuse_address = True + + +class GenericServer: + + def __init__(self, host='localhost', use_ssl=False, port=None, cert_file="self_signed.crt", + key_file="self_signed.key"): + + self.get_paths = {} + self.post_paths = {} + self.raw_paths = {} + self.stopping = False + parent = self + + class CustomHandler(SimpleHTTPRequestHandler): + + def handle_request(self, paths): + # sort so that longest path wins + for path, handler in sorted(paths.items(), key=lambda item: len(item[0]), reverse=True): + if self.path.startswith(path): + handler(self) + return + self.send_response(404) + self.end_headers() + self.wfile.close() + + def do_GET(self): + self.handle_request(parent.get_paths) + + # noinspection PyPep8Naming + def do_POST(self): + self.handle_request(parent.post_paths) + + self.httpd = TestServer(("0.0.0.0", 0), CustomHandler) + port = port if port is not None else self.httpd.socket.getsockname()[1] + self.url = ("https://" if use_ssl else "http://") + host + ":%s" % port + self.port = port + logging.info("serving at port %s: %s" % (port, self.url)) + + if use_ssl: + self.httpd.socket = ssl.wrap_socket(self.httpd.socket, + certfile=cert_file, + keyfile=key_file, + server_side=True, + ssl_version=ssl.PROTOCOL_TLSv1) + self.start() + + def start(self): + self.stopping = False + httpd_thread = threading.Thread(target=self.httpd.serve_forever) + httpd_thread.setDaemon(True) + httpd_thread.start() + + def stop(self): + self.shutdown() + + def post_events(self): + q = queuemod.Queue() + def do_nothing(handler): + handler.send_response(200) + handler.end_headers() + handler.wfile.close() + + self.post_paths["/api/events/bulk"] = do_nothing + return q + + def get(self, path, func): + """ + Registers a handler function to be called when a GET request beginning with 'path' is made. + + :param path: The path prefix to listen on + :param func: The function to call. Should be a function that takes the querystring as a parameter. + """ + self.get_paths[path] = func + + def post(self, path, func): + """ + Registers a handler function to be called when a POST request beginning with 'path' is made. + + :param path: The path prefix to listen on + :param func: The function to call. Should be a function that takes the post body as a parameter. + """ + self.post_paths[path] = func + + def shutdown(self): + self.stopping = True + self.httpd.shutdown() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + self.shutdown() + finally: + pass + + +class SSEServer(GenericServer): + def __init__(self, host='localhost', use_ssl=False, port=None, cert_file="self_signed.crt", + key_file="self_signed.key", queue=queuemod.Queue()): + GenericServer.__init__(self, host, use_ssl, port, cert_file, key_file) + + def feed_forever(handler): + handler.send_response(200) + handler.send_header('Content-type', 'text/event-stream; charset=utf-8') + handler.end_headers() + while not self.stopping: + try: + event = queue.get(block=True, timeout=1) + """ :type: ldclient.twisted_sse.Event """ + if event: + lines = "event: {event}\ndata: {data}\n\n".format(event=event.event, + data=json.dumps(event.data)) + handler.wfile.write(lines) + except Empty: + pass + + self.get_paths["/"] = feed_forever + self.queue = queue + + +@defer.inlineCallbacks +def wait_until(condition, timeout=5): + end_time = time.time() + timeout + + while True: + result = yield defer.maybeDeferred(condition) + if result: + defer.returnValue(condition) + elif time.time() > end_time: + raise Exception("Timeout waiting for {}".format(condition.__name__)) # pragma: no cover + else: + d = defer.Deferred() + reactor.callLater(.1, d.callback, None) + yield d + + +def is_equal(f, val): + @defer.inlineCallbacks + def is_equal_eval(): + result = yield defer.maybeDeferred(f) + defer.returnValue(result == val) + return is_equal_eval \ No newline at end of file diff --git a/testing/test_sse_twisted.py b/testing/test_sse_twisted.py new file mode 100644 index 00000000..911a8c92 --- /dev/null +++ b/testing/test_sse_twisted.py @@ -0,0 +1,71 @@ +import logging +from ldclient.twisted import TwistedLDClient, TwistedConfig +from ldclient.twisted_sse import Event +import pytest +from testing.sse_util import wait_until, SSEServer, GenericServer, is_equal + +logging.basicConfig(level=logging.DEBUG) + + +@pytest.fixture() +def server(request): + server = GenericServer() + def fin(): + server.shutdown() + request.addfinalizer(fin) + return server + +@pytest.fixture() +def stream(request): + server = SSEServer() + def fin(): + server.shutdown() + request.addfinalizer(fin) + return server + + +@pytest.inlineCallbacks +def test_sse_init(server, stream): + stream.queue.put(Event(event="put", data=feature("foo", "jim"))) + client = TwistedLDClient("apikey", TwistedConfig(stream=True, base_uri=server.url, stream_uri=stream.url)) + yield wait_until(is_equal(lambda: client.toggle("foo", user('xyz'), "blah"), "jim")) + + +@pytest.inlineCallbacks +def test_sse_reconnect(server, stream): + server.post_events() + stream.queue.put(Event(event="put", data=feature("foo", "on"))) + client = TwistedLDClient("apikey", TwistedConfig(stream=True, base_uri=server.url, stream_uri=stream.url)) + yield wait_until(is_equal(lambda: client.toggle("foo", user('xyz'), "blah"), "on")) + + stream.stop() + + yield wait_until(is_equal(lambda: client.toggle("foo", user('xyz'), "blah"), "on")) + + stream.start() + + stream.queue.put(Event(event="put", data=feature("foo", "jim"))) + client = TwistedLDClient("apikey", TwistedConfig(stream=True, base_uri=server.url, stream_uri=stream.url)) + yield wait_until(is_equal(lambda: client.toggle("foo", user('xyz'), "blah"), "jim")) + + +def feature(key, val): + return { + key: {"name": "Feature {}".format(key), "key": key, "kind": "flag", "salt": "Zm9v", "on": val, + "variations": [{"value": val, "weight": 100, + "targets": [{"attribute": "key", "op": "in", "values": []}], + "userTarget": {"attribute": "key", "op": "in", "values": []}}, + {"value": False, "weight": 0, + "targets": [{"attribute": "key", "op": "in", "values": []}], + "userTarget": {"attribute": "key", "op": "in", "values": []}}], + "commitDate": "2015-09-08T21:24:16.712Z", + "creationDate": "2015-09-08T21:06:16.527Z", "version": 4}} + + +def user(name): + return { + u'key': name, + u'custom': { + u'bizzle': u'def' + } + } diff --git a/testing/test_twisted.py b/testing/test_twisted.py new file mode 100644 index 00000000..b4dc17ad --- /dev/null +++ b/testing/test_twisted.py @@ -0,0 +1,158 @@ +from __future__ import absolute_import + +from builtins import object +import ldclient +from ldclient.twisted import TwistedLDClient +import pytest +from twisted.internet import defer + +try: + import queue +except: + import Queue as queue + +client = TwistedLDClient("API_KEY", ldclient.Config("http://localhost:3000")) + +user = { + u'key': u'xyz', + u'custom': { + u'bizzle': u'def' + } + } + +class MockConsumer(object): + def __init__(self): + self._running = False + + def stop(self): + self._running = False + + def start(self): + self._running = True + + def is_alive(self): + return self._running + + def flush(self): + return defer.succeed(True) + + +def mock_consumer(): + return MockConsumer() + +def noop_consumer(): + return + +def mock_toggle(key, user, default): + hash = minimal_feature = { + u'key': u'feature.key', + u'salt': u'abc', + u'on': True, + u'variations': [ + { + u'value': True, + u'weight': 100, + u'targets': [] + }, + { + u'value': False, + u'weight': 0, + u'targets': [] + } + ] + } + val = ldclient._evaluate(hash, user) + if val is None: + return defer.succeed(default) + return defer.succeed(val) + +def setup_function(function): + client.set_online() + client._queue = queue.Queue(10) + client._consumer = mock_consumer() + +@pytest.fixture(autouse=True) +def noop_check_consumer(monkeypatch): + monkeypatch.setattr(client, '_check_consumer', noop_consumer) + +@pytest.fixture(autouse=True) +def no_remote_toggle(monkeypatch): + monkeypatch.setattr(client, '_toggle', mock_toggle) + +def test_set_offline(): + client.set_offline() + assert client.is_offline() == True + +def test_set_online(): + client.set_offline() + client.set_online() + assert client.is_offline() == False + +@pytest.inlineCallbacks +def test_toggle(): + result = yield client.toggle('xyz', user, default=None) + assert result == True + +@pytest.inlineCallbacks +def test_toggle_offline(): + client.set_offline() + assert (yield client.toggle('xyz', user, default=None)) == None + +@pytest.inlineCallbacks +def test_toggle_event(): + val = yield client.toggle('xyz', user, default=None) + def expected_event(e): + return e['kind'] == 'feature' and e['key'] == 'xyz' and e['user'] == user and e['value'] == True + assert expected_event(client._queue.get(False)) + +@pytest.inlineCallbacks +def test_toggle_event_offline(): + client.set_offline() + yield client.toggle('xyz', user, default=None) + assert client._queue.empty() + + +def test_identify(): + client.identify(user) + def expected_event(e): + return e['kind'] == 'identify' and e['key'] == u'xyz' and e['user'] == user + assert expected_event(client._queue.get(False)) + +def test_identify_offline(): + client.set_offline() + client.identify(user) + assert client._queue.empty() + +def test_track(): + client.track('my_event', user, 42) + def expected_event(e): + return e['kind'] == 'custom' and e['key'] == 'my_event' and e['user'] == user and e['data'] == 42 + assert expected_event(client._queue.get(False)) + +def test_track_offline(): + client.set_offline() + client.track('my_event', user, 42) + assert client._queue.empty() + +def drain(queue): + while not queue.empty(): + queue.get() + queue.task_done() + return + +@pytest.inlineCallbacks +def test_flush_empties_queue(): + client.track('my_event', user, 42) + client.track('my_event', user, 33) + drain(client._queue) + yield client.flush() + assert client._queue.empty() + + +@pytest.inlineCallbacks +def test_flush_offline_does_not_empty_queue(): + client.track('my_event', user, 42) + client.track('my_event', user, 33) + client.set_offline() + yield client.flush() + assert not client._queue.empty() \ No newline at end of file diff --git a/twisted-requirements.txt b/twisted-requirements.txt new file mode 100644 index 00000000..3c7f85d8 --- /dev/null +++ b/twisted-requirements.txt @@ -0,0 +1,2 @@ +txrequests>=0.9 +pyOpenSSL>=0.14 \ No newline at end of file