Skip to content

Commit

Permalink
Add base model for multi-table statistic model, change single-table base class location (#102)
Browse files Browse the repository at this point in the history

* Create base.py for multi-table statistic models

* Update base.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update statistic single-table base class

* update multi-table base class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add functions (still draft)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update base.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix dict typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix type hint typo

* Update base.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add multi-table test fixture

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify check settings in metadata

* update multi-table base class

* add test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply reviewer's suggestions.

- remove `DataAccessType`
- remove `pydantic.BaseModel`
- set two mutually exclusive parameters `use_raw_data` and `use_dataloader` in `sdgx.models.base.SynthesizerModel`
- some other necessary modifications

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
MooooCat and pre-commit-ci[bot] authored Jan 16, 2024
1 parent f18b552 commit d29a2a0
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 6 deletions.
1 change: 0 additions & 1 deletion sdgx/data_models/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class MetadataCombiner(BaseModel):

def __init__(self, *args, **kwargs):
    """Construct the combiner via the parent initializer, then run `self.check()`.

    The check runs after field assignment so that validation sees the
    fully-populated combiner (check body not fully visible here).
    """
    super().__init__(*args, **kwargs)
    self.check()

def check(self):
"""Do necessary checks:
Expand Down
8 changes: 6 additions & 2 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def from_dataloader(
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors
Expand Down Expand Up @@ -257,7 +258,8 @@ def from_dataloader(
if not primary_keys:
metadata.update_primary_key(metadata.id_columns)

metadata.check()
if check:
metadata.check()
return metadata

@classmethod
Expand All @@ -267,6 +269,7 @@ def from_dataframe(
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors
Expand Down Expand Up @@ -294,7 +297,8 @@ def from_dataframe(
metadata = Metadata(primary_keys=[df.columns[0]], column_list=set(df.columns))
for inspector in inspectors:
metadata.update(inspector.inspect())
metadata.check()
if check:
metadata.check()
return metadata

def _dump_json(self):
Expand Down
19 changes: 19 additions & 0 deletions sdgx/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,28 @@

from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.exceptions import SynthesizerInitError


class SynthesizerModel:
    """Base class for synthesizer models.

    Exactly one of the two mutually exclusive data access flags should be
    enabled before fitting: ``use_dataloader`` (stream data through a
    DataLoader) or ``use_raw_data`` (pass raw data directly).
    """

    # Mutually exclusive data access flags; callers set exactly one of them.
    use_dataloader: bool = False
    use_raw_data: bool = False

    def __init__(self, *args, **kwargs) -> None:
        """Pick up the data access flags from keyword arguments, if present."""
        # `key in kwargs` suffices; the `.keys()` membership call was redundant.
        if "use_dataloader" in kwargs:
            self.use_dataloader = kwargs["use_dataloader"]
        if "use_raw_data" in kwargs:
            self.use_raw_data = kwargs["use_raw_data"]

    def _check_access_type(self):
        """Validate that exactly one data access type is enabled.

        Raises:
            SynthesizerInitError: if neither flag is set (access type not
                specified) or both flags are set (duplicate access type).
        """
        if not self.use_dataloader and not self.use_raw_data:
            raise SynthesizerInitError(
                "Data access type not specified, please use `use_raw_data: bool` or `use_dataloader: bool` to specify data access type."
            )
        if self.use_dataloader and self.use_raw_data:
            raise SynthesizerInitError("Duplicate data access type found.")

def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
"""
Fit the model using the given metadata and dataloader.
Expand Down
156 changes: 156 additions & 0 deletions sdgx/models/statistics/multi_tables/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from __future__ import annotations

from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd

from sdgx.data_loader import DataLoader
from sdgx.data_models.combiner import MetadataCombiner
from sdgx.exceptions import SynthesizerInitError
from sdgx.log import logger
from sdgx.models.base import SynthesizerModel


class MultiTableSynthesizerModel(SynthesizerModel):
    """Base model of multi-table statistic models.

    Holds a :class:`MetadataCombiner` — every table's metadata plus the
    relationships between tables — and derives the parent/child lookup maps
    used by concrete multi-table synthesizers.
    """

    # Stores all tables' metadata and relationships (sdgx builtin class).
    # Must be specified when initializing the multi-table model.
    metadata_combiner: MetadataCombiner = None

    # Every table's csv data frame, keyed by table name (small-data path).
    tables_data_frame: Dict[str, Any] = {}

    # Every table's data loader, keyed by table name.
    tables_data_loader: Dict[str, Any] = {}

    # All parent tables' primary keys.
    _parent_id: List = []

    # One model per table, keyed by table name.
    _table_synthesizers: Dict[str, Any] = {}

    # Mapping from each child table to its parent table.
    parent_map: Dict = {}

    # Mapping from each parent table to its child table.
    # NOTE(review): only one child is kept per parent — a parent with several
    # children retains the last relationship seen; confirm this is intended.
    child_map: Dict = {}

    def __init__(self, metadata_combiner: MetadataCombiner, *args, **kwargs):
        """Initialize the model and derive the parent/child maps.

        Args:
            metadata_combiner: combined metadata and relationships of all tables.
            **kwargs: may carry ``tables_data_frame`` / ``tables_data_loader``
                plus the base-class flags ``use_dataloader`` / ``use_raw_data``.
        """
        super().__init__(*args, **kwargs)

        # Fresh per-instance containers: the previous class-level mutable
        # defaults were shared across every instance of the class.
        self.tables_data_frame = kwargs.get("tables_data_frame", {})
        self.tables_data_loader = kwargs.get("tables_data_loader", {})
        self._parent_id = []
        self._table_synthesizers = {}
        self.parent_map = {}
        self.child_map = {}

        self.metadata_combiner = metadata_combiner
        self._calculate_parent_and_child_map()
        self.check()

    def _calculate_parent_and_child_map(self):
        """Populate the table mappings from the combiner's relationships.

        - ``parent_map``: child table name -> parent table name;
        - ``child_map``: parent table name -> child table name.
        """
        for relationship in self.metadata_combiner.relationships:
            self.parent_map[relationship.child_table] = relationship.parent_table
            self.child_map[relationship.parent_table] = relationship.child_table

    def _get_foreign_keys(self, parent_table, child_table):
        """Return the foreign key list of the exact (parent, child) relationship.

        Returns an empty list when no such relationship exists.
        """
        for relationship in self.metadata_combiner.relationships:
            if (
                relationship.parent_table == parent_table
                and relationship.child_table == child_table
            ):
                return relationship.foreign_keys
        return []

    def _get_all_foreign_keys(self, child_table):
        """Given a child table, return ALL foreign keys from metadata.

        Returns a list of foreign-key lists, one per relationship in which
        ``child_table`` is the child.
        """
        return [
            relationship.foreign_keys
            for relationship in self.metadata_combiner.relationships
            if relationship.child_table == child_table
        ]

    def _finalize(self):
        """Finalize the sampled output across tables; subclasses must implement."""
        raise NotImplementedError

    def check(self, check_circular=True):
        """Execute necessary checks:

        - check access type
        - check metadata_combiner
        - TODO: check relationship / each metadata
        - TODO: validate circular relationships
        - TODO: validate child map circular relationship
        - TODO: validate all tables connect relationship
        - TODO: validate column relationships foreign keys
        """
        self._check_access_type()

        if not isinstance(self.metadata_combiner, MetadataCombiner):
            raise SynthesizerInitError("Wrong Metadata Combiner found.")

    def fit(
        self, dataloader: Dict[str, DataLoader], raw_data: Dict[str, pd.DataFrame], *args, **kwargs
    ):
        """Fit the model using the given metadata and dataloader.

        Args:
            dataloader: table name -> DataLoader used to fit the model.
            raw_data: table name -> raw ``pd.DataFrame`` used to fit the model.
        """
        raise NotImplementedError

    def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
        """Sample data from the model.

        Args:
            count: the number of samples to generate.

        Returns:
            pd.DataFrame: the generated data.
        """
        raise NotImplementedError

    def save(self, save_dir: str | Path):
        """Persist the fitted model under ``save_dir`` (not implemented yet)."""
        pass

    @classmethod
    def load(cls, target_path: str | Path):
        """Load a model from ``target_path`` (not implemented yet).

        Bug fixed: the original ``@classmethod`` was declared without a
        ``cls`` parameter, so ``target_path`` received the class object.
        """
        pass
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ class SynthesizerModel:
random_states = None

def __init__(self, transformer=None, sampler=None) -> None:
# 以下几个变量都需要在初始化 model 时进行更改
self.model = None # 存放模型
self.model = None
self.status = "UNFINED"
self.model_type = "MODEL_TYPE_UNDEFINED"
# self.epochs = epochs
Expand Down
2 changes: 1 addition & 1 deletion sdgx/models/statistics/single_table/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
unflatten_dict,
validate_numerical_distributions,
)
from sdgx.models.statistics.base import SynthesizerModel
from sdgx.models.statistics.single_table.base import SynthesizerModel

