Chunking SQL export queries
dogversioning committed Jun 12, 2024
1 parent 5db1c2a commit 888d5e8
Showing 24 changed files with 703 additions and 220 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
@@ -0,0 +1,2 @@
[run]
omit =cumulus_library/schema/*
16 changes: 15 additions & 1 deletion .github/workflows/ci.yaml
@@ -29,9 +29,23 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install ".[test]"
- name: Create mock AWS credentials
run: |
mkdir ~/.aws && touch ~/.aws/credentials
echo -e "[test]\naws_access_key_id = test\naws_secret_access_key = test" > ~/.aws/credentials
- name: Test with pytest
run: |
python -m pytest
python -m pytest --cov-report xml --cov=cumulus_library tests
- name: Generate coverage report
uses: orgoro/[email protected]
with:
coverageFile: coverage.xml
token: ${{ secrets.GITHUB_TOKEN }}
thresholdAll: .9
thresholdNew: 1
thresholdModified: .95


lint:
runs-on: ubuntu-22.04
steps:
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ output.sql
*generated.md
MRCONSO.RRF
*.zip
coverage.xml

# Byte-compiled / optimized / DLL files
__pycache__/
45 changes: 38 additions & 7 deletions cumulus_library/actions/exporter.py
@@ -1,6 +1,7 @@
import csv
import pathlib

import pyarrow
from pyarrow import csv, parquet
from rich.progress import track

from cumulus_library import base_utils, databases, study_parser
@@ -25,6 +26,18 @@ def reset_counts_exports(
file.unlink()


def _write_chunk(writer, chunk, schema):
writer.write(
pyarrow.Table.from_pandas(
chunk.sort_values(
by=list(chunk.columns), ascending=False, na_position="first"
),
preserve_index=False,
schema=schema,
)
)


def export_study(
manifest_parser: study_parser.StudyManifestParser,
db: databases.DatabaseBackend,
@@ -56,13 +69,31 @@ def export_study(
description=f"Exporting {manifest_parser.get_study_prefix()} data...",
):
query = f"SELECT * FROM {table}"
dataframe = db.execute_as_pandas(query)
# Note: we assume that, for DuckDB, you are unlikely to be dealing with large
# exports, so the chunksize parameter is ignored there, since DuckDB does not
# provide a pandas-enabled cursor.
dataframe_chunks, db_schema = db.execute_as_pandas(query, chunksize=1000000)
first_chunk = next(dataframe_chunks)
path.mkdir(parents=True, exist_ok=True)
dataframe = dataframe.sort_values(
by=list(dataframe.columns), ascending=False, na_position="first"
)
dataframe.to_csv(f"{path}/{table}.csv", index=False, quoting=csv.QUOTE_MINIMAL)
dataframe.to_parquet(f"{path}/{table}.parquet", index=False)
schema = pyarrow.schema(db.col_pyarrow_types_from_sql(db_schema))
with parquet.ParquetWriter(f"{path}/{table}.parquet", schema) as p_writer:
with csv.CSVWriter(
f"{path}/{table}.csv",
schema,
write_options=csv.WriteOptions(
# Note that this quoting style is not exactly csv.QUOTE_MINIMAL
# https://github.com/apache/arrow/issues/42032
quoting_style="needed"
),
) as c_writer:
_write_chunk(p_writer, first_chunk, schema)
_write_chunk(c_writer, first_chunk, schema)
for chunk in dataframe_chunks:
_write_chunk(p_writer, chunk, schema)
_write_chunk(c_writer, chunk, schema)
queries.append(query)
if archive:
base_utils.zip_dir(path, data_path, manifest_parser.get_study_prefix())
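The export loop above streams each table in fixed-size chunks rather than loading the whole result into memory. Below is a minimal sketch of that streaming pattern, assuming a pandas-chunk iterator and a pyarrow schema are already in hand; the names (stream_export, chunks, schema, path, table) are hypothetical stand-ins rather than the library's API.

import pyarrow
from pyarrow import csv, parquet

def stream_export(chunks, schema, path, table):
    """Append an iterator of pandas DataFrames to <table>.parquet and <table>.csv."""
    with parquet.ParquetWriter(f"{path}/{table}.parquet", schema) as p_writer:
        with csv.CSVWriter(
            f"{path}/{table}.csv",
            schema,
            write_options=csv.WriteOptions(quoting_style="needed"),
        ) as c_writer:
            for chunk in chunks:
                # Each chunk becomes an Arrow table with the shared schema,
                # so every appended batch has identical column types.
                arrow_table = pyarrow.Table.from_pandas(
                    chunk, preserve_index=False, schema=schema
                )
                p_writer.write(arrow_table)
                c_writer.write(arrow_table)

Because both writers are opened once and fed chunk by chunk, peak memory stays near one chunk's worth of rows regardless of table size.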
4 changes: 2 additions & 2 deletions cumulus_library/cli.py
@@ -429,7 +429,7 @@ def run_cli(args: dict):
"set[/italic], primarily dates, on a per patient level.\n\n"
"[bold]By doing this, you are assuming the responsibility for "
"meeting your organization's security requirements for "
"storing this data in a secure manager.[/bold]\n\n"
"storing this data in a secure manner.[/bold]\n\n"
"Type Y to proceed, or any other value to quit.\n"
)
console.print(warning_text)
@@ -493,7 +493,7 @@ def main(cli_args=None):
("umls_key", "UMLS_API_KEY"),
("url", "CUMULUS_AGGREGATOR_URL"),
("user", "CUMULUS_AGGREGATOR_USER"),
("workgroup", "CUMULUS_LIBRARY_WORKGROUP"),
("work_group", "CUMULUS_LIBRARY_WORKGROUP"),
)
read_env_vars = []
for pair in arg_env_pairs:
1 change: 1 addition & 0 deletions cumulus_library/cli_parser.py
@@ -12,6 +12,7 @@ def add_aws_config(parser: argparse.ArgumentParser) -> None:
aws.add_argument(
"--workgroup",
default="cumulus",
dest="work_group",
help="Cumulus Athena workgroup (default: cumulus)",
)
aws.add_argument(
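The two CLI changes above work together: the --workgroup flag now stores its value under the work_group key, which is the name the env-var pairing in cli.py and create_db_backend expect. A tiny sketch of what dest= does here; the flag and default mirror the diff, everything else is illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--workgroup", default="cumulus", dest="work_group")

# The flag keeps its public spelling, but the parsed value lands under "work_group".
args = vars(parser.parse_args(["--workgroup", "my_athena_workgroup"]))
print(args["work_group"])  # -> my_athena_workgroup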
114 changes: 90 additions & 24 deletions cumulus_library/databases.py
@@ -9,6 +9,7 @@
"""

import abc
import collections
import datetime
import json
import os
@@ -35,16 +36,16 @@ class DatabaseCursor(Protocol):
"""Protocol for a PEP-249 compatible cursor"""

def execute(self, sql: str) -> None:
pass
pass # pragma: no cover

def fetchone(self) -> list | None:
pass
pass # pragma: no cover

def fetchmany(self, size: int | None) -> list[list] | None:
pass
pass # pragma: no cover

def fetchall(self) -> list[list] | None:
pass
pass # pragma: no cover


class DatabaseParser(abc.ABC):
@@ -151,7 +152,9 @@ def pandas_cursor(self) -> DatabaseCursor:
"""

@abc.abstractmethod
def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
def execute_as_pandas(
self, sql: str, chunksize: int | None = None
) -> (pandas.DataFrame | collections.abc.Iterator[pandas.DataFrame], list[tuple]):
"""Returns a pandas.DataFrame version of the results from the provided SQL"""

@abc.abstractmethod
@@ -172,9 +175,9 @@ def operational_errors(self) -> tuple[Exception]:
def col_parquet_types_from_pandas(self, field_types: list) -> list:
"""Returns appropriate types for creating tables based from parquet.
By default, returns the input (which assumes that the DB infers directly
By default, returns an empty list (which assumes that the DB infers directly
from parquet data types). Only override if your DB uses an explicit SerDe
format, or otherwise needs a modified typing to inject directly into a query."""
format, or otherwise needs a modified typing to inject directly into a query."""

# The following example shows the types we're expecting to catch with this
# approach and the rough type to cast them to.
@@ -196,9 +199,12 @@ def col_parquet_types_from_pandas(self, field_types: list) -> list:
# raise errors.CumulusLibraryError(
# f"Unsupported type {type(field)} found."
# )
# return output
return []

return None

return field_types
def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
return columns

def upload_file(
self,
@@ -257,8 +263,11 @@ def cursor(self) -> AthenaCursor:
def pandas_cursor(self) -> AthenaPandasCursor:
return self.connection.cursor(cursor=AthenaPandasCursor)

def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
return self.pandas_cursor().execute(sql).as_pandas()
def execute_as_pandas(
self, sql: str, chunksize: int | None = None
) -> (pandas.DataFrame | collections.abc.Iterator[pandas.DataFrame], list[tuple]):
query = self.pandas_cursor().execute(sql, chunksize=chunksize)
return query.as_pandas(), query.description

def parser(self) -> DatabaseParser:
return AthenaParser()
@@ -272,7 +281,10 @@ def col_parquet_types_from_pandas(self, field_types: list) -> list:
match field:
case numpy.dtypes.ObjectDType():
output.append("STRING")
case pandas.core.arrays.integer.Int64Dtype():
case (
pandas.core.arrays.integer.Int64Dtype()
| numpy.dtypes.Int64DType()
):
output.append("INT")
case numpy.dtypes.Float64DType():
output.append("DOUBLE")
@@ -282,7 +294,31 @@ def col_parquet_types_from_pandas(self, field_types: list) -> list:
output.append("TIMESTAMP")
case _:
raise errors.CumulusLibraryError(
f"Unsupported type {type(field)} found."
f"Unsupported pandas type {type(field)} found."
)
return output

def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
output = []
for column in columns:
match column[1]:
case "varchar":
output.append((column[0], pyarrow.string()))
case "bigint":
output.append((column[0], pyarrow.int64()))
case "integer":
output.append((column[0], pyarrow.int64()))
case "double":
output.append((column[0], pyarrow.float64()))
case "boolean":
output.append((column[0], pyarrow.bool_()))
case "date":
output.append((column[0], pyarrow.date64()))
case "timestamp":
output.append((column[0], pyarrow.timestamp("s")))
case _:
raise errors.CumulusLibraryError(
output.append(f"Unsupported SQL type '{column}' found.")
)
return output

@@ -296,9 +332,8 @@ def upload_file(
force_upload=False,
) -> str | None:
# We'll investigate the connection to get the relevant S3 upload path.
wg_conf = self.connection._client.get_work_group(WorkGroup=self.work_group)[
"WorkGroup"
]["Configuration"]["ResultConfiguration"]
workgroup = self.connection._client.get_work_group(WorkGroup=self.work_group)
wg_conf = workgroup["WorkGroup"]["Configuration"]["ResultConfiguration"]
s3_path = wg_conf["OutputLocation"]
bucket = "/".join(s3_path.split("/")[2:3])
key_prefix = "/".join(s3_path.split("/")[3:])
@@ -315,7 +350,7 @@
f"{key_prefix}cumulus_user_uploads/{self.schema_name}/" f"{study}/{topic}"
)
if not remote_filename:
remote_filename = file
remote_filename = file.name

session = boto3.Session(profile_name=self.connection.profile_name)
s3_client = session.client("s3")
@@ -525,12 +560,41 @@ def pandas_cursor(self) -> duckdb.DuckDBPyConnection:
# Since this is not provided, return the vanilla cursor
return self.connection

def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
def execute_as_pandas(
self, sql: str, chunksize: int | None = None
) -> (pandas.DataFrame | collections.abc.Iterator[pandas.DataFrame], list[tuple]):
# We call convert_dtypes here in case there are integer columns.
# Pandas will normally cast nullable-int as a float type unless
# we call this to convert to its nullable int column type.
# PyAthena seems to do this correctly for us, but not DuckDB.
return self.connection.execute(sql).df().convert_dtypes()
result = self.connection.execute(sql)
if chunksize:
return iter([result.df().convert_dtypes()]), result.description
return result.df().convert_dtypes(), result.description

def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
output = []
for column in columns:
match column[1]:
case "STRING":
output.append((column[0], pyarrow.string()))
case "INTEGER":
output.append((column[0], pyarrow.int64()))
case "NUMBER":
output.append((column[0], pyarrow.float64()))
case "DOUBLE":
output.append((column[0], pyarrow.float64()))
case "boolean" | "bool":
output.append((column[0], pyarrow.bool_()))
case "Date":
output.append((column[0], pyarrow.date64()))
case "TIMESTAMP" | "DATETIME":
output.append((column[0], pyarrow.timestamp("s")))
case _:
raise errors.CumulusLibraryError(
f"{column[0],column[1]} does not have a conversion type"
)
return output

def parser(self) -> DatabaseParser:
return DuckDbParser()
@@ -652,23 +716,25 @@ def read_ndjson_dir(path: str) -> dict[str, pyarrow.Table]:

def create_db_backend(args: dict[str, str]) -> DatabaseBackend:
db_config.db_type = args["db_type"]
database = args["schema_name"]
schema = args["schema_name"]
load_ndjson_dir = args.get("load_ndjson_dir")

if db_config.db_type == "duckdb":
backend = DuckDatabaseBackend(database) # `database` is path name in this case
backend = DuckDatabaseBackend(schema)  # `schema` is a path name in this case
if load_ndjson_dir:
backend.insert_tables(read_ndjson_dir(load_ndjson_dir))
elif db_config.db_type == "athena":
backend = AthenaDatabaseBackend(
args["region"],
args["workgroup"],
args["work_group"],
args["profile"],
database,
schema,
)
if load_ndjson_dir:
sys.exit("Loading an ndjson dir is not supported with --db-type=athena.")
else:
raise ValueError(f"Unexpected --db-type value '{db_config.db_type}'")
raise errors.CumulusLibraryError(
f"'{db_config.db_type}' is not a supported database."
)

return backend
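Taken together, the databases.py changes give callers two things per query: an iterator of pandas chunks and the DB-API cursor description, which col_pyarrow_types_from_sql turns into (name, pyarrow type) pairs. A rough sketch of how those pieces combine, using made-up Athena-style description tuples and a trimmed-down mapping; the real version lives in exporter.py and the backend classes above.

import pyarrow

# Hypothetical cursor.description entries: (name, type_code); trailing fields omitted.
db_schema = [
    ("subject_ref", "varchar"),
    ("cnt", "bigint"),
]

def col_pyarrow_types_from_sql(columns):
    # Trimmed-down stand-in for the Athena mapping added above.
    mapping = {"varchar": pyarrow.string(), "bigint": pyarrow.int64()}
    return [(name, mapping[sql_type]) for name, sql_type in columns]

schema = pyarrow.schema(col_pyarrow_types_from_sql(db_schema))
print(schema)  # subject_ref: string, cnt: int64

Building the schema from the SQL description, rather than inferring it from the first pandas chunk, keeps column types stable even when an early chunk is entirely null.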
Binary file added main
Binary file not shown.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dev = [
test = [
"freezegun",
"pytest",
"pytest-cov",
"responses"
]

Binary file added test
Binary file not shown.
@@ -1,3 +1,3 @@
cnt,category_code,recordedDate_month,code_display
"cnt","category_code","recordedDate_month","code_display"
15,,,
15,encounter-diagnosis,,
15,"encounter-diagnosis",,
