-
Notifications
You must be signed in to change notification settings - Fork 545
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add base model for multi-table statistic model, change single-table b…
…ase class location (#102) * Create base.py for multi-table statistic models * Update base.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update statistic single-table base class * update multi-table base class * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add functions (still draft) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update base.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix dict typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix type hint typo * Update base.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add multi-table test fixture * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify check settings in metadata * update multi-table base class * add test cases * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply reviewer's suggestions. - remove `DataAccessType` - remove `pydantic.BaseModel` - set two mutually exclusive parameters `use_raw_data` and `use_dataloader` in `sdgx.models.base.SynthesizerModel` - some other necessary modifications * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
f18b552
commit d29a2a0
Showing
8 changed files
with
238 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import Any, Dict, List | ||
|
||
import pandas as pd | ||
|
||
from sdgx.data_loader import DataLoader | ||
from sdgx.data_models.combiner import MetadataCombiner | ||
from sdgx.exceptions import SynthesizerInitError | ||
from sdgx.log import logger | ||
from sdgx.models.base import SynthesizerModel | ||
|
||
|
||
class MultiTableSynthesizerModel(SynthesizerModel): | ||
"""MultiTableSynthesizerModel | ||
The base model of multi-table statistic models. | ||
""" | ||
|
||
metadata_combiner: MetadataCombiner = None | ||
""" | ||
metadata_combiner is a sdgx builtin class, it stores all tables' metadata and relationships. | ||
This parameter must be specified when initializing the multi-table class. | ||
""" | ||
|
||
tables_data_frame: Dict[str, Any] = defaultdict() | ||
""" | ||
tables_data_frame is a dict contains every table's csv data frame. | ||
For a small amount of data, this scheme can be used. | ||
""" | ||
|
||
tables_data_loader: Dict[str, Any] = defaultdict() | ||
""" | ||
tables_data_loader is a dict contains every table's data loader. | ||
""" | ||
|
||
_parent_id: List = [] | ||
""" | ||
_parent_id is used to store all parent table's parimary keys in list. | ||
""" | ||
|
||
_table_synthesizers: Dict[str, Any] = {} | ||
""" | ||
_table_synthesizers is a dict to store model for each table. | ||
""" | ||
|
||
parent_map: Dict = defaultdict() | ||
""" | ||
The mapping from all child tables to their parent table. | ||
""" | ||
|
||
child_map: Dict = defaultdict() | ||
""" | ||
The mapping from all parent tabels to their child table. | ||
""" | ||
|
||
def __init__(self, metadata_combiner: MetadataCombiner, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
self.metadata_combiner = metadata_combiner | ||
self._calculate_parent_and_child_map() | ||
self.check() | ||
|
||
def _calculate_parent_and_child_map(self): | ||
"""Get the mapping from all parent tables to self._parent_map | ||
- key(str) is a child map; | ||
- value(str) is the parent map. | ||
""" | ||
relationships = self.metadata_combiner.relationships | ||
for each_relationship in relationships: | ||
parent_table = each_relationship.parent_table | ||
child_table = each_relationship.child_table | ||
self.parent_map[child_table] = parent_table | ||
self.child_map[parent_table] = child_table | ||
|
||
def _get_foreign_keys(self, parent_table, child_table): | ||
"""Get the foreign key list from a relationship""" | ||
relationships = self.metadata_combiner.relationships | ||
for each_relationship in relationships: | ||
# find the exact relationship and return foreign keys | ||
if ( | ||
each_relationship.parent_table == parent_table | ||
and each_relationship.child_table == child_table | ||
): | ||
return each_relationship.foreign_keys | ||
return [] | ||
|
||
def _get_all_foreign_keys(self, child_table): | ||
"""Given a child table, return ALL foreign keys from metadata.""" | ||
all_foreign_keys = [] | ||
relationships = self.metadata_combiner.relationships | ||
for each_relationship in relationships: | ||
# find the exact relationship and return foreign keys | ||
if each_relationship.child_table == child_table: | ||
all_foreign_keys.append(each_relationship.foreign_keys) | ||
|
||
return all_foreign_keys | ||
|
||
def _finalize(self): | ||
"""Finalize the""" | ||
raise NotImplementedError | ||
|
||
def check(self, check_circular=True): | ||
"""Excute necessary checks | ||
- check access type | ||
- check metadata_combiner | ||
- check relationship | ||
- check each metadata | ||
- validate circular relationships | ||
- validate child map_circular relationship | ||
- validate all tables connect relationship | ||
- validate column relationships foreign keys | ||
""" | ||
self._check_access_type() | ||
|
||
if not isinstance(self.metadata_combiner, MetadataCombiner): | ||
raise SynthesizerInitError("Wrong Metadata Combiner found.") | ||
pass | ||
|
||
def fit( | ||
self, dataloader: Dict[str, DataLoader], raw_data: Dict[str, pd.DataFrame], *args, **kwargs | ||
): | ||
""" | ||
Fit the model using the given metadata and dataloader. | ||
Args: | ||
dataloader (Dict[str, DataLoader]): The dataloader to use to fit the model. | ||
raw_data (Dict[str, pd.DataFrame]): The raw pd.DataFrame to use to fit the model. | ||
""" | ||
raise NotImplementedError | ||
|
||
def sample(self, count: int, *args, **kwargs) -> pd.DataFrame: | ||
""" | ||
Sample data from the model. | ||
Args: | ||
count (int): The number of samples to generate. | ||
Returns: | ||
pd.DataFrame: The generated data. | ||
""" | ||
|
||
raise NotImplementedError | ||
|
||
def save(self, save_dir: str | Path): | ||
pass | ||
|
||
@classmethod | ||
def load(target_path: str | Path): | ||
pass | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict, namedtuple | ||
|
||
import pytest | ||
|
||
from sdgx.models.statistics.multi_tables.base import MultiTableSynthesizerModel | ||
|
||
|
||
@pytest.fixture | ||
def demo_base_multi_table_synthesizer( | ||
demo_multi_table_data_metadata_combiner, demo_multi_table_data_loader | ||
): | ||
yield MultiTableSynthesizerModel( | ||
use_dataloader=True, | ||
metadata_combiner=demo_multi_table_data_metadata_combiner, | ||
tables_data_loader=demo_multi_table_data_loader, | ||
) | ||
|
||
|
||
def test_base_multi_table_synthesizer(demo_base_multi_table_synthesizer): | ||
KeyTuple = namedtuple("KeyTuple", ["parent", "child"]) | ||
|
||
assert demo_base_multi_table_synthesizer.parent_map == defaultdict(None, {"train": "store"}) | ||
assert demo_base_multi_table_synthesizer.child_map == defaultdict(None, {"store": "train"}) | ||
assert demo_base_multi_table_synthesizer._get_all_foreign_keys("train")[0][0] == KeyTuple( | ||
parent="Store", child="Store" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main(["-vv", "-s", __file__]) |