Skip to content

Commit 943c99e

Browse files
committed
Initial support for polars.Series and polars.DataFrame.
1 parent e168147 commit 943c99e

8 files changed

+144
-19
lines changed

CHANGELOG

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33

44
- [NEW]: Upgrade to Cython 0.29.24
55

6+
- [NEW]: Support polars.Series in Function API
7+
8+
- [NEW]: Support polars.Series in Streaming API
9+
10+
- [NEW]: Support polars.DataFrame in Abstract API
11+
612
0.4.21
713
======
814

README.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ will have an initial "lookback" period (a required number of observations
233233
before an output is generated) set to ``NaN``.
234234

235235
For convenience, the Function API supports ``numpy.ndarray``,
236-
``pandas.Series`` inputs.
236+
``pandas.Series`` and ``polars.Series`` inputs.
237237

238238
All of the following examples use the Function API:
239239

@@ -270,9 +270,10 @@ If you're already familiar with using the function API, you should feel right
270270
at home using the Abstract API.
271271

272272
Every function takes a collection of named inputs, either a ``dict`` of
273-
``numpy.ndarray`` or ``pandas.Series``, or a ``pandas.DataFrame``. If a
274-
``pandas.DataFrame`` is provided, the output is returned as a
275-
``pandas.DataFrame`` with named output columns.
273+
``numpy.ndarray`` or ``pandas.Series`` or ``polars.Series``, or a
274+
``pandas.DataFrame`` or ``polars.DataFrame``. If a ``pandas.DataFrame`` or
275+
``polars.DataFrame`` is provided, the output is returned as the same type
276+
with named output columns.
276277

277278
For example, inputs could be provided for the typical "OHLCV" data:
278279

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
numpy>=1.19.4
2-
Cython>=0.29.21
1+
numpy
2+
Cython

requirements_dev.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
-r requirements.txt
2-
beautifulsoup4>=4.9.3
3-
mistune>=0.8.4
4-
Pygments>=2.7.4
2+
beautifulsoup4
3+
mistune
4+
Pygments

requirements_test.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
-r requirements.txt
2-
pandas>=1.1.4
3-
nose>=1.3.7
2+
pandas
3+
nose
4+
polars

talib/__init__.py

+47
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,39 @@ def wrapper(*args, **kwargs):
4949

5050
return wrapper
5151

52+
# If polars is available, wrap talib functions so that they support
53+
# polars.Series input
54+
try:
55+
from polars import Series as _pl_Series
56+
except ImportError:
57+
# polars not available, nothing to wrap
58+
_polars_wrapper = lambda x: x
59+
else:
60+
def _polars_wrapper(func):
61+
@wraps(func)
62+
def wrapper(*args, **kwargs):
63+
# Use Series' float64 values if pandas, else use values as passed
64+
args = [arg.to_numpy().astype(float) if isinstance(arg, _pl_Series) else arg
65+
for arg in args]
66+
kwargs = {k: v.to_numpy().astype(float) if isinstance(v, _pl_Series) else v
67+
for k, v in kwargs.items()}
68+
69+
result = func(*args, **kwargs)
70+
71+
# check to see if we got a streaming result
72+
first_result = result[0] if isinstance(result, tuple) else result
73+
is_streaming_fn_result = not hasattr(first_result, '__len__')
74+
if is_streaming_fn_result:
75+
return result
76+
77+
# Series was passed in, Series gets out
78+
if isinstance(result, tuple):
79+
# Handle multi-array results such as BBANDS
80+
return tuple(_pl_Series(arr) for arr in result)
81+
return _pl_Series(result)
82+
83+
return wrapper
84+
5285
from ._ta_lib import (
5386
_ta_initialize, _ta_shutdown, MA_Type, __ta_version__,
5487
_ta_set_unstable_period as set_unstable_period,
@@ -75,6 +108,20 @@ def wrapper(*args, **kwargs):
75108
setattr(stream, func_name, wrapped_func)
76109
globals()[stream_func_name] = wrapped_func
77110

111+
# wrap them with polars
# Relative import (level=1) of the sibling ``_ta_lib`` module, asking for every
# TA function name so that each one can be looked up as an attribute below.
func = __import__("_ta_lib", globals(), locals(), __TA_FUNCTION_NAMES__, level=1)
for func_name in __TA_FUNCTION_NAMES__:
    # Replace the function on the ``_ta_lib`` module and re-export the wrapped
    # version under the same name from this module.
    wrapped_func = _polars_wrapper(getattr(func, func_name))
    setattr(func, func_name, wrapped_func)
    globals()[func_name] = wrapped_func

# Same treatment for the ``stream`` module: the function is looked up and
# replaced on ``stream`` under its plain name, and exported from this module
# under the ``stream_``-prefixed name.
stream_func_names = ['stream_%s' % fname for fname in __TA_FUNCTION_NAMES__]
stream = __import__("stream", globals(), locals(), stream_func_names, level=1)
for func_name, stream_func_name in zip(__TA_FUNCTION_NAMES__, stream_func_names):
    wrapped_func = _polars_wrapper(getattr(stream, func_name))
    setattr(stream, func_name, wrapped_func)
    globals()[stream_func_name] = wrapped_func
124+
78125
__version__ = '0.4.22'
79126

80127
# In order to use this python library, talib (i.e. this __file__) will be

talib/_abstract.pxi

+37-8
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,35 @@ __INPUT_PRICE_SERIES_DEFAULTS = {'price': 'close',
2929
'periods': 'periods', # only used by MAVP; not a price series!
3030
}
3131

32+
# Accepted container / array types.  They start with the always-available
# types and grow as optional libraries are found; both are frozen into tuples
# at the end so they can be handed to isinstance().
__INPUT_ARRAYS_TYPES = [dict]
__ARRAY_TYPES = [np.ndarray]

# allow use of pandas.DataFrame for input arrays
try:
    import pandas
except ImportError:
    # pandas is optional: the None sentinels let later code skip pandas checks
    __PANDAS_DATAFRAME = None
    __PANDAS_SERIES = None
else:
    __PANDAS_DATAFRAME = pandas.DataFrame
    __PANDAS_SERIES = pandas.Series
    __INPUT_ARRAYS_TYPES.append(__PANDAS_DATAFRAME)
    __ARRAY_TYPES.append(__PANDAS_SERIES)

# allow use of polars.DataFrame for input arrays
try:
    import polars
except ImportError:
    # polars is optional: the None sentinels let later code skip polars checks
    __POLARS_DATAFRAME = None
    __POLARS_SERIES = None
else:
    __POLARS_DATAFRAME = polars.DataFrame
    __POLARS_SERIES = polars.Series
    __INPUT_ARRAYS_TYPES.append(__POLARS_DATAFRAME)
    __ARRAY_TYPES.append(__POLARS_SERIES)

__INPUT_ARRAYS_TYPES = tuple(__INPUT_ARRAYS_TYPES)
__ARRAY_TYPES = tuple(__ARRAY_TYPES)
59+
60+
4561
if sys.version >= '3':
4662

4763
def str2bytes(s):
@@ -64,10 +80,10 @@ class Function(object):
6480
intended to simplify using individual TALIB functions by providing a
6581
unified interface for setting/controlling input data, setting function
6682
parameters and retrieving results. Input data consists of a ``dict`` of
67-
``numpy`` arrays (or a ``pandas.DataFrame``), one array for each of open,
68-
high, low, close and volume. This can be set with the set_input_arrays()
69-
method. Which keyed array(s) are used as inputs when calling the function
70-
is controlled using the input_names property.
83+
``numpy`` arrays (or a ``pandas.DataFrame`` or ``polars.DataFrame``), one
84+
array for each of open, high, low, close and volume. This can be set with
85+
the set_input_arrays() method. Which keyed array(s) are used as inputs when
86+
calling the function is controlled using the input_names property.
7187
7288
This class gets initialized with a TALIB function name and optionally an
7389
input_arrays object. It provides the following primary functions for
@@ -334,6 +350,13 @@ class Function(object):
334350
return __PANDAS_DATAFRAME(numpy.column_stack(ret),
335351
index=index,
336352
columns=self.output_names)
353+
elif __POLARS_DATAFRAME is not None and \
354+
isinstance(self.__input_arrays, __POLARS_DATAFRAME):
355+
if len(ret) == 1:
356+
return __POLARS_SERIES(ret[0])
357+
else:
358+
return __POLARS_DATAFRAME(numpy.column_stack(ret),
359+
columns=self.output_names)
337360
else:
338361
return ret[0] if len(ret) == 1 else ret
339362

@@ -382,6 +405,9 @@ class Function(object):
382405
if __PANDAS_DATAFRAME is not None \
383406
and isinstance(self.__input_arrays, __PANDAS_DATAFRAME):
384407
no_existing_input_arrays = self.__input_arrays.empty
408+
elif __POLARS_DATAFRAME is not None \
409+
and isinstance(self.__input_arrays, __POLARS_DATAFRAME):
410+
no_existing_input_arrays = self.__input_arrays.empty
385411
else:
386412
no_existing_input_arrays = not bool(self.__input_arrays)
387413

@@ -432,6 +458,9 @@ class Function(object):
432458
if __PANDAS_SERIES is not None and \
433459
isinstance(series, __PANDAS_SERIES):
434460
series = series.values.astype(float)
461+
elif __POLARS_SERIES is not None and \
462+
isinstance(series, __POLARS_SERIES):
463+
series = series.to_numpy().astype(float)
435464
args.append(series)
436465
for opt_input in self.__opt_inputs:
437466
value = self.__get_opt_input_value(opt_input)

talib/test_polars.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import polars as pl
3+
from nose.tools import assert_equals, assert_is_instance, assert_true
4+
5+
import talib
6+
from talib.test_data import series, assert_np_arrays_equal
7+
8+
def test_MOM():
    """MOM on a polars.Series returns a polars.Series for every timeperiod."""
    values = pl.Series([90.0,88.0,89.0])
    # (timeperiod, expected output) pairs, checked in this order
    cases = [
        (1, [np.nan, -2, 1]),
        (2, [np.nan, np.nan, -1]),
        (3, [np.nan, np.nan, np.nan]),
        (4, [np.nan, np.nan, np.nan]),
    ]
    for timeperiod, expected in cases:
        result = talib.MOM(values, timeperiod=timeperiod)
        assert_is_instance(result, pl.Series)
        assert_np_arrays_equal(result.to_numpy(), expected)
22+
23+
def test_MAVP():
    """MAVP with polars.Series inputs returns polars.Series and agrees with SMA slices."""
    def checked(series):
        # Assert the polars return type, then hand back the numpy values
        # used for the numeric comparisons.
        assert_is_instance(series, pl.Series)
        return series.to_numpy()

    a = pl.Series([1,5,3,4,7,3,8,1,4,6], dtype=pl.Float64)
    b = pl.Series([2,4,2,4,2,4,2,4,2,4], dtype=pl.Float64)

    # periods bounded to [2, 4]
    mavp = checked(talib.MAVP(a, b, minperiod=2, maxperiod=4))
    assert_np_arrays_equal(mavp, [np.nan,np.nan,np.nan,3.25,5.5,4.25,5.5,4.75,2.5,4.75])
    sma2 = checked(talib.SMA(a, 2))
    assert_np_arrays_equal(mavp[4::2], sma2[4::2])
    sma4 = checked(talib.SMA(a, 4))
    assert_np_arrays_equal(mavp[3::2], sma4[3::2])

    # periods bounded to [2, 3]
    mavp = checked(talib.MAVP(a, b, minperiod=2, maxperiod=3))
    assert_np_arrays_equal(mavp, [np.nan,np.nan,4,4,5.5,4.666666666666667,5.5,4,2.5,3.6666666666666665])
    sma3 = checked(talib.SMA(a, 3))
    assert_np_arrays_equal(mavp[2::2], sma2[2::2])
    assert_np_arrays_equal(mavp[3::2], sma3[3::2])

0 commit comments

Comments
 (0)