LOGGER = logging.getLogger(__name__)

Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.data_loader import DataLoader
from sdgx.data_models.combiner import MetadataCombiner
from sdgx.data_models.metadata import Metadata
from sdgx.data_models.relationship import Relationship
from sdgx.utils import download_demo_data, download_multi_table_demo_data

_HERE = os.path.dirname(__file__)
Expand Down Expand Up @@ -158,3 +160,24 @@ def demo_multi_table_data_loader(demo_multi_table_data_connector, cacher_kwargs)
yield loader_dict
for each_table in demo_multi_table_data_connector.keys():
demo_multi_table_data_connector[each_table].finalize()


@pytest.fixture
def demo_multi_data_relationship():
    """store (parent) -> train (child) relationship keyed on the ``Store`` column."""
    relationship = Relationship.build(
        parent_table="store", child_table="train", foreign_keys=["Store"]
    )
    yield relationship


@pytest.fixture
def demo_multi_table_data_metadata_combiner(
    demo_multi_table_data_loader, demo_multi_data_relationship
):
    """MetadataCombiner built from per-table metadata plus the demo relationship."""
    # 1. get metadata for every demo table, keyed by table name
    metadata_dict = {
        table_name: Metadata.from_dataloader(loader)
        for table_name, loader in demo_multi_table_data_loader.items()
    }
    # 2. the relationship is already defined by the sibling fixture
    # 3. define the combiner
    combiner = MetadataCombiner(
        named_metadata=metadata_dict,
        relationships=[demo_multi_data_relationship],
    )

    yield combiner
