1 change: 0 additions & 1 deletion graphql_subscriptions/__init__.py
@@ -3,4 +3,3 @@
from .subscription_transport_ws import SubscriptionServer

__all__ = ['RedisPubsub', 'SubscriptionManager', 'SubscriptionServer']

Empty file.
82 changes: 82 additions & 0 deletions graphql_subscriptions/executors/asyncio.py
@@ -0,0 +1,82 @@
from __future__ import absolute_import

import asyncio
from websockets import ConnectionClosed

try:
from asyncio import ensure_future
except ImportError:
# ensure_future is only implemented in Python 3.4.4+
# Reference: https://github.com/graphql-python/graphql-core/blob/master/graphql/execution/executors/asyncio.py
def ensure_future(coro_or_future, loop=None):
"""Wrap a coroutine or an awaitable in a future.
If the argument is a Future, it is returned directly.
"""
if isinstance(coro_or_future, asyncio.Future):
if loop is not None and loop is not coro_or_future._loop:
raise ValueError('loop argument must agree with Future')
return coro_or_future
elif asyncio.iscoroutine(coro_or_future):
if loop is None:
loop = asyncio.get_event_loop()
task = loop.create_task(coro_or_future)
if task._source_traceback:
del task._source_traceback[-1]
return task
else:
raise TypeError(
'A Future, a coroutine or an awaitable is required')


class AsyncioExecutor(object):
error = ConnectionClosed
task_cancel_error = asyncio.CancelledError

def __init__(self, loop=None):
if loop is None:
loop = asyncio.get_event_loop()
self.loop = loop
self.futures = []

def ws_close(self, code):
return self.ws.close(code)

def ws_protocol(self):
return self.ws.subprotocol

def ws_isopen(self):
if self.ws.open:
return True
else:
return False

def ws_send(self, msg):
return self.ws.send(msg)

def ws_recv(self):
return self.ws.recv()

def sleep(self, time):
if self.loop.is_running():
return asyncio.sleep(time)
return self.loop.run_until_complete(asyncio.sleep(time))

@staticmethod
def kill(future):
future.cancel()

def join(self, future=None, timeout=None):
if not isinstance(future, asyncio.Future):
return
if self.loop.is_running():
return asyncio.wait_for(future, timeout=timeout)
return self.loop.run_until_complete(
asyncio.wait_for(future, timeout=timeout))

def execute(self, fn, *args, **kwargs):
result = fn(*args, **kwargs)
if isinstance(result, asyncio.Future) or asyncio.iscoroutine(result):
future = ensure_future(result, loop=self.loop)
self.futures.append(future)
return future
return result
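
For reviewers, a minimal usage sketch of AsyncioExecutor driven outside a running event loop (illustrative only, not part of the diff; it assumes this package layout and the websockets dependency are importable):

import asyncio

from graphql_subscriptions.executors.asyncio import AsyncioExecutor


async def double(value):
    # stand-in for real async work
    await asyncio.sleep(0.01)
    return value * 2


loop = asyncio.new_event_loop()
executor = AsyncioExecutor(loop=loop)

# execute() wraps the coroutine in a Task, tracks it, and returns it
future = executor.execute(double, 21)

# with no loop running, join() falls back to loop.run_until_complete()
result = executor.join(future, timeout=1)
print(result)  # 42
loop.close()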
Empty file.
52 changes: 52 additions & 0 deletions graphql_subscriptions/executors/gevent.py
@@ -0,0 +1,52 @@
from __future__ import absolute_import

from geventwebsocket.exceptions import WebSocketError
import gevent


class GeventExecutor(object):
# used to patch socket library so it doesn't block
socket = gevent.socket
error = WebSocketError

def __init__(self):
self.greenlets = []

def ws_close(self, code):
self.ws.close(code)

def ws_protocol(self):
return self.ws.protocol

def ws_isopen(self):
if self.ws.closed:
return False
else:
return True

def ws_send(self, msg, **kwargs):
self.ws.send(msg, **kwargs)

def ws_recv(self):
return self.ws.receive()

@staticmethod
def sleep(time):
gevent.sleep(time)

@staticmethod
def kill(greenlet):
gevent.kill(greenlet)

@staticmethod
def join(greenlet, timeout=None):
greenlet.join(timeout)

def join_all(self):
gevent.joinall(self.greenlets)
self.greenlets = []

def execute(self, fn, *args, **kwargs):
greenlet = gevent.spawn(fn, *args, **kwargs)
self.greenlets.append(greenlet)
return greenlet
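
Similarly, a hedged sketch of driving GeventExecutor directly (illustrative; assumes gevent and gevent-websocket are installed):

import gevent

from graphql_subscriptions.executors.gevent import GeventExecutor


def slow_echo(msg, delay=0.01):
    gevent.sleep(delay)  # cooperative yield, same primitive the executor's sleep() uses
    return msg


executor = GeventExecutor()

# execute() spawns a greenlet and tracks it in executor.greenlets
greenlet = executor.execute(slow_echo, 'hello')

executor.join(greenlet, timeout=1)  # block until the greenlet finishes
print(greenlet.value)               # 'hello'
executor.join_all()                 # drain anything else still tracked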
4 changes: 4 additions & 0 deletions graphql_subscriptions/subscription_manager/__init__.py
@@ -0,0 +1,4 @@
from .manager import SubscriptionManager
from .pubsub import RedisPubsub

__all__ = ['SubscriptionManager', 'RedisPubsub']
@@ -2,71 +2,15 @@
standard_library.install_aliases()
from builtins import object
from types import FunctionType
import pickle

from graphql import parse, validate, specified_rules, value_from_ast, execute
from graphql.language.ast import OperationDefinition
from promise import Promise
import gevent
import redis

from .utils import to_snake_case
from .validation import SubscriptionHasSingleRootField


class RedisPubsub(object):
def __init__(self, host='localhost', port=6379, *args, **kwargs):
redis.connection.socket = gevent.socket
self.redis = redis.StrictRedis(host, port, *args, **kwargs)
self.pubsub = self.redis.pubsub()
self.subscriptions = {}
self.sub_id_counter = 0
self.greenlet = None

def publish(self, trigger_name, message):
self.redis.publish(trigger_name, pickle.dumps(message))
return True

def subscribe(self, trigger_name, on_message_handler, options):
self.sub_id_counter += 1
try:
if trigger_name not in list(self.subscriptions.values())[0]:
self.pubsub.subscribe(trigger_name)
except IndexError:
self.pubsub.subscribe(trigger_name)
self.subscriptions[self.sub_id_counter] = [
trigger_name, on_message_handler
]
if not self.greenlet:
self.greenlet = gevent.spawn(self.wait_and_get_message)
return Promise.resolve(self.sub_id_counter)

def unsubscribe(self, sub_id):
trigger_name, on_message_handler = self.subscriptions[sub_id]
del self.subscriptions[sub_id]
try:
if trigger_name not in list(self.subscriptions.values())[0]:
self.pubsub.unsubscribe(trigger_name)
except IndexError:
self.pubsub.unsubscribe(trigger_name)
if not self.subscriptions:
self.greenlet = self.greenlet.kill()

def wait_and_get_message(self):
while True:
message = self.pubsub.get_message(ignore_subscribe_messages=True)
if message:
self.handle_message(message)
gevent.sleep(.001)

def handle_message(self, message):
if isinstance(message['channel'], bytes):
channel = message['channel'].decode()
for sub_id, trigger_map in self.subscriptions.items():
if trigger_map[0] == channel:
trigger_map[1](pickle.loads(message['data']))


class ValidationError(Exception):
def __init__(self, errors):
self.errors = errors
@@ -79,7 +23,7 @@ def __init__(self, schema, pubsub, setup_funcs={}):
self.pubsub = pubsub
self.setup_funcs = setup_funcs
self.subscriptions = {}
self.max_subscription_id = 0
self.max_subscription_id = 1

def publish(self, trigger_name, payload):
self.pubsub.publish(trigger_name, payload)
@@ -145,11 +89,6 @@ def subscribe(self, query, operation_name, callback, variables, context,
except AttributeError:
channel_options = {}

# TODO: Think about this some more...the Apollo library
# let's all messages through by default, even if
# the users incorrectly uses the setup_funcs (does not
# use 'filter' or 'channel_options' keys); I think it
# would be better to raise an exception here
def filter(arg1, arg2):
return True

@@ -181,7 +120,8 @@ def context_do_execute_handler(result):
subscription_promises.append(
self.pubsub.
subscribe(trigger_name, on_message, channel_options).then(
lambda id: self.subscriptions[external_subscription_id].append(id)
lambda id: self.subscriptions[external_subscription_id].
append(id)
))

return Promise.all(subscription_promises).then(
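
For context on the filter / channel_options defaults touched above, here is a hedged sketch of the setup_funcs shape the manager consumes; the schema-specific names (new_comment, comment_added, post_id) and the exact filter signature are hypothetical:

def new_comment_setup(options, args, subscription_name):
    # one entry per pubsub trigger this subscription should listen on
    return {
        'comment_added': {
            # assumed filter signature: (payload, context) -> bool
            'filter': lambda payload, context: payload.get('post_id') == args.get('post_id'),
            # forwarded as-is to pubsub.subscribe()
            'channel_options': {},
        },
    }


setup_funcs = {'new_comment': new_comment_setup}
# manager = SubscriptionManager(schema, pubsub, setup_funcs=setup_funcs)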
113 changes: 113 additions & 0 deletions graphql_subscriptions/subscription_manager/pubsub.py
@@ -0,0 +1,113 @@
from future import standard_library
standard_library.install_aliases()
from builtins import object
import pickle
import sys

from promise import Promise
import redis

from ..executors.gevent import GeventExecutor
from ..executors.asyncio import AsyncioExecutor

PY3 = sys.version_info[0] == 3


class RedisPubsub(object):
def __init__(self,
host='localhost',
port=6379,
executor=GeventExecutor,
*args,
**kwargs):

if executor == AsyncioExecutor:
try:
import aredis
except ImportError:
raise ImportError(
'You need the redis client "aredis" installed for use w/ '
'asyncio')

redis_client = aredis
else:
redis_client = redis

# patch redis socket library so it doesn't block if using gevent
if executor == GeventExecutor:
redis_client.connection.socket = executor.socket

self.redis = redis_client.StrictRedis(host, port, *args, **kwargs)
self.pubsub = self.redis.pubsub(ignore_subscribe_messages=True)

self.executor = executor()
self.backgrd_task = None

self.subscriptions = {}
self.sub_id_counter = 0

def publish(self, trigger_name, message):
self.executor.execute(self.redis.publish, trigger_name,
pickle.dumps(message))
return True

def subscribe(self, trigger_name, on_message_handler, options):
self.sub_id_counter += 1

self.subscriptions[self.sub_id_counter] = [
trigger_name, on_message_handler]

if PY3:
trigger_name = trigger_name.encode()

if trigger_name not in list(self.pubsub.channels.keys()):
self.executor.join(self.executor.execute(self.pubsub.subscribe,
trigger_name))
if not self.backgrd_task:
self.backgrd_task = self.executor.execute(
self.wait_and_get_message)

return Promise.resolve(self.sub_id_counter)

def unsubscribe(self, sub_id):
trigger_name, on_message_handler = self.subscriptions[sub_id]
del self.subscriptions[sub_id]

if PY3:
trigger_name = trigger_name.encode()

if trigger_name not in list(self.pubsub.channels.keys()):
self.executor.execute(self.pubsub.unsubscribe, trigger_name)

if not self.subscriptions:
self.backgrd_task = self.executor.kill(self.backgrd_task)

async def _wait_and_get_message_async(self):
try:
while True:
message = await self.pubsub.get_message()
if message:
self.handle_message(message)
await self.executor.sleep(.001)
except self.executor.task_cancel_error:
return

def _wait_and_get_message_sync(self):
while True:
message = self.pubsub.get_message()
if message:
self.handle_message(message)
self.executor.sleep(.001)

def wait_and_get_message(self):
if hasattr(self.executor, 'loop'):
return self._wait_and_get_message_async()
return self._wait_and_get_message_sync()

def handle_message(self, message):

channel = message['channel'].decode() if PY3 else message['channel']

for sub_id, trigger_map in self.subscriptions.items():
if trigger_map[0] == channel:
trigger_map[1](pickle.loads(message['data']))
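
A hedged end-to-end sketch of this class with the default gevent executor (illustrative; assumes a local Redis server on port 6379 and the promise library's blocking get() accessor):

from graphql_subscriptions.subscription_manager.pubsub import RedisPubsub


def on_message(payload):
    print('received:', payload)


pubsub = RedisPubsub()  # GeventExecutor is the default

# subscribe() resolves immediately to an integer subscription id
sub_id = pubsub.subscribe('comment_added', on_message, options={}).get()

pubsub.publish('comment_added', {'id': 1, 'text': 'hello'})  # pickled under the hood

pubsub.executor.sleep(0.1)  # let the background reader greenlet deliver the message
pubsub.unsubscribe(sub_id)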
@@ -3,8 +3,7 @@

FIELD = 'Field'

# XXX from Apollo pacakge: Temporarily use this validation
# rule to make our life a bit easier.
# Temporarily use this validation rule to make our life a bit easier.


class SubscriptionHasSingleRootField(ValidationRule):
@@ -27,8 +26,8 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors):
else:
self.context.report_error(
GraphQLError(
'Apollo subscriptions do not support fragments on\
the root field', [node]))
'Subscriptions do not support fragments on '
'the root field', [node]))
if num_fields > 1:
self.context.report_error(
GraphQLError(
@@ -38,5 +37,5 @@ def enter_OperationDefinition(self, node, key, parent, path, ancestors):

@staticmethod
def too_many_subscription_fields_error(subscription_name):
return 'Subscription "{0}" must have only one\
field.'.format(subscription_name)
return ('Subscription "{0}" must have only one '
'field.'.format(subscription_name))
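
To illustrate the reworded errors, a hedged sketch that runs this rule through graphql-core's validate() against a made-up two-field subscription type (the module path is assumed from the imports in manager.py):

from graphql import (GraphQLField, GraphQLObjectType, GraphQLSchema,
                     GraphQLString, parse, specified_rules, validate)

from graphql_subscriptions.subscription_manager.validation import \
    SubscriptionHasSingleRootField

# hypothetical schema with two subscription fields
subscription_type = GraphQLObjectType(
    name='Subscription',
    fields={
        'commentAdded': GraphQLField(GraphQLString),
        'postAdded': GraphQLField(GraphQLString),
    })
query_type = GraphQLObjectType(
    name='Query', fields={'dummy': GraphQLField(GraphQLString)})
schema = GraphQLSchema(query=query_type, subscription=subscription_type)

rules = specified_rules + [SubscriptionHasSingleRootField]

good = parse('subscription NewComment { commentAdded }')
bad = parse('subscription Everything { commentAdded postAdded }')

print(validate(schema, good, rules))  # [] -- a single root field passes
print(validate(schema, bad, rules))   # one error: must have only one field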
@@ -0,0 +1 @@
from .server import SubscriptionServer