Add demo data for multi-table scenario (#98)
* add multi-table demo data

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update get_demo_multi_table

* add some multi-table pytest fixtures

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add return type hints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
MooooCat and pre-commit-ci[bot] committed Jan 8, 2024
1 parent c10270f commit 9479c1b
Showing 2 changed files with 114 additions and 2 deletions.
88 changes: 87 additions & 1 deletion sdgx/utils.py
@@ -18,7 +18,27 @@
except ImportError:
    from functools import lru_cache as cache

__all__ = ["download_demo_data", "get_demo_single_table", "cache", "Singleton", "find_free_port"]
__all__ = [
"download_demo_data",
"get_demo_single_table",
"cache",
"Singleton",
"find_free_port",
"download_multi_table_demo_data",
"get_demo_single_table",
]

MULTI_TABLE_DEMO_DATA = {
    "rossman": {
        "parent_table": "store",
        "child_table": "train",
        "parent_url": "https://raw.githubusercontent.com/juniorcl/rossman-store-sales/main/databases/store.csv",
        "child_url": "https://raw.githubusercontent.com/juniorcl/rossman-store-sales/main/databases/train.csv",
        "parent_primary_keys": ["Store"],
        "child_primary_keys": ["Store", "Date"],
        "foreign_keys": ["Store"],
    }
}
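
# Illustration (a hedged sketch, not part of the library API): the registry
# above declares that the child table ("train") references the parent table
# ("store") through the "Store" foreign key. Assuming the default download
# paths produced by download_multi_table_demo_data below, the two tables can
# be joined like this:
#
#   store = pd.read_csv("./dataset/rossman_store.csv")
#   train = pd.read_csv("./dataset/rossman_train.csv")
#   joined = train.merge(store, on="Store", how="left")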


def find_free_port():
@@ -95,6 +115,72 @@ def __call__(cls, *args, **kwargs):
        return cls._instances[cls]


def download_multi_table_demo_data(
    data_dir: str | Path = "./dataset", dataset_name: str = "rossman"
) -> dict[str, Path]:
    """
    Download the multi-table demo data ("Rossmann Store Sales") if it does not already exist.

    Args:
        data_dir(str | Path): data directory
        dataset_name(str): dataset name; currently only "rossman" is available

    Returns:
        dict[str, pathlib.Path]: dict whose keys are table names and whose values are demo data paths
    """
    demo_data_info = MULTI_TABLE_DEMO_DATA[dataset_name]
    data_dir = Path(data_dir).expanduser().resolve()
    parent_file_name = dataset_name + "_" + demo_data_info["parent_table"] + ".csv"
    child_file_name = dataset_name + "_" + demo_data_info["child_table"] + ".csv"
    demo_data_path_parent = data_dir / parent_file_name
    demo_data_path_child = data_dir / child_file_name
    # For now it is OK to hardcode the URL for each dataset.
    # In the future we can consider using our own S3 bucket or providing more datasets through sdg.idslab.io.
    if not demo_data_path_parent.exists():
        # make the directory
        demo_data_path_parent.parent.mkdir(parents=True, exist_ok=True)
        # download the parent table from the GitHub link
        logger.info("Downloading parent table from github to {}".format(demo_data_path_parent))
        parent_url = demo_data_info["parent_url"]
        urllib.request.urlretrieve(parent_url, demo_data_path_parent)
    # then the child table
    if not demo_data_path_child.exists():
        # make the directory
        demo_data_path_child.parent.mkdir(parents=True, exist_ok=True)
        # download the child table from the GitHub link
        logger.info("Downloading child table from github to {}".format(demo_data_path_child))
        child_url = demo_data_info["child_url"]
        urllib.request.urlretrieve(child_url, demo_data_path_child)

    return {
        demo_data_info["parent_table"]: demo_data_path_parent,
        demo_data_info["child_table"]: demo_data_path_child,
    }
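
# Hypothetical usage sketch; the paths shown are assumptions based on the
# defaults above ("./dataset" and the "rossman" file-name prefix):
#
#   paths = download_multi_table_demo_data("./dataset", "rossman")
#   paths["store"]  # ./dataset/rossman_store.csv
#   paths["train"]  # ./dataset/rossman_train.csv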


def get_demo_multi_table(
    data_dir: str | Path = "./dataset", dataset_name: str = "rossman"
) -> dict[str, pd.DataFrame]:
    """
    Get the multi-table demo data as DataFrames.

    Args:
        data_dir(str | Path): data directory
        dataset_name(str): dataset name; currently only "rossman" is available

    Returns:
        dict[str, pd.DataFrame]: multi-table data dict whose keys are table names and whose values are DataFrames
    """
    multi_table_dict = {}
    # download if the data does not exist yet
    demo_data_dict = download_multi_table_demo_data(data_dir, dataset_name)
    # read each table from its path
    for table_name, csv_path in demo_data_dict.items():
        multi_table_dict[table_name] = pd.read_csv(csv_path)

    return multi_table_dict
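
# Hypothetical usage sketch: load both demo tables into DataFrames keyed by
# table name (per MULTI_TABLE_DEMO_DATA above):
#
#   tables = get_demo_multi_table("./dataset", "rossman")
#   tables["store"].head()
#   tables["train"].head()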


def ignore_warnings(category: Warning):
    def ignore_warnings_decorator(func: Callable):
        @functools.wraps(func)
28 changes: 27 additions & 1 deletion tests/conftest.py
@@ -13,7 +13,7 @@
from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
-from sdgx.utils import download_demo_data
+from sdgx.utils import download_demo_data, download_multi_table_demo_data

_HERE = os.path.dirname(__file__)

@@ -132,3 +132,29 @@ def demo_single_table_data_loader(demo_single_table_data_connector, cacher_kwargs):
@pytest.fixture
def demo_single_table_metadata(demo_single_table_data_loader):
    yield Metadata.from_dataloader(demo_single_table_data_loader)


@pytest.fixture
def demo_multi_table_path():
    yield download_multi_table_demo_data(DATA_DIR)


@pytest.fixture
def demo_multi_table_data_connector(demo_multi_table_path):
    connector_dict = {}
    for table_name, csv_path in demo_multi_table_path.items():
        connector_dict[table_name] = CsvConnector(path=csv_path)
    yield connector_dict


@pytest.fixture
def demo_multi_table_data_loader(demo_multi_table_data_connector, cacher_kwargs):
    loader_dict = {}
    for table_name, connector in demo_multi_table_data_connector.items():
        loader_dict[table_name] = DataLoader(connector, cacher_kwargs=cacher_kwargs)
    yield loader_dict
    for connector in demo_multi_table_data_connector.values():
        connector.finalize()
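
# A hypothetical test built on these fixtures might look like the sketch
# below; the expected table names follow MULTI_TABLE_DEMO_DATA in sdgx/utils.py.
#
#   def test_demo_multi_table_loader(demo_multi_table_data_loader):
#       assert set(demo_multi_table_data_loader) == {"store", "train"}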
