diff --git a/PKG-INFO b/PKG-INFO index 53e47e39..e2f6b936 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: tqsdk -Version: 3.5.4 +Version: 3.5.5 Summary: TianQin SDK Home-page: https://www.shinnytech.com/tqsdk Author: TianQin diff --git a/doc/conf.py b/doc/conf.py index 23082e88..399675f5 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -48,9 +48,9 @@ # built documents. # # The short X.Y version. -version = u'3.5.4' +version = u'3.5.5' # The full version, including alpha/beta/rc tags. -release = u'3.5.4' +release = u'3.5.5' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/doc/version.rst b/doc/version.rst index f3c7dad3..59d2e53b 100644 --- a/doc/version.rst +++ b/doc/version.rst @@ -2,6 +2,13 @@ 版本变更 ============================= +3.5.5 (2024/03/27) + +* 修复:TqSim 在调用 set_margin 之后,使用 is_changing 判断某个对象是否更新,可能返回的结果不正确 +* 优化:多账户下使用 :py:meth:`~tqsdk.algorithm.time_table_generater.vwap_table`, + :py:meth:`~tqsdk.algorithm.time_table_generater.twap_table` 不需要用户多次指定账户 + + 3.5.4 (2024/03/01) * 修复:回测时,订阅多合约 K 线时,成交可能不符合预期的问题 diff --git a/setup.py b/setup.py index 701b762f..643d785e 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setuptools.setup( name='tqsdk', - version="3.5.4", + version="3.5.5", description='TianQin SDK', author='TianQin', author_email='tianqincn@gmail.com', diff --git a/tqsdk/__version__.py b/tqsdk/__version__.py index 17553bb1..b363ca01 100644 --- a/tqsdk/__version__.py +++ b/tqsdk/__version__.py @@ -1 +1 @@ -__version__ = '3.5.4' +__version__ = '3.5.5' diff --git a/tqsdk/algorithm/time_table_generater.py b/tqsdk/algorithm/time_table_generater.py index 0c5feff0..ec3e1695 100644 --- a/tqsdk/algorithm/time_table_generater.py +++ b/tqsdk/algorithm/time_table_generater.py @@ -13,6 +13,7 @@ from tqsdk import utils from tqsdk.datetime import _get_trading_timestamp, _get_trade_timestamp, _get_trading_day_from_timestamp, \ _datetime_to_timestamp_nano, _timestamp_nano_to_datetime +from tqsdk.lib.time_table import TqTimeTable from tqsdk.rangeset import _rangeset_slice, _rangeset_head from tqsdk.tradeable import TqAccount, TqKq, TqSim @@ -124,7 +125,7 @@ def twap_table(api: TqApi, symbol: str, target_pos: int, duration: int, min_volu interval_list = _gen_random_list(sum_val=duration, min_val=min_interval, max_val=max_interval, length=len(volume_list)) - time_table = DataFrame(columns=['interval', 'volume', 'price']) + time_table = TqTimeTable(account=account) for index, volume in enumerate(volume_list): assert interval_list[index] >= 3 active_interval = 2 @@ -253,7 +254,7 @@ def vwap_table(api: TqApi, symbol: str, target_pos: int, duration: float, predicted_percent = volume_percent.groupby(level=1).mean() # 将历史上相同时间单元的成交量占比使用算数平均计算出预测值 # 计算每个时间单元的成交量预测值 - time_table = DataFrame(columns=['interval', 'volume', 'price']) + time_table = TqTimeTable(account=account) volume_left = target_volume # 剩余手数 percent_left = 1 # 剩余百分比 for index, value in predicted_percent.items(): diff --git a/tqsdk/api.py b/tqsdk/api.py index 6824f4fd..62db77a7 100644 --- a/tqsdk/api.py +++ b/tqsdk/api.py @@ -972,11 +972,12 @@ def _get_data_series(self, call_func: str, symbol_list: Union[str, List[str]], d if adj_type not in [None, "F", "B"]: raise Exception("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权) ") ds = DataSeries(self, symbol_list, dur_nano, start_dt_nano, end_dt_nano, adj_type) - while not self._loop.is_running() and not ds.is_ready: - deadline = time.time() + 30 - if not self.wait_update(deadline=deadline): - raise TqTimeoutError( - f"{call_func} 获取数据 ({symbol_list, duration_seconds, start_dt, end_dt}) 超时,请检查客户端及网络是否正常。") + if not self._loop.is_running(): + while not ds._task.done(): + deadline = time.time() + 30 + if not self.wait_update(deadline=deadline, _task=ds._task): + raise TqTimeoutError( + f"{call_func} 获取数据 ({symbol_list, duration_seconds, start_dt, end_dt}) 超时,请检查客户端及网络是否正常。") return ds.df # ---------------------------------------------------------------------- @@ -1811,13 +1812,12 @@ def set_risk_management_rule(self, exchange_id: str, enable: bool, count_limit: rule = _get_obj(self._data, ["trade", self._account._get_account_key(account), "risk_management_rule", exchange_id], RiskManagementRule(self)) if not self._loop.is_running(): deadline = time.time() + 30 - while not (rule_pack['enable'] == rule['enable'] - and rule_pack['self_trade'].items() <= rule['self_trade'].items() - and rule_pack['frequent_cancellation'].items() <= rule['frequent_cancellation'].items() - and rule_pack['trade_position_ratio'].items() <= rule['trade_position_ratio'].items()): - # @todo: merge diffs - if not self.wait_update(deadline=deadline): - raise TqTimeoutError("设置风控规则超时请检查客户端及网络是否正常") + cond = lambda: (rule_pack['enable'] == rule['enable'] + and rule_pack['self_trade'].items() <= rule['self_trade'].items() + and rule_pack['frequent_cancellation'].items() <= rule['frequent_cancellation'].items() + and rule_pack['trade_position_ratio'].items() <= rule['trade_position_ratio'].items()) + if not self._wait_update_until(cond=cond, deadline=deadline): + raise TqTimeoutError("设置风控规则超时请检查客户端及网络是否正常") return rule # ---------------------------------------------------------------------- @@ -1933,6 +1933,38 @@ def wait_update(self, deadline: Optional[float] = None, _task: Union[asyncio.Tas else: # 订阅多个合约 self._update_serial_multi(serial) + def _wait_update_until(self, cond: Callable[[], bool], deadline: Optional[float] = None) -> bool: + """ + TqApi 内部使用,用于等待某个条件满足。持续调用 wait_update(),直到 cond() 返回 True。 + + Args: + cond (Callable[[], bool]): 条件函数 + deadline (float): [可选]指定截止时间,自unix epoch(1970-01-01 00:00:00 GMT)以来的秒数(time.time())。默认没有超时(无限等待) + + Returns: + bool: 当 cond() 为 True 时返回 True, 如果到截止时间 cond() 依然为 False 则返回 False + + 注:用于 tqsdk 内部,某些地方会用到 api.wait_update(),等待数据更新后再返回给用户, + * 简单调用 wait_update() 导致 api._sync_diffs 丢失变更 + * 为了避免这种情况,内部调用 wait_update() 应该传入 _task 参数,这样 api._sync_diffs 不会丢失变更 + """ + if cond(): + return True + + async def _async_wait_task(): + async with self.register_update_notify() as update_chan: + async for _ in update_chan: + if cond(): + break + + _task = self.create_task(_async_wait_task()) + + while not cond(): + data_updated = self.wait_update(deadline=deadline, _task=_task) + if data_updated is False: + return False # TimeoutError + return True + # ---------------------------------------------------------------------- def is_changing(self, obj: Any, key: Union[str, List[str], None] = None) -> bool: """ @@ -2226,9 +2258,9 @@ def query_graphql(self, query: str, variables: dict, query_id: Optional[str] = N }) deadline = time.time() + 60 if not self._loop.is_running(): - while query_id not in symbols: - if not self.wait_update(deadline=deadline): - raise TqTimeoutError("查询合约服务 %s 超时,请检查客户端及网络是否正常 %s" % (query, query_id)) + if not self._wait_update_until(cond=lambda: query_id in symbols, deadline=deadline): + # 使用 _task 参数,确保不会丢掉 _sync_diffs 里的变更 + raise TqTimeoutError("查询合约服务 %s 超时,请检查客户端及网络是否正常 %s" % (query, query_id)) if isinstance(self._backtest, TqBacktest): self._send_pack({ "aid": "ins_query", diff --git a/tqsdk/data_series.py b/tqsdk/data_series.py index 5f4328f5..255108a0 100644 --- a/tqsdk/data_series.py +++ b/tqsdk/data_series.py @@ -68,9 +68,12 @@ def __init__(self, api, symbol_list, dur_nano, start_dt_nano, end_dt_nano, adj_t self._adj_type = adj_type self._dividend_cache = {} # 缓存合约对应的复权系数矩阵,每个合约只计算一次 self.df = pd.DataFrame() - self.is_ready = False DataSeries._ensure_cache_dir() # 确认缓存文件夹存在 - self._api.create_task(self._run()) + self._task = self._api.create_task(self._run()) + + @property + def is_ready(self): + return self._task.done() async def _run(self): symbol = self._symbol_list[0] # todo: 目前只处理一个合约的情况 @@ -105,7 +108,6 @@ async def _run(self): target_rangeset_dt = _rangeset_intersection([(self._start_dt_nano, self._end_dt_nano)], rangeset_dt) assert len(target_rangeset_dt) <= 1 # 用户请求应该落在一个时间段内,或者用户请求的时间段内没有任何数据 if len(target_rangeset_dt) == 0: # 用户请求的时间段内没有任何数据 - self.is_ready = True return # 此时用户请求时间范围,转化为 target_rangeset_dt[0] @@ -160,8 +162,6 @@ async def _run(self): adj_cols = DataSeries._get_adj_cols(symbol, self._dur_nano) ge = self.df["datetime"].ge(dt) self.df.loc[ge, adj_cols] = self.df.loc[ge, adj_cols] / factor - # 结束状态 - self.is_ready = True async def _download_data_series(self, rangeset): symbol = self._symbol_list[0] diff --git a/tqsdk/lib/target_pos_scheduler.py b/tqsdk/lib/target_pos_scheduler.py index 9e1d02cd..5609930d 100644 --- a/tqsdk/lib/target_pos_scheduler.py +++ b/tqsdk/lib/target_pos_scheduler.py @@ -11,6 +11,7 @@ from tqsdk.channel import TqChan from tqsdk.datetime import _get_trade_timestamp from tqsdk.lib.target_pos_task import TargetPosTask +from tqsdk.lib.time_table import TqTimeTable from tqsdk.lib.utils import _check_time_table, _get_deadline_from_interval from tqsdk.objs import Trade @@ -93,6 +94,8 @@ def __init__(self, api: TqApi, symbol: str, time_table: DataFrame, offset_priori api.close() """ self._api = api + if isinstance(time_table, TqTimeTable): + account = time_table.__dict__["_account"] self._account = api._account._check_valid(account) # 这些参数直接传给 TargetPosTask,由 TargetPosTask 来检查其合法性 diff --git a/tqsdk/lib/time_table.py b/tqsdk/lib/time_table.py new file mode 100644 index 00000000..3bb3611b --- /dev/null +++ b/tqsdk/lib/time_table.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +__author__ = 'mayanqiong' + + +from pandas import DataFrame + + +class TqTimeTable(DataFrame): + + def __init__(self, account=None): + self.__dict__["_account"] = account + self.__dict__["_columns"] = ['interval', 'volume', 'price'] + super(TqTimeTable, self).__init__(data=[], columns=self.__dict__["_columns"]) diff --git a/tqsdk/tradeable/sim/tqsim.py b/tqsdk/tradeable/sim/tqsim.py index 528ec728..8207db2f 100644 --- a/tqsdk/tradeable/sim/tqsim.py +++ b/tqsdk/tradeable/sim/tqsim.py @@ -129,9 +129,14 @@ def set_margin(self, symbol: str, margin: float=float('nan')): # 当用户代码执行到 sim.set_margin(),立即向 quote_chan 中发送一个数据包,quote_task 就会到 ready 状态,此时调用 wait_update(), # 到所有 task 执行到 pending 状态时,sim 的 diffs 中有数据了,此时收到 api 发来 peek_message 不会转发给上游,用户会先收到 sim 本身的账户数据, # 在下一次 wait_update,sim 的 diffs 为空,才会收到行情数据 + # 3. 20240322 增加,用户拿到的 _sync_diffs 不应该有丢失: + # * api 中区分了 _diffs (每次调用 wait_update 都会更新) 和 _sync_diffs (仅在同步代码,或者说用户代码调用 wait_update 时更新) + # * 用户在调用 set_margin 之后,如果立即调用 is_changing,会使用 _sync_diffs 判断变更 + # * 如果中间这一次调用 wait_update() api._sync_diffs 会丢失变更 + # * 如果中间这一次调用 wait_update(_task=task) api._sync_diffs 不会重置, 就不会丢失变更 # 在回测时,以下代码应该只经历一次 wait_update - while margin != self.get_position(symbol).get("future_margin"): - self._api.wait_update() + cond = lambda: margin == self.get_position(symbol).get("future_margin") + self._api._wait_update_until(cond=cond) return margin def get_margin(self, symbol: str):