Added bootstrap estimate tests and documentation
Signed-off-by: Alexan <[email protected]>
Alcray committed Sep 27, 2024
1 parent f5788ee commit ed09054
Showing 6 changed files with 214 additions and 101 deletions.
29 changes: 27 additions & 2 deletions docs/gen_docs.py
@@ -28,7 +28,7 @@
IGNORE_CONFIGS = []


def gen_docs():
def gen_dataset_config_docs():
config_dir = str(Path(__file__).absolute().parents[1] / 'dataset_configs')
config_docs_dir = str(Path(__file__).parents[0] / 'src' / 'sdp' / 'config-docs')

@@ -53,6 +53,31 @@ def gen_docs():
with open(destination_path, "wt", encoding="utf-8") as fout:
fout.write(docs + link)

def gen_metric_config_docs():
config_dir = str(Path(__file__).absolute().parents[1] / 'metrics_configs')
config_docs_dir = str(Path(__file__).parents[0] / 'src' / 'sdp' / 'config-docs')

for root, dirs, files in os.walk(config_dir):
# Create corresponding directories in the destination directory
for directory in dirs:
source_path = os.path.join(root, directory)
destination_path = source_path.replace(config_dir, config_docs_dir)
os.makedirs(destination_path, exist_ok=True)

# Copy files and change the file extensions
for file in files:
if file.endswith('.yaml'):
source_path = os.path.join(root, file)
config_path = source_path.replace(config_dir, '')[1:] # removing leading /
if config_path in IGNORE_CONFIGS:
continue
destination_path = source_path.replace(config_dir, config_docs_dir).replace('.yaml', '.rst')
with open(source_path, "rt", encoding="utf-8") as fin:
docs = yaml.safe_load(fin).get('documentation', "Documentation is not yet available.") + "\n\n"
link = f"Config link: `dataset_configs/{config_path} <{ROOT_LINK}/{config_path}>`_"
with open(destination_path, "wt", encoding="utf-8") as fout:
fout.write(docs + link)

if __name__ == '__main__':
gen_docs()
gen_dataset_config_docs()
gen_metric_config_docs()
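
For reference, the sketch below reproduces what the new gen_metric_config_docs() emits for a single metrics config; the file path and the print call are illustrative and not part of the commit.

import yaml
from pathlib import Path

# Illustrative sketch: rebuild the .rst stub contents for one metrics config.
src = Path("metrics_configs/bootstrap/config.yaml")
docs = yaml.safe_load(src.read_text(encoding="utf-8")).get(
    "documentation", "Documentation is not yet available."
) + "\n\n"
# The generated stub is this documentation block followed by a "Config link:" line.
print(docs.splitlines()[0])
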
5 changes: 5 additions & 0 deletions docs/src/sdp/api.rst
@@ -187,6 +187,11 @@ ASR-based processors
``text_key`` (defaults to "text") and ``pred_text_key`` (defaults to "text_pred")
to control which fields contain transcription and ASR model predictions.

Metric calculation
''''''''''''''''''
.. autodata:: sdp.utils.BootstrapProcessor
:annotation:

Data modifications
''''''''''''''''''

15 changes: 14 additions & 1 deletion docs/src/sdp/existing_configs.rst
@@ -262,4 +262,17 @@ Kazakh Speech Corpus 2 (KSC2)
.. toctree::
:hidden:

config-docs/kazakh/ksc2/config
config-docs/kazakh/ksc2/config


Bootstrap Estimation
~~~~~~~~~~~~~~~~~~~~


`config <https://github.com/NVIDIA/NeMo-speech-data-processor/blob/main/metrics_configs/bootstrap/config.yaml>`__ |
:doc:`documentation <config-docs/bootstrap/config>`

.. toctree::
:hidden:

config-docs/bootstrap/config
metrics_configs/bootstrap/config.yaml
@@ -13,38 +13,46 @@ documentation: |
**Required arguments**:
* **workspace_dir**: Specify the workspace folder where all the data and results will be stored.
* **manifest_files**: List of file paths to the manifest files in JSONL format.
* **workspace_dir**: Specify the workspace folder where the results will be stored.
* **raw_data_dir**: Specify the data folder where all the data will be stored.
* **bootstrap_manifest_files**: List of manifest file paths in JSONL format; the paths are resolved relative to raw_data_dir.
* **metric_type**: The metric to compute. Supported options include 'wer', 'cer', 'wmr', 'charrate', 'wordrate'.
* **dataset_size**: Proportion of dataset size for each bootstrap sample.
* **num_bootstraps**: The number of bootstrap iterations for metric computation.
* **ci_lower**: Lower bound percentile for confidence intervals (default: 2.5).
* **ci_upper**: Upper bound percentile for confidence intervals (default: 97.5).
* **calculate_pairwise**: Whether to calculate pairwise metric differences and POI.
* **text_key**: The key in the manifest that contains the ground truth text.
* **pred_text_key**: The key in the manifest that contains the predicted text.
* **random_state**: Random seed used to make the bootstrap sampling reproducible.
**Output format**:
The config generates the following outputs:
* **output_file**: A JSON file containing the results of the metric computation.
* **output_manifest_file**: A JSON file containing the results of the metric computation.
* The output file will contain fields for the mean metric value, confidence intervals (ci_lower, ci_upper), and optional pairwise comparisons between models.
processors_to_run: all
workspace_dir: ???
manifest_files: ["${workspace_dir}/manifest1.json", "${workspace_dir}/manifest2.json"]
output_file: ${workspace_dir}/results.json
bootstrap_manifest_files: ["manifest1.json", "manifest2.json"]
final_manifest: ${workspace_dir}/manifest.json
metric_type: "wer"
num_bootstraps: 10
num_bootstraps: 1000
ci_lower: 2.5
ci_upper: 97.5

processors:
- _target_: sdp.utils.BootstrapProcessor
manifest_files: ${manifest_files}
bootstrap_manifest_files: ${bootstrap_manifest_files}
raw_data_dir: ${workspace_dir}/data
num_bootstraps: ${num_bootstraps}
dataset_size: 1.0
output_file: ${output_file}
output_manifest_file: ${final_manifest}
calculate_pairwise: true
metric_type: ${metric_type}
text_key: "text" # Can be customized if manifest uses a different key for ground truth text
pred_text_key: "pred_text" # Can be customized if manifest uses a different key for predicted text
text_key: "text"
pred_text_key: "pred_text"
ci_lower: ${ci_lower}
ci_upper: ${ci_upper}
random_state: 42
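
Assuming the repository's usual Hydra entry point (main.py at the repo root; the exact flags and override below are an assumption, not shown in this commit), the config could be run along these lines:

python main.py --config-path=metrics_configs/bootstrap --config-name=config.yaml workspace_dir=/path/to/workspace
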
96 changes: 40 additions & 56 deletions sdp/utils/bootstrap_estimates.py
@@ -27,80 +27,66 @@ class BootstrapProcessor(BaseProcessor):
is set to True.
Args:
manifest_files (List[str]): List of file paths to manifest (JSONL) files for metric calculation
bootstrap_manifest_files (List[str]): Manifest (JSONL) file paths used for metric calculation; resolved relative to raw_data_dir
raw_data_dir (str): Directory containing the manifest files
num_bootstraps (int): Number of bootstrap iterations for metric computation
dataset_size (float): Proportion of dataset size for each bootstrap sample
output_file (str): Path to the output JSON file to save results
output_manifest_file (str): Path to the output JSON file to save results
calculate_pairwise (bool): Whether to calculate pairwise metric difference and POI (default: True)
metric_type (str): The type of metric to calculate, options include 'wer', 'cer', 'wmr', 'charrate', 'wordrate' (default: 'wer')
text_key (str): The key in the manifest that contains the ground truth text (default: 'text')
pred_text_key (str): The key in the manifest that contains the predicted text (default: 'pred_text')
ci_lower (float): Lower bound percentile for confidence intervals (default: 2.5)
ci_upper (float): Upper bound percentile for confidence intervals (default: 97.5)
random_state (int): Random seed used to make the bootstrap sampling reproducible
"""

def __init__(
self,
manifest_files: List[str],
bootstrap_manifest_files: List[str],
raw_data_dir: str,
num_bootstraps: int = 1000,
dataset_size: float = 1.0,
output_file: Optional[str] = None,
calculate_pairwise: bool = True, # Option for pairwise comparison
metric_type: str = 'wer', # Default metric is WER
text_key: str = 'text', # Default key for ground truth text in the manifest
pred_text_key: str = 'pred_text', # Default key for predicted text in the manifest
ci_lower: float = 2.5, # Lower percentile for confidence intervals
ci_upper: float = 97.5, # Upper percentile for confidence intervals
calculate_pairwise: bool = True,
metric_type: str = 'wer',
text_key: str = 'text',
pred_text_key: str = 'pred_text',
ci_lower: float = 2.5,
ci_upper: float = 97.5,
random_state: Optional[int] = None,
**kwargs,
):
super().__init__(**kwargs)

self.manifest_files = manifest_files
self.bootstrap_manifest_files = bootstrap_manifest_files
self.raw_data_dir = raw_data_dir
self.num_bootstraps = num_bootstraps
self.dataset_size = dataset_size
self.output_file = output_file
self.calculate_pairwise = calculate_pairwise # Store the option to calculate pairwise metrics
self.metric_type = metric_type.lower() # Store metric type and convert to lowercase for consistency
self.text_key = text_key # Store the key for ground truth text
self.pred_text_key = pred_text_key # Store the key for predicted text
self.ci_lower = ci_lower # Store the lower percentile for confidence intervals
self.ci_upper = ci_upper # Store the upper percentile for confidence intervals

# Validate metric type
self.calculate_pairwise = calculate_pairwise
self.metric_type = metric_type.lower()
self.text_key = text_key
self.pred_text_key = pred_text_key
self.ci_lower = ci_lower
self.ci_upper = ci_upper
self.random_state = random_state


if self.random_state is not None:
np.random.seed(self.random_state)

if self.metric_type not in ['wer', 'cer', 'wmr', 'charrate', 'wordrate']:
raise ValueError(f"Invalid metric_type '{self.metric_type}'! Must be one of ['wer', 'cer', 'wmr', 'charrate', 'wordrate']")

def read_manifest(self, manifest_path: Path) -> List[Dict[str, Union[str, float]]]:
"""
Read a manifest file in JSONL format and return a list of dictionaries.
Args:
manifest_path (Path): Path to the manifest file (.jsonl)
Returns:
List[Dict[str, Union[str, float]]]: A list of dictionaries where each dictionary corresponds to an entry in the manifest.
"""
manifest_data = []
with manifest_path.open('r', encoding='utf-8') as f:
for line in f:
# Each line in the manifest file is a JSON object
data = json.loads(line.strip()) # Parse the JSON object
data = json.loads(line.strip())
manifest_data.append(data)

return manifest_data

def calculate_metric(self, text: str, pred_text: str, duration: Optional[float] = None) -> float:
"""
Calculate the specified metric between ground truth and predicted text.
Args:
text (str): Ground truth text
pred_text (str): Predicted text
duration (Optional[float]): Duration of the audio (used for charrate and wordrate)
Returns:
float: Computed metric value
"""
if self.metric_type == 'wer':
return metrics.get_wer(text, pred_text)
elif self.metric_type == 'cer':
@@ -195,13 +181,13 @@ def process(self):
results = {}

# Load ground truth and predictions
manifest_files = [Path(f) for f in self.manifest_files]
bootstrap_manifest_files = [Path(f) for f in self.bootstrap_manifest_files]
ground_truth = []
predicted_texts = []
durations = [] # Optional durations for charrate and wordrate
durations = []

for manifest_file in manifest_files:
manifest_data = self.read_manifest(manifest_file)
for manifest_file in bootstrap_manifest_files:
manifest_data = self.read_manifest(Path(self.raw_data_dir) / manifest_file)
# Use text_key and pred_text_key to extract ground truth and predictions
gt_texts = [entry[self.text_key] for entry in manifest_data]
pred_texts = [entry[self.pred_text_key] for entry in manifest_data]
@@ -225,7 +211,7 @@
ci_upper_value = np.percentile(metric_conf_intervals, self.ci_upper)
mean_metric = np.mean(metric_conf_intervals)

results["individual_results"][manifest_files[idx].name] = {
results["individual_results"][bootstrap_manifest_files[idx].name] = {
f"mean_{self.metric_type}": mean_metric,
"ci_lower": ci_lower_value,
"ci_upper": ci_upper_value
@@ -247,20 +233,18 @@
ci_upper_value = np.percentile(delta_metric_bootstrap, self.ci_upper)

results["pairwise_comparisons"].append({
"file_1": manifest_files[i].name,
"file_2": manifest_files[j].name,
"file_1": bootstrap_manifest_files[i].name,
"file_2": bootstrap_manifest_files[j].name,
f"delta_{self.metric_type}_mean": mean_delta_metric,
"ci_lower": ci_lower_value,
"ci_upper": ci_upper_value,
"poi": poi
})
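# "poi" is read here as the probability of improvement estimated from the bootstrap
# delta distribution; its actual computation sits in code not shown in this diff, so
# treat this interpretation as an assumption.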

# Save results to output file
if self.output_file:
output_path = Path(self.output_file)
output_path.parent.mkdir(exist_ok=True, parents=True)
with output_path.open('w') as out_file:
json.dump(results, out_file, indent=4)
output_path = Path(self.output_manifest_file)
output_path.parent.mkdir(exist_ok=True, parents=True)
with output_path.open('w') as out_file:
json.dump(results, out_file, indent=4)

print(f"Results saved to {self.output_file}")
print(f"Results saved to {self.output_manifest_file}")