32 changes: 32 additions & 0 deletions tests/models/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from collections import defaultdict, namedtuple

import pytest

from sdgx.models.statistics.multi_tables.base import MultiTableSynthesizerModel


@pytest.fixture
def demo_base_multi_table_synthesizer(
    demo_multi_table_data_metadata_combiner, demo_multi_table_data_loader
):
    """Base multi-table synthesizer wired to the demo metadata and loaders."""
    model = MultiTableSynthesizerModel(
        metadata_combiner=demo_multi_table_data_metadata_combiner,
        tables_data_loader=demo_multi_table_data_loader,
        use_dataloader=True,
    )
    yield model


def test_base_multi_table_synthesizer(demo_base_multi_table_synthesizer):
    """The base synthesizer derives parent/child maps and foreign keys."""
    KeyTuple = namedtuple("KeyTuple", ["parent", "child"])
    model = demo_base_multi_table_synthesizer

    assert model.parent_map == defaultdict(None, {"train": "store"})
    assert model.child_map == defaultdict(None, {"store": "train"})

    first_foreign_key = model._get_all_foreign_keys("train")[0][0]
    assert first_foreign_key == KeyTuple(parent="Store", child="Store")


if __name__ == "__main__":
    # Allow running this test module directly with verbose, unbuffered output.
    pytest.main(["-vv", "-s", __file__])

0 comments on commit d29a2a0

Please sign in to comment.