Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added constant for float relative tolerance #2071

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

COND_IDX = str(uuid.uuid4())
FIXED_RNG_SEED = 73251
+ FLOAT_RTOL = 0.01


class BaseSynthesizer:
Expand Down Expand Up @@ -576,7 +577,7 @@ def _filter_conditions(sampled, conditions, float_rtol):
return sampled

def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
- float_rtol=0.1, previous_rows=None, keep_extra_columns=False):
+ float_rtol=FLOAT_RTOL, previous_rows=None, keep_extra_columns=False):
"""Sample rows with the given conditions.

Input conditions is taken both in the raw input format, which will be used
Expand Down Expand Up @@ -654,7 +655,7 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
return sampled, num_rows

def _sample_batch(self, batch_size, max_tries=100,
- conditions=None, transformed_conditions=None, float_rtol=0.01,
+ conditions=None, transformed_conditions=None, float_rtol=FLOAT_RTOL,
progress_bar=None, output_file_path=None, keep_extra_columns=False):
"""Sample a batch of rows with the given conditions.

Expand Down Expand Up @@ -774,7 +775,7 @@ def _make_condition_dfs(conditions):
]

def _sample_in_batches(self, num_rows, batch_size, max_tries_per_batch, conditions=None,
- transformed_conditions=None, float_rtol=0.01, progress_bar=None,
+ transformed_conditions=None, float_rtol=FLOAT_RTOL, progress_bar=None,
output_file_path=None):
sampled = []
batch_size = batch_size if num_rows > batch_size else num_rows
Expand All @@ -794,7 +795,7 @@ def _sample_in_batches(self, num_rows, batch_size, max_tries_per_batch, conditio
return sampled.head(num_rows)

def _conditionally_sample_rows(self, dataframe, condition, transformed_condition,
- max_tries_per_batch=None, batch_size=None, float_rtol=0.01,
+ max_tries_per_batch=None, batch_size=None, float_rtol=FLOAT_RTOL,
graceful_reject_sampling=True, progress_bar=None,
output_file_path=None):
batch_size = batch_size or len(dataframe)
Expand Down