diff --git a/.gitignore b/.gitignore
index 036cec9..51fb095 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,3 +63,4 @@
 benchmarks.db
 .dir-locals.el
 TAGS
+.gdb_history
diff --git a/.travis.yml b/.travis.yml
index 6b100df..2c22faa 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,6 @@
 language: python
 sudo: false
 python:
-  - "3.4"
   - "3.5"
   - "3.6"
 
@@ -20,8 +19,10 @@ addons:
 
 install:
   - ${CC} --version
-  - pip install numpy
+  - pip install numpy pandas sqlalchemy
  - python -c "import numpy;print(numpy.__version__)"
+  - python -c "import pandas;print(pandas.__version__)";
+  - python -c "import sqlalchemy;print(sqlalchemy.__version__)";
   - pip install -e .[dev]
 
 script:
diff --git a/setup.py b/setup.py
index 175244b..083247b 100644
--- a/setup.py
+++ b/setup.py
@@ -42,17 +42,14 @@
         ),
     ],
     install_requires=[
-        'datashape',
         'numpy',
-        'pandas',
-        'sqlalchemy',
         'psycopg2',
-        'odo',
-        'toolz',
-        'networkx<=1.11',
     ],
     extras_require={
         'dev': [
+            'odo',
+            'pandas==0.18.1',
+            'networkx<=1.11',
             'flake8==3.3.0',
             'pycodestyle==2.3.1',
             'pyflakes==1.5.0',
diff --git a/warp_prism/__init__.py b/warp_prism/__init__.py
index 7a560a5..6bb2897 100644
--- a/warp_prism/__init__.py
+++ b/warp_prism/__init__.py
@@ -1,251 +1,9 @@
-from io import BytesIO
+from .query import to_arrays, to_dataframe
+from .odo import register_odo_dataframe_edge
 
-from datashape import discover
-from datashape.predicates import istabular
-import numpy as np
-from odo import convert
-import pandas as pd
-import sqlalchemy as sa
-from sqlalchemy.ext.compiler import compiles
-from toolz import keymap
+__version__ = '0.2.1'
 
-from ._warp_prism import (
-    raw_to_arrays as _raw_to_arrays,
-    typeid_map as _raw_typeid_map,
-)
-
-__version__ = '0.1.1'
-
-
-_typeid_map = keymap(np.dtype, _raw_typeid_map)
-_object_type_id = _raw_typeid_map['object']
-
-
-class _CopyToBinary(sa.sql.expression.Executable, sa.sql.ClauseElement):
-
-    def __init__(self, element, bind):
-        self.element = element
-        self._bind = bind = bind
-
-    @property
-    def bind(self):
-        return self._bind
-
-
-def literal_compile(s):
-    """Compile a sql expression with bind params inlined as literals.
-
-    Parameters
-    ----------
-    s : Selectable
-        The expression to compile.
-
-    Returns
-    -------
-    cs : str
-        An equivalent sql string.
-    """
-    return str(s.compile(compile_kwargs={'literal_binds': True}))
-
-
-@compiles(_CopyToBinary, 'postgresql')
-def _compile_copy_to_binary_postgres(element, compiler, **kwargs):
-    selectable = element.element
-    return compiler.process(
-        sa.text(
-            'COPY {stmt} TO STDOUT (FORMAT BINARY)'.format(
-                stmt=(
-                    compiler.preparer.format_table(selectable)
-                    if isinstance(selectable, sa.Table) else
-                    '({})'.format(literal_compile(selectable))
-                ),
-            )
-        ),
-        **kwargs
-    )
-
-
-def _warp_prism_types(query):
-    for name, dtype in discover(query).measure.fields:
-        try:
-            np_dtype = getattr(dtype, 'ty', dtype).to_numpy_dtype()
-            if np_dtype.kind == 'U':
-                yield _object_type_id
-            else:
-                yield _typeid_map[np_dtype]
-        except KeyError:
-            raise TypeError(
-                'warp_prism cannot query columns of type %s' % dtype,
-            )
-
-
-def _getbind(selectable, bind):
-    """Return an explicitly passed connection or infer the connection from
-    the selectable.
-
-    Parameters
-    ----------
-    selectable : sa.sql.Selectable
-        The selectable object being queried.
-    bind : bind or None
-        The explicit connection or engine to use to execute the query.
-
-    Returns
-    -------
-    bind : The bind which should be used to execute the query.
- """ - if bind is None: - return selectable.bind - - if isinstance(bind, sa.engine.base.Engine): - return bind - - return sa.create_engine(bind) - - -def to_arrays(query, *, bind=None): - """Run the query returning a the results as np.ndarrays. - - Parameters - ---------- - query : sa.sql.Selectable - The query to run. This can be a select or a table. - bind : sa.Engine, optional - The engine used to create the connection. If not provided - ``query.bind`` will be used. - - Returns - ------- - arrays : dict[str, (np.ndarray, np.ndarray)] - A map from column name to the result arrays. The first array holds the - values and the second array is a boolean mask for NULLs. The values - where the mask is False are 0 interpreted by the type. - """ - # check types before doing any work - types = tuple(_warp_prism_types(query)) - - buf = BytesIO() - bind = _getbind(query, bind) - - stmt = _CopyToBinary(query, bind) - with bind.connect() as conn: - conn.connection.cursor().copy_expert(literal_compile(stmt), buf) - out = _raw_to_arrays(buf.getbuffer(), types) - column_names = query.c.keys() - return {column_names[n]: v for n, v in enumerate(out)} - - -null_values = keymap(np.dtype, { - 'float32': np.nan, - 'float64': np.nan, - 'int16': np.nan, - 'int32': np.nan, - 'int64': np.nan, - 'bool': np.nan, - 'datetime64[ns]': np.datetime64('nat', 'ns'), - 'object': None, -}) - -# alias because ``to_dataframe`` shadows this name -_default_null_values_for_type = null_values - - -def to_dataframe(query, *, bind=None, null_values=None): - """Run the query returning a the results as a pd.DataFrame. - - Parameters - ---------- - query : sa.sql.Selectable - The query to run. This can be a select or a table. - bind : sa.Engine, optional - The engine used to create the connection. If not provided - ``query.bind`` will be used. - null_values : dict[str, any] - The null values to use for each column. This falls back to - ``warp_prism.null_values`` for columns that are not specified. - - Returns - ------- - df : pd.DataFrame - A pandas DataFrame holding the results of the query. The columns - of the DataFrame will be named the same and be in the same order as the - query. - """ - arrays = to_arrays(query, bind=bind) - - if null_values is None: - null_values = {} - - for name, (array, mask) in arrays.items(): - if array.dtype.kind == 'i': - if not mask.all(): - try: - null = null_values[name] - except KeyError: - # no explicit override, cast to float and use NaN as null - array = array.astype('float64') - null = np.nan - - array[~mask] = null - - arrays[name] = array - continue - - if array.dtype.kind == 'M': - # pandas needs datetime64[ns], not ``us`` or ``D`` - array = array.astype('datetime64[ns]') - - try: - null = null_values[name] - except KeyError: - null = _default_null_values_for_type[array.dtype] - - array[~mask] = null - arrays[name] = array - - return pd.DataFrame(arrays, columns=[column.name for column in query.c]) - - -def register_odo_dataframe_edge(): - """Register an odo edge for sqlalchemy selectable objects to dataframe. - - This edge will have a lower cost that the default edge so it will be - selected as the fasted path. - - If the selectable is not in a postgres database, it will fallback to the - default odo edge. 
- """ - # estimating 8 times faster - df_cost = convert.graph.edge[sa.sql.Select][pd.DataFrame]['cost'] / 8 - - @convert.register( - pd.DataFrame, - (sa.sql.Select, sa.sql.Selectable), - cost=df_cost, - ) - def select_or_selectable_to_frame(el, bind=None, dshape=None, **kwargs): - bind = _getbind(el, bind) - - if bind.dialect.name != 'postgresql': - # fall back to the general edge - raise NotImplementedError() - - return to_dataframe(el, bind=bind) - - # higher priority than df edge so that - # ``odo('select one_column from ...', list)`` returns a list of scalars - # instead of a list of tuples of length 1 - @convert.register( - pd.Series, - (sa.sql.Select, sa.sql.Selectable), - cost=df_cost - 1, - ) - def select_or_selectable_to_series(el, bind=None, dshape=None, **kwargs): - bind = _getbind(el, bind) - - if istabular(dshape) or bind.dialect.name != 'postgresql': - # fall back to the general edge - raise NotImplementedError() - - return to_dataframe(el, bind=bind).iloc[:, 0] +__all__ = [ + 'to_arrays', 'to_dataframe', 'register_odo_dataframe_edge', +] diff --git a/warp_prism/odo.py b/warp_prism/odo.py new file mode 100644 index 0000000..336df31 --- /dev/null +++ b/warp_prism/odo.py @@ -0,0 +1,51 @@ +from .query import to_dataframe +from .sql import getbind + + +def register_odo_dataframe_edge(): + """Register an odo edge for sqlalchemy selectable objects to dataframe. + + This edge will have a lower cost that the default edge so it will be + selected as the fasted path. + + If the selectable is not in a postgres database, it will fallback to the + default odo edge. + """ + from datashape.predicates import istabular + from odo import convert + import pandas as pd + import sqlalchemy as sa + + # estimating 8 times faster + df_cost = convert.graph.edge[sa.sql.Select][pd.DataFrame]['cost'] / 8 + + @convert.register( + pd.DataFrame, + (sa.sql.Select, sa.sql.Selectable), + cost=df_cost, + ) + def select_or_selectable_to_frame(el, bind=None, dshape=None, **kwargs): + bind = getbind(el, bind) + + if bind.dialect.name != 'postgresql': + # fall back to the general edge + raise NotImplementedError() + + return to_dataframe(el, bind=bind) + + # higher priority than df edge so that + # ``odo('select one_column from ...', list)`` returns a list of scalars + # instead of a list of tuples of length 1 + @convert.register( + pd.Series, + (sa.sql.Select, sa.sql.Selectable), + cost=df_cost - 1, + ) + def select_or_selectable_to_series(el, bind=None, dshape=None, **kwargs): + bind = getbind(el, bind) + + if istabular(dshape) or bind.dialect.name != 'postgresql': + # fall back to the general edge + raise NotImplementedError() + + return to_dataframe(el, bind=bind).iloc[:, 0] diff --git a/warp_prism/query.py b/warp_prism/query.py new file mode 100644 index 0000000..ebba412 --- /dev/null +++ b/warp_prism/query.py @@ -0,0 +1,125 @@ +from functools import wraps +import io + +try: + import pandas as pd +except ImportError: + pd = None +import numpy as np + +from .sql import getbind, mogrify +from .types import query_typeids +from ._warp_prism import raw_to_arrays as _raw_to_arrays + + +def to_arrays(query, params=None, *, bind=None): + """Run the query returning a the results as np.ndarrays. + + Parameters + ---------- + query : str or sa.sql.Selectable + The query to run. This can be a select or a table. + params : dict or tuple or None + Bind parameters for ``query``. + bind : psycopg2.connection, sa.Engine, or sa.Connection, optional + The engine used to create the connection. 
+        ``query.bind`` will be used.
+
+    Returns
+    -------
+    arrays : dict[str, (np.ndarray, np.ndarray)]
+        A map from column name to the result arrays. The first array holds the
+        values and the second array is a boolean mask for NULLs. The values
+        where the mask is False are 0 as interpreted by the type.
+    """
+
+    buf = io.BytesIO()
+    bind = getbind(query, bind)
+
+    with bind.cursor() as cur:
+        bound_query = mogrify(cur, query, params)
+        column_names, typeids = query_typeids(cur, bound_query)
+        cur.copy_expert('copy (%s) to stdout binary' % bound_query, buf)
+
+    out = _raw_to_arrays(buf.getbuffer(), typeids)
+
+    return {column_names[n]: v for n, v in enumerate(out)}
+
+
+null_values = {np.dtype(k): v for k, v in {
+    'float32': np.nan,
+    'float64': np.nan,
+    'int16': np.nan,
+    'int32': np.nan,
+    'int64': np.nan,
+    'bool': np.nan,
+    'datetime64[ns]': np.datetime64('nat', 'ns'),
+    'object': None,
+}.items()}
+
+# alias because ``to_dataframe`` shadows this name
+_default_null_values_for_type = null_values
+
+
+def to_dataframe(query, params=None, *, bind=None, null_values=None):
+    """Run the query, returning the results as a pd.DataFrame.
+
+    Parameters
+    ----------
+    query : str or sa.sql.Selectable
+        The query to run. This can be a select or a table.
+    params : dict or tuple or None
+        Bind parameters for ``query``.
+    bind : psycopg2.connection, sa.Engine, or sa.Connection, optional
+        The engine used to create the connection. If not provided
+        ``query.bind`` will be used.
+    null_values : dict[str, any]
+        The null values to use for each column. This falls back to
+        ``warp_prism.null_values`` for columns that are not specified.
+
+    Returns
+    -------
+    df : pd.DataFrame
+        A pandas DataFrame holding the results of the query. The columns
+        of the DataFrame will be named the same and be in the same order as the
+        query.
+    """
+    arrays = to_arrays(query, bind=bind)
+
+    if null_values is None:
+        null_values = {}
+
+    for name, (array, mask) in arrays.items():
+        if array.dtype.kind == 'i':
+            if not mask.all():
+                try:
+                    null = null_values[name]
+                except KeyError:
+                    # no explicit override, cast to float and use NaN as null
+                    array = array.astype('float64')
+                    null = np.nan
+
+                array[~mask] = null
+
+            arrays[name] = array
+            continue
+
+        if array.dtype.kind == 'M':
+            # pandas needs datetime64[ns], not ``us`` or ``D``
+            array = array.astype('datetime64[ns]')
+
+        try:
+            null = null_values[name]
+        except KeyError:
+            null = _default_null_values_for_type[array.dtype]
+
+        array[~mask] = null
+        arrays[name] = array
+
+    return pd.DataFrame(arrays)
+
+
+if pd is None:
+    @wraps(to_dataframe)
+    def to_dataframe(*args, **kwargs):
+        raise NotImplementedError('to_dataframe requires pandas')
diff --git a/warp_prism/sa.py b/warp_prism/sa.py
new file mode 100644
index 0000000..55b5fa3
--- /dev/null
+++ b/warp_prism/sa.py
@@ -0,0 +1,12 @@
+def literal_compile(s):
+    """Compile a sql expression with bind params inlined as literals.
+    Parameters
+    ----------
+    s : Selectable
+        The expression to compile.
+    Returns
+    -------
+    cs : str
+        An equivalent sql string.
+    """
+    return str(s.compile(compile_kwargs={'literal_binds': True}))
diff --git a/warp_prism/sql.py b/warp_prism/sql.py
new file mode 100644
index 0000000..4102345
--- /dev/null
+++ b/warp_prism/sql.py
@@ -0,0 +1,60 @@
+import psycopg2
+try:
+    import sqlalchemy as sa
+except ImportError:
+    sa = None
+
+
+def _sa_literal_compile(s):
+    """Compile a sql expression with variables inlined as literals.
+
+    Parameters
+    ----------
+    s : sa.sql.Selectable
+        The expression to compile.
+
+    Returns
+    -------
+    cs : str
+        An equivalent sql string.
+    """
+    return str(s.compile(compile_kwargs={'literal_binds': True}))
+
+
+def mogrify(cursor, query, params):
+    if sa is not None:
+        if isinstance(query, sa.Table):
+            query = _sa_literal_compile(sa.select(query.c))
+        elif isinstance(query, sa.sql.Selectable):
+            query = _sa_literal_compile(query)
+
+    return cursor.mogrify(query, params).decode('utf-8')
+
+
+def getbind(query, bind):
+    """Get the connection to use for a query.
+
+    Parameters
+    ----------
+    query : str or sa.sql.Selectable
+        The query to run.
+    bind : psycopg2.extensions.connection or sa.engine.base.Engine or None
+        The explicitly provided bind.
+
+    Returns
+    -------
+    bind : psycopg2.extensions.connection
+        The connection to use for the query.
+    """
+    if bind is not None:
+        if sa is None or isinstance(bind, psycopg2.extensions.connection):
+            return bind
+
+        if isinstance(bind, sa.engine.base.Engine):
+            return bind.connect().connection.connection
+
+        return sa.create_engine(bind).connect().connection.connection
+    elif sa is None or not isinstance(query, sa.sql.Selectable):
+        raise TypeError("missing 1 required argument: 'bind'")
+    else:
+        return query.bind.connect().connection.connection
diff --git a/warp_prism/tests/__init__.py b/warp_prism/tests/__init__.py
index 26b3c26..06e0911 100644
--- a/warp_prism/tests/__init__.py
+++ b/warp_prism/tests/__init__.py
@@ -2,40 +2,12 @@
 from uuid import uuid4
 import warnings
 
-from odo import resource
-import sqlalchemy as sa
+import psycopg2
 
 
-def _dropdb(root_conn, db_name):
-    root_conn.execute('COMMIT')
-    root_conn.execute('DROP DATABASE %s' % db_name)
-
-
-@contextmanager
-def disposable_engine(uri):
-    """An engine which is disposed on exit.
-
-    Parameters
-    ----------
-    uri : str
-        The uri to the db.
-
-    Yields
-    ------
-    engine : sa.engine.Engine
-    """
-    engine = resource(uri)
-    try:
-        yield engine
-    finally:
-        engine.dispose()
-
-
-_pg_stat_activity = sa.Table(
-    'pg_stat_activity',
-    sa.MetaData(),
-    sa.Column('pid', sa.Integer),
-)
+def _dropdb(cur, db_name):
+    cur.execute('COMMIT')
+    cur.execute('DROP DATABASE %s' % db_name)
 
 
 @contextmanager
@@ -45,35 +17,26 @@ def tmp_db_uri():
     db_name = '_warp_prism_test_' + uuid4().hex
     root = 'postgresql://localhost/'
     uri = root + db_name
-    with disposable_engine(root + 'postgres') as e, e.connect() as root_conn:
-        root_conn.execute('COMMIT')
-        root_conn.execute('CREATE DATABASE %s' % db_name)
+    with psycopg2.connect(root + 'postgres') as conn, conn.cursor() as cur:
+        cur.execute('COMMIT')
+        cur.execute('CREATE DATABASE %s' % db_name)
 
         try:
             yield uri
         finally:
-            resource(uri).dispose()
             try:
-                _dropdb(root_conn, db_name)
-            except sa.exc.OperationalError:
-                # We couldn't drop the db. The most likely cause is that there
-                # are active queries. Even more likely is that these are
-                # rollbacks because there was an exception somewhere inside the
-                # tests. We will cancel all the running queries and try to drop
-                # the database again.
-                pid = _pg_stat_activity.c.pid
-                root_conn.execute(
-                    sa.select(
-                        (sa.func.pg_terminate_backend(pid),),
-                    ).where(
-                        pid != sa.func.pg_backend_pid(),
-                    )
+                cur.execute("""
+                select
+                    pg_terminate_backend(pid)
+                from
+                    pg_stat_activity
+                where
+                    pid != pg_backend_pid()
+                """)
+                _dropdb(cur, db_name)
+            except:  # pragma: no cover # noqa
+                # The database wasn't cleaned up. Just tell the user to deal
+                # with this manually.
+                warnings.warn(
+                    "leaking database '%s', please manually delete this" %
+                    db_name,
                 )
-            try:
-                _dropdb(root_conn, db_name)
-            except sa.exc.OperationalError:  # pragma: no cover
-                # The database STILL wasn't cleaned up. Just tell the user
-                # to deal with this manually.
-                warnings.warn(
-                    "leaking database '%s', please manually delete this" %
-                    db_name,
-                )
diff --git a/warp_prism/tests/test_warp_prism.py b/warp_prism/tests/test_warp_prism.py
index 9ec9532..653c55c 100644
--- a/warp_prism/tests/test_warp_prism.py
+++ b/warp_prism/tests/test_warp_prism.py
@@ -2,26 +2,33 @@
 import struct
 from uuid import uuid4
 
-from datashape import var, R, Option, dshape
 import numpy as np
-from odo import resource, odo
-import pandas as pd
+import psycopg2
 import pytest
-import sqlalchemy as sa
 
 from warp_prism._warp_prism import (
     postgres_signature,
     raw_to_arrays,
     test_overflow_operations as _test_overflow_operations,
 )
-from warp_prism import (
-    to_arrays,
-    to_dataframe,
-    null_values as null_values_for_type,
-    _typeid_map,
-)
+from warp_prism import to_arrays, to_dataframe
+from warp_prism.query import null_values as null_values_for_type
+from warp_prism.types import dtype_to_typeid
 from warp_prism.tests import tmp_db_uri as tmp_db_uri_ctx
 
+try:
+    import pandas as pd
+except ImportError:
+    pd = None
+
+try:
+    import sqlalchemy as sa
+
+    use_sqlalchemy = pytest.mark.parametrize('use_sqlalchemy', [False, True])
+except ImportError:
+    sa = None
+    use_sqlalchemy = pytest.mark.parametrize('use_sqlalchemy', [False])
+
 
 @pytest.fixture(scope='module')
 def tmp_db_uri():
@@ -34,79 +41,158 @@ def tmp_table_uri(tmp_db_uri):
     return '%s::%s%s' % (tmp_db_uri, 'table_', uuid4().hex)
 
 
-def check_roundtrip_nonnull(table_uri, data, dtype, sqltype):
+def item(a):
+    """Convert a value to a Python built-in type (not a numpy type).
+
+    Parameters
+    ----------
+    a : any
+        The value to convert.
+
+    Returns
+    -------
+    item : any
+        The base Python type equivalent value.
+    """
+    try:
+        return a.item()
+    except AttributeError:
+        return a
+
+
+def check_roundtrip_nonnull(table_uri, data, dtype, sqltype, use_sqlalchemy):
     """Check the data roundtrip through postgres using warp_prism to read the
     data
 
     Parameters
     ----------
     table_uri : str
-        The uri to a unique table.
+        The uri for the table.
     data : np.array
         The input data.
     dtype : str
         The dtype of the data.
-    sqltype : type
-        The sqlalchemy type of the data.
+    sqltype : str
+        The sql type of the data.
+    use_sqlalchemy : bool
+        Use sqlalchemy for the query instead of psycopg2.
     """
-    input_dataframe = pd.DataFrame({'a': data})
-    table = odo(input_dataframe, table_uri, dshape=var * R['a': dtype])
-    # Ensure that odo created the table correctly. If these fail the other
-    # tests are not well defined.
-    assert table.columns.keys() == ['a']
-    assert isinstance(table.columns['a'].type, sqltype)
-
-    arrays = to_arrays(table)
+    db, table = table_uri.split('::')
+    with psycopg2.connect(db) as conn:
+        with conn.cursor() as cur:
+            cur.execute('create table %s (a %s)' % (table, sqltype))
+            cur.executemany(
+                'insert into {} values (%s)'.format(table),
+                [(item(v),) for v in data],
+            )
+            cur.execute('commit')
+
+        if use_sqlalchemy:
+            bind = sa.create_engine(db)
+            meta = sa.MetaData(bind)
+            t = sa.Table(table, meta, autoload=True)
+            query = sa.select(t.c)
+        else:
+            bind = conn
+            query = 'select * from %s' % table
+
+        arrays = to_arrays(query, bind=bind)
+
+        if pd is not None:
+            output_dataframe = to_dataframe(query, bind=conn)
+
     assert len(arrays) == 1
     array, mask = arrays['a']
     assert (array == data).all()
     assert mask.all()
 
-    output_dataframe = to_dataframe(table)
-    pd.util.testing.assert_frame_equal(output_dataframe, input_dataframe)
+    if pd is not None:
+        expected_dataframe = pd.DataFrame({'a': data})
+        pd.util.testing.assert_frame_equal(
+            output_dataframe,
+            expected_dataframe,
+        )
 
 
+@use_sqlalchemy
 @pytest.mark.parametrize('dtype,sqltype,start,stop,step', (
-    ('int16', sa.SmallInteger, 0, 5000, 1),
-    ('int32', sa.Integer, 0, 5000, 1),
-    ('int64', sa.BigInteger, 0, 5000, 1),
-    ('float32', sa.REAL, 0, 2500, 0.5),
-    ('float64', sa.FLOAT, 0, 2500, 0.5),
+    ('int16', 'int2', 0, 5000, 1),
+    ('int32', 'int4', 0, 5000, 1),
+    ('int64', 'int8', 0, 5000, 1),
+    ('float32', 'float4', 0, 2500, 0.5),
+    ('float64', 'float8', 0, 2500, 0.5),
))
 def test_numeric_type_nonnull(tmp_table_uri, dtype, sqltype, start, stop,
-                              step):
+                              step,
+                              use_sqlalchemy):
     data = np.arange(start, stop, step, dtype=dtype)
-    check_roundtrip_nonnull(tmp_table_uri, data, dtype, sqltype)
+    check_roundtrip_nonnull(
+        tmp_table_uri,
+        data,
+        dtype,
+        sqltype,
+        use_sqlalchemy,
+    )
 
 
-def test_bool_type_nonnull(tmp_table_uri):
+@use_sqlalchemy
+def test_bool_type_nonnull(tmp_table_uri, use_sqlalchemy):
     data = np.array([True] * 2500 + [False] * 2500, dtype=bool)
-    check_roundtrip_nonnull(tmp_table_uri, data, 'bool', sa.Boolean)
+    check_roundtrip_nonnull(
+        tmp_table_uri,
+        data,
+        'bool',
+        'bool',
+        use_sqlalchemy,
+    )
 
 
-def test_string_type_nonnull(tmp_table_uri):
+@use_sqlalchemy
+def test_string_type_nonnull(tmp_table_uri, use_sqlalchemy):
     data = np.array(list(ascii_letters) * 200, dtype='object')
-    check_roundtrip_nonnull(tmp_table_uri, data, 'string', sa.String)
+    check_roundtrip_nonnull(
+        tmp_table_uri,
+        data,
+        'object',
+        'text',
+        use_sqlalchemy,
+    )
 
 
-def test_datetime_type_nonnull(tmp_table_uri):
-    data = pd.date_range(
+@use_sqlalchemy
+def test_datetime_type_nonnull(tmp_table_uri, use_sqlalchemy):
+    data = np.arange(
         '2000',
         '2016',
-    ).values.astype('datetime64[us]')
-    check_roundtrip_nonnull(tmp_table_uri, data, 'datetime', sa.DateTime)
+        dtype='M8[D]',
+    ).astype('datetime64[us]')
+    check_roundtrip_nonnull(
+        tmp_table_uri,
+        data,
+        'datetime64[us]',
+        'timestamp',
+        use_sqlalchemy,
+    )
 
 
-def test_date_type_nonnull(tmp_table_uri):
-    data = pd.date_range(
+@use_sqlalchemy
+def test_date_type_nonnull(tmp_table_uri, use_sqlalchemy):
+    data = np.arange(
         '2000',
         '2016',
-    ).values.astype('datetime64[D]')
-    check_roundtrip_nonnull(tmp_table_uri, data, 'date', sa.Date)
+        dtype='M8[D]',
+    ).astype('datetime64[D]')
+    check_roundtrip_nonnull(
+        tmp_table_uri,
+        data,
+        'datetime64[D]',
+        'date',
+        use_sqlalchemy,
+    )
 
 
 def check_roundtrip_null_values(table_uri,
@@ -115,6 +201,7 @@
                                 sqltype,
                                 null_values,
                                 mask,
+                                use_sqlalchemy,
                                 *,
                                 astype=False):
     """Check the data roundtrip through postgres using warp_prism to read the
@@ -128,44 +215,69 @@ def check_roundtrip_null_values(table_uri,
         The input data.
     dtype : str
         The dtype of the data.
-    sqltype : type
-        The sqlalchemy type of the data.
+    sqltype : str
+        The sql type of the data.
     null_values : dict[str, any]
         The value to coerce ``NULL`` to.
+    mask : np.ndarray[bool]
+        A mask indicating which values are non-null.
+    use_sqlalchemy : bool
+        Use sqlalchemy for the query instead of psycopg2.
     astype : bool, optional
         Coerce the input data to the given dtype before making assertions
         about the output data.
     """
-    table = resource(table_uri, dshape=var * R['a': Option(dtype)])
-    # Ensure that odo created the table correctly. If these fail the other
-    # tests are not well defined.
-    assert table.columns.keys() == ['a']
-    assert isinstance(table.columns['a'].type, sqltype)
-    table.insert().values([{'a': v} for v in data]).execute()
-
-    arrays = to_arrays(table)
+    db, table = table_uri.split('::')
+    with psycopg2.connect(db) as conn:
+        with conn.cursor() as cur:
+            cur.execute('create table %s (a %s)' % (table, sqltype))
+            cur.executemany(
+                'insert into {} values (%s)'.format(table),
+                [(item(v),) for v in data],
+            )
+            cur.execute('commit')
+
+        if use_sqlalchemy:
+            bind = sa.create_engine(db)
+            meta = sa.MetaData(bind)
+            t = sa.Table(table, meta, autoload=True)
+            query = sa.select(t.c)
+        else:
+            bind = conn
+            query = 'select * from %s' % table
+
+        arrays = to_arrays(query, bind=conn)
+
+        if pd is not None:
+            output_dataframe = to_dataframe(
+                query,
+                null_values=null_values,
+                bind=conn,
+            )
+
     assert len(arrays) == 1
     array, actual_mask = arrays['a']
     assert (actual_mask == mask).all()
+    assert (array[mask] == data[mask]).all()
 
-    output_dataframe = to_dataframe(table, null_values=null_values)
-    if astype:
-        data = data.astype(dshape(dtype).measure.to_numpy_dtype())
-    expected_dataframe = pd.DataFrame({'a': data})
-    expected_dataframe[~mask] = null_values.get(
-        'a',
-        null_values_for_type[
-            array.dtype
-            if array.dtype.kind != 'M' else
-            np.dtype('datetime64[ns]')
-        ],
-    )
-    pd.util.testing.assert_frame_equal(
-        output_dataframe,
-        expected_dataframe,
-        check_dtype=False,
-    )
+    if pd is not None:
+        if astype:
+            data = data.astype(dtype, copy=False)
+
+        expected_dataframe = pd.DataFrame({'a': data})
+        expected_dataframe[~mask] = null_values.get(
+            'a',
+            null_values_for_type[
+                array.dtype
+                if array.dtype.kind != 'M' else
+                np.dtype('datetime64[ns]')
+            ],
+        )
+        pd.util.testing.assert_frame_equal(
+            output_dataframe,
+            expected_dataframe,
+            check_dtype=False,
+        )
 
 
 def check_roundtrip_null(table_uri,
@@ -174,6 +286,7 @@
                          sqltype,
                          null,
                          mask,
+                         use_sqlalchemy,
                          *,
                          astype=False):
     """Check the data roundtrip through postgres using warp_prism to read the
@@ -187,10 +300,14 @@ def check_roundtrip_null(table_uri,
         The input data.
     dtype : str
         The dtype of the data.
-    sqltype : type
-        The sqlalchemy type of the data.
+    sqltype : str
+        The sql type of the data.
     null : any
         The value to coerce ``NULL`` to.
+    mask : np.ndarray[bool]
+        A mask indicating which values are non-null.
+    use_sqlalchemy : bool
+        Use sqlalchemy for the query instead of psycopg2.
     astype : bool, optional
         Coerce the input data to the given dtype before making assertions
         about the output data.
@@ -202,16 +319,18 @@ def check_roundtrip_null(table_uri,
         sqltype,
         {'a': null},
         mask,
+        use_sqlalchemy,
         astype=astype,
     )
 
 
+@use_sqlalchemy
 @pytest.mark.parametrize('dtype,sqltype,start,stop,step,null', (
-    ('int16', sa.SmallInteger, 0, 5000, 1, -1),
-    ('int32', sa.Integer, 0, 5000, 1, -1),
-    ('int64', sa.BigInteger, 0, 5000, 1, -1),
-    ('float32', sa.REAL, 0, 2500, 0.5, -1.0),
-    ('float64', sa.FLOAT, 0, 2500, 0.5, -1.0),
+    ('int16', 'int2', 0, 5000, 1, -1),
+    ('int32', 'int4', 0, 5000, 1, -1),
+    ('int64', 'int8', 0, 5000, 1, -1),
+    ('float32', 'float4', 0, 2500, 0.5, -1.0),
+    ('float64', 'float8', 0, 2500, 0.5, -1.0),
))
 def test_numeric_type_null(tmp_table_uri,
                            dtype,
@@ -219,67 +338,101 @@ def test_numeric_type_null(tmp_table_uri,
                            start,
                            stop,
                            step,
-                           null):
+                           null,
+                           use_sqlalchemy):
     data = np.arange(start, stop, step, dtype=dtype).astype(object)
     mask = np.tile(np.array([True, False]), len(data) // 2)
     data[~mask] = None
-    check_roundtrip_null(tmp_table_uri, data, dtype, sqltype, null, mask)
+    check_roundtrip_null(
+        tmp_table_uri,
+        data,
+        dtype,
+        sqltype,
+        null,
+        mask,
+        use_sqlalchemy,
+    )
 
 
+@use_sqlalchemy
 @pytest.mark.parametrize('dtype,sqltype', (
-    ('int16', sa.SmallInteger),
-    ('int32', sa.Integer),
-    ('int64', sa.BigInteger),
+    ('int16', 'int2'),
+    ('int32', 'int4'),
+    ('int64', 'int8'),
))
-def test_numeric_default_null_promote(tmp_table_uri, dtype, sqltype):
+def test_numeric_default_null_promote(tmp_table_uri,
+                                      dtype,
+                                      sqltype,
+                                      use_sqlalchemy):
     data = np.arange(0, 100, dtype=dtype).astype(object)
     mask = np.tile(np.array([True, False]), len(data) // 2)
     data[~mask] = None
-    check_roundtrip_null_values(tmp_table_uri, data, dtype, sqltype, {}, mask)
+    check_roundtrip_null_values(
+        tmp_table_uri,
+        data,
+        dtype,
+        sqltype,
+        {},
+        mask,
+        use_sqlalchemy,
+    )
 
 
-def test_bool_type_null(tmp_table_uri):
+@use_sqlalchemy
+def test_bool_type_null(tmp_table_uri, use_sqlalchemy):
     data = np.array([True] * 2500 + [False] * 2500, dtype=bool).astype(object)
     mask = np.tile(np.array([True, False]), len(data) // 2)
     data[~mask] = None
-    check_roundtrip_null(tmp_table_uri, data, 'bool', sa.Boolean, False, mask)
+    check_roundtrip_null(
+        tmp_table_uri,
+        data,
+        'bool',
+        'bool',
+        False,
+        mask,
+        use_sqlalchemy,
+    )
 
 
-def test_string_type_null(tmp_table_uri):
+@use_sqlalchemy
+def test_string_type_null(tmp_table_uri, use_sqlalchemy):
     data = np.array(list(ascii_letters) * 200, dtype='object')
     mask = np.tile(np.array([True, False]), len(data) // 2)
     data[~mask] = None
     check_roundtrip_null(
         tmp_table_uri,
         data,
-        'string',
-        sa.String,
+        'object',
+        'text',
         'ayy lmao',
         mask,
+        use_sqlalchemy,
     )
 
 
-def test_datetime_type_null(tmp_table_uri):
-    data = np.array(
-        list(pd.date_range(
-            '2000',
-            '2016',
-        )),
-        dtype=object,
-    )[:-1]  # slice the last element off to have an even number
+@use_sqlalchemy
+def test_datetime_type_null(tmp_table_uri, use_sqlalchemy):
+    data = np.arange(
+        '2000',
+        '2016',
+        dtype='M8[D]',
+    ).astype('M8[us]').astype('O')
+
     mask = np.tile(np.array([True, False]), len(data) // 2)
     data[~mask] = None
     check_roundtrip_null(
         tmp_table_uri,
         data,
-        'datetime',
-        sa.DateTime,
-        pd.Timestamp('1995-12-13').to_datetime64(),
+        'datetime64[us]',
+        'timestamp',
+        np.datetime64('1995-12-13', 'ns'),
         mask,
+        use_sqlalchemy,
     )
 
 
-def test_date_type_null(tmp_table_uri):
+@use_sqlalchemy
+def test_date_type_null(tmp_table_uri, use_sqlalchemy):
     data = np.arange(
         '2000',
         '2016',
@@ -290,10 +443,11 @@ def test_date_type_null(tmp_table_uri):
     check_roundtrip_null(
         tmp_table_uri,
         data,
+        'datetime64[D]',
         'date',
-        sa.Date,
-        pd.Timestamp('1995-12-13').to_datetime64(),
+        np.datetime64('1995-12-13', 'ns'),
         mask,
+        use_sqlalchemy,
         astype=True,
     )
 
@@ -341,7 +495,7 @@ def test_invalid_numeric_size(dtype):
     )
 
     with pytest.raises(ValueError) as e:
-        raw_to_arrays(input_data, (_typeid_map[dtype],))
+        raw_to_arrays(input_data, (dtype_to_typeid(dtype),))
 
     assert str(e.value) == 'mismatched %s size: %s' % (
         dtype.name,
@@ -357,13 +511,12 @@ def test_invalid_datetime_size():
     input_data = _pack_as_invalid_size_postgres_binary_data(
         'q',  # int64_t (quadword)
         8,
-        (pd.Timestamp('2014-01-01').to_datetime64().astype('datetime64[us]') +
-         _epoch_offset).view('int64'),
+        (np.datetime64('2014-01-01', 'us') + _epoch_offset).view('int64'),
     )
 
     dtype = np.dtype('datetime64[us]')
 
     with pytest.raises(ValueError) as e:
-        raw_to_arrays(input_data, (_typeid_map[dtype],))
+        raw_to_arrays(input_data, (dtype_to_typeid(dtype),))
 
     assert str(e.value) == 'mismatched datetime size: 7'
@@ -377,7 +530,7 @@ def test_invalid_date_size():
 
     dtype = np.dtype('datetime64[D]')
 
     with pytest.raises(ValueError) as e:
-        raw_to_arrays(input_data, (_typeid_map[dtype],))
+        raw_to_arrays(input_data, (dtype_to_typeid(dtype),))
 
     assert str(e.value) == 'mismatched date size: 3'
@@ -405,7 +558,7 @@ def test_invalid_text():
 
     # we put the invalid unicode as the first column to test that we can clean
     # up the cell in the second column before we have written a string there
-    str_typeid = _typeid_map[np.dtype(object)]
+    str_typeid = dtype_to_typeid(np.dtype(object))
 
     with pytest.raises(UnicodeDecodeError):
         raw_to_arrays(input_data, (str_typeid, str_typeid))
diff --git a/warp_prism/types.py b/warp_prism/types.py
new file mode 100644
index 0000000..9f99d6c
--- /dev/null
+++ b/warp_prism/types.py
@@ -0,0 +1,128 @@
+import numpy as np
+
+from ._warp_prism import typeid_map as _raw_typeid_map
+
+_typeid_map = {np.dtype(k): v for k, v in _raw_typeid_map.items()}
+
+
+def dtype_to_typeid(dtype):
+    """Convert a numpy dtype to a warp_prism type id.
+
+    Parameters
+    ----------
+    dtype : np.dtype
+        The numpy dtype to convert.
+
+    Returns
+    -------
+    typeid : int
+        The type id for ``dtype``.
+    """
+    try:
+        return _typeid_map[dtype]
+    except KeyError:
+        raise ValueError('no warp_prism type id for dtype %s' % dtype)
+
+
+_oid_map = {
+    16: np.dtype('?'),
+
+    # text
+    17: np.dtype('O'),
+    18: np.dtype('S1'),
+    19: np.dtype('O'),
+    25: np.dtype('O'),
+
+    # int
+    20: np.dtype('i8'),
+    21: np.dtype('i2'),
+    23: np.dtype('i4'),
+    1042: np.dtype('O'),
+    1043: np.dtype('O'),
+
+    # float
+    700: np.dtype('f4'),
+    701: np.dtype('f8'),
+
+    # date(time)
+    1082: np.dtype('M8[D]'),
+    1114: np.dtype('M8[us]'),
+    1184: np.dtype('M8[us]'),
+}
+
+
+def oid_to_dtype(oid):
+    """Get a numpy dtype from a postgres oid.
+
+    Parameters
+    ----------
+    oid : int
+        The oid to convert.
+
+    Returns
+    -------
+    dtype : np.dtype
+        The corresponding numpy dtype.
+    """
+    try:
+        return _oid_map[oid]
+    except KeyError:
+        raise ValueError('cannot convert oid %s to numpy dtype' % oid)
+
+
+def query_dtypes(cursor, bound_query):
+    """Get the numpy dtypes for each column returned by a query.
+
+    Parameters
+    ----------
+    cursor : psycopg2.cursor
+        The psycopg2 cursor to use to get the type information.
+    bound_query : str
+        The query to check the types of with all parameters bound.
+
+    Returns
+    -------
+    names : tuple[str]
+        The column names.
+    dtypes : tuple[np.dtype]
+        The column dtypes.
+ """ + cursor.execute('select * from (%s) a limit 0' % bound_query) + invalid = [] + names = [] + dtypes = [] + for c in cursor.description: + try: + dtypes.append(oid_to_dtype(c.type_code)) + except ValueError: + invalid.append(c) + else: + names.append(c.name) + + if invalid: + raise ValueError( + 'columns cannot be converted to numpy dtype: %s' % invalid + ) + + return tuple(names), tuple(dtypes) + + +def query_typeids(cursor, bound_query): + """Get the warp_prism typeid for each column returned by a query. + + Parameters + ---------- + cursor : psycopg2.cursor + The psycopg2 cursor to use to get the type information. + bound_query : str + The query to check the types of with all parameters bound. + + Returns + ------- + names : tuple[str] + The column names. + typeids : tuple[int] + The warp_prism typeid for each column.. + """ + names, dtypes = query_dtypes(cursor, bound_query) + return names, tuple(dtype_to_typeid(dtype) for dtype in dtypes)