From 53f7aafb4f842fd9f27011f31ac538b5273e9dcd Mon Sep 17 00:00:00 2001 From: shinny-pack Date: Fri, 21 Jun 2024 02:40:31 +0000 Subject: [PATCH] Update Version 3.6.0 --- PKG-INFO | 2 +- doc/conf.py | 7 +- doc/demo/index.rst | 1 + doc/demo/jupyter.rst | 14 ++++ doc/demo/notebooks/demo.ipynb | 116 ++++++++++++++++++++++++++++++ doc/quickstart.rst | 2 +- doc/usage/index.rst | 1 + doc/usage/jupyter.rst | 35 +++++++++ doc/usage/option_trade.rst | 6 +- doc/version.rst | 8 +++ setup.py | 2 +- tqsdk/__version__.py | 2 +- tqsdk/algorithm/twap.py | 6 +- tqsdk/api.py | 69 +++++++++++------- tqsdk/backtest/backtest.py | 15 ++-- tqsdk/backtest/replay.py | 3 +- tqsdk/baseApi.py | 27 +++++++ tqsdk/baseModule.py | 3 +- tqsdk/connect.py | 6 +- tqsdk/data_extension.py | 2 +- tqsdk/data_series.py | 3 +- tqsdk/lib/target_pos_scheduler.py | 6 +- tqsdk/lib/target_pos_task.py | 3 +- tqsdk/stockprofit.py | 2 +- tqsdk/symbols.py | 3 +- tqsdk/tafunc.py | 25 ++++++- tqsdk/tqwebhelper.py | 7 +- tqsdk/trade_extension.py | 2 +- tqsdk/tradeable/otg/tqzq.py | 2 +- tqsdk/tradeable/sim/basesim.py | 8 +-- 30 files changed, 304 insertions(+), 84 deletions(-) create mode 100644 doc/demo/jupyter.rst create mode 100644 doc/demo/notebooks/demo.ipynb create mode 100644 doc/usage/jupyter.rst diff --git a/PKG-INFO b/PKG-INFO index 24659c5c..a9c2698d 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: tqsdk -Version: 3.5.10 +Version: 3.6.0 Summary: TianQin SDK Home-page: https://www.shinnytech.com/tqsdk Author: TianQin diff --git a/doc/conf.py b/doc/conf.py index 1e98e98a..73674371 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -14,7 +14,7 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.githubpages', 'autodocsumm'] +extensions = ["nbsphinx", 'sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.githubpages', 'autodocsumm', ] smartquotes = False # 设置 graphviz_dot 路径 @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = u'3.5.10' +version = u'3.6.0' # The full version, including alpha/beta/rc tags. -release = u'3.5.10' +release = u'3.6.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -63,6 +63,7 @@ # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +nbsphinx_execute = 'never' # 无输出的notebook cell将不会被运行,有输出的将在make过程中保留 # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' diff --git a/doc/demo/index.rst b/doc/demo/index.rst index ae57a70f..911bd391 100644 --- a/doc/demo/index.rst +++ b/doc/demo/index.rst @@ -10,4 +10,5 @@ option_base.rst algorithm.rst strategy.rst + jupyter.rst diff --git a/doc/demo/jupyter.rst b/doc/demo/jupyter.rst new file mode 100644 index 00000000..4740f470 --- /dev/null +++ b/doc/demo/jupyter.rst @@ -0,0 +1,14 @@ +.. _demo_jupyter: + +Jupyter 示例 +==================================================== + +.. contents:: 目录 + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + notebooks/demo + diff --git a/doc/demo/notebooks/demo.ipynb b/doc/demo/notebooks/demo.ipynb new file mode 100644 index 00000000..3a32fdc2 --- /dev/null +++ b/doc/demo/notebooks/demo.ipynb @@ -0,0 +1,116 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "## 获取实时行情", + "id": "cf67a7321f3bb1f9" + }, + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-06-19T02:43:38.474676Z", + "start_time": "2024-06-19T02:43:37.562702Z" + } + }, + "source": [ + "from tqsdk import TqApi, TqSim, TqAuth\n", + "auth = TqAuth('快期账户', '快期密码')" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "在使用天勤量化之前,默认您已经知晓并同意以下免责条款,如果不同意请立即停止使用:https://www.shinnytech.com/blog/disclaimer/\n" + ] + } + ], + "execution_count": 1 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-19T02:43:40.994847Z", + "start_time": "2024-06-19T02:43:38.476915Z" + } + }, + "cell_type": "code", + "source": "api = TqApi(TqSim(), auth=auth)", + "id": "1e9889a79d1712c6", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-19T02:43:41.019471Z", + "start_time": "2024-06-19T02:43:40.996711Z" + } + }, + "cell_type": "code", + "source": "q = api.get_quote('SHFE.rb2410')", + "id": "2882551c574b0eb", + "outputs": [], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-19T02:43:41.025745Z", + "start_time": "2024-06-19T02:43:41.021628Z" + } + }, + "cell_type": "code", + "source": "print(q.datetime, q.last_price)", + "id": "61778b7e63c12fd7", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-06-19 10:43:40.500000 3639\n" + ] + } + ], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-19T02:43:41.051300Z", + "start_time": "2024-06-19T02:43:41.027038Z" + } + }, + "cell_type": "code", + "source": "api.close()", + "id": "d399ddff6ba9fc70", + "outputs": [], + "execution_count": 5 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/quickstart.rst b/doc/quickstart.rst index b19341a6..8cd65369 100644 --- a/doc/quickstart.rst +++ b/doc/quickstart.rst @@ -24,7 +24,7 @@ ------------------------------------------------- 天勤量化的核心是TqSdk开发包, 在安装天勤量化 (TqSdk) 前, 你需要先准备适当的环境和Python包管理工具, 包括: -* Python >=3.6.4,3.7,3.8,3.9 版本 +* Python >=3.7,3.8,3.9,3.10,3.11,3.12 版本 * Windows 7 以上版本, Mac Os, 或 Linux diff --git a/doc/usage/index.rst b/doc/usage/index.rst index 771c5f08..0bd9bbcf 100644 --- a/doc/usage/index.rst +++ b/doc/usage/index.rst @@ -16,3 +16,4 @@ targetpostask.rst backtest.rst web_gui.rst + jupyter.rst diff --git a/doc/usage/jupyter.rst b/doc/usage/jupyter.rst new file mode 100644 index 00000000..46780ef2 --- /dev/null +++ b/doc/usage/jupyter.rst @@ -0,0 +1,35 @@ +.. _jupyter: + +在 Jupyter Notebook 中使用 TqSdk +==================================================== + +本文档将介绍如何在 Jupyter Notebook 中使用 TqSdk。与普通代码类似,直接导入 TqSdk 并使用即可。 + +安装 Jupyter Notebook +---------------------------------------------------- + +Jupyter Notebook 是一个开源项目,能够在交互式编程环境中提供丰富的可视化表达,可以创建和共享代码和文档。 + +可以使用以下命令安装 Jupyter Notebook: + +```bash +pip install jupyter +``` + +更多 Jupyter Notebook 安装文档请参考 + +`Jupyter Notebook 安装文档 `_ :: +`Jupyter Notebook 文档 `_ :: + + +安装 TqSdk +---------------------------------------------------- + +请参考 :ref:`安装文档 ` 。 + + +在 Jupyter Notebook 中使用 TqSdk +---------------------------------------------------- + +请参考 :ref:`示例 ` 。 + diff --git a/doc/usage/option_trade.rst b/doc/usage/option_trade.rst index 2865f29c..122614f9 100644 --- a/doc/usage/option_trade.rst +++ b/doc/usage/option_trade.rst @@ -61,13 +61,13 @@ TqSdk 内提供了完善的期权查询函数 :py:meth:`~tqsdk.TqApi.query_optio ls = api.query_options("SHFE.au2012", strike_price=340) print(ls) # 标的为 "SHFE.au2012" 、行权价为 340 的期权 - ls = api.query_options("SSE.510300", exchange_id="CFFEX") + ls = api.query_options("SSE.510300") print(ls) # 中金所沪深300股指期权 - ls = api.query_options("SSE.510300", exchange_id="SSE") + ls = api.query_options("SSE.510300") print(ls) # 上交所沪深300etf期权 - ls = api.query_options("SSE.510300", exchange_id="SSE", exercise_year=2020, exercise_month=12) + ls = api.query_options("SSE.510300", exercise_year=2020, exercise_month=12) print(ls) # 上交所沪深300etf期权, 限制条件 2020 年 12 月份行权 diff --git a/doc/version.rst b/doc/version.rst index 770f95ba..89bf9b01 100644 --- a/doc/version.rst +++ b/doc/version.rst @@ -2,6 +2,14 @@ 版本变更 ============================= +3.6.0 (2024/06/21) + +* 增加:支持在 Jupyter 中使用 Tqsdk,详情参考 :ref:`jupyter` +* 修复:tafunc 中部分计算函数对 :py:meth:`~tqsdk.TqApi.get_kline_data_series`、:py:meth:`~tqsdk.TqApi.get_tick_data_series` 返回的 duration 列单位处理错误的问题 +* 优化:tqsdk 内部 task 的取消机制,避免了部分 task 无法正常取消的问题 +* docs: 修正文档,支持的 python 版本为 3.7, 3.8, 3.9, 3.10, 3.11, 3.12 + + 3.5.10 (2024/05/27) * 修复:某些情况下网络连接发生超时错误时,可能无法重连的问题 diff --git a/setup.py b/setup.py index f154ba08..c2d9c677 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setuptools.setup( name='tqsdk', - version="3.5.10", + version="3.6.0", description='TianQin SDK', author='TianQin', author_email='tianqincn@gmail.com', diff --git a/tqsdk/__version__.py b/tqsdk/__version__.py index 8e9d94a6..826cf62c 100644 --- a/tqsdk/__version__.py +++ b/tqsdk/__version__.py @@ -1 +1 @@ -__version__ = '3.5.10' +__version__ = '3.6.0' diff --git a/tqsdk/algorithm/twap.py b/tqsdk/algorithm/twap.py index 5649d269..335a395a 100644 --- a/tqsdk/algorithm/twap.py +++ b/tqsdk/algorithm/twap.py @@ -218,8 +218,7 @@ async def _insert_order(self, volume, end_time, strict_end_time, exit_immediatel if exit_immediately and volume_left == 0: break finally: - self._order_task._task.cancel() - await asyncio.gather(self._order_task._task, return_exceptions=True) + await self._api._cancel_task(self._order_task._task) while not trade_chan.empty(): v = await trade_chan.recv() volume_left = volume_left - (v if self._direction == "BUY" else -v) @@ -240,8 +239,7 @@ async def _insert_order_active(self, volume): break finally: await trade_chan.close() - self._order_task._task.cancel() - await asyncio.gather(self._order_task._task, return_exceptions=True) + await self._api._cancel_task(self._order_task._task) def _get_volume_list(self): if self._volume < self._max_volume_each_order: diff --git a/tqsdk/api.py b/tqsdk/api.py index b8af6ae5..50a7a696 100644 --- a/tqsdk/api.py +++ b/tqsdk/api.py @@ -30,6 +30,7 @@ import warnings from datetime import datetime, date, timedelta from typing import Union, List, Any, Optional, Coroutine, Callable, Tuple, Dict +from asyncio.events import _get_running_loop, _set_running_loop import numpy as np import psutil @@ -348,17 +349,23 @@ def close(self) -> None: """ if self._loop.is_closed(): return - if self._loop.is_running(): - raise Exception("不能在协程中调用 close, 如需关闭 api 实例需在 wait_update 返回后再关闭") - elif asyncio._get_running_loop(): - raise Exception( - "TqSdk 使用了 python3 的原生协程和异步通讯库 asyncio,您所使用的 IDE 不支持 asyncio, 请使用 pycharm 或其它支持 asyncio 的 IDE") - # 总会发送 serial_extra_array 数据,由 TqWebHelper 处理 - for _, serial in self._serials.items(): - self._process_serial_extra_array(serial) - super(TqApi, self)._close() - mem = psutil.virtual_memory() - self._logger.debug("process end", mem_total=mem.total, mem_free=mem.free) + other_loop = None + try: + if self._loop.is_running(): + raise Exception("不能在协程中调用 close, 如需关闭 api 实例需在 wait_update 返回后再关闭") + else: + other_loop = _get_running_loop() + if other_loop: + _set_running_loop(None) + # 总会发送 serial_extra_array 数据,由 TqWebHelper 处理 + for _, serial in self._serials.items(): + self._process_serial_extra_array(serial) + super(TqApi, self)._close() + mem = psutil.virtual_memory() + self._logger.debug("process end", mem_total=mem.total, mem_free=mem.free) + finally: + if other_loop: + _set_running_loop(other_loop) def __enter__(self): return self @@ -603,7 +610,7 @@ def get_kline_serial(self, symbol: Union[str, List[str]], duration_seconds: int, 注意: 周期在日线以内时此参数可以任意填写, 在日线以上时只能是日线(86400)的整数倍, 最大为28天 (86400*28)。 data_length (int): 需要获取的序列长度。默认200根, 返回的K线序列数据是从当前最新一根K线开始往回取data_length根。\ - 每个序列最大支持请求 8000 个数据 + 每个序列最大支持请求 10000 个数据 adj_type (str/None): [可选]指定复权类型,默认为 None。adj_type 参数只对股票和基金类型合约有效。\ "F" 表示前复权;"B" 表示后复权;None 表示不做处理。 @@ -619,7 +626,7 @@ def get_kline_serial(self, symbol: Union[str, List[str]], duration_seconds: int, 3. 若设置了较大的序列长度参数,而所有可对齐的数据并没有这么多,则序列前面部分数据为NaN(这与获取单合约K线且数据不足序列长度时情况相似)。 - 4. 若主合约与副合约的交易时间在所有合约数据中最晚一根K线时间开始往回的 8000*周期 时间段内完全不重合,则无法生成多合约K线,程序会报出获取数据超时异常。 + 4. 若主合约与副合约的交易时间在所有合约数据中最晚一根K线时间开始往回的 10000*周期 时间段内完全不重合,则无法生成多合约K线,程序会报出获取数据超时异常。 5. datetime、duration是所有合约公用的字段,则未单独为每个副合约增加一份副本,这两个字段使用原始字段名(即没有数字后缀)。 @@ -698,8 +705,8 @@ def get_kline_serial(self, symbol: Union[str, List[str]], duration_seconds: int, adj_type = adj_type[0] if adj_type else adj_type if adj_type and len(symbol) > 1: raise Exception("参数错误,多合约 K 线序列不支持复权。") - if data_length > 8964: - data_length = 8964 + if data_length > 10000: + data_length = 10000 dur_id = duration_seconds * 1000000000 request = (tuple(symbol), duration_seconds, data_length, adj_type) # request 中 symbols 为 tuple 序列 serial = self._requests["klines"].get(request, None) @@ -708,7 +715,7 @@ def get_kline_serial(self, symbol: Union[str, List[str]], duration_seconds: int, "chart_id": _generate_uuid("PYSDK_realtime"), "ins_list": ",".join(symbol), "duration": dur_id, - "view_width": data_length if len(symbol) == 1 else 8964, + "view_width": data_length if len(symbol) == 1 else 10000, # 如果同时订阅了两个以上合约K线,初始化数据时默认获取 1w 根K线(初始化完成后修改指令为设定长度) } if serial is None else {} # 将数据权转移给TqChan时其所有权也随之转移,因pack还需要被用到,所以传入副本 @@ -726,7 +733,7 @@ def get_kline_serial(self, symbol: Union[str, List[str]], duration_seconds: int, if not self.wait_update(deadline=deadline, _task=[task, serial["df"].__dict__["_task"]]): if len(symbol) > 1: raise TqTimeoutError("获取 %s (%d) 的K线超时,请检查客户端及网络是否正常,或任一副合约在主合约行情的最后 %d 秒内无可对齐的K线" % ( - symbol, duration_seconds, 8964 * duration_seconds)) + symbol, duration_seconds, 10000 * duration_seconds)) else: raise TqTimeoutError("获取 %s (%d) 的K线超时,请检查客户端及网络是否正常" % (symbol, duration_seconds)) return serial["df"] @@ -750,7 +757,7 @@ def get_tick_serial(self, symbol: str, data_length: int = 200, adj_type: Union[s Args: symbol (str): 指定合约代码. - data_length (int): 需要获取的序列长度。每个序列最大支持请求 8000 个数据 + data_length (int): 需要获取的序列长度。每个序列最大支持请求 10000 个数据 adj_type (str/None): [可选]指定复权类型,默认为 None。adj_type 参数只对股票和基金类型合约有效。\ "F" 表示前复权;"B" 表示后复权;None 表示不做处理。 @@ -799,8 +806,8 @@ def get_tick_serial(self, symbol: str, data_length: int = 200, adj_type: Union[s if adj_type not in [None, "F", "B", "FORWARD", "BACK"]: raise Exception("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权)") adj_type = adj_type[0] if adj_type else adj_type - if data_length > 8964: - data_length = 8964 + if data_length > 10000: + data_length = 10000 request = (symbol, data_length, adj_type) serial = self._requests["ticks"].get(request, None) pack = { @@ -1050,7 +1057,7 @@ def query_his_cont_quotes(self, symbol: Union[str, List[str]], n: int = 200) -> * str: 一个合约代码 * list of str: 合约代码列表 (一次提取多个合约的K线并根据相同的时间向第一个合约(主合约)对齐) - n:返回 n 个交易日交易日的对应品种的主力, 默认值为 200,最大为 8964 + n:返回 n 个交易日交易日的对应品种的主力, 默认值为 200 Returns: pandas.DataFrame: 包含 n 行数据,列数为指定主连合约代码个数加 1,有以下列: @@ -1876,12 +1883,20 @@ def wait_update(self, deadline: Optional[float] = None, _task: Union[asyncio.Tas 可能输出 ""(空字符串), 表示还没有收到该合约的行情 """ - if self._loop.is_running(): - raise Exception("不能在协程中调用 wait_update, 如需在协程中等待业务数据更新请使用 register_update_notify") - elif asyncio._get_running_loop(): - raise Exception( - "TqSdk 使用了 python3 的原生协程和异步通讯库 asyncio,您所使用的 IDE 不支持 asyncio, 请使用 pycharm 或其它支持 asyncio 的 IDE") - self._wait_timeout = False + other_loop = None + try: + if self._loop.is_running(): + raise Exception("不能在协程中调用 wait_update, 如需在协程中等待业务数据更新请使用 register_update_notify") + else: + other_loop = _get_running_loop() + if other_loop: + _set_running_loop(None) + return self._wait_update(deadline=deadline, _task=_task) + finally: + if other_loop: + _set_running_loop(other_loop) + + def _wait_update(self, deadline: Optional[float] = None, _task: Union[asyncio.Task, List[asyncio.Task], None] = None) -> bool: # 先尝试执行各个task,再请求下个业务数据,可能用户的同步代码会在 chan 中 send 数据,需要先 run_tasks self._run_until_idle(async_run=False) diff --git a/tqsdk/backtest/backtest.py b/tqsdk/backtest/backtest.py index fc46c560..7d76c2ba 100644 --- a/tqsdk/backtest/backtest.py +++ b/tqsdk/backtest/backtest.py @@ -193,8 +193,7 @@ async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_ch # 关闭所有 generator for s in self._generators.values(): await s.aclose() - md_task.cancel() - await asyncio.gather(md_task, return_exceptions=True) + await self._api._cancel_task(md_task) async def _md_handler(self): async for pack in self._md_recv_chan: @@ -476,7 +475,7 @@ async def _gen_serial(self, ins, dur): """k线/tick 序列的 async generator, yield 出来的行情数据带有时间戳, 因此 _send_diff 可以据此归并""" # 先定位左端点, focus_datetime 是 lower_bound ,这里需要的是 upper_bound # 因此将 view_width 和 focus_position 设置成一样,这样 focus_datetime 所对应的 k线刚好位于屏幕外 - # 使用两个长度为 8964 的 chart,去缓存/回收下游需要的数据 + # 使用两个长度为 10000 的 chart,去缓存/回收下游需要的数据 chart_id_a = _generate_uuid("PYSDK_backtest") chart_id_b = _generate_uuid("PYSDK_backtest") chart_info = { @@ -484,9 +483,9 @@ async def _gen_serial(self, ins, dur): "chart_id": chart_id_a, "ins_list": ins, "duration": dur, - "view_width": 8964, # 设为8964原因:可满足用户所有的订阅长度,并在backtest中将所有的 相同合约及周期 的K线用同一个serial存储 + "view_width": 10000, # 设为 10000 原因:可满足用户所有的订阅长度,并在backtest中将所有的 相同合约及周期 的K线用同一个serial存储 "focus_datetime": int(self._current_dt), - "focus_position": 8964, + "focus_position": 10000, } chart_a = _get_obj(self._data, ["charts", chart_id_a]) chart_b = _get_obj(self._data, ["charts", chart_id_b]) @@ -535,7 +534,7 @@ async def _gen_serial(self, ins, dur): right_id = chart.get("right_id", -1) if current_id is None: current_id = max(left_id, 0) - # 发送下一段 chart 8964 根 kline + # 发送下一段 chart 10000 根 kline chart_info["chart_id"] = chart_id_b if chart_info["chart_id"] == chart_id_a else chart_id_a chart_info["left_kline_id"] = right_id chart_info.pop("focus_datetime", None) @@ -545,7 +544,7 @@ async def _gen_serial(self, ins, dur): if current_id > last_id: # 当前 id 已超过 last_id return - # 将订阅的8964长度的窗口中的数据都遍历完后,退出循环,然后再次进入并处理下一窗口数据 + # 将订阅的10000长度的窗口中的数据都遍历完后,退出循环,然后再次进入并处理下一窗口数据 if current_id > right_id: break item = {k: v for k, v in serials[0]["data"].get(str(current_id), {}).items()} @@ -556,7 +555,7 @@ async def _gen_serial(self, ins, dur): "last_id": current_id, "data": { str(current_id): item, - str(current_id - 8964): None, + str(current_id - 10000): None, } } } diff --git a/tqsdk/backtest/replay.py b/tqsdk/backtest/replay.py index 9db6c35b..4f95ba85 100644 --- a/tqsdk/backtest/replay.py +++ b/tqsdk/backtest/replay.py @@ -80,8 +80,7 @@ async def _run(self): await asyncio.sleep(30) finally: await self._send_chan.close() - _senddata_task.cancel() - await asyncio.gather(_senddata_task, return_exceptions=True) + await self._api._cancel_task(_senddata_task) def _prepare_session(self): create_session_url = "http://replay.api.shinnytech.com/t/rmd/replay/create_session" diff --git a/tqsdk/baseApi.py b/tqsdk/baseApi.py index 99b51f19..e51f46d4 100644 --- a/tqsdk/baseApi.py +++ b/tqsdk/baseApi.py @@ -44,6 +44,33 @@ def _create_task(self, coro: Coroutine, _caller_api: bool = False) -> asyncio.Ta task.add_done_callback(self._on_task_done) return task + async def _cancel_tasks(self, *tasks): + # 目前的 task 退出流程无法处理在 finally 中被 cancel 的情况, + # 例如: twap _run 中调用 _insert_order 并且刚好时间段结束执行到 finally 时整个 api 被 close 触发 CancelError + if len(tasks) == 0: + return + task = tasks[0] + other_tasks = tasks[1:] + try: + await self._cancel_task(task) + finally: + if tasks: + await self._cancel_tasks(*other_tasks) + + async def _cancel_task(self, task): + exception = None + task.cancel() + # 如果 task 已经 done,可以调用 cancel 但是 cancelled 不会变成 True + while not task.done(): + try: + await asyncio.shield(task) # task 不会再被 cancel + except asyncio.CancelledError as ex: + if not task.cancelled(): + exception = ex + await asyncio.sleep(0) # Give callbacks a chance to run + if exception: + raise exception from None + def _call_soon(self, org_call_soon, callback, *args, **kargs): """ioloop.call_soon的补丁, 用来追踪是否有任务完成并等待执行""" self._event_rev += 1 diff --git a/tqsdk/baseModule.py b/tqsdk/baseModule.py index ab22ddfc..4553f816 100644 --- a/tqsdk/baseModule.py +++ b/tqsdk/baseModule.py @@ -44,8 +44,7 @@ async def _run(self, api, api_send_chan, api_recv_chan, *args): await self._handle_req_data(pack) await self._send_diff(api_recv_chan) finally: - [task.cancel() for task in up_handle_tasks] - await asyncio.gather(*up_handle_tasks, return_exceptions=True) + await self._api._cancel_tasks(*up_handle_tasks) async def _up_handler(self, api_send_chan, recv_chan, chan_index): async for pack in recv_chan: diff --git a/tqsdk/connect.py b/tqsdk/connect.py index d5ea793a..ebd40a20 100644 --- a/tqsdk/connect.py +++ b/tqsdk/connect.py @@ -196,8 +196,7 @@ async def _run(self, api, url, send_chan, recv_chan): self._logger.debug("websocket connection info", current_time=time.time(), start_read_message=client.reader._start_read_message, read_size=client.reader._read_size) - send_task.cancel() - await send_task + await self._api._cancel_task(send_task) # 希望做到的效果是遇到网络问题可以断线重连, 但是可能抛出的例外太多了(TimeoutError,socket.gaierror等), 又没有文档或工具可以理出 try 代码中所有可能遇到的例外 # 而这里的 except 又需要处理所有子函数及子函数的子函数等等可能抛出的例外, 因此这里只能遇到问题之后再补, 并且无法避免 false positive 和 false negative except (websockets.exceptions.ConnectionClosed, websockets.exceptions.InvalidStatusCode, websockets.exceptions.InvalidURI, @@ -322,8 +321,7 @@ async def _run(self, api, api_send_chan, api_recv_chan, ws_send_chan, ws_recv_ch else: await api_recv_chan.send(pack) finally: - send_task.cancel() - await asyncio.gather(send_task, return_exceptions=True) + await self._api._cancel_task(send_task) async def _send_handler(self, api_send_chan, ws_send_chan): async for pack in api_send_chan: diff --git a/tqsdk/data_extension.py b/tqsdk/data_extension.py index c398e51e..0dfb741e 100644 --- a/tqsdk/data_extension.py +++ b/tqsdk/data_extension.py @@ -100,7 +100,7 @@ async def _run(self, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan): else: await self._md_send_chan.send(pack) finally: - md_task.cancel() + await self._api._cancel_task(md_task) async def _md_handler(self): """0 接收上游数据包 """ diff --git a/tqsdk/data_series.py b/tqsdk/data_series.py index 0d6083c5..69ad0bdf 100644 --- a/tqsdk/data_series.py +++ b/tqsdk/data_series.py @@ -185,8 +185,7 @@ async def _download_data_series(self, rangeset): target_filename = os.path.join(CACHE_DIR, f"{symbol}.{self._dur_nano}.{start_id}.{end_id + 1}") shutil.move(temp_filename, target_filename) finally: - task.cancel() - await task + await self._api._cancel_task(task) async def _download_data(self, start_dt, end_dt, data_chan): # 下载的数据应该是 [start_dt, end_dt) diff --git a/tqsdk/lib/target_pos_scheduler.py b/tqsdk/lib/target_pos_scheduler.py index cab55e24..de862fa6 100644 --- a/tqsdk/lib/target_pos_scheduler.py +++ b/tqsdk/lib/target_pos_scheduler.py @@ -135,8 +135,7 @@ async def _run(self): async for _ in self._api.register_update_notify(quote): if _get_trade_timestamp(quote.datetime, float('nan')) > row['deadline']: if target_pos_task: - target_pos_task.cancel() - await asyncio.gather(target_pos_task, return_exceptions=True) + await self._api._cancel_task(target_pos_task._task) break elif target_pos_task: # 最后一项,如果有 target_pos_task 等待持仓调整完成,否则直接退出 position = self._account.get_position(self._symbol) @@ -147,8 +146,7 @@ async def _run(self): _index = _index + 1 finally: if target_pos_task: - target_pos_task.cancel() - await asyncio.gather(target_pos_task, return_exceptions=True) + await self._api._cancel_task(target_pos_task._task) await self._trade_objs_chan.close() await self._trade_recv_task diff --git a/tqsdk/lib/target_pos_task.py b/tqsdk/lib/target_pos_task.py index ccd8e4e2..4f02abf6 100644 --- a/tqsdk/lib/target_pos_task.py +++ b/tqsdk/lib/target_pos_task.py @@ -235,8 +235,7 @@ async def _exit_task(self): # self._account 类型为 TqSim/TqKq/TqAccount,都包括 _account_key 变量 TargetPosTaskSingleton._instances.pop(self._account._account_key + "#" + self._symbol, None) await self._pos_chan.close() - self._time_update_task.cancel() - await asyncio.gather(self._time_update_task, return_exceptions=True) + await self._api._cancel_task(self._time_update_task) self._wait_task_finished.set_result(True) def __await__(self): diff --git a/tqsdk/stockprofit.py b/tqsdk/stockprofit.py index b26171d6..c0815ac4 100644 --- a/tqsdk/stockprofit.py +++ b/tqsdk/stockprofit.py @@ -47,7 +47,7 @@ async def _run(self, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan): else: await self._md_send_chan.send(pack) finally: - md_task.cancel() + await self._api._cancel_task(md_task) async def _md_handler(self): diff --git a/tqsdk/symbols.py b/tqsdk/symbols.py index 4986804f..a8c5494b 100644 --- a/tqsdk/symbols.py +++ b/tqsdk/symbols.py @@ -60,8 +60,7 @@ async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_ch data.append({"quotes": updated_quotes}) await self._sim_recv_chan.send(pack) finally: - sim_task.cancel() - await asyncio.gather(sim_task, return_exceptions=True) + await self._api._cancel_task(sim_task) async def _sim_handler(self): # 下游发来的数据包,直接转发到上游 diff --git a/tqsdk/tafunc.py b/tqsdk/tafunc.py index b25cec32..6cd4ecc1 100644 --- a/tqsdk/tafunc.py +++ b/tqsdk/tafunc.py @@ -820,7 +820,27 @@ def barlast(cond): return pd.Series(r) -def _get_t_series(series: pd.Series, dur: int, expire_datetime: int): +def _duration_ensure_unit_to_s(dur: Union[int, pd.Series]) -> int: + """ + 判断输入duration这一时间的单位。如果单位是纳秒,则转换为秒。 + + Args: + dur: Union[int, pd.Series]: 数据周期 + 可能是一个可以直接判断的int,也可能是Dataframe中的某一列 + + Returns: + int: 以秒为单位的数据周期 + """ + if isinstance(dur, pd.Series): + # 进行类型转换,确保运算过程不报错 + dur = dur[0] + if dur % (10 ** 9) == 0: + dur //= (10 ** 9) + return dur + + +def _get_t_series(series: pd.Series, dur: Union[int, pd.Series], expire_datetime: int): + dur = _duration_ensure_unit_to_s(dur) t = pd.Series(pd.to_timedelta(expire_datetime - (series / 1e9 + dur), unit='s')) return (t.dt.days * 86400 + t.dt.seconds) / (360 * 86400) @@ -921,7 +941,7 @@ def _get_volatility(series: pd.Series, dur: Union[pd.Series, int] = 86400, tradi if series_u.size < 2: # 自由度小于2无法计算,返回一个默认值 return float("nan") seconds_per_day = 24 * 60 * 60 - dur = dur[0] if isinstance(dur, pd.Series) else dur + dur = _duration_ensure_unit_to_s(dur) if dur < 24 * 60 * 60 and trading_time: periods = _get_period_timestamp(0, trading_time.get("day", []) + trading_time.get("night", [])) seconds_per_day = sum([p[1] - p[0] for p in periods]) / 1e9 @@ -1531,4 +1551,3 @@ def _cum_counts(s: Series): output: [0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 2, 0, 1, 0, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 0, 0] """ return s * (s.groupby((s != s.shift()).cumsum()).cumcount() + 1) - diff --git a/tqsdk/tqwebhelper.py b/tqsdk/tqwebhelper.py index 05ebe479..ea9f71c4 100644 --- a/tqsdk/tqwebhelper.py +++ b/tqsdk/tqwebhelper.py @@ -82,8 +82,7 @@ async def _run(self, api_send_chan, api_recv_chan, web_send_chan, web_recv_chan) if pack['aid'] not in ['set_chart_data', 'set_report_data']: await web_send_chan.send(pack) finally: - _data_handler_without_web_task.cancel() - await asyncio.gather(_data_handler_without_web_task, return_exceptions=True) + await self._api._cancel_task(_data_handler_without_web_task) else: self._web_dir = os.path.join(os.path.dirname(__file__), 'web') file_path = os.path.abspath(sys.argv[0]) @@ -159,9 +158,7 @@ async def _run(self, api_send_chan, api_recv_chan, web_send_chan, web_recv_chan) # 发送的转发给上游 await web_send_chan.send(pack) finally: - _data_task.cancel() - _httpserver_task.cancel() - await asyncio.gather(_data_task, _httpserver_task, return_exceptions=True) + await self._api._cancel_tasks(*[_data_task, _httpserver_task]) async def _data_handler_without_web(self, api_recv_chan, web_recv_chan): # 没有 web_gui, 接受全部数据转发给下游 api_recv_chan diff --git a/tqsdk/trade_extension.py b/tqsdk/trade_extension.py index d950843c..3733e943 100644 --- a/tqsdk/trade_extension.py +++ b/tqsdk/trade_extension.py @@ -81,7 +81,7 @@ async def _run(self, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan): else: await self._md_send_chan.send(pack) finally: - md_task.cancel() + await self._api._cancel_task(md_task) async def _md_handler(self): """0 接收上游数据包 """ diff --git a/tqsdk/tradeable/otg/tqzq.py b/tqsdk/tradeable/otg/tqzq.py index b9241311..2a3bc36b 100644 --- a/tqsdk/tradeable/otg/tqzq.py +++ b/tqsdk/tradeable/otg/tqzq.py @@ -25,7 +25,7 @@ def __init__(self, account_id: str, password: str, td_url: str) -> None: Example1:: from tqsdk import TqApi, TqZq - account = TqZq(user_name="众期账户", password="众期密码", td_url="众期柜台地址") + account = TqZq(account_id="众期账户", password="众期密码", td_url="众期柜台地址") api = TqApi(account, auth=TqAuth("快期账户", "账户密码")) """ diff --git a/tqsdk/tradeable/sim/basesim.py b/tqsdk/tradeable/sim/basesim.py index 7b434833..d7fd723b 100644 --- a/tqsdk/tradeable/sim/basesim.py +++ b/tqsdk/tradeable/sim/basesim.py @@ -78,9 +78,8 @@ async def _run(self, api, api_send_chan, api_recv_chan, md_send_chan, md_recv_ch await super(BaseSim, self)._run(api, api_send_chan, api_recv_chan, md_send_chan, md_recv_chan) finally: self._handle_stat_report() - for s in self._quote_tasks: - self._quote_tasks[s]["task"].cancel() - await asyncio.gather(*[self._quote_tasks[s]["task"] for s in self._quote_tasks], return_exceptions=True) + tasks = [self._quote_tasks[s]["task"] for s in self._quote_tasks] + await self._api._cancel_tasks(*tasks) async def _handle_recv_data(self, pack, chan): """ @@ -237,8 +236,7 @@ async def _quote_handler(self, symbol, quote_chan, order_chan): finally: await quote_chan.close() await order_chan.close() - task.cancel() - await asyncio.gather(task, return_exceptions=True) + await self._api._cancel_task(task) async def _forward_chan_handler(self, chan_from, chan_to): async for pack in chan_from: