Commit f5fda83

Author: Joe Jevnik
MAINT: remove hard dependency on odo
Odo is currently a hard dependency of warp_prism, used to convert sqlalchemy types into numpy dtypes. Odo is no longer actively maintained and breaks with newer versions of pandas. This change reimplements the needed functionality directly in warp_prism without using odo. The PR keeps the odo edge registration code so that existing users don't see a change in functionality.
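For context, warp_prism's public entry points are unchanged by this commit; only the SQL-type-to-numpy-dtype discovery moves in-house. Below is a minimal usage sketch, not part of the diff: the engine URL, table name, and columns are hypothetical, while to_arrays(query, *, bind=None) and the binary COPY path are taken from the code shown further down.

import sqlalchemy as sa
import warp_prism

# Hypothetical table and postgres URL; warp_prism fetches results through
# postgres' binary COPY support, so a postgres engine is assumed here.
engine = sa.create_engine('postgresql://localhost/mydb')
table = sa.Table(
    'prices',                      # hypothetical table name
    sa.MetaData(),
    sa.Column('sid', sa.Integer),
    sa.Column('close', sa.Float),
)

# Internally each column's SQL type is mapped to a numpy dtype; this commit
# moves that mapping from odo/datashape into warp_prism itself.
arrays = warp_prism.to_arrays(table.select(), bind=engine)

to_arrays returns a dict keyed by column name, as the unchanged return statement in the diff below shows.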
1 parent dbd61bf commit f5fda83

2 files changed (+110, -16 lines)

2 files changed

+110
-16
lines changed

setup.py

Lines changed: 3 additions & 3 deletions
@@ -42,17 +42,17 @@
         ),
     ],
     install_requires=[
-        'datashape',
         'numpy',
         'pandas',
         'sqlalchemy',
         'psycopg2',
-        'odo',
         'toolz',
-        'networkx<=1.11',
     ],
     extras_require={
         'dev': [
+            'odo',
+            'pandas==0.18.1',
+            'networkx<=1.11',
             'flake8==3.3.0',
             'pycodestyle==2.3.1',
             'pyflakes==1.5.0',

warp_prism/__init__.py

Lines changed: 107 additions & 13 deletions
@@ -1,13 +1,11 @@
 from io import BytesIO
+import numbers
 
-from datashape import discover
-from datashape.predicates import istabular
 import numpy as np
-from odo import convert
 import pandas as pd
 import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql as _postgresql
 from sqlalchemy.ext.compiler import compiles
-from toolz import keymap
 
 from ._warp_prism import (
     raw_to_arrays as _raw_to_arrays,
@@ -18,7 +16,7 @@
 __version__ = '0.1.1'
 
 
-_typeid_map = keymap(np.dtype, _raw_typeid_map)
+_typeid_map = {np.dtype(k): v for k, v in _raw_typeid_map.items()}
 _object_type_id = _raw_typeid_map['object']
 
 
@@ -66,14 +64,107 @@ def _compile_copy_to_binary_postgres(element, compiler, **kwargs):
     )
 
 
+types = {np.dtype(k): v for k, v in {
+    'i8': sa.BigInteger,
+    'i4': sa.Integer,
+    'i2': sa.SmallInteger,
+    'f4': sa.REAL,
+    'f8': sa.FLOAT,
+    'O': sa.Text,
+    'M8[D]': sa.Date,
+    'M8[us]': sa.DateTime,
+    '?': sa.Boolean,
+    "m8[D]": sa.Interval(second_precision=0, day_precision=9),
+    "m8[h]": sa.Interval(second_precision=0, day_precision=0),
+    "m8[m]": sa.Interval(second_precision=0, day_precision=0),
+    "m8[s]": sa.Interval(second_precision=0, day_precision=0),
+    "m8[ms]": sa.Interval(second_precision=3, day_precision=0),
+    "m8[us]": sa.Interval(second_precision=6, day_precision=0),
+    "m8[ns]": sa.Interval(second_precision=9, day_precision=0),
+}.items()}
+
+_revtypes = dict(map(reversed, types.items()))
+_revtypes.update({
+    sa.DATETIME: np.dtype('M8[us]'),
+    sa.TIMESTAMP: np.dtype('M8[us]'),
+    sa.FLOAT: np.dtype('f8'),
+    sa.DATE: np.dtype('M8[D]'),
+    sa.BIGINT: np.dtype('i8'),
+    sa.INTEGER: np.dtype('i4'),
+    sa.BIGINT: np.dtype('i8'),
+    sa.types.NullType: np.dtype('O'),
+    sa.REAL: np.dtype('f4'),
+    sa.Float: np.dtype('f8'),
+})
+
+_precision_types = {
+    sa.Float,
+    _postgresql.base.DOUBLE_PRECISION,
+}
+
+
+def _precision_to_dtype(precision):
+    if isinstance(precision, numbers.Integral):
+        if 1 <= precision <= 24:
+            return np.dtype('f4')
+        elif 25 <= precision <= 53:
+            return np.dtype('f8')
+    raise ValueError('%s is not a supported precision' % precision)
+
+
+_units_of_power = {
+    0: 's',
+    3: 'ms',
+    6: 'us',
+    9: 'ns'
+}
+
+
+def _discover_type(type_):
+    if isinstance(type_, sa.Interval):
+        if type_.second_precision is None and type_.day_precision is None:
+            return np.dtype('m8[us]')
+        elif type_.second_precision == 0 and type_.day_precision == 0:
+            return np.dtype('m8[s]')
+
+        if (type_.second_precision in _units_of_power and
+                not type_.day_precision):
+            unit = _units_of_power[type_.second_precision]
+        elif type_.day_precision > 0:
+            unit = 'D'
+        else:
+            raise ValueError(
+                'Cannot infer INTERVAL type_e with parameters'
+                'second_precision=%d, day_precision=%d' %
+                (type_.second_precision, type_.day_precision),
+            )
+        return np.dtype('m8[%s]' % unit)
+    if type(type_) in _precision_types and type_.precision is not None:
+        return _precision_to_dtype(type_.precision)
+    if type_ in _revtypes:
+        return _revtypes[type_]
+    if type(type_) in _revtypes:
+        return _revtypes[type(type_)]
+    if isinstance(type_, sa.Numeric):
+        raise ValueError('Cannot adapt numeric type to numpy dtype')
+    if isinstance(type_, (sa.String, sa.Unicode)):
+        return np.dtype('O')
+    else:
+        for k, v in _revtypes.items():
+            if isinstance(k, type) and (isinstance(type_, k) or
+                                        hasattr(type_, 'impl') and
+                                        isinstance(type_.impl, k)):
+                return v
+            if k == type_:
+                return v
+    raise NotImplementedError('No SQL-numpy match for type %s' % type_)
+
+
 def _warp_prism_types(query):
-    for name, dtype in discover(query).measure.fields:
+    for col in query.columns:
+        dtype = _discover_type(col.type)
         try:
-            np_dtype = getattr(dtype, 'ty', dtype).to_numpy_dtype()
-            if np_dtype.kind == 'U':
-                yield _object_type_id
-            else:
-                yield _typeid_map[np_dtype]
+            yield _typeid_map[dtype]
         except KeyError:
             raise TypeError(
                 'warp_prism cannot query columns of type %s' % dtype,
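To illustrate how the reimplemented lookup behaves, here are a few spot checks against the private _discover_type helper defined above. These calls are not part of the commit and import a private name purely for illustration:

import numpy as np
import sqlalchemy as sa
from warp_prism import _discover_type  # private helper, imported only to illustrate

# Plain types resolve through the _revtypes table, floats with an explicit
# precision go through _precision_to_dtype, and intervals are keyed on their
# second/day precision.
assert _discover_type(sa.Integer()) == np.dtype('i4')
assert _discover_type(sa.Text()) == np.dtype('O')
assert _discover_type(sa.Float(precision=53)) == np.dtype('f8')
assert _discover_type(sa.Interval()) == np.dtype('m8[us]')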
@@ -136,7 +227,7 @@ def to_arrays(query, *, bind=None):
     return {column_names[n]: v for n, v in enumerate(out)}
 
 
-null_values = keymap(np.dtype, {
+null_values = {np.dtype(k): v for k, v in {
     'float32': np.nan,
     'float64': np.nan,
     'int16': np.nan,
@@ -145,7 +236,7 @@ def to_arrays(query, *, bind=None):
     'bool': np.nan,
     'datetime64[ns]': np.datetime64('nat', 'ns'),
     'object': None,
-})
+}.items()}
 
 # alias because ``to_dataframe`` shadows this name
 _default_null_values_for_type = null_values
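The null-sentinel table is rebuilt the same way: keys are normalized through np.dtype so any spelling of a dtype resolves to one entry. A small sketch of how the mapping would be consulted (illustrative only, not part of the commit):

import numpy as np
from warp_prism import null_values

# Missing values are filled with a dtype-appropriate sentinel: NaN for
# numeric columns, NaT for datetimes, and None for object columns.
assert np.isnan(null_values[np.dtype('float64')])
assert np.isnat(null_values[np.dtype('datetime64[ns]')])
assert null_values[np.dtype('object')] is None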
@@ -216,6 +307,9 @@ def register_odo_dataframe_edge():
     If the selectable is not in a postgres database, it will fallback to the
     default odo edge.
     """
+    from odo import convert
+    from datashape.predicates import istabular
+
     # estimating 8 times faster
     df_cost = convert.graph.edge[sa.sql.Select][pd.DataFrame]['cost'] / 8
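Because odo and datashape are now imported inside register_odo_dataframe_edge, only users who explicitly opt back into the odo integration need those packages installed. A sketch of that opt-in; the try/except wrapper is illustrative and not part of the library:

import warp_prism

try:
    # Registers warp_prism's faster path as an odo conversion edge; requires
    # the optional odo/datashape/networkx packages now listed under the
    # 'dev' extra.
    warp_prism.register_odo_dataframe_edge()
except ImportError:
    # odo is not installed; warp_prism's own to_arrays/to_dataframe still work.
    pass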
