Skip to content

Commit 076d87a

Browse files
committed
[Add] 初步完成基于aiohttp的websocket客户端
1 parent 320eca5 commit 076d87a

File tree

4 files changed

+243
-0
lines changed

4 files changed

+243
-0
lines changed

.flake8

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[flake8]
2+
exclude = build,__pycache__,__init__.py
3+
ignore =
4+
E501 line too long, fixed by black
5+
W503 line break before binary operator

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,4 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
.vscode/settings.json

vnpy_websocket/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .websocket_client import WebsocketClient

vnpy_websocket/websocket_client.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import json
2+
import logging
3+
import sys
4+
import traceback
5+
from datetime import datetime
6+
from typing import Optional
7+
import asyncio
8+
import threading
9+
10+
import aiohttp
11+
12+
from vnpy.trader.utility import get_file_logger
13+
14+
15+
class WebsocketClient:
16+
"""
17+
Websocket API
18+
19+
After creating the client object, use start() to run worker and ping threads.
20+
The worker thread connects websocket automatically.
21+
22+
Use stop to stop threads and disconnect websocket before destroying the client
23+
object (especially when exiting the programme).
24+
25+
Default serialization format is json.
26+
27+
Callbacks to overrides:
28+
* unpack_data
29+
* on_connected
30+
* on_disconnected
31+
* on_packet
32+
* on_error
33+
34+
After start() is called, the ping thread will ping server every 60 seconds.
35+
36+
If you want to send anything other than JSON, override send_packet.
37+
"""
38+
39+
def __init__(self):
40+
"""Constructor"""
41+
self.host = None
42+
43+
self._ws = None
44+
45+
self.session: aiohttp.ClientSession = aiohttp.ClientSession()
46+
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
47+
self.thread: threading.Thread = None
48+
49+
self.proxy_host = None
50+
self.proxy_port = None
51+
self.ping_interval = 60 # seconds
52+
self.header = {}
53+
54+
self.logger: Optional[logging.Logger] = None
55+
56+
# For debugging
57+
self._last_sent_text = None
58+
self._last_received_text = None
59+
60+
def init(
61+
self,
62+
host: str,
63+
proxy_host: str = "",
64+
proxy_port: int = 0,
65+
ping_interval: int = 60,
66+
header: dict = None,
67+
log_path: Optional[str] = None,
68+
):
69+
"""
70+
初始化客户端
71+
"""
72+
self.host = host
73+
self.ping_interval = ping_interval # seconds
74+
if log_path is not None:
75+
self.logger = get_file_logger(log_path)
76+
self.logger.setLevel(logging.DEBUG)
77+
78+
if header:
79+
self.header = header
80+
81+
if proxy_host and proxy_port:
82+
self.proxy = f"http://{proxy_host}:{proxy_port}"
83+
84+
def start(self):
85+
"""
86+
启动客户端
87+
88+
连接成功后会自动调用on_connected回调函数,
89+
90+
请等待on_connected被调用后,再发送数据包。
91+
"""
92+
# 如果目前没有任何事件循环在运行,则启动后台线程
93+
if not self.loop.is_running():
94+
self.thread = threading.Thread(target=self.run)
95+
self.thread.start()
96+
# 否则直接在事件循环中加入新的任务
97+
else:
98+
asyncio.run_coroutine_threadsafe(self._run(), self.loop)
99+
100+
def stop(self):
101+
"""
102+
停止客户端。
103+
"""
104+
coro = self._ws.close()
105+
asyncio.run_coroutine_threadsafe(coro, self.loop)
106+
107+
def join(self):
108+
"""
109+
等待后台线程退出。
110+
"""
111+
if self.thread.is_alive():
112+
self.thread.join()
113+
114+
def send_packet(self, packet: dict):
115+
"""
116+
发送数据包字典到服务器。
117+
118+
如果需要发送非json数据,请重载实现本函数。
119+
"""
120+
text = json.dumps(packet)
121+
self._record_last_sent_text(text)
122+
123+
coro = self._ws.send_str(text)
124+
asyncio.run_coroutine_threadsafe(coro, self.loop)
125+
self._log('sent text: %s', text)
126+
127+
def _log(self, msg, *args):
128+
"""记录日志信息"""
129+
logger = self.logger
130+
if logger:
131+
logger.debug(msg, *args)
132+
133+
def run(self):
134+
"""
135+
在后台线程中运行的主函数
136+
"""
137+
if not self.loop.is_running():
138+
asyncio.set_event_loop(self.loop)
139+
self.loop.run_forever()
140+
141+
asyncio.run_coroutine_threadsafe(self._run(), self.loop)
142+
143+
async def _run(self):
144+
"""
145+
在事件循环中运行的主协程
146+
"""
147+
self._ws = await self.session.ws_connect(
148+
self.host,
149+
proxy=self.proxy,
150+
verify_ssl=False
151+
)
152+
153+
self.on_connected()
154+
155+
async for msg in self._ws:
156+
text = msg.data
157+
print("recv", text)
158+
159+
self._record_last_received_text(text)
160+
161+
try:
162+
data = self.unpack_data(text)
163+
except ValueError as e:
164+
print("websocket unable to parse data: " + text)
165+
raise e
166+
167+
self._log('recv data: %s', data)
168+
self.on_packet(data)
169+
170+
@staticmethod
171+
def unpack_data(data: str):
172+
"""
173+
对字符串数据进行json格式解包
174+
175+
如果需要使用json以外的解包格式,请重载实现本函数。
176+
"""
177+
return json.loads(data)
178+
179+
@staticmethod
180+
def on_connected():
181+
"""
182+
Callback when websocket is connected successfully.
183+
"""
184+
pass
185+
186+
@staticmethod
187+
def on_disconnected():
188+
"""
189+
Callback when websocket connection is lost.
190+
"""
191+
pass
192+
193+
@staticmethod
194+
def on_packet(packet: dict):
195+
"""
196+
Callback when receiving data from server.
197+
"""
198+
pass
199+
200+
def on_error(self, exception_type: type, exception_value: Exception, tb):
201+
"""
202+
Callback when exception raised.
203+
"""
204+
sys.stderr.write(
205+
self.exception_detail(exception_type, exception_value, tb)
206+
)
207+
return sys.excepthook(exception_type, exception_value, tb)
208+
209+
def exception_detail(
210+
self, exception_type: type, exception_value: Exception, tb
211+
):
212+
"""
213+
Print detailed exception information.
214+
"""
215+
text = "[{}]: Unhandled WebSocket Error:{}\n".format(
216+
datetime.now().isoformat(), exception_type
217+
)
218+
text += "LastSentText:\n{}\n".format(self._last_sent_text)
219+
text += "LastReceivedText:\n{}\n".format(self._last_received_text)
220+
text += "Exception trace: \n"
221+
text += "".join(
222+
traceback.format_exception(exception_type, exception_value, tb)
223+
)
224+
return text
225+
226+
def _record_last_sent_text(self, text: str):
227+
"""
228+
Record last sent text for debug purpose.
229+
"""
230+
self._last_sent_text = text[:1000]
231+
232+
def _record_last_received_text(self, text: str):
233+
"""
234+
Record last received text for debug purpose.
235+
"""
236+
self._last_received_text = text[:1000]

0 commit comments

Comments
 (0)