diff --git a/livetrading/__init__.py b/livetrading/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/livetrading/broker.py b/livetrading/broker.py new file mode 100644 index 00000000..65c13488 --- /dev/null +++ b/livetrading/broker.py @@ -0,0 +1,85 @@ +from decimal import Decimal +from typing import Any, Dict, Optional + +from livetrading.event import KLinesEventSource, Pair, PairInfo, TickersEventSource +from livetrading.rest_cli import RestClient +from livetrading.websocket_client import WSClient + + +class Broker: + """A client for crypto currency exchange. + + :param dispatcher: The event dispatcher. + :param config: Config settings for exchange. + """ + def __init__( + self, dispatcher, config + ): + self.dispatcher = dispatcher + self.config = config + self.api_cli = RestClient(self.config) + self.cli: Optional[Any] = None # external libs as ccxt + self.ws_cli = WSClient(config) + self._cached_pairs: Dict[Pair] = {} + + def subscribe_to_ticker_events( + self, pair: Pair, interval: str, event_handler + ): + """Registers a callable that will be called every ticker. + + :param bar_duration: The bar duration. One of 1s, 1m, 3m, 5m, 15m, 30m, 1h, 2h, 4h, 6h, 8h, 12h, 1d, 3d, 1w, 1M. + :param pair: The trading pair. + :param event_handler: A callable that receives an TickerEvent. + """ + + event_source = TickersEventSource(pair, interval, self.ws_cli) + channel = "ticker" + + self._subscribe_to_ws_channel_events( + channel, + event_handler, + event_source + ) + + def subscribe_to_bar_events( + self, pair: Pair, event_handler, interval + ): + """Registers a callable that will be called every bar. + + :param pair: The trading pair. + :param event_handler: A callable that receives an BarEvent. + """ + event_source = KLinesEventSource(pair, self.ws_cli) + channel = event_source.ws_channel(interval) + + self._subscribe_to_ws_channel_events( + channel, + event_handler, + event_source + ) + + def get_pair_info(self, pair: Pair) -> PairInfo: + """Returns information about a trading pair. + + :param pair: The trading pair. + """ + ret = self._cached_pairs.get(pair) + api_path = '/'.join(['products', pair]) + if not ret: + pair_info = self.api_cli.call(method='GET', apipath=api_path) + self._cached_pairs[pair] = PairInfo(Decimal(pair_info['base_increment']), + Decimal(pair_info['quote_increment'])) + return self._cached_pairs + + def get_data_df(self, event_source): + data_source = self.ws_cli.event_sources[event_source] + return list(data_source.events) + + def _subscribe_to_ws_channel_events( + self, channel: str, event_handler, event_source + ): + # Set the event source for the channel. + self.ws_cli.set_channel_event_source(channel, event_source) + + # Subscribe the event handler to the event source. + self.dispatcher.subscribe(event_source, event_handler) diff --git a/livetrading/config.py b/livetrading/config.py new file mode 100644 index 00000000..148f9033 --- /dev/null +++ b/livetrading/config.py @@ -0,0 +1,4 @@ +from configloader import ConfigLoader + +config = ConfigLoader() +config.update_from_json_file('path_to_json_file') diff --git a/livetrading/converter.py b/livetrading/converter.py new file mode 100644 index 00000000..666cf2fa --- /dev/null +++ b/livetrading/converter.py @@ -0,0 +1,17 @@ +import pandas as pd + +DEFAULT_DATAFRAME_COLUMNS = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume'] + +def ohlcv_to_dataframe(historical_data: list) -> pd.DataFrame: + """ + Converts historical data to a Dataframe + :param historical_data: list with candle (OHLCV) data + :return: DataFrame + """ + df = pd.DataFrame( + [{fn: getattr(f, fn) for fn in DEFAULT_DATAFRAME_COLUMNS} for f in historical_data] + ) + df['Date'] = pd.to_datetime(df['Date'], unit='ms', utc=True, ) + df = df.set_index('Date') + df = df.sort_index(ascending=True) + return df.head() diff --git a/livetrading/env b/livetrading/env new file mode 100644 index 00000000..059df011 --- /dev/null +++ b/livetrading/env @@ -0,0 +1,5 @@ +{ +"ws_url": "wss://ws-feed.exchange.coinbase.com", +"api_url": "https://api.exchange.coinbase.com/", +"ws_timeout": 5 +} diff --git a/livetrading/event.py b/livetrading/event.py new file mode 100644 index 00000000..3c755b62 --- /dev/null +++ b/livetrading/event.py @@ -0,0 +1,236 @@ +import abc +import dataclasses +import datetime + +from collections import deque +from dateutil.parser import isoparse +from typing import Optional + + +intervals = { + "1s": 1, + "1m": 60, + "3m": 3 * 60, + "5m": 5 * 60, + "15m": 15 * 60, + "30m": 30 * 60, + "1h": 3600, + "2h": 2 * 3600, + "4h": 4 * 3600, + "6h": 6 * 3600, + "8h": 8 * 3600, + "12h": 12 * 3600, + "1d": 86400, + "3d": 3 * 86400, + "1w": 7 * 86400, + "1M": 31 * 86400 +} + + +@dataclasses.dataclass +class Bar: + """A Bar, aka candlestick, is the summary of the trading activity in a given period. + + :param date: The beginning of the period. It must have timezone information set. + :param pair: The trading pair. + :param open: The opening price. + :param high: The highest traded price. + :param low: The lowest traded price. + :param close: The closing price. + :param volume: The volume traded. + """ + date: datetime + pair: str + Open: float + High: float + Low: float + Close: float + Volume: float + + +@dataclasses.dataclass +class Pair: + """A trading pair. + + :param base_symbol: The base symbol. + :param quote_symbol: The quote symbol. + """ + base_symbol: str + quote_symbol: str + + def __str__(self): + # change format here to reflect corresponding exchange + return "{}-{}".format(self.base_symbol, self.quote_symbol) + + +@dataclasses.dataclass +class PairInfo: + """Information about a trading pair. + + :param base_increment: The increment for the base symbol. + :param quote_increment: The increment for the quote symbol. + """ + base_increment: float + quote_increment: float + + +class Ticker: + """A Ticker constantly updating stream of information about a stock. + :param datetime: The beginning of the period. It must have timezone information set. + :param pair: The trading pair. + :param open: The opening price. + :param high: The highest traded price. + :param low: The lowest traded price. + :param price: The price. + :param volume: The volume traded. + """ + def __init__(self, pair: Pair, json: dict): + self.pair: Pair = pair + self.json: dict = json + self.Date = isoparse(json['time']) + self.Volume = float(json["volume_24h"]) + self.Open = float(json["open_24h"]) + self.High = float(json["high_24h"]) + self.Low = float(json["low_24h"]) + self.Close = float(json["price"]) + + +class KlineBar(Bar): + """ + K-line, aka candlestick, is a chart marked with the opening price, closing price, + highest price, and lowest price to reflect price changes. + :param pair: The trading pair. + :param json: Message json. + """ + def __init__(self, pair: Pair, json: dict): + super().__init__( + datetime.utcfromtimestamp( + int(json["t"] / 1e3).replace(tzinfo=datetime.timezone.utc)), + pair, float(json["o"]), float(json["h"]), + float(json["l"]), float(json["c"]), float(json["v"]) + ) + self.pair: Pair = pair + self.json: dict = json + + +class EventProducer: + """Base class for event producers. + .. note:: + + Main method is for main functions that should be performed for an event producer. + Finalize method is called on error or stop. + """ + def main(self): + """Override to run the loop that produces events.""" + pass + + def finalize(self): + """Override to perform task and transaction cancellation.""" + pass + + +class Event: + """Base class for events. + + :param when: The datetime when the event occurred. + Used to calculate the datetime for the next event. + It must have timezone information set. + """ + + def __init__(self, when: datetime.datetime): + self.when: datetime.datetime = when + + +class EventSource(metaclass=abc.ABCMeta): + """Base class for events storage. + + :param producer: EventProducer. + """ + + def __init__(self, producer: Optional[EventProducer] = None): + self.producer = producer + self.events = deque() + + +class ChannelEventSource(EventSource): + """Base class for websockets channels. + + :param producer: EventProducer. + """ + def __init__(self, producer: EventProducer): + super().__init__(producer=producer) + + @abc.abstractmethod + def push_to_queue(self, message: dict): + raise NotImplementedError() + + +class TickersEventSource(ChannelEventSource): + """An event source for :class:`Ticker` instances. + + :param pair: The trading pair. + """ + def __init__(self, pair: Pair, when: datetime, producer: EventProducer): + super().__init__(producer=producer) + self.pair: Pair = pair + self.when = intervals.get(when) + + def push_to_queue(self, message: dict): + timestamp = message["time"] + dt = isoparse(timestamp) + datetime.timedelta(seconds=self.when) + self.events.append(TickerEvent( + dt, + Ticker(self.pair, message))) + + +class KLinesEventSource(EventSource): + """An event source for :class:`KLineBar` instances. + + :param pair: The trading pair.. + """ + def __init__(self, pair: Pair, producer: EventProducer): + super().__init__(producer=producer) + self.pair: Pair = pair + + def push_to_queue(self, message: dict): + kline_event = message["data"] + kline = kline_event["k"] + # Wait for the last update to the kline. + if kline["x"] is False: + return + self.events.append(BarEvent( + datetime.utcfromtimestamp( + int(kline_event["E"] / 1e3).replace(tzinfo=datetime.timezone.utc)), + KlineBar(self.pair, kline))) + + def ws_channel(self, interval: str) -> str: + """ + Generate websocket channel + """ + return "{}@kline_{}".format( + "{}{}".format(self.pair.base_symbol.upper(), self.pair.quote_symbol.upper()).lower(), + interval) + + +class BarEvent(Event): + """An event for :class:`Bar` instances. + + :param when: The datetime when the event occurred. It must have timezone information set. + :param bar: The bar. + """ + def __init__(self, when, bar: Bar): + super().__init__(when) + + self.data = bar + + +class TickerEvent(Event): + """An event for :class:`Ticker` instances. + + :param when: The datetime when the event occurred. It must have timezone information set. + :param ticker: The Ticker. + """ + def __init__(self, when, ticker: Ticker): + super().__init__(when) + + self.data = ticker diff --git a/livetrading/executor.py b/livetrading/executor.py new file mode 100644 index 00000000..3fe4f5d0 --- /dev/null +++ b/livetrading/executor.py @@ -0,0 +1,130 @@ +import time +import datetime +import logging +from functools import partial +from typing import Any, Dict, List, Set, Optional + +from backtesting import Backtest +from .converter import ohlcv_to_dataframe +from .event import Event, EventSource, EventProducer + +logger = logging.getLogger(__name__) + + +class EventDispatcher: + """Responsible for connecting event sources to event handlers and dispatching events + in the right order. + """ + def __init__(self, strategy): + self._event_handlers: Dict[EventSource, List[Any]] = {} + self._prefetched_events: Dict[EventSource, Optional[Event]] = {} + self._prev_events: Dict[EventSource, datetime.datetime] = {} + self._producers: Set[EventProducer] = set() + self._running = False + self._stopped = False + self._current_event_dt = None + self.strategy = strategy + self.backtesting = None + + def set_strategy(self, strategy): + self._strategy = strategy + + def set_backtesting_partial(self, cash: float = 10_000, + commission: float = .0, + margin: float = 1., + trade_on_close=False, + hedging=False, + exclusive_orders=False): + self.backtesting = partial(Backtest, strategy=self.strategy, cash=cash, commission=commission, + margin=margin, trade_on_close=trade_on_close, + hedging=hedging, exclusive_orders=exclusive_orders) + + def subscribe(self, source: EventSource, event_handler: Any): + """Registers an callable that will be called when an event source has new events. + + :param source: An event source. + :param event_handler: An callable that receives an event. + """ + assert not self._running + handlers = self._event_handlers.setdefault(source, []) + if event_handler not in handlers: + handlers.append(event_handler) + if source.producer: + self._producers.add(source.producer) + + def run(self): + assert not self._running, "Running or already ran" + + self._running = True + try: + # Run producers and dispatch loop. + for producer in self._producers: + producer.main() + self._dispatch_loop() + except Exception as error: + logger.error(error) + finally: + for producer in self._producers: + producer.finalize() + + def on_error(self, error: Any): + logger.error(error) + + def _dispatch_next(self, ge_or_assert: Optional[datetime.datetime]): + # Pre-fetch events from all sources. + sources_to_pop = [ + source for source in self._event_handlers.keys() if + self._prefetched_events.get(source) is None + ] + for source in sources_to_pop: + if source.events: + df = ohlcv_to_dataframe([event.data for event in source.events]) + bt = self.backtesting(data=df) + bt.run() + + event = source.events.pop() + # Check that events from the same source are returned in order. + prev_event = self._prev_events.get(source) + if prev_event is not None and event.when < prev_event.when: + continue + + self._prev_events[source] = event + self._prefetched_events[source] = event + + # Calculate the datetime for the next event using the prefetched events. + next_dt = None + prefetched_events = [e for e in self._prefetched_events.values() if e] + if prefetched_events: + next_dt = min(map(lambda e: e.when, prefetched_events)) + assert ge_or_assert is None or next_dt is None or next_dt >= ge_or_assert, \ + f"{next_dt} can't be dispatched after {ge_or_assert}" + + # Dispatch events matching the desired datetime. + event_handlers = [] + for source, e in self._prefetched_events.items(): + if e is not None and e.when == next_dt: + # Collect event handlers for the event source. + event_handlers += [event_handler(e) for event_handler in + self._event_handlers.get(source, [])] + # Consume the event. + self._prefetched_events[source] = None + + self._current_event_dt = None + return next_dt + + def stop(self): + """Requests the event dispatcher to stop the event processing loop.""" + self._stopped = True + + for producer in self._producers: + producer.finalize() + + def _dispatch_loop(self): + last_dt = None + + while not self._stopped: + dispatched_dt = self._dispatch_next(last_dt) + if dispatched_dt is None: + time.sleep(0.01) + else: + last_dt = dispatched_dt diff --git a/livetrading/live_trading.py b/livetrading/live_trading.py new file mode 100644 index 00000000..b8a7c0a7 --- /dev/null +++ b/livetrading/live_trading.py @@ -0,0 +1,68 @@ +import pandas as pd +import websocket + +from backtesting import Strategy +from livetrading import executor +from livetrading.broker import Broker, Pair +from livetrading.config import config + + +def SMA(arr: pd.Series, n: int) -> pd.Series: + """ + Returns `n`-period simple moving average of array `arr`. + """ + return pd.Series(arr).rolling(n).mean() + + +class LiveStrategy(Strategy): + n1 = 10 + n2 = 20 + + def __init__(self, broker, data, params): + super().__init__(broker=broker, data=data, params=params) + + def init(self): + sma1 = self.I(SMA, self.data.Close, self.n1) + sma2 = self.I(SMA, self.data.Close, self.n2) + + def set_atr_periods(self): + if len(self.data) > 1: + print(self.data.High, self.data.Low) + + def next(self): + print(self.data) + + +class PositionManager: + def __init__(self, exchange, position_amount): + assert position_amount > 0 + self.exchange = exchange + self.position_amount = position_amount + + def on_event(self, bar_event): + # react on event from websocket + pass + + +if __name__ == '__main__': + + websocket.enableTrace(False) + + event_dis = executor.EventDispatcher(LiveStrategy) + + exchange = Broker(event_dis, config=config) + + pair_info = exchange.get_pair_info('BTC-USD') + + position_mgr = PositionManager(exchange, 0.8) + + strategy = LiveStrategy(exchange, [], {}) + + exchange.subscribe_to_ticker_events(Pair(base_symbol="UTC", quote_symbol="SDT"), + '3m', position_mgr.on_event) + + event_dis.set_strategy(strategy) + + event_dis.set_backtesting_partial(cash=100000) + + event_dis.run() diff --git a/livetrading/rest_cli.py b/livetrading/rest_cli.py new file mode 100644 index 00000000..2b9faf79 --- /dev/null +++ b/livetrading/rest_cli.py @@ -0,0 +1,37 @@ +import json +import logging +import requests + +from typing import Optional +from urllib.parse import urljoin + +logger = logging.getLogger(__name__) + + +class RestClient: + """"Class for REST API. + :param config: Config settings for exchange. + """ + def __init__(self, config): + self.url = config['api_url'] + self.session = requests.Session() + self.session.auth = (config.get('username'), config.get('password')) + + def call(self, method, apipath, params: Optional[dict] = None, data=None): + + if str(method).upper() not in ('GET', 'POST', 'PUT', 'DELETE'): + raise ValueError(f'invalid method <{method}>') + + headers = {"Accept": "application/json", + "Content-Type": "application/json" + } + url = urljoin(self.url, apipath) + + try: + resp = self.session.request(method, url, headers=headers, data=json.dumps(data), + params=params) + if resp.status_code == 200: + return resp.json() + return resp.text + except ConnectionError: + logger.warning("Connection error") diff --git a/livetrading/websocket_client.py b/livetrading/websocket_client.py new file mode 100644 index 00000000..6a6c4f8d --- /dev/null +++ b/livetrading/websocket_client.py @@ -0,0 +1,63 @@ +import logging +import websocket, json, _thread + +from typing import Dict, List, Set + +from livetrading.event import EventSource, EventProducer + +logger = logging.getLogger(__name__) + + +class WSClient(EventProducer, websocket.WebSocketApp): + """"Class for channel based web socket clients. + :param config: Config settings for exchange. + """ + def __init__(self, config): + super(WSClient, self).__init__(config['ws_url']) + self.event_sources: Dict[str, EventSource] = {} + self.pending_subscriptions: Set[str] = set() + self.timeout = config['ws_timeout'] + self.on_open = lambda ws: self.subscribe_msg() + self.on_message = lambda ws, msg: self.handle_message(json.loads(msg)) + self.on_error = lambda ws, e: logger.warning(f"Error: {e}") + self.on_close = self.on_close + self._running = False + self.thread = None + + def set_channel_event_source(self, channel: str, event_source: EventSource): + assert channel not in self.event_sources, "channel already registered" + self.event_sources[channel] = event_source + self.pending_subscriptions.add(channel) + + def subscribe_msg(self): + self.pending_subscriptions.update(self.event_sources.keys()) + channels = list(self.pending_subscriptions) + self.subscribe_to_channels(channels) + + def on_close(self): + self.pending_subscriptions = set() + + def main(self): + if not self._running: + self.thread = _thread.start_new_thread(self.run_forever, ()) + self._running = True + + def subscribe_to_channels( + self, channels: List[str] + ): + sub_msg = { + "type": "subscribe", + "product_ids": [ + "ETH-USD", + "BTC-USD" + ], + "channels": channels + } + self.send(json.dumps(sub_msg)) + logger.info(f"Subscribed to channels: {channels}") + + def handle_message(self, message: dict) -> None: + channel = message.get("type") + event_source = self.event_sources.get(channel) + if event_source: + event_source.push_to_queue(message) \ No newline at end of file diff --git a/setup.py b/setup.py index 60fa15ea..242c6f87 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,10 @@ 'numpy >= 1.17.0', 'pandas >= 0.25.0, != 0.25.0', 'bokeh >= 1.4.0', - ], + 'configloader >= 1.0.1', + 'websocket-client >= 1.6.0', + 'urllib3 >= 2.0.3' + ], extras_require={ 'doc': [ 'pdoc3',