diff --git a/sdgx/utils.py b/sdgx/utils.py index 4e8ba504..b8799ce6 100644 --- a/sdgx/utils.py +++ b/sdgx/utils.py @@ -18,7 +18,27 @@ except ImportError: from functools import lru_cache as cache -__all__ = ["download_demo_data", "get_demo_single_table", "cache", "Singleton", "find_free_port"] +__all__ = [ + "download_demo_data", + "get_demo_single_table", + "cache", + "Singleton", + "find_free_port", + "download_multi_table_demo_data", + "get_demo_single_table", +] + +MULTI_TABLE_DEMO_DATA = { + "rossman": { + "parent_table": "store", + "child_table": "train", + "parent_url": "https://raw.githubusercontent.com/juniorcl/rossman-store-sales/main/databases/store.csv", + "child_url": "https://raw.githubusercontent.com/juniorcl/rossman-store-sales/main/databases/train.csv", + "parent_primary_keys": ["Store"], + "child_primary_keys": ["Store", "Date"], + "foreign_keys": ["Store"], + } +} def find_free_port(): @@ -95,6 +115,72 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] +def download_multi_table_demo_data( + data_dir: str | Path = "./dataset", dataset_name="rossman" +) -> dict[str, Path]: + """ + Download multi-table demo data "Rossman Store Sales" or "Rossmann Store Sales" if not exist + + Args: + data_dir(str | Path): data directory + + Returns: + dict[str, pathlib.Path]: dict, the key is table name, value is demo data path + """ + demo_data_info = MULTI_TABLE_DEMO_DATA[dataset_name] + data_dir = Path(data_dir).expanduser().resolve() + parent_file_name = dataset_name + "_" + demo_data_info["parent_table"] + ".csv" + child_file_name = dataset_name + "_" + demo_data_info["child_table"] + ".csv" + demo_data_path_parent = data_dir / parent_file_name + demo_data_path_child = data_dir / child_file_name + # For now, I think it's OK to hardcode the URL for each dataset + # In the future we can consider using our own S3 Bucket or providing more data sets through sdg.idslab.io. + if not demo_data_path_parent.exists(): + # make dir + demo_data_path_parent.parent.mkdir(parents=True, exist_ok=True) + # download parent table from github link + logger.info("Downloading parent table from github to {}".format(demo_data_path_parent)) + parent_url = demo_data_info["parent_url"] + urllib.request.urlretrieve(parent_url, demo_data_path_parent) + # then child table + if not demo_data_path_child.exists(): + # make dir + demo_data_path_child.parent.mkdir(parents=True, exist_ok=True) + # download child table from github link + logger.info("Downloading child table from github to {}".format(demo_data_path_child)) + parent_url = demo_data_info["child_url"] + urllib.request.urlretrieve(parent_url, demo_data_path_child) + + return { + demo_data_info["parent_table"]: demo_data_path_parent, + demo_data_info["child_table"]: demo_data_path_child, + } + + +def get_demo_multi_table( + data_dir: str | Path = "./dataset", dataset_name="rossman" +) -> dict[str, pd.DataFrame]: + """ + Get multi-table demo data as DataFrame and relationship + + Args: + data_dir(str | Path): data directory + + Returns: + dict[str, pd.DataFrame]: multi-table data dict, the key is table name, value is DataFrame. + """ + multi_table_dict = {} + # download if not exist + demo_data_dict = download_multi_table_demo_data(data_dir, dataset_name) + # read Data from path + for table_name in demo_data_dict.keys(): + each_path = demo_data_dict[table_name] + pd_obj = pd.read_csv(each_path) + multi_table_dict[table_name] = pd_obj + + return multi_table_dict + + def ignore_warnings(category: Warning): def ignore_warnings_decorator(func: Callable): @functools.wraps(func) diff --git a/tests/conftest.py b/tests/conftest.py index ee1b3a00..8f1635f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from sdgx.data_connectors.csv_connector import CsvConnector from sdgx.data_loader import DataLoader from sdgx.data_models.metadata import Metadata -from sdgx.utils import download_demo_data +from sdgx.utils import download_demo_data, download_multi_table_demo_data _HERE = os.path.dirname(__file__) @@ -132,3 +132,29 @@ def demo_single_table_data_loader(demo_single_table_data_connector, cacher_kwarg @pytest.fixture def demo_single_table_metadata(demo_single_table_data_loader): yield Metadata.from_dataloader(demo_single_table_data_loader) + + +@pytest.fixture +def demo_multi_table_path(): + yield download_multi_table_demo_data(DATA_DIR) + + +@pytest.fixture +def demo_multi_table_data_connector(demo_multi_table_path): + connector_dict = {} + for each_table in demo_multi_table_path.keys(): + each_path = demo_multi_table_path[each_table] + connector_dict[each_table] = CsvConnector(path=each_path) + yield connector_dict + + +@pytest.fixture +def demo_multi_table_data_loader(demo_multi_table_data_connector, cacher_kwargs): + loader_dict = {} + for each_table in demo_multi_table_data_connector.keys(): + each_connector = demo_multi_table_data_connector[each_table] + each_d = DataLoader(each_connector, cacher_kwargs=cacher_kwargs) + loader_dict[each_table] = each_d + yield loader_dict + for each_table in demo_multi_table_data_connector.keys(): + demo_multi_table_data_connector[each_table].finalize()