diff --git a/tests/test_trading_calendar.py b/tests/test_trading_calendar.py index 2e0cd526..00034724 100644 --- a/tests/test_trading_calendar.py +++ b/tests/test_trading_calendar.py @@ -571,10 +571,28 @@ def test_minutes_for_period(self): _open, _close = self.calendar.open_and_close_for_session( full_session_label ) + _break_start, _break_end = ( + self.calendar.break_start_and_end_for_session( + full_session_label + ) + ) + if not pd.isnull(_break_start): + constructed_minutes = np.concatenate([ + pd.date_range( + start=_open, end=_break_start, freq="min" + ), + pd.date_range( + start=_break_end, end=_close, freq="min" + ) + ]) + else: + constructed_minutes = pd.date_range( + start=_open, end=_close, freq="min" + ) np.testing.assert_array_equal( minutes, - pd.date_range(start=_open, end=_close, freq="min") + constructed_minutes, ) # early close period @@ -679,23 +697,48 @@ def test_minutes_in_range(self): np.testing.assert_array_equal(minutes1, minutes2[1:-1]) # manually construct the minutes - all_minutes = np.concatenate([ - pd.date_range( - start=first_open, - end=first_close, - freq="min" - ), - pd.date_range( - start=middle_open, - end=middle_close, - freq="min" - ), - pd.date_range( - start=last_open, - end=last_close, - freq="min" - ) - ]) + first_break_start, first_break_end = ( + self.calendar.break_start_and_end_for_session(sessions[0]) + ) + middle_break_start, middle_break_end = ( + self.calendar.break_start_and_end_for_session(sessions[1]) + ) + last_break_start, last_break_end = ( + self.calendar.break_start_and_end_for_session(sessions[-1]) + ) + + intervals = [ + (first_open, first_break_start, first_break_end, first_close), + (middle_open, middle_break_start, middle_break_end, middle_close), + (last_open, last_break_start, last_break_end, last_close), + ] + all_minutes = [] + + for _open, _break_start, _break_end, _close in intervals: + if pd.isnull(_break_start): + all_minutes.append( + pd.date_range( + start=_open, + end=_close, + freq="min" + ), + ) + else: + all_minutes.append( + pd.date_range( + start=_open, + end=_break_start, + freq="min" + ), + ) + all_minutes.append( + pd.date_range( + start=_break_end, + end=_close, + freq="min" + ), + ) + all_minutes = np.concatenate(all_minutes) np.testing.assert_array_equal(all_minutes, minutes1) diff --git a/tests/test_xhkg_calendar.py b/tests/test_xhkg_calendar.py index 91b98a4e..207e60b2 100644 --- a/tests/test_xhkg_calendar.py +++ b/tests/test_xhkg_calendar.py @@ -45,6 +45,38 @@ def test_constrain_construction_dates(self): ), ) + def test_session_break(self): + # Test that the calendar correctly reports itself as closed during + # session break + normal_minute = pd.Timestamp('2003-01-27 03:30:00') + break_minute = pd.Timestamp('2003-01-27 04:30:00') + + self.assertTrue(self.calendar.is_open_on_minute(normal_minute)) + self.assertFalse(self.calendar.is_open_on_minute(break_minute)) + # Make sure that ignoring breaks indicates the exchange is open + self.assertTrue( + self.calendar.is_open_on_minute(break_minute, ignore_breaks=True) + ) + + current_session_label = self.calendar.minute_to_session_label( + normal_minute, + direction="none" + ) + self.assertEqual( + current_session_label, + self.calendar.minute_to_session_label( + break_minute, + direction="previous" + ) + ) + self.assertEqual( + current_session_label, + self.calendar.minute_to_session_label( + break_minute, + direction="next" + ) + ) + def test_lunar_new_year_2003(self): # NOTE: Lunar Month 12 2002 is the 12th month of the lunar year that # begins in 2002; this month actually takes place in January 2003. diff --git a/trading_calendars/calendar_helpers.py b/trading_calendars/calendar_helpers.py index 74726a98..08fb6783 100644 --- a/trading_calendars/calendar_helpers.py +++ b/trading_calendars/calendar_helpers.py @@ -1,7 +1,10 @@ import numpy as np +import pandas as pd NANOSECONDS_PER_MINUTE = int(6e10) +NP_NAT = np.array([pd.NaT], dtype=np.int64)[0] + def next_divider_idx(dividers, minute_val): @@ -25,47 +28,43 @@ def previous_divider_idx(dividers, minute_val): return divider_idx - 1 -def is_open(opens, closes, minute_val): - - open_idx = np.searchsorted(opens, minute_val) - close_idx = np.searchsorted(closes, minute_val) - - if open_idx != close_idx: - # if the indices are not same, that means the market is open - return True - else: - try: - # if they are the same, it might be the first minute of a - # session - return minute_val == opens[open_idx] - except IndexError: - # this can happen if we're outside the schedule's range (like - # after the last close) - return False - - -def compute_all_minutes(opens_in_ns, closes_in_ns): +def compute_all_minutes( + opens_in_ns, break_starts_in_ns, break_ends_in_ns, closes_in_ns, +): """ - Given arrays of opens and closes, both in nanoseconds, - return an array of each minute between the opens and closes. - """ - deltas = closes_in_ns - opens_in_ns - - # + 1 because we want 390 mins per standard day, not 389 - daily_sizes = (deltas // NANOSECONDS_PER_MINUTE) + 1 - num_minutes = daily_sizes.sum() + Given arrays of opens and closes (in nanoseconds) and optionally + break_starts and break ends, return an array of each minute between the + opens and closes. - # One allocation for the entire thing. This assumes that each day - # represents a contiguous block of minutes. + NOTE: Add an extra minute to ending boundaries (break_start and close) + so we include the last bar (arange doesn't include its stop). + """ pieces = [] - - for open_, size in zip(opens_in_ns, daily_sizes): - pieces.append( - np.arange(open_, - open_ + size * NANOSECONDS_PER_MINUTE, - NANOSECONDS_PER_MINUTE) - ) - - out = np.concatenate(pieces).view('datetime64[ns]') - assert len(out) == num_minutes + for open_time, break_start_time, break_end_time, close_time in zip( + opens_in_ns, break_starts_in_ns, break_ends_in_ns, closes_in_ns + ): + if break_start_time != NP_NAT: + pieces.append( + np.arange( + open_time, + break_start_time + NANOSECONDS_PER_MINUTE, + NANOSECONDS_PER_MINUTE, + ) + ) + pieces.append( + np.arange( + break_end_time, + close_time + NANOSECONDS_PER_MINUTE, + NANOSECONDS_PER_MINUTE, + ) + ) + else: + pieces.append( + np.arange( + open_time, + close_time + NANOSECONDS_PER_MINUTE, + NANOSECONDS_PER_MINUTE, + ) + ) + out = np.concatenate(pieces).view("datetime64[ns]") return out diff --git a/trading_calendars/exchange_calendar_xhkg.py b/trading_calendars/exchange_calendar_xhkg.py index d95bcb1f..27a7e9a6 100644 --- a/trading_calendars/exchange_calendar_xhkg.py +++ b/trading_calendars/exchange_calendar_xhkg.py @@ -246,6 +246,7 @@ class XHKGExchangeCalendar(TradingCalendar): Exchange calendar for the Hong Kong Stock Exchange (XHKG). Open Time: 9:31 AM, Asia/Hong_Kong + Lunch Break: 12:01 PM - 1:00 PM Asia/Hong_Kong Close Time: 4:00 PM, Asia/Hong_Kong Regularly-Observed Holidays: @@ -279,6 +280,12 @@ class XHKGExchangeCalendar(TradingCalendar): (None, time(10, 1)), (pd.Timestamp('2011-03-07'), time(9, 31)), ) + break_start_times = ( + (None, time(12, 1)), + ) + break_end_times = ( + (None, time(13, 0)), + ) close_times = ( (None, time(16)), ) diff --git a/trading_calendars/trading_calendar.py b/trading_calendars/trading_calendar.py index 163ee3bf..0681292e 100644 --- a/trading_calendars/trading_calendar.py +++ b/trading_calendars/trading_calendar.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABCMeta, abstractproperty +from collections import OrderedDict import warnings -from operator import attrgetter from pandas.tseries.holiday import AbstractHolidayCalendar from six import with_metaclass from pytz import UTC @@ -31,14 +31,13 @@ import toolz from .calendar_helpers import ( + NP_NAT, compute_all_minutes, - is_open, next_divider_idx, previous_divider_idx, ) from .utils.memoize import lazyval from .utils.pandas_utils import days_at_time -from .utils.preprocess import preprocess, coerce start_default = pd.Timestamp('1990-01-01', tz=UTC) @@ -66,7 +65,9 @@ def selection(arr, start, end): return arr[np.all(predicates, axis=0)] -def _group_times(all_days, times, tz, offset): +def _group_times(all_days, times, tz, offset=0): + if times is None: + return None elements = [ days_at_time( selection(all_days, start, end), @@ -114,6 +115,17 @@ def __init__(self, start=start_default, end=end_default): self.tz, self.open_offset, ) + self._break_starts = _group_times( + _all_days, + self.break_start_times, + self.tz, + ) + self._break_ends = _group_times( + _all_days, + self.break_end_times, + self.tz, + ) + self._closes = _group_times( _all_days, self.close_times, @@ -129,17 +141,25 @@ def __init__(self, start=start_default, end=end_default): # Overwrite the special opens and closes on top of the standard ones. _overwrite_special_dates(_all_days, self._opens, _special_opens) _overwrite_special_dates(_all_days, self._closes, _special_closes) + _remove_breaks_for_special_dates( + _all_days, + self._break_starts, + _special_closes, + ) + _remove_breaks_for_special_dates( + _all_days, + self._break_ends, + _special_closes, + ) - # In pandas 0.16.1 _opens and _closes will lose their timezone - # information. This looks like it has been resolved in 0.17.1. - # http://pandas.pydata.org/pandas-docs/stable/whatsnew.html#datetime-with-tz # noqa self.schedule = DataFrame( index=_all_days, - columns=['market_open', 'market_close'], - data={ - 'market_open': self._opens, - 'market_close': self._closes, - }, + data=OrderedDict([ + ('market_open', self._opens), + ('break_start', self._break_starts), + ('break_end', self._break_ends), + ('market_close', self._closes), + ]), dtype='datetime64[ns]', ) @@ -152,9 +172,20 @@ def __init__(self, start=start_default, end=end_default): self.market_opens_nanos = self.schedule.market_open.values.\ astype(np.int64) + self.market_break_starts_nanos = self.schedule.break_start.values.\ + astype(np.int64) + + self.market_break_ends_nanos = self.schedule.break_end.values.\ + astype(np.int64) + self.market_closes_nanos = self.schedule.market_close.values.\ astype(np.int64) + _check_breaks_match( + self.market_break_starts_nanos, + self.market_break_ends_nanos + ) + self._trading_minutes_nanos = self.all_minutes.values.\ astype(np.int64) @@ -193,6 +224,24 @@ def open_times(self): """ raise NotImplementedError() + @property + def break_start_times(self): + """ + Returns a optional list of tuples of (start_date, break_start_time). + If the break start time is constant throughout the calendar, use None + for the start_date. If there is no break, return `None`. + """ + return None + + @property + def break_end_times(self): + """ + Returns a optional list of tuples of (start_date, break_end_time). If + the break end time is constant throughout the calendar, use None for + the start_date. If there is no break, return `None`. + """ + return None + @abstractproperty def close_times(self): """ @@ -224,8 +273,13 @@ def close_offset(self): @lazyval def _minutes_per_session(self): - diff = self.schedule.market_close - self.schedule.market_open - diff = diff.astype('timedelta64[m]') + close_to_open_diff = ( + self.schedule.market_close - self.schedule.market_open + ) + break_diff = ( + self.schedule.break_end - self.schedule.break_start + ).fillna(pd.Timedelta(seconds=0)) + diff = (close_to_open_diff - break_diff).astype("timedelta64[m]") return diff + 1 def minutes_count_for_sessions_in_range(self, start_session, end_session): @@ -340,22 +394,53 @@ def is_session(self, dt): """ return dt in self.schedule.index - def is_open_on_minute(self, dt): + def is_open_on_minute(self, dt, ignore_breaks=False): """ Given a dt, return whether this exchange is open at the given dt. Parameters ---------- - dt: pd.Timestamp + dt : pd.Timestamp or nanosecond offset The dt for which to check if this exchange is open. + ignore_breaks: bool + Whether to consider midday breaks when determining if an exchange + is open. Returns ------- bool Whether the exchange is open on this dt. """ - return is_open(self.market_opens_nanos, self.market_closes_nanos, - dt.value) + if isinstance(dt, pd.Timestamp): + dt = dt.value + + open_idx = np.searchsorted(self.market_opens_nanos, dt) + close_idx = np.searchsorted(self.market_closes_nanos, dt) + + # if the indices are not same, that means we are within a session + if open_idx != close_idx: + if ignore_breaks: + return True + + break_start_on_open_dt = \ + self.market_break_starts_nanos[open_idx - 1] + break_end_on_open_dt = self.market_break_ends_nanos[open_idx - 1] + # NaT comparisions will result in False + if break_start_on_open_dt <= dt < break_end_on_open_dt: + # we're in the middle of a break + return False + else: + return True + + else: + try: + # if they are the same, it might be the first minute of a + # session + return dt == self.market_opens_nanos[open_idx] + except IndexError: + # this can happen if we're outside the schedule's range (like + # after the last close) + return False def next_open(self, dt): """ @@ -755,14 +840,29 @@ def open_and_close_for_session(self, session_label): (Timestamp, Timestamp) The open and close for the given session. """ - sched = self.schedule + return ( + self.session_open(session_label), + self.session_close(session_label) + ) + + def break_start_and_end_for_session(self, session_label): + """ + Returns a tuple of timestamps of the break start and end of the session + represented by the given label. + + Parameters + ---------- + session_label: pd.Timestamp + The session whose break start and end are desired. - # `market_open` and `market_close` should be timezone aware, but pandas - # 0.16.1 does not appear to support this: - # http://pandas.pydata.org/pandas-docs/stable/whatsnew.html#datetime-with-tz # noqa + Returns + ------- + (Timestamp, Timestamp) + The break start and end for the given session. + """ return ( - sched.at[session_label, 'market_open'].tz_localize(UTC), - sched.at[session_label, 'market_close'].tz_localize(UTC), + self.session_break_start(session_label), + self.session_break_end(session_label) ) def session_open(self, session_label): @@ -771,6 +871,28 @@ def session_open(self, session_label): 'market_open' ].tz_localize(UTC) + def session_break_start(self, session_label): + break_start = self.schedule.at[ + session_label, + 'break_start' + ] + if not pd.isnull(break_start): + # older versions of pandas need this guard + break_start = break_start.tz_localize(UTC) + + return break_start + + def session_break_end(self, session_label): + break_end = self.schedule.at[ + session_label, + 'break_end' + ] + if not pd.isnull(break_end): + # older versions of pandas need this guard + break_end = break_end.tz_localize(UTC) + + return break_end + def session_close(self, session_label): return self.schedule.at[ session_label, @@ -812,20 +934,16 @@ def all_minutes(self): """ Returns a DatetimeIndex representing all the minutes in this calendar. """ - opens_in_ns = self._opens.values.astype( - 'datetime64[ns]', - ).view('int64') - - closes_in_ns = self._closes.values.astype( - 'datetime64[ns]', - ).view('int64') - return DatetimeIndex( - compute_all_minutes(opens_in_ns, closes_in_ns), + compute_all_minutes( + self.market_opens_nanos, + self.market_break_starts_nanos, + self.market_break_ends_nanos, + self.market_closes_nanos, + ), tz=UTC, ) - @preprocess(dt=coerce(pd.Timestamp, attrgetter('value'))) def minute_to_session_label(self, dt, direction="next"): """ Given a minute, get the label of its containing session. @@ -850,6 +968,9 @@ def minute_to_session_label(self, dt, direction="next"): pd.Timestamp (midnight UTC) The label of the containing session. """ + if isinstance(dt, pd.Timestamp): + dt = dt.value + if direction == "next": if self._minute_to_session_label_cache[0] == dt: return self._minute_to_session_label_cache[1] @@ -858,16 +979,15 @@ def minute_to_session_label(self, dt, direction="next"): current_or_next_session = self.schedule.index[idx] if direction == "next": - self._minute_to_session_label_cache = (dt, current_or_next_session) + self._minute_to_session_label_cache = ( + dt, current_or_next_session + ) return current_or_next_session elif direction == "previous": - if not is_open(self.market_opens_nanos, self.market_closes_nanos, - dt): - # if the exchange is closed, use the previous session + if not self.is_open_on_minute(dt, ignore_breaks=True): return self.schedule.index[idx - 1] elif direction == "none": - if not is_open(self.market_opens_nanos, self.market_closes_nanos, - dt): + if not self.is_open_on_minute(dt): # if the exchange is closed, blow up raise ValueError("The given dt is not an exchange minute!") else: @@ -896,7 +1016,6 @@ def minute_index_to_session_labels(self, index): raise ValueError( "Non-ordered index passed to minute_index_to_session_labels." ) - # Find the indices of the previous open and the next close for each # minute. prev_opens = ( @@ -914,8 +1033,16 @@ def minute_index_to_session_labels(self, index): example = index[bad_ix] prev_day = prev_opens[bad_ix] - prev_open, prev_close = self.schedule.iloc[prev_day] - next_open, next_close = self.schedule.iloc[prev_day + 1] + prev_open, prev_close = ( + self.schedule.iloc[prev_day].loc[ + ['market_open', 'market_close'] + ] + ) + next_open, next_close = ( + self.schedule.iloc[prev_day + 1].loc[ + ['market_open', 'market_close'] + ] + ) raise ValueError( "{num} non-market minutes in minute_index_to_session_labels:\n" @@ -1002,6 +1129,32 @@ def _calculate_special_closes(self, start, end): ) +def _check_breaks_match(market_break_starts_nanos, market_break_ends_nanos): + """Checks that market_break_starts_nanos and market_break_ends_nanos + + Parameters + ---------- + market_break_starts_nanos : np.ndarray + market_break_ends_nanos : np.ndarray + """ + nats_match = np.equal( + NP_NAT == market_break_starts_nanos, + NP_NAT == market_break_ends_nanos + ) + if not nats_match.all(): + raise ValueError( + """ + Mismatched market breaks + Break starts: + %s + Break ends: + %s + """, + market_break_starts_nanos[~nats_match], + market_break_ends_nanos[~nats_match] + ) + + def scheduled_special_times(calendar, start, end, time, tz): """ Returns a Series mapping each holiday (as a UTC midnight Timestamp) @@ -1015,27 +1168,28 @@ def scheduled_special_times(calendar, start, end, time, tz): ) -def _overwrite_special_dates(midnight_utcs, - opens_or_closes, - special_opens_or_closes): +def _overwrite_special_dates( + session_labels, opens_or_closes, special_opens_or_closes +): """ Overwrite dates in open_or_closes with corresponding dates in - special_opens_or_closes, using midnight_utcs for alignment. + special_opens_or_closes, using session_labels for alignment. """ # Short circuit when nothing to apply. if not len(special_opens_or_closes): return - len_m, len_oc = len(midnight_utcs), len(opens_or_closes) + len_m, len_oc = len(session_labels), len(opens_or_closes) if len_m != len_oc: raise ValueError( "Found misaligned dates while building calendar.\n" - "Expected midnight_utcs to be the same length as open_or_closes,\n" - "but len(midnight_utcs)=%d, len(open_or_closes)=%d" % len_m, len_oc + "Expected session_labels to be the same length as " + "open_or_closes but,\n" + "len(session_labels)=%d, len(open_or_closes)=%d" % (len_m, len_oc) ) # Find the array indices corresponding to each special date. - indexer = midnight_utcs.get_indexer(special_opens_or_closes.index) + indexer = session_labels.get_indexer(special_opens_or_closes.index) # -1 indicates that no corresponding entry was found. If any -1s are # present, then we have special dates that doesn't correspond to any @@ -1051,6 +1205,47 @@ def _overwrite_special_dates(midnight_utcs, opens_or_closes.values[indexer] = special_opens_or_closes.values +def _remove_breaks_for_special_dates( + session_labels, break_start_or_end, special_opens_or_closes +): + """ + Overwrite breaks in break_start_or_end with corresponding dates in + special_opens_or_closes, using session_labels for alignment. + """ + # Short circuit when we have no breaks + if break_start_or_end is None: + return + + # Short circuit when nothing to apply. + if not len(special_opens_or_closes): + return + + len_m, len_oc = len(session_labels), len(break_start_or_end) + if len_m != len_oc: + raise ValueError( + "Found misaligned dates while building calendar.\n" + "Expected session_labels to be the same length as break_starts,\n" + "but len(session_labels)=%d, len(break_start_or_end)=%d" + % (len_m, len_oc) + ) + + # Find the array indices corresponding to each special date. + indexer = session_labels.get_indexer(special_opens_or_closes.index) + + # -1 indicates that no corresponding entry was found. If any -1s are + # present, then we have special dates that doesn't correspond to any + # trading day. + if -1 in indexer: + bad_dates = list(special_opens_or_closes[indexer == -1]) + raise ValueError("Special dates %s are not trading days." % bad_dates) + + # NOTE: This is a slightly dirty hack. We're in-place overwriting the + # internal data of an Index, which is conceptually immutable. Since we're + # maintaining sorting, this should be ok, but this is a good place to + # sanity check if things start going haywire with calendar computations. + break_start_or_end.values[indexer] = NP_NAT + + class HolidayCalendar(AbstractHolidayCalendar): def __init__(self, rules): super(HolidayCalendar, self).__init__(rules=rules) diff --git a/trading_calendars/utils/compat.py b/trading_calendars/utils/compat.py deleted file mode 100644 index de7257df..00000000 --- a/trading_calendars/utils/compat.py +++ /dev/null @@ -1,107 +0,0 @@ -import functools -from operator import methodcaller -import sys - -from six import PY2 - - -if PY2: - from abc import ABCMeta - from types import DictProxyType - from ctypes import py_object, pythonapi - - _new_mappingproxy = pythonapi.PyDictProxy_New - _new_mappingproxy.argtypes = [py_object] - _new_mappingproxy.restype = py_object - - # Make mappingproxy a "class" so that we can use multipledispatch - # with it or do an ``isinstance(ob, mappingproxy)`` check in Python 2. - # You will never actually get an instance of this object, you will just - # get instances of ``types.DictProxyType``; however, ``mappingproxy`` is - # registered as a virtual super class so ``isinstance`` and ``issubclass`` - # will work as expected. The only thing that will appear strange is that: - # ``type(mappingproxy({})) is not mappingproxy``, but you shouldn't do - # that. - class mappingproxy(object): - __metaclass__ = ABCMeta - - def __new__(cls, *args, **kwargs): - return _new_mappingproxy(*args, **kwargs) - - mappingproxy.register(DictProxyType) - - # clear names not imported in the other branch - del DictProxyType - del ABCMeta - del py_object - del pythonapi - - def exc_clear(): - sys.exc_clear() - - def update_wrapper(wrapper, - wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - """Backport of Python 3's functools.update_wrapper for __wrapped__. - """ - for attr in assigned: - try: - value = getattr(wrapped, attr) - except AttributeError: - pass - else: - setattr(wrapper, attr, value) - for attr in updated: - getattr(wrapper, attr).update(getattr(wrapped, attr, {})) - # Issue #17482: set __wrapped__ last so we don't inadvertently copy it - # from the wrapped function when updating __dict__ - wrapper.__wrapped__ = wrapped - # Return the wrapper so this can be used as a decorator via partial() - return wrapper - - def wraps(wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - """Decorator factory to apply update_wrapper() to a wrapper function - - Returns a decorator that invokes update_wrapper() with the decorated - function as the wrapper argument and the arguments to wraps() as the - remaining arguments. Default arguments are as for update_wrapper(). - This is a convenience function to simplify applying partial() to - update_wrapper(). - """ - return functools.partial(update_wrapper, wrapped=wrapped, - assigned=assigned, updated=updated) - - values_as_list = methodcaller('values') - -else: - from types import MappingProxyType as mappingproxy - - def exc_clear(): - # exc_clear was removed in Python 3. The except statement automatically - # clears the exception. - pass - - update_wrapper = functools.update_wrapper - wraps = functools.wraps - - def values_as_list(dictionary): - """Return the dictionary values as a list without forcing a copy - in Python 2. - """ - return list(dictionary.values()) - - -unicode = type(u'') - -__all__ = [ - 'PY2', - 'exc_clear', - 'mappingproxy', - 'unicode', - 'update_wrapper', - 'values_as_list', - 'wraps', -] diff --git a/trading_calendars/utils/preprocess.py b/trading_calendars/utils/preprocess.py deleted file mode 100644 index fb0f8ce6..00000000 --- a/trading_calendars/utils/preprocess.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -Utilities for validating inputs to user-facing API functions. - -Note ----- -This file is lifted from zipline. Figure out how to dedupe this -stuff later on. -""" -from textwrap import dedent -from types import CodeType -from inspect import getargspec -from uuid import uuid4 - -from toolz.curried.operator import getitem -from six import viewkeys, exec_ - -from trading_calendars.utils.compat import wraps - - -_code_argorder = ( - ('co_argcount',) -) + ( - ('co_kwonlyargcount', ) if hasattr(CodeType, 'co_kwonlyargcount') else () -) + ( - ('co_posonlyargcount', ) if hasattr(CodeType, 'co_posonlyargcount') else () -) + ( - 'co_nlocals', - 'co_stacksize', - 'co_flags', - 'co_code', - 'co_consts', - 'co_names', - 'co_varnames', - 'co_filename', - 'co_name', - 'co_firstlineno', - 'co_lnotab', - 'co_freevars', - 'co_cellvars', -) - -NO_DEFAULT = object() - - -def preprocess(*_unused, **processors): - """ - Decorator that applies pre-processors to the arguments of a function before - calling the function. - - Parameters - ---------- - **processors : dict - Map from argument name -> processor function. - - A processor function takes three arguments: (func, argname, argvalue). - - `func` is the the function for which we're processing args. - `argname` is the name of the argument we're processing. - `argvalue` is the value of the argument we're processing. - - Examples - -------- - >>> def _ensure_tuple(func, argname, arg): - ... if isinstance(arg, tuple): - ... return argvalue - ... try: - ... return tuple(arg) - ... except TypeError: - ... raise TypeError( - ... "%s() expected argument '%s' to" - ... " be iterable, but got %s instead." % ( - ... func.__name__, argname, arg, - ... ) - ... ) - ... - >>> @preprocess(arg=_ensure_tuple) - ... def foo(arg): - ... return arg - ... - >>> foo([1, 2, 3]) - (1, 2, 3) - >>> foo("a") - ('a',) - >>> foo(2) - Traceback (most recent call last): - ... - TypeError: foo() expected argument 'arg' to be iterable, but got 2 instead. - """ - if _unused: - raise TypeError("preprocess() doesn't accept positional arguments") - - def _decorator(f): - args, varargs, varkw, defaults = argspec = getargspec(f) - if defaults is None: - defaults = () - no_defaults = (NO_DEFAULT,) * (len(args) - len(defaults)) - args_defaults = list(zip(args, no_defaults + defaults)) - if varargs: - args_defaults.append((varargs, NO_DEFAULT)) - if varkw: - args_defaults.append((varkw, NO_DEFAULT)) - - argset = set(args) | {varargs, varkw} - {None} - - # Arguments can be declared as tuples in Python 2. - if not all(isinstance(arg, str) for arg in args): - raise TypeError( - "Can't validate functions using tuple unpacking: %s" % - (argspec,) - ) - - # Ensure that all processors map to valid names. - bad_names = viewkeys(processors) - argset - if bad_names: - raise TypeError( - "Got processors for unknown arguments: %s." % bad_names - ) - - return _build_preprocessed_function( - f, processors, args_defaults, varargs, varkw, - ) - return _decorator - - -def call(f): - """ - Wrap a function in a processor that calls `f` on the argument before - passing it along. - - Useful for creating simple arguments to the `@preprocess` decorator. - - Parameters - ---------- - f : function - Function accepting a single argument and returning a replacement. - - Examples - -------- - >>> @preprocess(x=call(lambda x: x + 1)) - ... def foo(x): - ... return x - ... - >>> foo(1) - 2 - """ - @wraps(f) - def processor(func, argname, arg): - return f(arg) - return processor - - -def _build_preprocessed_function(func, - processors, - args_defaults, - varargs, - varkw): - """ - Build a preprocessed function with the same signature as `func`. - - Uses `exec` internally to build a function that actually has the same - signature as `func. - """ - format_kwargs = {'func_name': func.__name__} - - def mangle(name): - return 'a' + uuid4().hex + name - - format_kwargs['mangled_func'] = mangled_funcname = mangle(func.__name__) - - def make_processor_assignment(arg, processor_name): - template = "{arg} = {processor}({func}, '{arg}', {arg})" - return template.format( - arg=arg, - processor=processor_name, - func=mangled_funcname, - ) - - exec_globals = {mangled_funcname: func, 'wraps': wraps} - defaults_seen = 0 - default_name_template = 'a' + uuid4().hex + '_%d' - signature = [] - call_args = [] - assignments = [] - star_map = { - varargs: '*', - varkw: '**', - } - - def name_as_arg(arg): - return star_map.get(arg, '') + arg - - for arg, default in args_defaults: - if default is NO_DEFAULT: - signature.append(name_as_arg(arg)) - else: - default_name = default_name_template % defaults_seen - exec_globals[default_name] = default - signature.append('='.join([name_as_arg(arg), default_name])) - defaults_seen += 1 - - if arg in processors: - procname = mangle('_processor_' + arg) - exec_globals[procname] = processors[arg] - assignments.append(make_processor_assignment(arg, procname)) - - call_args.append(name_as_arg(arg)) - - exec_str = dedent( - """\ - @wraps({wrapped_funcname}) - def {func_name}({signature}): - {assignments} - return {wrapped_funcname}({call_args}) - """ - ).format( - func_name=func.__name__, - signature=', '.join(signature), - assignments='\n '.join(assignments), - wrapped_funcname=mangled_funcname, - call_args=', '.join(call_args), - ) - compiled = compile( - exec_str, - func.__code__.co_filename, - mode='exec', - ) - - exec_locals = {} - exec_(compiled, exec_globals, exec_locals) - new_func = exec_locals[func.__name__] - - code = new_func.__code__ - args = { - attr: getattr(code, attr) - for attr in dir(code) - if attr.startswith('co_') - } - # Copy the firstlineno out of the underlying function so that exceptions - # get raised with the correct traceback. - # This also makes dynamic source inspection (like IPython `??` operator) - # work as intended. - try: - # Try to get the pycode object from the underlying function. - original_code = func.__code__ - except AttributeError: - try: - # The underlying callable was not a function, try to grab the - # `__func__.__code__` which exists on method objects. - original_code = func.__func__.__code__ - except AttributeError: - # The underlying callable does not have a `__code__`. There is - # nothing for us to correct. - return new_func - - args['co_firstlineno'] = original_code.co_firstlineno - new_func.__code__ = CodeType(*map(getitem(args), _code_argorder)) - return new_func - - -def coerce(from_, to, **to_kwargs): - """ - A preprocessing decorator that coerces inputs of a given type by passing - them to a callable. - - Parameters - ---------- - from : type or tuple or types - Inputs types on which to call ``to``. - to : function - Coercion function to call on inputs. - **to_kwargs - Additional keywords to forward to every call to ``to``. - - Examples - -------- - >>> @preprocess(x=coerce(float, int), y=coerce(float, int)) - ... def floordiff(x, y): - ... return x - y - ... - >>> floordiff(3.2, 2.5) - 1 - - >>> @preprocess(x=coerce(str, int, base=2), y=coerce(str, int, base=2)) - ... def add_binary_strings(x, y): - ... return bin(x + y)[2:] - ... - >>> add_binary_strings('101', '001') - '110' - """ - def preprocessor(func, argname, arg): - if isinstance(arg, from_): - return to(arg, **to_kwargs) - return arg - return preprocessor