diff --git a/.gitignore b/.gitignore index 49b785f..b397a0b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ raw_data *.ts *.zip -./__pycache__ \ No newline at end of file +./__pycache__ +upload_to_huggingface.py diff --git a/.linkspector.yml b/.linkspector.yml new file mode 100644 index 0000000..9de8111 --- /dev/null +++ b/.linkspector.yml @@ -0,0 +1,15 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# + +dirs: + - . +useGitIgnore: true +ignorePatterns: + - pattern: "doc:/" + - pattern: "http://localhost" + - pattern: "https://doi.org" diff --git a/README.md b/README.md index 98e6931..35a8c0d 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,6 @@ OpenTSLM models can reason over multiple time series of any length at once, gene

- ## Installation 1. **Clone the Repository** @@ -80,7 +79,83 @@ OpenTSLM has been tested and works with the following models: Other variants may work but have not been extensively tested. -## Multi-stage training (Curriculum) + +## πŸš€ Quickstart with pretrained models + +OpenTSLM provides a factory class, `OpenTSLM`, for easily loading pre-trained models from Hugging Face Hub. The `load_pretrained` method automatically detects the model type and returns the appropriate model instance. + + +```python +from src import OpenTSLM, TextPrompt, TextTimeSeriesPrompt, FullPrompt + +# Load model +model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-har-flamingo") + +# Create prompt with raw time series data (normalization handled automatically) +prompt = FullPrompt( + pre_prompt=TextPrompt("You are an expert in HAR analysis."), + text_time_series_prompt_list=[ + TextTimeSeriesPrompt("X-axis accelerometer", [2.34, 2.34, 7.657, 3.21, -1.2]) + ], + post_prompt=TextPrompt("What activity is this? Reason step by step, providing a full rationale before replying.") +) + +# Generate response +output = model.eval_prompt(prompt, normalize=True) +print(output) +``` + +### πŸ€— HuggingFace Demo Scripts + +We provide ready-to-use demo scripts in the `demo/huggingface/` directory that demonstrate how to load pretrained models from HuggingFace Hub and run inference on the evaluation sets for each task: + +- **`01_test_hf_tsqa.py`** - Test TSQA (Time Series Question Answering) models +- **`02_test_hf_m4.py`** - Test M4 (Time Series Captioning) models +- **`03_test_hf_har_cot.py`** - Test HAR CoT (Human Activity Recognition Chain-of-Thought) models +- **`04_test_hf_sleep_cot.py`** - Test Sleep CoT (Sleep Stage Classification) models +- **`05_test_hf_ecg_qa_cot.py`** - Test ECG QA CoT (ECG Question Answering) models + +Each script: +1. Downloads the model checkpoint from HuggingFace Hub automatically (change the repo ID as needed) +2. Loads the corresponding test dataset +3. Runs inference on the evaluation set +4. Prints model outputs with sample information + +**Note:** The scripts above use the OpenTSLM-SP models except for ECG-QA, as they require less VRAM and should run on most hardware. Change the model checkpoints as needed in each file. + +**Usage:** + +```bash +# Run any of the demo scripts +python demo/huggingface/01_test_hf_tsqa.py +python demo/huggingface/02_test_hf_m4.py +python demo/huggingface/03_test_hf_har_cot.py +python demo/huggingface/04_test_hf_sleep_cot.py +python demo/huggingface/05_test_hf_ecg_qa_cot.py +``` + +**Customizing the model:** + +Edit the `REPO_ID` variable at the top of each script to test different model variants. For example: + +```python +# In 01_test_hf_tsqa.py +REPO_ID = "OpenTSLM/llama-3.2-1b-tsqa-sp" # Soft Prompt model +# or +REPO_ID = "OpenTSLM/llama-3.2-1b-tsqa-flamingo" # Flamingo model +``` + +**Available models on HuggingFace:** + +All pretrained models are available under the `OpenTSLM` organization on HuggingFace Hub.
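As a minimal sketch of how the demo scripts further down in this diff pick a loading mode from the repository suffix (they pass `enable_lora=True` for Soft Prompt `-sp` checkpoints and use the defaults for `-flamingo` ones), assuming the `OpenTSLM.load_pretrained` API and the naming pattern described below:

```python
from src import OpenTSLM

# Illustrative repo ID only; any repository that follows the naming pattern below
# can be loaded the same way.
REPO_ID = "OpenTSLM/llama-3.2-1b-har-sp"

# The demo scripts enable LoRA when loading Soft Prompt (-sp) checkpoints
# and load Flamingo (-flamingo) checkpoints with the default settings.
model = OpenTSLM.load_pretrained(REPO_ID, enable_lora=REPO_ID.endswith("-sp"))

# The loaded model exposes the tokenizer's EOS token, which the dataset
# classes used by the demo scripts expect when constructing prompts.
print(model.get_eos_token())
```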
Model names follow the pattern: +- `OpenTSLM/{base_model}-{dataset}-{model_type}` + - `base_model`: `llama-3.2-1b`, `llama-3.2-3b`, `gemma-3-1b-pt`, `gemma-3-270m` + - `dataset`: `tsqa`, `m4`, `har`, `sleep`, `ecg` + - `model_type`: `sp` (Soft Prompt) or `flamingo` (Flamingo) + +Example: `OpenTSLM/llama-3.2-1b-ecg-flamingo` + +## Training: Multi-stage training (Curriculum) OpenTSLM uses curriculum learning with progressive training stages: @@ -137,6 +212,11 @@ python curriculum_learning.py --model OpenTSLMFlamingo --eval_only - `--gradient_checkpointing`: Enable gradient checkpointing for memory efficiency - `--verbose`: Enable verbose logging +### Repository Naming Convention + +- Repository IDs ending with `-sp` will load and return `EmbedHealthSP` models +- Repository IDs ending with `-flamingo` will load and return `EmbedHealthFlamingo` models + ## πŸ“ Results Structure diff --git a/create_doctor_eval_dataset.py b/create_doctor_eval_dataset.py new file mode 100644 index 0000000..7c7216c --- /dev/null +++ b/create_doctor_eval_dataset.py @@ -0,0 +1,419 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Script to create doctor evaluation dataset with correct model predictions. +This script extracts ECG-QA templates with correct model outputs from llama3b_flamingo_predictions.jsonl +and creates organized folders with ECG plots, CSV data, and evaluation materials. +""" + +import json +import os +import sys +import re +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import wfdb +from pathlib import Path +from typing import Dict, List, Set, Tuple +from collections import defaultdict +from tqdm import tqdm +import shutil + +# Add the src directory to the path +sys.path.append(os.path.join(os.path.dirname(__file__), 'src')) +from time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset +from time_series_datasets.ecg_qa.plot_example import draw_ecg, get_ptbxl_ecg_path + +# Configuration +MODEL_PREDICTIONS_FILE = "/Users/planger/Development/EmbedHealth/evaluation/embedhealth/ecg_qa_cot/llama3b_flamingo_predictions.jsonl" +OUTPUT_DIR = "ecg_doctor_eval" +SAMPLES_PER_TEMPLATE = 2 + +def extract_answer_from_generated(generated_text: str) -> str: + """Extract the final answer from generated text after 'Answer: '""" + if "Answer: " not in generated_text: + return generated_text.strip() + + answer = generated_text.split("Answer: ")[-1].strip() + # Remove any end-of-text tokens and trailing punctuation + answer = re.sub(r'<\|.*?\|>|$', '', answer).strip() + answer = re.sub(r'\.$', '', answer).strip() + return answer + +def is_correct_prediction(generated_text: str, correct_answer: str) -> bool: + """Check if the model prediction matches the correct answer""" + predicted_answer = extract_answer_from_generated(generated_text) + return predicted_answer.lower().strip() == correct_answer.lower().strip() + +def load_model_predictions() -> Dict[int, List[Dict]]: + """Load model predictions and group by template_id""" + print(f"Loading model predictions from {MODEL_PREDICTIONS_FILE}") + + template_predictions = defaultdict(list) + + with open(MODEL_PREDICTIONS_FILE, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(tqdm(f, desc="Loading predictions"), 1): + try: + data = json.loads(line.strip()) + + template_id = data.get('template_id') + if template_id is None: + continue + + # Check 
if prediction is correct + generated_text = data.get('generated', '') + correct_answer = data.get('correct_answer', '') + + if is_correct_prediction(generated_text, correct_answer): + template_predictions[template_id].append({ + 'template_id': template_id, + 'ecg_id': data.get('ecg_id', [None])[0] if data.get('ecg_id') else None, + 'generated': generated_text, + 'correct_answer': correct_answer, + 'pre_prompt': data.get('pre_prompt', ''), + 'line_number': line_num + }) + + except Exception as e: + print(f"Error processing line {line_num}: {e}") + continue + + print(f"Found correct predictions for {len(template_predictions)} templates") + return template_predictions + +def extract_clinical_context(pre_prompt: str) -> str: + """Extract clinical context from pre_prompt""" + if 'Clinical Context:' in pre_prompt: + context_start = pre_prompt.find('Clinical Context:') + context_end = pre_prompt.find('\n\n', context_start) + if context_end == -1: + context_end = pre_prompt.find('\n', context_start) + if context_end != -1: + return pre_prompt[context_start:context_end].strip() + return "Clinical context not available" + +def extract_question_from_prompt(pre_prompt: str) -> str: + """Extract the question from pre_prompt - this is the authoritative question for each sample""" + if 'Question: ' in pre_prompt: + question_start = pre_prompt.find('Question: ') + question_end = pre_prompt.find('\n\n', question_start) + if question_end == -1: + question_end = pre_prompt.find('\n', question_start) + if question_end != -1: + return pre_prompt[question_start:question_end].strip() + return "Question not available" + +def get_answer_options_for_template(template_id: int) -> List[str]: + """Get answer options for a template""" + try: + return ECGQACoTQADataset.get_possible_answers_for_template(template_id) + except Exception as e: + print(f"Warning: Could not get answer options for template {template_id}: {e}") + return [] + +def load_ecg_data(ecg_id: int) -> Tuple[np.ndarray, str]: + """Load ECG data for a given ECG ID""" + try: + ecg_path = get_ptbxl_ecg_path(ecg_id) + + if not os.path.exists(ecg_path + '.dat'): + raise FileNotFoundError(f"ECG file not found: {ecg_path}.dat") + + # Read ECG data using wfdb + ecg_data, meta = wfdb.rdsamp(ecg_path) + + # Get sampling frequency + sampling_freq = meta.get('fs', 500) # Default to 500Hz if not specified + + return ecg_data, sampling_freq + + except Exception as e: + raise RuntimeError(f"Failed to load ECG {ecg_id}: {e}") + +def downsample_to_100hz(ecg_data: np.ndarray, original_freq: int) -> np.ndarray: + """Downsample ECG data to 100Hz""" + if original_freq == 100: + return ecg_data + + # Calculate downsampling factor + downsample_factor = original_freq // 100 + + # Downsample by taking every nth sample + downsampled_data = ecg_data[::downsample_factor] + + return downsampled_data + +def save_ecg_as_csv(ecg_data: np.ndarray, output_dir: str, ecg_id: int): + """Save ECG data as separate CSV files for each lead""" + # Downsample to 100Hz if needed + if ecg_data.shape[0] > 1000: # Likely 500Hz data + ecg_data = downsample_to_100hz(ecg_data, 500) + + lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"] + + for lead_idx, lead_name in enumerate(lead_names): + if lead_idx < ecg_data.shape[1]: # Make sure we don't exceed available leads + lead_data = ecg_data[:, lead_idx] + + # Create DataFrame with time and signal values + time_points = np.arange(len(lead_data)) / 100.0 # 100Hz sampling + df = pd.DataFrame({ + 'time_seconds': 
time_points, + 'signal_mV': lead_data + }) + + # Save to CSV + csv_filename = f"{output_dir}/lead_{lead_name}.csv" + df.to_csv(csv_filename, index=False) + +def create_ecg_plot(ecg_data: np.ndarray, template_id: int, ecg_id: int, + question: str, answer_options: List[str], + clinical_context: str, model_output: str, + correct_answer: str, output_dir: str): + """Create ECG plot with all information""" + + # Downsample to 100Hz if needed + if ecg_data.shape[0] > 1000: # Likely 500Hz data + ecg_data = downsample_to_100hz(ecg_data, 500) + + # Create the plot with all 12 leads + fig, axes = plt.subplots(12, 1, figsize=(14, 24)) + fig.suptitle(f"Template {template_id}: ECG Analysis\nECG ID: {ecg_id}", + fontsize=16, fontweight='bold') + + # Create time array for 100Hz sampling (10 seconds) + time_points = np.arange(0, 10, 0.01) # 100Hz for 10 seconds + + # Plot all 12 leads + lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"] + for i, (ax, lead_name) in enumerate(zip(axes, lead_names)): + if i < ecg_data.shape[1]: # Make sure we don't exceed available leads + # Plot the ECG signal for this lead + ax.plot(time_points, ecg_data[:, i], linewidth=2, color="k", alpha=1.0) + + # Add grid lines (millimeter paper style) + # Major grid lines (every 0.2s and 0.5mV) + ax.vlines(np.arange(0, 10, 0.2), -2.5, 2.5, colors="r", alpha=0.3, linewidth=0.5) + ax.hlines(np.arange(-2.5, 2.5, 0.5), 0, 10, colors="r", alpha=0.3, linewidth=0.5) + + # Minor grid lines (every 0.04s and 0.1mV) + ax.vlines(np.arange(0, 10, 0.04), -2.5, 2.5, colors="r", alpha=0.1, linewidth=0.3) + ax.hlines(np.arange(-2.5, 2.5, 0.1), 0, 10, colors="r", alpha=0.1, linewidth=0.3) + + ax.set_xticks(np.arange(0, 11, 1.0)) + ax.set_ylabel(f'Lead {lead_name} (mV)', fontweight='bold') + ax.margins(0.0) + ax.set_ylim(-2.5, 2.5) + ax.set_title(f'Lead {lead_name}', fontweight='bold', pad=10) + else: + ax.set_title(f'Lead {lead_name} (not available)', fontweight='bold', pad=10) + ax.text(0.5, 0.5, 'Lead not available', ha='center', va='center', transform=ax.transAxes) + + # Add information text box + info_text = f"""Question: {question} + +Answer Options: {' | '.join(answer_options[:5])}{'...' if len(answer_options) > 5 else ''} + +Clinical Context: {clinical_context[:200]}{'...' if len(clinical_context) > 200 else ''} + +Model Output: {model_output[:300]}{'...' if len(model_output) > 300 else ''} + +Expected Answer: {correct_answer}""" + + fig.text(0.02, 0.02, info_text, fontsize=9, transform=fig.transFigure, + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)) + + # Save plot + plot_filename = f"{output_dir}/ecg_plot.png" + fig.savefig(plot_filename, dpi=300, bbox_inches='tight', facecolor='white') + plt.close(fig) + + return plot_filename + +def create_evaluation_text_file(output_dir: str, template_id: int, ecg_id: int, + question: str, answer_options: List[str], + clinical_context: str, model_output: str, + correct_answer: str): + """Create a text file with all evaluation information""" + + txt_filename = f"{output_dir}/evaluation_info.txt" + with open(txt_filename, 'w') as f: + f.write(f"ECG-QA Doctor Evaluation\n") + f.write(f"=" * 50 + "\n\n") + + f.write(f"Template ID: {template_id}\n") + f.write(f"ECG ID: {ecg_id}\n\n") + + f.write(f"Question:\n{question}\n\n") + + f.write(f"Answer Options:\n") + for i, option in enumerate(answer_options, 1): + f.write(f"{i}. 
{option}\n") + f.write(f"\n") + + f.write(f"Clinical Context:\n{clinical_context}\n\n") + + f.write(f"Model Output (Llama3B-Flamingo):\n{model_output}\n\n") + + f.write(f"Expected Answer: {correct_answer}\n") + +def create_doctor_evaluation_dataset(): + """Main function to create the doctor evaluation dataset""" + + print("Creating doctor evaluation dataset...") + print("Note: Each template_id can have multiple different questions - using the specific question from each sample") + + # Create output directory + os.makedirs(OUTPUT_DIR, exist_ok=True) + + # Load model predictions + template_predictions = load_model_predictions() + + if not template_predictions: + print("No correct predictions found!") + return + + # Process each template + processed_templates = 0 + + for template_id in sorted(template_predictions.keys()): + predictions = template_predictions[template_id] + + if len(predictions) < SAMPLES_PER_TEMPLATE: + print(f"Template {template_id}: Only {len(predictions)} correct predictions, skipping") + continue + + print(f"\nProcessing template {template_id} with {len(predictions)} correct predictions") + + # Get answer options for this template + answer_options = get_answer_options_for_template(template_id) + if not answer_options: + print(f"Template {template_id}: No answer options found, skipping") + continue + + # Process up to SAMPLES_PER_TEMPLATE samples + samples_to_process = predictions[:SAMPLES_PER_TEMPLATE] + + for sample_idx, prediction in enumerate(samples_to_process, 1): + try: + ecg_id = prediction['ecg_id'] + if ecg_id is None: + print(f"Template {template_id}, Sample {sample_idx}: No ECG ID, skipping") + continue + + print(f" Processing sample {sample_idx}: ECG {ecg_id}") + + # Create sample directory + sample_dir = f"{OUTPUT_DIR}/template_{template_id:02d}/sample{sample_idx}" + os.makedirs(sample_dir, exist_ok=True) + + # Extract information + clinical_context = extract_clinical_context(prediction['pre_prompt']) + question = extract_question_from_prompt(prediction['pre_prompt']) # Use the specific question from this sample + model_output = prediction['generated'] + correct_answer = prediction['correct_answer'] + + # Load ECG data + try: + ecg_data, sampling_freq = load_ecg_data(ecg_id) + print(f" Loaded ECG data: {ecg_data.shape}, {sampling_freq}Hz") + except Exception as e: + print(f" Error loading ECG {ecg_id}: {e}") + continue + + # Save ECG as CSV files + try: + save_ecg_as_csv(ecg_data, sample_dir, ecg_id) + print(f" Saved ECG CSV files") + except Exception as e: + print(f" Error saving ECG CSV: {e}") + + # Create ECG plot + try: + plot_filename = create_ecg_plot( + ecg_data, template_id, ecg_id, question, answer_options, + clinical_context, model_output, correct_answer, sample_dir + ) + print(f" Created ECG plot: {plot_filename}") + except Exception as e: + print(f" Error creating ECG plot: {e}") + + # Create evaluation text file + try: + create_evaluation_text_file( + sample_dir, template_id, ecg_id, question, answer_options, + clinical_context, model_output, correct_answer + ) + print(f" Created evaluation text file") + except Exception as e: + print(f" Error creating text file: {e}") + + except Exception as e: + print(f" Error processing sample {sample_idx}: {e}") + continue + + processed_templates += 1 + print(f"Completed template {template_id}") + + print(f"\nDoctor evaluation dataset creation completed!") + print(f"Processed {processed_templates} templates") + print(f"Output directory: {OUTPUT_DIR}") + + # Create summary file + 
create_summary_file(template_predictions) + +def create_summary_file(template_predictions: Dict[int, List[Dict]]): + """Create a summary file with statistics""" + summary_file = f"{OUTPUT_DIR}/dataset_summary.txt" + + with open(summary_file, 'w') as f: + f.write("ECG-QA Doctor Evaluation Dataset Summary\n") + f.write("=" * 50 + "\n\n") + + f.write(f"Total templates with correct predictions: {len(template_predictions)}\n") + f.write(f"Samples per template: {SAMPLES_PER_TEMPLATE}\n") + f.write(f"Total samples created: {len(template_predictions) * SAMPLES_PER_TEMPLATE}\n\n") + + f.write("Template Statistics:\n") + f.write("-" * 30 + "\n") + + for template_id in sorted(template_predictions.keys()): + predictions = template_predictions[template_id] + f.write(f"Template {template_id:2d}: {len(predictions):3d} correct predictions\n") + + f.write(f"\nDataset Structure:\n") + f.write(f"ecg_doctor_eval/\n") + f.write(f"β”œβ”€β”€ template_01/\n") + f.write(f"β”‚ β”œβ”€β”€ sample1/\n") + f.write(f"β”‚ β”‚ β”œβ”€β”€ ecg_plot.png\n") + f.write(f"β”‚ β”‚ β”œβ”€β”€ evaluation_info.txt\n") + f.write(f"β”‚ β”‚ β”œβ”€β”€ lead_I.csv\n") + f.write(f"β”‚ β”‚ β”œβ”€β”€ lead_II.csv\n") + f.write(f"β”‚ β”‚ └── ... (all 12 leads)\n") + f.write(f"β”‚ └── sample2/\n") + f.write(f"β”‚ └── ... (same structure)\n") + f.write(f"β”œβ”€β”€ template_02/\n") + f.write(f"β”‚ └── ...\n") + f.write(f"└── dataset_summary.txt\n") + + f.write(f"\nNotes:\n") + f.write(f"- All predictions are CORRECT (model answer matches expected answer)\n") + f.write(f"- ECG data is downsampled to 100Hz for consistency\n") + f.write(f"- Each sample includes clinical context, question, answer options, and model reasoning\n") + f.write(f"- CSV files contain time series data for each ECG lead\n") + + print(f"Summary file created: {summary_file}") + +if __name__ == "__main__": + try: + create_doctor_evaluation_dataset() + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() diff --git a/curriculum_learning.py b/curriculum_learning.py index a9cf893..1269414 100644 --- a/curriculum_learning.py +++ b/curriculum_learning.py @@ -781,7 +781,6 @@ def _evaluate_stage( result["ecg_id"] = sample["ecg_id"] if "correct_answer" in sample: result["correct_answer"] = sample["correct_answer"] - results.append(result) # Stream write each result immediately to per-rank file results_fp.write(json.dumps(result, ensure_ascii=False) + "\n") @@ -822,11 +821,7 @@ def _evaluate_stage( print(f"Merged per-rank predictions into: {final_results_file}") finally: pass - - # Report test loss as NaN since we skip explicit loss computation during evaluation - # Before, we were computing the loss explicitly, but this required to run the model twice, once for loss and once for predictions. avg_test_loss = float("nan") - # Calculate stage-specific metrics metrics = {"test_loss": avg_test_loss} if epoch is not None: @@ -848,7 +843,6 @@ def _evaluate_stage( continue additional_metrics = metric_func(predictions, gold_answers) metrics.update(additional_metrics) - # Save results only on rank 0 (or when not distributed) if (not dist.is_initialized()) or (self.rank == 0): # Save metrics @@ -1366,58 +1360,6 @@ def stage5_ecg_cot( sampler=sampler, ) - def stage4_sleep_cot( - self, batch_size: int = None, eval_only: bool = False - ) -> Dict[str, Any]: - """Stage 4: Chain-of-Thought Reasoning (SleepEDF). 
- - Configuration: - - Epochs: 60 - - OpenTSLMSP: encoder_lr=2e-4, projector_lr=1e-4 - - OpenTSLMFlamingo: base_lr=2e-4 - - Metric: Test loss only (chain-of-thought reasoning) - """ - sampler = None - - return self._train_stage( - stage_name="stage4_sleep_cot", - dataset_class=SleepEDFCoTQADataset, - num_epochs=60, - lr_encoder=2e-4, - lr_projector=1e-4, - lr_base=2e-4, - metric_func=None, # Only test loss for chain-of-thought reasoning - batch_size=batch_size, - eval_only=eval_only, - sampler=sampler, - ) - - def stage5_ecg_cot( - self, batch_size: int = None, eval_only: bool = False - ) -> Dict[str, Any]: - """Stage 5: Chain-of-Thought Reasoning (ECG QA CoT). - - Configuration: - - Epochs: 60 - - OpenTSLMSP: encoder_lr=2e-4, projector_lr=1e-4 - - OpenTSLMFlamingo: base_lr=2e-4 - - Metric: Test loss only (chain-of-thought reasoning) - """ - sampler = None - - return self._train_stage( - stage_name="stage5_ecg_cot", - dataset_class=ECGQACoTQADataset, - num_epochs=60, - lr_encoder=2e-4, - lr_projector=1e-4, - lr_base=2e-4, - metric_func=None, # Only test loss for chain-of-thought reasoning - batch_size=batch_size, - eval_only=eval_only, - sampler=sampler, - ) - def run_curriculum( self, stages: List[str] = None, batch_size: int = None, eval_only: bool = False ): diff --git a/demo/huggingface/.gitignore b/demo/huggingface/.gitignore new file mode 100644 index 0000000..6320cd2 --- /dev/null +++ b/demo/huggingface/.gitignore @@ -0,0 +1 @@ +data \ No newline at end of file diff --git a/demo/huggingface/01_test_hf_tsqa.py b/demo/huggingface/01_test_hf_tsqa.py new file mode 100755 index 0000000..0280ba7 --- /dev/null +++ b/demo/huggingface/01_test_hf_tsqa.py @@ -0,0 +1,81 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Demo script for testing TSQA (Time Series Question Answering) model from HuggingFace. + +This script: +1. Loads a pretrained model from HuggingFace Hub +2. Loads the TSQA test dataset +3. Generates predictions on the evaluation set +4. 
Prints model outputs +""" + +import sys +import os + +# Add src to path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src"))) + +from model.llm.OpenTSLM import OpenTSLM +from time_series_datasets.TSQADataset import TSQADataset +from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate +from torch.utils.data import DataLoader +from model_config import PATCH_SIZE + +# Model repository ID - change this to test different models +REPO_ID = "OpenTSLM/llama-3.2-1b-tsqa-sp" + +def main(): + print("=" * 60) + print("TSQA Model Demo") + print("=" * 60) + + # Load model from HuggingFace + print(f"\nπŸ“₯ Loading model from {REPO_ID}...") + model = OpenTSLM.load_pretrained(REPO_ID) + + # Create dataset + print("\nπŸ“Š Loading TSQA test dataset...") + test_dataset = TSQADataset("test", EOS_TOKEN=model.get_eos_token()) + + # Create data loader + test_loader = DataLoader( + test_dataset, + shuffle=False, + batch_size=1, + collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate( + batch, patch_size=PATCH_SIZE + ), + ) + + print(f"\nπŸ” Running inference on {len(test_dataset)} test samples...") + print("=" * 60) + + # Iterate over evaluation set + for i, batch in enumerate(test_loader): + # Generate predictions + predictions = model.generate(batch, max_new_tokens=200) + + # Print results + for sample, pred in zip(batch, predictions): + print(f"\nπŸ“ Sample {i + 1}:") + print(f" Question: {sample.get('pre_prompt', 'N/A')}") + if 'time_series_text' in sample: + print(f" Time Series Info: {sample['time_series_text'][:100]}...") + print(f" Gold Answer: {sample.get('answer', 'N/A')}") + print(f" Model Output: {pred}") + print("-" * 60) + + # Limit to first 5 samples for demo + if i >= 4: + print("\nβœ… Demo complete! (Showing first 5 samples)") + break + +if __name__ == "__main__": + main() + diff --git a/demo/huggingface/02_test_hf_m4.py b/demo/huggingface/02_test_hf_m4.py new file mode 100755 index 0000000..dc33e7d --- /dev/null +++ b/demo/huggingface/02_test_hf_m4.py @@ -0,0 +1,82 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Demo script for testing M4 (Time Series Captioning) model from HuggingFace. + +This script: +1. Loads a pretrained model from HuggingFace Hub +2. Loads the M4 test dataset +3. Generates predictions on the evaluation set +4. 
Prints model outputs +""" + +import sys +import os + +# Add src to path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src"))) + +from model.llm.OpenTSLM import OpenTSLM +from time_series_datasets.m4.M4QADataset import M4QADataset +from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate +from torch.utils.data import DataLoader +from model_config import PATCH_SIZE + +# Model repository ID - change this to test different models +REPO_ID = "OpenTSLM/llama-3.2-1b-m4-sp" + +def main(): + print("=" * 60) + print("M4 Captioning Model Demo") + print("=" * 60) + + # Load model from HuggingFace + print(f"\nπŸ“₯ Loading model from {REPO_ID}...") + model = OpenTSLM.load_pretrained(REPO_ID) + + # Create dataset + print("\nπŸ“Š Loading M4 test dataset...") + test_dataset = M4QADataset("test", EOS_TOKEN=model.get_eos_token()) + + # Create data loader + test_loader = DataLoader( + test_dataset, + shuffle=False, + batch_size=1, + collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate( + batch, patch_size=PATCH_SIZE + ), + ) + + print(f"\nπŸ” Running inference on {len(test_dataset)} test samples...") + print("=" * 60) + + # Iterate over evaluation set + for i, batch in enumerate(test_loader): + # Generate predictions + predictions = model.generate(batch, max_new_tokens=200) + + # Print results + for sample, pred in zip(batch, predictions): + print(f"\nπŸ“ Sample {i + 1}:") + if 'id' in sample: + print(f" Time Series ID: {sample['id']}") + if 'pre_prompt' in sample: + print(f" Prompt: {sample['pre_prompt']}") + print(f" Gold Caption: {sample.get('answer', 'N/A')}") + print(f" Model Output: {pred}") + print("-" * 60) + + # Limit to first 5 samples for demo + if i >= 9: + print("\nβœ… Demo complete! (Showing first 10 samples)") + break + +if __name__ == "__main__": + main() + diff --git a/demo/huggingface/03_test_hf_har_cot.py b/demo/huggingface/03_test_hf_har_cot.py new file mode 100755 index 0000000..3c1f484 --- /dev/null +++ b/demo/huggingface/03_test_hf_har_cot.py @@ -0,0 +1,85 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Demo script for testing HAR CoT (Human Activity Recognition Chain-of-Thought) model from HuggingFace. + +This script: +1. Loads a pretrained model from HuggingFace Hub +2. Loads the HAR CoT test dataset +3. Generates predictions on the evaluation set +4. 
Prints model outputs +""" + +import sys +import os + +# Add src to path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src"))) + +from model.llm.OpenTSLM import OpenTSLM +from time_series_datasets.har_cot.HARCoTQADataset import HARCoTQADataset +from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate +from torch.utils.data import DataLoader +from model_config import PATCH_SIZE + +# Model repository ID - change this to test different models +REPO_ID = "OpenTSLM/llama-3.2-1b-har-sp" + +def main(): + print("=" * 60) + print("HAR CoT Model Demo") + print("=" * 60) + + # Load model from HuggingFace + print(f"\nπŸ“₯ Loading model from {REPO_ID}...") + enable_lora = False + if "-sp" in REPO_ID: + enable_lora = True + model = OpenTSLM.load_pretrained(REPO_ID, enable_lora=enable_lora) + + + + # Create dataset + print("\nπŸ“Š Loading HAR CoT test dataset...") + test_dataset = HARCoTQADataset("test", EOS_TOKEN=model.get_eos_token()) + + # Create data loader + test_loader = DataLoader( + test_dataset, + shuffle=False, + batch_size=1, + collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate( + batch, patch_size=PATCH_SIZE + ), + ) + + print(f"\nπŸ” Running inference on {len(test_dataset)} test samples...") + print("=" * 60) + + # Iterate over evaluation set + for i, batch in enumerate(test_loader): + # Generate predictions + predictions = model.generate(batch, max_new_tokens=500) + + # Print results + for sample, pred in zip(batch, predictions): + print(f"\nπŸ“ Sample {i + 1}:") + if 'pre_prompt' in sample: + print(f" Question: {sample['pre_prompt']}") + print(f" Gold Answer: {sample.get('answer', 'N/A')}") + print(f" Model Output: {pred}") + print("-" * 60) + + # Limit to first 5 samples for demo + if i >= 4: + print("\nβœ… Demo complete! (Showing first 5 samples)") + break + +if __name__ == "__main__": + main() + diff --git a/demo/huggingface/04_test_hf_sleep_cot.py b/demo/huggingface/04_test_hf_sleep_cot.py new file mode 100755 index 0000000..6258be5 --- /dev/null +++ b/demo/huggingface/04_test_hf_sleep_cot.py @@ -0,0 +1,83 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Demo script for testing Sleep CoT (Sleep Stage Classification Chain-of-Thought) model from HuggingFace. + +This script: +1. Loads a pretrained model from HuggingFace Hub +2. Loads the Sleep CoT test dataset +3. Generates predictions on the evaluation set +4. 
Prints model outputs +""" + +import sys +import os + +# Add src to path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src"))) + +from model.llm.OpenTSLM import OpenTSLM +from time_series_datasets.sleep.SleepEDFCoTQADataset import SleepEDFCoTQADataset +from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate +from torch.utils.data import DataLoader +from model_config import PATCH_SIZE + +# Model repository ID - change this to test different models +REPO_ID = "OpenTSLM/llama-3.2-1b-sleep-sp" + +def main(): + print("=" * 60) + print("Sleep CoT Model Demo") + print("=" * 60) + + # Load model from HuggingFace + print(f"\nπŸ“₯ Loading model from {REPO_ID}...") + enable_lora = False + if "-sp" in REPO_ID: + enable_lora = True + model = OpenTSLM.load_pretrained(REPO_ID, enable_lora=enable_lora) + + # Create dataset + print("\nπŸ“Š Loading Sleep CoT test dataset...") + test_dataset = SleepEDFCoTQADataset("test", EOS_TOKEN=model.get_eos_token()) + + # Create data loader + test_loader = DataLoader( + test_dataset, + shuffle=False, + batch_size=1, + collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate( + batch, patch_size=PATCH_SIZE + ), + ) + + print(f"\nπŸ” Running inference on {len(test_dataset)} test samples...") + print("=" * 60) + + # Iterate over evaluation set + for i, batch in enumerate(test_loader): + # Generate predictions + predictions = model.generate(batch, max_new_tokens=500) + + # Print results + for sample, pred in zip(batch, predictions): + print(f"\nπŸ“ Sample {i + 1}:") + if 'pre_prompt' in sample: + print(f" Question: {sample['pre_prompt']}") + print(f" Gold Answer: {sample.get('answer', 'N/A')}") + print(f" Model Output: {pred}") + print("-" * 60) + + # Limit to first 5 samples for demo + if i >= 4: + print("\nβœ… Demo complete! (Showing first 5 samples)") + break + +if __name__ == "__main__": + main() + diff --git a/demo/huggingface/05_test_hf_ecg_qa_cot.py b/demo/huggingface/05_test_hf_ecg_qa_cot.py new file mode 100755 index 0000000..7c7ac98 --- /dev/null +++ b/demo/huggingface/05_test_hf_ecg_qa_cot.py @@ -0,0 +1,87 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Demo script for testing ECG QA CoT (ECG Question Answering Chain-of-Thought) model from HuggingFace. + +This script: +1. Loads a pretrained model from HuggingFace Hub +2. Loads the ECG QA CoT test dataset +3. Generates predictions on the evaluation set +4. 
Prints model outputs +""" + +import sys +import os + +# Add src to path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src"))) + +from model.llm.OpenTSLM import OpenTSLM +from time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset +from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate +from torch.utils.data import DataLoader +from model_config import PATCH_SIZE + +# Model repository ID - change this to test different models +REPO_ID = "OpenTSLM/llama-3.2-1b-ecg-sp" + +def main(): + print("=" * 60) + print("ECG QA CoT Model Demo") + print("=" * 60) + + # Load model from HuggingFace + print(f"\nπŸ“₯ Loading model from {REPO_ID}...") + enable_lora = False + if "-sp" in REPO_ID: + enable_lora = True + model = OpenTSLM.load_pretrained(REPO_ID, enable_lora=enable_lora) + + # Create dataset + print("\nπŸ“Š Loading ECG QA CoT test dataset...") + test_dataset = ECGQACoTQADataset("test", EOS_TOKEN=model.get_eos_token()) + + # Create data loader + test_loader = DataLoader( + test_dataset, + shuffle=False, + batch_size=1, + collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate( + batch, patch_size=PATCH_SIZE + ), + ) + + print(f"\nπŸ” Running inference on {len(test_dataset)} test samples...") + print("=" * 60) + + # Iterate over evaluation set + for i, batch in enumerate(test_loader): + # Generate predictions + predictions = model.generate(batch, max_new_tokens=500) + + # Print results + for sample, pred in zip(batch, predictions): + print(f"\nπŸ“ Sample {i + 1}:") + if 'pre_prompt' in sample: + print(f" Question: {sample['pre_prompt']}") + if 'template_id' in sample: + print(f" Template ID: {sample['template_id']}") + if 'ecg_id' in sample: + print(f" ECG ID: {sample['ecg_id']}") + print(f" Gold Answer: {sample.get('answer', 'N/A')}") + print(f" Model Output: {pred}") + print("-" * 60) + + # Limit to first 5 samples for demo + if i >= 4: + print("\nβœ… Demo complete! (Showing first 5 samples)") + break + +if __name__ == "__main__": + main() + diff --git a/evaluation/baseline/common_evaluator.py b/evaluation/baseline/common_evaluator.py index b6e5cd1..9edbec4 100644 --- a/evaluation/baseline/common_evaluator.py +++ b/evaluation/baseline/common_evaluator.py @@ -498,9 +498,9 @@ def _consolidate_jsonl_results(self, model_name: str, dataset_name: str) -> str: "dataset_name": dataset_name, "total_samples": total_samples, "successful_inferences": successful_inferences, - "success_rate": successful_inferences / total_samples - if total_samples > 0 - else 0.0, + "success_rate": ( + successful_inferences / total_samples if total_samples > 0 else 0.0 + ), "metrics": aggregate_metrics, "detailed_results": individual_results, } diff --git a/evaluation/baseline/parse_predictions_sleep_baseline.py b/evaluation/baseline/parse_predictions_sleep_baseline.py index e8b9fd0..6352c8f 100644 --- a/evaluation/baseline/parse_predictions_sleep_baseline.py +++ b/evaluation/baseline/parse_predictions_sleep_baseline.py @@ -58,6 +58,7 @@ # --- Inline minimal utilities (avoid importing modules that require extra packages) --- import re + def extract_answer(text: str) -> str: """Extract the final answer from text by taking content after 'Answer:' and trimming trailing special tokens. 
@@ -65,21 +66,21 @@ def extract_answer(text: str) -> str: if "Answer: " not in text: return text answer = text.split("Answer: ")[-1].strip() - answer = re.sub(r'<\|.*?\|>$', '', answer).strip() + answer = re.sub(r"<\|.*?\|>$", "", answer).strip() return answer def calculate_f1_score(prediction: str, ground_truth: str): """Binary exact-match F1 on normalized strings (lower/strip/punct).""" - pred_normalized = prediction.lower().strip().rstrip('.,!?;:') - truth_normalized = ground_truth.lower().strip().rstrip('.,!?;:') + pred_normalized = prediction.lower().strip().rstrip(".,!?;:") + truth_normalized = ground_truth.lower().strip().rstrip(".,!?;:") f1 = 1.0 if pred_normalized == truth_normalized else 0.0 return { - 'f1_score': f1, - 'precision': f1, - 'recall': f1, - 'prediction_normalized': pred_normalized, - 'ground_truth_normalized': truth_normalized, + "f1_score": f1, + "precision": f1, + "recall": f1, + "prediction_normalized": pred_normalized, + "ground_truth_normalized": truth_normalized, } @@ -122,7 +123,11 @@ def calculate_f1_stats(data_points: List[Dict], allowed_labels=None): tp, fp, fn = counts["tp"], counts["fp"], counts["fn"] precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 - f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) class_f1_scores[class_name] = { "f1": f1, "precision": precision, @@ -186,8 +191,8 @@ def canonicalize_sleep_label(s: str) -> str: # If it's an option-like artifact (e.g., "(a) wake"), keep only the label part # but since SleepEDF doesn't use options, just strip leading option markers if present - if len(t) > 3 and t[0] == '(' and ')' in t[:4]: - t = t.split(')', 1)[-1].strip() + if len(t) > 3 and t[0] == "(" and ")" in t[:4]: + t = t.split(")", 1)[-1].strip() # Short aliases if t in {"w"}: @@ -290,18 +295,18 @@ def extract_structured_data(obj: Dict) -> List[Dict]: # Binary exact-match accuracy on normalized labels handled in calculate_f1_score, but keep explicit flag f1_result = calculate_f1_score(model_prediction, ground_truth) - accuracy = f1_result['f1_score'] == 1.0 + accuracy = f1_result["f1_score"] == 1.0 data_point = { "generated": generated, "model_prediction": model_prediction, "ground_truth": ground_truth, "accuracy": accuracy, - "f1_score": f1_result['f1_score'], - "precision": f1_result['precision'], - "recall": f1_result['recall'], - "prediction_normalized": f1_result['prediction_normalized'], - "ground_truth_normalized": f1_result['ground_truth_normalized'], + "f1_score": f1_result["f1_score"], + "precision": f1_result["precision"], + "recall": f1_result["recall"], + "prediction_normalized": f1_result["prediction_normalized"], + "ground_truth_normalized": f1_result["ground_truth_normalized"], } data_points.append(data_point) @@ -316,16 +321,16 @@ def main(): "--detailed-json", type=Path, required=True, - help="Path to a single results JSON file containing 'detailed_results'" + help="Path to a single results JSON file containing 'detailed_results'", ) ap.add_argument( "--clean-out", type=Path, - help="Optional path to write clean JSONL of parsed per-sample points" + help="Optional path to write clean JSONL of parsed per-sample points", ) args = ap.parse_args() - with args.detailed_json.open('r', encoding='utf-8') as f: + with args.detailed_json.open("r", encoding="utf-8") as f: obj = json.load(f) # Extract per-sample points @@ -374,9 +379,9 @@ def 
main(): print(f"Macro-F1 Score: {f1_stats.get('macro_f1', 0.0):.4f}") print(f"Total Classes: {f1_stats.get('total_classes', 0)}") - if f1_stats.get('class_f1_scores'): + if f1_stats.get("class_f1_scores"): print(f"\nPer-Class F1 Scores:") - for class_name, scores in f1_stats['class_f1_scores'].items(): + for class_name, scores in f1_stats["class_f1_scores"].items(): print( f" {class_name}: F1={scores['f1']:.4f}, " f"P={scores['precision']:.4f}, R={scores['recall']:.4f}" @@ -384,7 +389,7 @@ def main(): # Optional clean JSONL output if args.clean_out: - with args.clean_out.open('w', encoding='utf-8') as f: + with args.clean_out.open("w", encoding="utf-8") as f: for item in data_points: f.write(json.dumps(item, indent=2) + "\n") print(f"\nData saved to {args.clean_out}") diff --git a/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py b/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py index fe416a5..c72223d 100644 --- a/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py +++ b/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py @@ -18,17 +18,18 @@ from tqdm import tqdm # Add the src directory to the path to import from the dataset class -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "src")) from time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset + def calculate_f1_score(prediction, ground_truth, possible_answers): """Calculate F1 score for single-label classification with template-specific answers. - + Args: prediction: Model's predicted answer ground_truth: Ground truth answer possible_answers: List of valid answers for this template - + Returns: Dict with F1 metrics and metadata """ @@ -38,77 +39,78 @@ def calculate_f1_score(prediction, ground_truth, possible_answers): raise ValueError("Ground truth cannot be None") if not possible_answers: raise ValueError("Possible answers list cannot be empty") - + # Normalize predictions and ground truth - pred_normalized = prediction.lower().strip().rstrip('.,!?;:') - truth_normalized = ground_truth.lower().strip().rstrip('.,!?;:') - + pred_normalized = prediction.lower().strip().rstrip(".,!?;:") + truth_normalized = ground_truth.lower().strip().rstrip(".,!?;:") + # Check if prediction is in supported answers possible_answers_lower = [ans.lower().strip() for ans in possible_answers] pred_supported = pred_normalized in possible_answers_lower truth_supported = truth_normalized in possible_answers_lower - + # Calculate F1 (exact match after normalization) f1 = 1.0 if pred_normalized == truth_normalized else 0.0 - + return { - 'f1_score': f1, - 'precision': f1, - 'recall': f1, - 'prediction_normalized': pred_normalized, - 'ground_truth_normalized': truth_normalized, - 'prediction_supported': pred_supported, - 'ground_truth_supported': truth_supported, - 'possible_answers': possible_answers, + "f1_score": f1, + "precision": f1, + "recall": f1, + "prediction_normalized": pred_normalized, + "ground_truth_normalized": truth_normalized, + "prediction_supported": pred_supported, + "ground_truth_supported": truth_supported, + "possible_answers": possible_answers, } + def calculate_template_f1_stats(data_points): """Calculate F1 statistics per template. 
- + Args: data_points: List of data points with template_id and metrics - + Returns: Dict with per-template and overall statistics """ if not data_points: return {} - + # Group by template_id template_groups = defaultdict(list) for point in data_points: - template_id = point.get('template_id') + template_id = point.get("template_id") if template_id is None: raise ValueError(f"Missing template_id in data point: {point}") template_groups[template_id].append(point) - + # Calculate per-template metrics template_stats = {} total_samples = 0 total_correct = 0 total_f1_sum = 0 total_macro_f1_weighted_sum = 0 - + for template_id, points in template_groups.items(): if not points: continue - + # Get possible answers for this template - required for evaluation - possible_answers = points[0].get('possible_answers', []) + possible_answers = points[0].get("possible_answers", []) if not possible_answers: raise ValueError(f"No possible answers found for template {template_id}") - + # Calculate per-template F1 stats class_predictions = {} for answer in possible_answers: class_predictions[answer.lower()] = {"tp": 0, "fp": 0, "fn": 0} - + # Count TP, FP, FN for each class in this template for point in points: gt_class = point.get("ground_truth_normalized", "") pred_class = point.get("prediction_normalized", "") pred_supported = point.get("prediction_supported", False) - + # Only count if ground truth is in supported answers if gt_class in class_predictions: if pred_class == gt_class: @@ -119,19 +121,23 @@ def calculate_template_f1_stats(data_points): # Count FP only if prediction is supported if pred_supported and pred_class in class_predictions: class_predictions[pred_class]["fp"] += 1 - + # Calculate per-class F1 for this template class_f1_scores = {} template_f1_sum = 0 valid_classes = 0 - + for class_name, counts in class_predictions.items(): tp, fp, fn = counts["tp"], counts["fp"], counts["fn"] - + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 - + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + class_f1_scores[class_name] = { "f1": f1, "precision": precision, @@ -140,20 +146,24 @@ def calculate_template_f1_stats(data_points): "fp": fp, "fn": fn, } - + template_f1_sum += f1 valid_classes += 1 - + # Calculate macro-F1 for this template macro_f1 = template_f1_sum / valid_classes if valid_classes > 0 else 0 - + # Calculate accuracy for this template template_correct = sum(1 for point in points if point.get("accuracy", False)) template_accuracy = template_correct / len(points) if points else 0 - + # Calculate average F1 for this template - template_avg_f1 = sum(point.get("f1_score", 0) for point in points) / len(points) if points else 0 - + template_avg_f1 = ( + sum(point.get("f1_score", 0) for point in points) / len(points) + if points + else 0 + ) + template_stats[template_id] = { "num_samples": len(points), "accuracy": template_accuracy, @@ -163,19 +173,21 @@ def calculate_template_f1_stats(data_points): "num_classes": valid_classes, "correct_predictions": template_correct, } - + total_samples += len(points) total_correct += template_correct total_f1_sum += template_avg_f1 * len(points) # Weighted by number of samples total_macro_f1_weighted_sum += macro_f1 * len(points) - + # Calculate overall statistics overall_accuracy = total_correct / total_samples if total_samples > 0 else 0 overall_avg_f1 = total_f1_sum / 
total_samples if total_samples > 0 else 0 - + # Calculate macro-F1 across templates template_macro_f1s = [stats["macro_f1"] for stats in template_stats.values()] - overall_macro_f1 = sum(template_macro_f1s) / len(template_macro_f1s) if template_macro_f1s else 0 + overall_macro_f1 = ( + sum(template_macro_f1s) / len(template_macro_f1s) if template_macro_f1s else 0 + ) overall_macro_f1_weighted = ( total_macro_f1_weighted_sum / total_samples if total_samples > 0 else 0 ) @@ -183,9 +195,11 @@ def calculate_template_f1_stats(data_points): # Calculate unweighted average of per-template accuracies template_accuracies = [stats["accuracy"] for stats in template_stats.values()] overall_template_accuracy_avg = ( - sum(template_accuracies) / len(template_accuracies) if template_accuracies else 0 + sum(template_accuracies) / len(template_accuracies) + if template_accuracies + else 0 ) - + return { "overall": { "total_samples": total_samples, @@ -199,36 +213,38 @@ def calculate_template_f1_stats(data_points): "per_template": template_stats, } + def calculate_accuracy_stats(data_points): """Calculate overall accuracy statistics from data points""" if not data_points: return {} - + total = len(data_points) correct = sum(1 for point in data_points if point.get("accuracy", False)) accuracy_percentage = (correct / total) * 100 if total > 0 else 0 - + return { "total_samples": total, "correct_predictions": correct, "incorrect_predictions": total - correct, - "accuracy_percentage": accuracy_percentage + "accuracy_percentage": accuracy_percentage, } + def parse_ecg_qa_cot_jsonl(input_file, output_file=None): """Parse ECG-QA CoT JSONL file and extract JSON objects with per-template metrics.""" if output_file is None: input_path = Path(input_file) output_file = str(input_path.parent / f"{input_path.stem}.clean.jsonl") - + print(f"Parsing {input_file}") print(f"Output will be saved to {output_file}") - + extracted_data = extract_structured_data(input_file) - + if extracted_data: print(f"Extracted {len(extracted_data)} data points") - + # Calculate and display overall accuracy statistics accuracy_stats = calculate_accuracy_stats(extracted_data) print(f"\nOverall Accuracy Statistics:") @@ -236,23 +252,31 @@ def parse_ecg_qa_cot_jsonl(input_file, output_file=None): print(f"Correct predictions: {accuracy_stats['correct_predictions']}") print(f"Incorrect predictions: {accuracy_stats['incorrect_predictions']}") print(f"Accuracy: {accuracy_stats['accuracy_percentage']:.2f}%") - + # Calculate and display per-template F1 statistics f1_stats = calculate_template_f1_stats(extracted_data) print(f"\nOverall F1 Statistics:") overall = f1_stats.get("overall", {}) print(f"Total templates: {overall.get('total_templates', 0)}") print(f"Average F1 Score (sample-weighted): {overall.get('average_f1', 0):.4f}") - print(f"Macro-F1 Score (unweighted over templates): {overall.get('macro_f1', 0):.4f}") - print(f"Macro-F1 Score (sample-weighted over templates): {overall.get('macro_f1_weighted', 0):.4f}") - print(f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}") + print( + f"Macro-F1 Score (unweighted over templates): {overall.get('macro_f1', 0):.4f}" + ) + print( + f"Macro-F1 Score (sample-weighted over templates): {overall.get('macro_f1_weighted', 0):.4f}" + ) + print( + f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}" + ) # Final single-value summary (aggregated across all templates) print("\nFinal Results (aggregated across all templates):") print(f"Final Accuracy 
(micro over samples): {overall.get('accuracy', 0):.4f}") print(f"Final F1 (micro over samples): {overall.get('accuracy', 0):.4f}") - print(f"Final Macro-F1 (weighted by template size): {overall.get('macro_f1_weighted', 0):.4f}") - + print( + f"Final Macro-F1 (weighted by template size): {overall.get('macro_f1_weighted', 0):.4f}" + ) + # Display per-template statistics per_template = f1_stats.get("per_template", {}) if per_template: @@ -263,18 +287,22 @@ def parse_ecg_qa_cot_jsonl(input_file, output_file=None): print(f" Accuracy: {stats['accuracy']:.4f}") print(f" Average F1: {stats['average_f1']:.4f}") print(f" Macro-F1: {stats['macro_f1']:.4f}") - + # Show per-class F1 scores for this template - if stats['class_f1_scores']: + if stats["class_f1_scores"]: print(f" Per-class F1:") - for class_name, scores in stats['class_f1_scores'].items(): - if scores['tp'] + scores['fp'] + scores['fn'] > 0: # Only show classes with samples - print(f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}") - - with open(output_file, 'w', encoding='utf-8') as f: + for class_name, scores in stats["class_f1_scores"].items(): + if ( + scores["tp"] + scores["fp"] + scores["fn"] > 0 + ): # Only show classes with samples + print( + f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}" + ) + + with open(output_file, "w", encoding="utf-8") as f: for item in extracted_data: f.write(json.dumps(item, indent=2) + "\n") - + print(f"\nData saved to {output_file}") # Print concise overall stats at the end as a final summary @@ -284,59 +312,73 @@ def parse_ecg_qa_cot_jsonl(input_file, output_file=None): print(f"F1 (micro): {overall.get('accuracy', 0):.4f}") print(f"Macro-F1 (unweighted): {overall.get('macro_f1', 0):.4f}") print(f"Macro-F1 (weighted): {overall.get('macro_f1_weighted', 0):.4f}") - print(f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}") + print( + f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}" + ) return extracted_data else: print("No data could be extracted from the file.") return [] + def extract_structured_data(input_file): """Extract structured data from JSONL file""" data_points = [] - - with open(input_file, 'r', encoding='utf-8') as f: + + with open(input_file, "r", encoding="utf-8") as f: for line_num, line in tqdm(enumerate(f, 1), desc="Processing JSONL"): try: # Parse JSON line data = json.loads(line.strip()) - + # Extract generated and gold fields - these are required - generated_text = data.get('generated_answer') - gold_text = data.get('target_answer') - + generated_text = data.get("generated_answer") + gold_text = data.get("target_answer") + if generated_text is None: raise ValueError(f"Missing 'generated' field in line {line_num}") if gold_text is None: raise ValueError(f"Missing 'gold' field in line {line_num}") - + # Extract template_id and ecg_id - these are required for ECG-QA evaluation - template_id = data.get('template_id') - ecg_id = data.get('ecg_id', "None") - + template_id = data.get("template_id") + ecg_id = data.get("ecg_id", "None") + if template_id is None: raise ValueError(f"Missing template_id in line {line_num}") if ecg_id is None: raise ValueError(f"Missing ecg_id in line {line_num}") - + # Extract answers from both fields model_prediction_raw = extract_answer(generated_text) ground_truth_raw = extract_answer(gold_text) - + # Get possible answers for this template - required for evaluation try: - possible_answers = 
ECGQACoTQADataset.get_possible_answers_for_template(template_id) + possible_answers = ( + ECGQACoTQADataset.get_possible_answers_for_template(template_id) + ) except Exception as e: - raise ValueError(f"Could not get possible answers for template {template_id}: {e}") - + raise ValueError( + f"Could not get possible answers for template {template_id}: {e}" + ) + if not possible_answers: - raise ValueError(f"No possible answers found for template {template_id}") - + raise ValueError( + f"No possible answers found for template {template_id}" + ) + # Calculate F1 score with template-specific answers - f1_result = calculate_f1_score(model_prediction_raw, ground_truth_raw, possible_answers) - + f1_result = calculate_f1_score( + model_prediction_raw, ground_truth_raw, possible_answers + ) + # Calculate accuracy (exact match) - accuracy = (f1_result['prediction_normalized'] == f1_result['ground_truth_normalized']) and f1_result['ground_truth_supported'] - + accuracy = ( + f1_result["prediction_normalized"] + == f1_result["ground_truth_normalized"] + ) and f1_result["ground_truth_supported"] + data_point = { "generated": generated_text, "model_prediction": model_prediction_raw, @@ -344,15 +386,15 @@ def extract_structured_data(input_file): "template_id": template_id, "ecg_id": ecg_id, "accuracy": accuracy, - "f1_score": f1_result['f1_score'], - "precision": f1_result['precision'], - "recall": f1_result['recall'], - "prediction_normalized": f1_result['prediction_normalized'], - "ground_truth_normalized": f1_result['ground_truth_normalized'], - "prediction_supported": f1_result['prediction_supported'], - "ground_truth_supported": f1_result['ground_truth_supported'], + "f1_score": f1_result["f1_score"], + "precision": f1_result["precision"], + "recall": f1_result["recall"], + "prediction_normalized": f1_result["prediction_normalized"], + "ground_truth_normalized": f1_result["ground_truth_normalized"], + "prediction_supported": f1_result["prediction_supported"], + "ground_truth_supported": f1_result["ground_truth_supported"], "possible_answers": possible_answers, - "line_number": line_num + "line_number": line_num, } data_points.append(data_point) except json.JSONDecodeError as e: @@ -361,24 +403,30 @@ def extract_structured_data(input_file): except Exception as e: print(f"Unexpected error on line {line_num}: {e}") continue - + return data_points + def extract_answer(text): """Extract the final answer from text""" if "Answer: " not in text: return text - + answer = text.split("Answer: ")[-1].strip() # Remove any end-of-text tokens (including and <|...|>) - answer = re.sub(r'<\|.*?\|>|$', '', answer).strip() + answer = re.sub(r"<\|.*?\|>|$", "", answer).strip() # Remove trailing periods and normalize - answer = re.sub(r'\.$', '', answer).strip() + answer = re.sub(r"\.$", "", answer).strip() return answer + if __name__ == "__main__": current_dir = Path(__file__).parent - input_file = current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.jsonl" - clean_output = current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.clean.jsonl" - + input_file = ( + current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.jsonl" + ) + clean_output = ( + current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.clean.jsonl" + ) + parse_ecg_qa_cot_jsonl(input_file, clean_output) diff --git a/evaluation/opentslm/parse_predictions.py b/evaluation/opentslm/parse_predictions.py index a686fe1..8f1c67a 100644 --- a/evaluation/opentslm/parse_predictions.py +++ 
b/evaluation/opentslm/parse_predictions.py @@ -17,8 +17,8 @@ from collections import Counter # Add the src directory to the path to import from the dataset class -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) -sys.path.append(os.path.join(project_root, 'src')) +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.append(os.path.join(project_root, "src")) # Import the dataset class to get labels from time_series_datasets.har_cot.HARCoTQADataset import HARCoTQADataset @@ -26,23 +26,25 @@ # Get the supported labels from the dataset class SUPPORTED_LABELS = HARCoTQADataset.get_labels() + def calculate_f1_score(prediction, ground_truth): """Calculate F1 score for classification labels""" # Normalize labels for comparison (lowercase, strip whitespace and trailing punctuation) - pred_normalized = prediction.lower().strip().rstrip('.,!?;:') - truth_normalized = ground_truth.lower().strip().rstrip('.,!?;:') - + pred_normalized = prediction.lower().strip().rstrip(".,!?;:") + truth_normalized = ground_truth.lower().strip().rstrip(".,!?;:") + # For single prediction vs single ground truth, F1 is binary f1 = 1.0 if pred_normalized == truth_normalized else 0.0 - + return { - 'f1_score': f1, - 'precision': f1, # For single-label classification, precision = recall = f1 - 'recall': f1, - 'prediction_normalized': pred_normalized, - 'ground_truth_normalized': truth_normalized + "f1_score": f1, + "precision": f1, # For single-label classification, precision = recall = f1 + "recall": f1, + "prediction_normalized": pred_normalized, + "ground_truth_normalized": truth_normalized, } + def calculate_f1_stats(data_points, allowed_labels=None): """Calculate both macro-F1 and average F1 (micro-F1) statistics. 
@@ -53,11 +55,11 @@ def calculate_f1_stats(data_points, allowed_labels=None): """ if not data_points: return {} - + # Calculate average F1 (micro-F1) - simple average across all predictions f1_scores = [point.get("f1_score", 0) for point in data_points] average_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0 - + # Group by ground truth class for macro-F1 class_predictions = {} if allowed_labels: @@ -66,10 +68,10 @@ def calculate_f1_stats(data_points, allowed_labels=None): for point in data_points: gt_class = point.get("ground_truth_normalized", "") pred_class = point.get("prediction_normalized", "") - + if gt_class not in class_predictions: class_predictions[gt_class] = {"tp": 0, "fp": 0, "fn": 0} - + # True positive: prediction matches ground truth if pred_class == gt_class: class_predictions[gt_class]["tp"] += 1 @@ -82,73 +84,79 @@ def calculate_f1_stats(data_points, allowed_labels=None): class_predictions[pred_class]["fp"] += 1 else: class_predictions[pred_class] = {"tp": 0, "fp": 1, "fn": 0} - + # Calculate F1 per class class_f1_scores = {} total_f1 = 0 valid_classes = 0 - + for class_name, counts in class_predictions.items(): tp, fp, fn = counts["tp"], counts["fp"], counts["fn"] - + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 - + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + class_f1_scores[class_name] = { "f1": f1, "precision": precision, "recall": recall, "tp": tp, "fp": fp, - "fn": fn + "fn": fn, } - + total_f1 += f1 valid_classes += 1 - + # Calculate macro-F1 (average across all classes) macro_f1 = total_f1 / valid_classes if valid_classes > 0 else 0 - + return { "average_f1": average_f1, "macro_f1": macro_f1, "class_f1_scores": class_f1_scores, - "total_classes": valid_classes + "total_classes": valid_classes, } + def calculate_accuracy_stats(data_points): """Calculate accuracy statistics from data points""" if not data_points: return {} - + total = len(data_points) correct = sum(1 for point in data_points if point.get("accuracy", False)) accuracy_percentage = (correct / total) * 100 if total > 0 else 0 - + return { "total_samples": total, "correct_predictions": correct, "incorrect_predictions": total - correct, - "accuracy_percentage": accuracy_percentage + "accuracy_percentage": accuracy_percentage, } - def parse_rtf_jsonl(input_file, output_file=None): """Parse RTF-formatted JSONL file and extract JSON objects.""" if output_file is None: input_path = Path(input_file) - output_file = str(input_path.parent / f"{input_path.stem.split('.')[0]}.clean.jsonl") - + output_file = str( + input_path.parent / f"{input_path.stem.split('.')[0]}.clean.jsonl" + ) + print(f"Parsing {input_file}") print(f"Output will be saved to {output_file}") - - with open(input_file, 'rb') as f: - rtf_content = f.read().decode('utf-8', errors='ignore') - + + with open(input_file, "rb") as f: + rtf_content = f.read().decode("utf-8", errors="ignore") + extracted_data = extract_structured_data(rtf_content) - + # Use the predefined supported labels from the dataset class # This ensures consistency and prevents OOV predictions from creating new classes allowed_labels = set(SUPPORTED_LABELS) @@ -159,12 +167,14 @@ def parse_rtf_jsonl(input_file, output_file=None): point["excluded"] = not is_valid_prediction if not is_valid_prediction: excluded_count += 1 - + if extracted_data: print(f"Extracted 
{len(extracted_data)} data points") if excluded_count > 0: - print(f"Excluded {excluded_count} predictions not in the label set from metrics") - + print( + f"Excluded {excluded_count} predictions not in the label set from metrics" + ) + # Calculate and display accuracy statistics (include all samples) accuracy_stats = calculate_accuracy_stats(extracted_data) print(f"\nAccuracy Statistics:") @@ -172,80 +182,85 @@ def parse_rtf_jsonl(input_file, output_file=None): print(f"Correct predictions: {accuracy_stats['correct_predictions']}") print(f"Incorrect predictions: {accuracy_stats['incorrect_predictions']}") print(f"Accuracy: {accuracy_stats['accuracy_percentage']:.2f}%") - + # Calculate and display F1 statistics (prevent OOV predictions from creating new classes) f1_stats = calculate_f1_stats(extracted_data, allowed_labels=allowed_labels) print(f"\nF1 Score Statistics:") print(f"Average F1 Score: {f1_stats['average_f1']:.4f}") print(f"Macro-F1 Score: {f1_stats['macro_f1']:.4f}") print(f"Total Classes: {f1_stats['total_classes']}") - + # Display per-class F1 scores - if f1_stats['class_f1_scores']: + if f1_stats["class_f1_scores"]: print(f"\nPer-Class F1 Scores:") - for class_name, scores in f1_stats['class_f1_scores'].items(): - print(f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}") - - with open(output_file, 'w', encoding='utf-8') as f: + for class_name, scores in f1_stats["class_f1_scores"].items(): + print( + f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}" + ) + + with open(output_file, "w", encoding="utf-8") as f: for item in extracted_data: f.write(json.dumps(item, indent=2) + "\n") - + print(f"\nData saved to {output_file}") return extracted_data else: print("No data could be extracted from the file.") return [] + def extract_structured_data(rtf_content): """Extract structured data from RTF content""" data_points = [] - + # Find key components generated_pattern = r'generated":\s*"(.*?)"' generated_matches = re.findall(generated_pattern, rtf_content) - + gold_pattern = r'gold":\s*"(.*?)"' gold_matches = re.findall(gold_pattern, rtf_content) - + min_length = min(len(generated_matches), len(gold_matches)) - + for i in range(min_length): model_prediction = extract_answer(generated_matches[i]).replace("", "") ground_truth = extract_answer(gold_matches[i]).replace("", "") - + # Calculate accuracy (exact match) accuracy = model_prediction == ground_truth - + # Calculate F1 score f1_result = calculate_f1_score(model_prediction, ground_truth) - + data_point = { "generated": generated_matches[i], "model_prediction": model_prediction, "ground_truth": ground_truth, "accuracy": accuracy, - "f1_score": f1_result['f1_score'], - "precision": f1_result['precision'], - "recall": f1_result['recall'], - "prediction_normalized": f1_result['prediction_normalized'], - "ground_truth_normalized": f1_result['ground_truth_normalized'] + "f1_score": f1_result["f1_score"], + "precision": f1_result["precision"], + "recall": f1_result["recall"], + "prediction_normalized": f1_result["prediction_normalized"], + "ground_truth_normalized": f1_result["ground_truth_normalized"], } data_points.append(data_point) - + return data_points + def extract_answer(text): """Extract the final answer from text""" if "Answer: " not in text: return text - + answer = text.split("Answer: ")[-1].strip() - answer = re.sub(r'<\|.*?\|>$', '', answer).strip() + answer = re.sub(r"<\|.*?\|>$", "", answer).strip() return answer + if __name__ == 
"__main__": current_dir = Path(__file__).parent input_file = current_dir / "gemma3_270m_sp_har.jsonl" clean_output = current_dir / "gemma3_270m_sp_har.clean.jsonl" - + parse_rtf_jsonl(input_file, clean_output) diff --git a/evaluation/opentslm/sleep/parse_sleep_cot_data.py b/evaluation/opentslm/sleep/parse_sleep_cot_data.py index bed0b88..1fb2086 100644 --- a/evaluation/opentslm/sleep/parse_sleep_cot_data.py +++ b/evaluation/opentslm/sleep/parse_sleep_cot_data.py @@ -18,7 +18,7 @@ from tqdm import tqdm # Add the src directory to the path to import from the dataset class -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "src")) from time_series_datasets.sleep.SleepEDFCoTQADataset import SleepEDFCoTQADataset # We'll determine supported labels dynamically from the actual ground truth data @@ -26,6 +26,7 @@ FALLBACK_LABELS = SleepEDFCoTQADataset.get_labels() SUPPORTED_LABELS = [] # Will be populated dynamically + def _canonicalize_label(text): """Return canonical label with stage 4 merged into stage 3. @@ -39,8 +40,8 @@ def _canonicalize_label(text): cleaned = str(text).strip() # Remove any end-of-text tokens and trailing period - cleaned = re.sub(r'<\|.*?\|>|$', '', cleaned).strip() - cleaned = re.sub(r'\.$', '', cleaned).strip() + cleaned = re.sub(r"<\|.*?\|>|$", "", cleaned).strip() + cleaned = re.sub(r"\.$", "", cleaned).strip() lowered = cleaned.lower() @@ -76,6 +77,8 @@ def _canonicalize_label(text): label_set = SUPPORTED_LABELS if SUPPORTED_LABELS else FALLBACK_LABELS is_supported = canonical in label_set return canonical if canonical else cleaned, is_supported + + def calculate_f1_score(prediction, ground_truth): """Calculate F1 score for single-label classification with supported labels. @@ -89,15 +92,16 @@ def calculate_f1_score(prediction, ground_truth): f1 = 1.0 if pred_canon == truth_canon else 0.0 return { - 'f1_score': f1, - 'precision': f1, - 'recall': f1, - 'prediction_normalized': pred_canon.lower().strip(), - 'ground_truth_normalized': truth_canon.lower().strip(), - 'prediction_supported': pred_supported, - 'ground_truth_supported': truth_supported, + "f1_score": f1, + "precision": f1, + "recall": f1, + "prediction_normalized": pred_canon.lower().strip(), + "ground_truth_normalized": truth_canon.lower().strip(), + "prediction_supported": pred_supported, + "ground_truth_supported": truth_supported, } + def calculate_f1_stats(data_points): """Calculate both macro-F1 and average F1 (micro-F1) statistics. 
@@ -107,16 +111,18 @@ def calculate_f1_stats(data_points): """ if not data_points: return {} - + # Calculate average F1 (micro-F1) - simple average across all predictions f1_scores = [point.get("f1_score", 0) for point in data_points] average_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0 - + # Initialize class buckets for only supported classes (lowercased for consistency) # Use discovered labels if available, otherwise fall back to dataset labels labels_to_use = SUPPORTED_LABELS if SUPPORTED_LABELS else FALLBACK_LABELS supported_lower = {label.lower(): label for label in labels_to_use} - class_predictions = {lab.lower(): {"tp": 0, "fp": 0, "fn": 0} for lab in labels_to_use} + class_predictions = { + lab.lower(): {"tp": 0, "fp": 0, "fn": 0} for lab in labels_to_use + } for point in data_points: gt_class = point.get("ground_truth_normalized", "") @@ -147,7 +153,11 @@ def calculate_f1_stats(data_points): precision = tp / (tp + fp) if (tp + fp) > 0 else 0 recall = tp / (tp + fn) if (tp + fn) > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0 + ) # Use canonical casing in output keys pretty_name = supported_lower.get(class_name, class_name) @@ -172,45 +182,47 @@ def calculate_f1_stats(data_points): "total_classes": valid_classes, } + def calculate_accuracy_stats(data_points): """Calculate accuracy statistics from data points""" if not data_points: return {} - + total = len(data_points) correct = sum(1 for point in data_points if point.get("accuracy", False)) accuracy_percentage = (correct / total) * 100 if total > 0 else 0 - + return { "total_samples": total, "correct_predictions": correct, "incorrect_predictions": total - correct, - "accuracy_percentage": accuracy_percentage + "accuracy_percentage": accuracy_percentage, } + def parse_sleep_cot_jsonl(input_file, output_file=None): """Parse sleep COT JSONL file and extract JSON objects.""" if output_file is None: input_path = Path(input_file) output_file = str(input_path.parent / f"{input_path.stem}.clean.jsonl") - + print(f"Parsing {input_file}") print(f"Output will be saved to {output_file}") - + # First, discover the actual labels from the ground truth data global SUPPORTED_LABELS discovered_labels = discover_ground_truth_labels(input_file) SUPPORTED_LABELS = discovered_labels - + print(f"Discovered {len(discovered_labels)} labels from ground truth data:") for label in sorted(discovered_labels): print(f" - {label}") - + extracted_data = extract_structured_data(input_file) - + if extracted_data: print(f"Extracted {len(extracted_data)} data points") - + # Calculate and display accuracy statistics accuracy_stats = calculate_accuracy_stats(extracted_data) print(f"\nAccuracy Statistics:") @@ -218,89 +230,93 @@ def parse_sleep_cot_jsonl(input_file, output_file=None): print(f"Correct predictions: {accuracy_stats['correct_predictions']}") print(f"Incorrect predictions: {accuracy_stats['incorrect_predictions']}") print(f"Accuracy: {accuracy_stats['accuracy_percentage']:.2f}%") - + # Calculate and display F1 statistics f1_stats = calculate_f1_stats(extracted_data) print(f"\nF1 Score Statistics:") print(f"Average F1 Score: {f1_stats['average_f1']:.4f}") print(f"Macro-F1 Score: {f1_stats['macro_f1']:.4f}") print(f"Total Classes: {f1_stats['total_classes']}") - + # Display per-class F1 scores - if f1_stats['class_f1_scores']: + if f1_stats["class_f1_scores"]: print(f"\nPer-Class F1 Scores:") - for 
class_name, scores in f1_stats['class_f1_scores'].items(): - print(f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}") + for class_name, scores in f1_stats["class_f1_scores"].items(): + print( + f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}" + ) pass - - with open(output_file, 'w', encoding='utf-8') as f: + + with open(output_file, "w", encoding="utf-8") as f: for item in extracted_data: f.write(json.dumps(item, indent=2) + "\n") - + print(f"\nData saved to {output_file}") return extracted_data else: print("No data could be extracted from the file.") return [] + def discover_ground_truth_labels(input_file): """Discover actual labels from ground truth data in the JSONL file""" discovered_labels = set() - - with open(input_file, 'r', encoding='utf-8') as f: + + with open(input_file, "r", encoding="utf-8") as f: for line in f: try: data = json.loads(line.strip()) - gold_text = data.get('gold', '') + gold_text = data.get("gold", "") ground_truth_raw = extract_answer(gold_text) gt_canon, _ = _canonicalize_label(ground_truth_raw) if gt_canon: discovered_labels.add(gt_canon) except (json.JSONDecodeError, Exception): continue - + return list(discovered_labels) + def extract_structured_data(input_file): """Extract structured data from JSONL file""" data_points = [] - - with open(input_file, 'r', encoding='utf-8') as f: + + with open(input_file, "r", encoding="utf-8") as f: for line_num, line in tqdm(enumerate(f, 1)): try: # Parse JSON line data = json.loads(line.strip()) - + # Extract generated and gold fields - generated_text = data.get('generated', '') - gold_text = data.get('gold', '') - + generated_text = data.get("generated", "") + gold_text = data.get("gold", "") + # Extract answers from both fields model_prediction_raw = extract_answer(generated_text) ground_truth_raw = extract_answer(gold_text) # Canonicalize labels and merge stage 4 -> stage 3 pred_canon, pred_supported = _canonicalize_label(model_prediction_raw) gt_canon, gt_supported = _canonicalize_label(ground_truth_raw) - + # Calculate accuracy (exact match) accuracy = (pred_canon == gt_canon) and gt_supported - + # Calculate F1 score f1_result = calculate_f1_score(model_prediction_raw, ground_truth_raw) - + data_point = { "generated": generated_text, "model_prediction": model_prediction_raw, "ground_truth": ground_truth_raw, "accuracy": accuracy, - "f1_score": f1_result['f1_score'], - "precision": f1_result['precision'], - "recall": f1_result['recall'], - "prediction_normalized": f1_result['prediction_normalized'], - "ground_truth_normalized": f1_result['ground_truth_normalized'], - "prediction_supported": f1_result['prediction_supported'], - "ground_truth_supported": f1_result['ground_truth_supported'], - "line_number": line_num + "f1_score": f1_result["f1_score"], + "precision": f1_result["precision"], + "recall": f1_result["recall"], + "prediction_normalized": f1_result["prediction_normalized"], + "ground_truth_normalized": f1_result["ground_truth_normalized"], + "prediction_supported": f1_result["prediction_supported"], + "ground_truth_supported": f1_result["ground_truth_supported"], + "line_number": line_num, } data_points.append(data_point) except json.JSONDecodeError as e: @@ -309,24 +325,26 @@ def extract_structured_data(input_file): except Exception as e: print(f"Unexpected error on line {line_num}: {e}") continue - + return data_points + def extract_answer(text): """Extract the final answer from text""" if "Answer: " not in text: 
return text - + answer = text.split("Answer: ")[-1].strip() # Remove any end-of-text tokens (including and <|...|>) - answer = re.sub(r'<\|.*?\|>|$', '', answer).strip() + answer = re.sub(r"<\|.*?\|>|$", "", answer).strip() # Remove trailing periods and normalize - answer = re.sub(r'\.$', '', answer).strip() + answer = re.sub(r"\.$", "", answer).strip() return answer + if __name__ == "__main__": current_dir = Path(__file__).parent input_file = current_dir / "llama_1b_flamingo_predictions.jsonl" clean_output = current_dir / "llama_1b_flamingo_predictions.clean.jsonl" - + parse_sleep_cot_jsonl(input_file, clean_output) diff --git a/evaluation/opentslm/sleep/plot_sleep_predictions.py b/evaluation/opentslm/sleep/plot_sleep_predictions.py index 17cb618..c613859 100644 --- a/evaluation/opentslm/sleep/plot_sleep_predictions.py +++ b/evaluation/opentslm/sleep/plot_sleep_predictions.py @@ -22,30 +22,31 @@ OUTPUT_DIR = "sleep_cot_plots" # Publication style -plt.style.use('seaborn-v0_8') +plt.style.use("seaborn-v0_8") sns.set_palette("colorblind") # Create output directory os.makedirs(OUTPUT_DIR, exist_ok=True) display_label_map = { - 'W': 'Wake', - 'N1': 'Non-REM stage 1', - 'N2': 'Non-REM stage 2', - 'N3': 'Non-REM stage 3', - 'N4': 'Non-REM stage 4', - 'REM': 'REM sleep', - 'M': 'Movement', - 'Unknown': 'Unknown' + "W": "Wake", + "N1": "Non-REM stage 1", + "N2": "Non-REM stage 2", + "N3": "Non-REM stage 3", + "N4": "Non-REM stage 4", + "REM": "REM sleep", + "M": "Movement", + "Unknown": "Unknown", } + def plot_sample(row, idx): - eeg_data = np.array(json.loads(row['eeg_data'])) - full_pred = row['full_prediction'] - gt_label = row['ground_truth_label'] - pred_label = row['predicted_label'] - sample_idx = row['sample_index'] - series_length = row['series_length'] + eeg_data = np.array(json.loads(row["eeg_data"])) + full_pred = row["full_prediction"] + gt_label = row["ground_truth_label"] + pred_label = row["predicted_label"] + sample_idx = row["sample_index"] + series_length = row["series_length"] # Map labels to pretty names pretty_gt = display_label_map.get(gt_label, gt_label) @@ -59,7 +60,7 @@ def plot_sample(row, idx): elif len(full_pred) > text_length: # Truncate if longer full_pred = full_pred[:text_length] - + # Add extra newlines to ensure consistent text box height full_pred = full_pred + "\n" @@ -71,27 +72,42 @@ def plot_sample(row, idx): fig, ax1 = plt.subplots(figsize=(12, 7)) t = np.arange(len(eeg_plot)) # Use the same blue color as PAMAP2 plots (first color from colorblind palette) - ax1.plot(t, eeg_plot, linewidth=2.5, color='#0173B2', alpha=0.8, label='EEG') - ax1.set_xlabel('Time Step', fontsize=26) - ax1.set_ylabel('Normalized EEG Amplitude', fontsize=26) - ax1.set_title(f"Sample {sample_idx} | GT: {pretty_gt} | Pred: {pretty_pred}", fontsize=22, fontweight='bold') - ax1.legend(fontsize=13, loc='upper right') + ax1.plot(t, eeg_plot, linewidth=2.5, color="#0173B2", alpha=0.8, label="EEG") + ax1.set_xlabel("Time Step", fontsize=26) + ax1.set_ylabel("Normalized EEG Amplitude", fontsize=26) + ax1.set_title( + f"Sample {sample_idx} | GT: {pretty_gt} | Pred: {pretty_pred}", + fontsize=22, + fontweight="bold", + ) + ax1.legend(fontsize=13, loc="upper right") ax1.grid(True, alpha=0.3) - ax1.tick_params(axis='both', which='major', labelsize=26) + ax1.tick_params(axis="both", which="major", labelsize=26) ax1.set_ylim(-3, 3) ax1.set_yticks(np.linspace(-3, 3, 7)) # Add full_prediction as a text box below the plot (same as PAMAP2) - plt.gcf().text(0.01, -0.02, f"Prediction:\n{full_pred}", 
fontsize=30, ha='left', va='top', wrap=True, - bbox=dict(boxstyle='round', facecolor='whitesmoke', alpha=0.9, edgecolor='gray')) + plt.gcf().text( + 0.01, + -0.02, + f"Prediction:\n{full_pred}", + fontsize=30, + ha="left", + va="top", + wrap=True, + bbox=dict( + boxstyle="round", facecolor="whitesmoke", alpha=0.9, edgecolor="gray" + ), + ) plt.tight_layout(rect=[0, 0.05, 1, 1]) fname = f"sample_{idx+1:03d}_gt_{pretty_gt.lower().replace(' ', '_').replace('-', '_')}.png" - plt.savefig(os.path.join(OUTPUT_DIR, fname), dpi=300, bbox_inches='tight') + plt.savefig(os.path.join(OUTPUT_DIR, fname), dpi=300, bbox_inches="tight") plt.close() print(f"Saved {fname}") + def main(): df = pd.read_csv(CSV_PATH) print(f"Loaded {len(df)} samples from {CSV_PATH}") @@ -99,5 +115,6 @@ def main(): plot_sample(row, idx) print(f"All plots saved to {OUTPUT_DIR}/") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/get_memory_use.py b/get_memory_use.py index a13d2cf..f39f4e1 100644 --- a/get_memory_use.py +++ b/get_memory_use.py @@ -420,17 +420,23 @@ def main(): res["dataset"], res["loss"], res["peak_cuda_bytes"], - f"{peak_gb:.4f}" - if isinstance(peak_gb, float) and peak_gb >= 0 - else peak_gb, + ( + f"{peak_gb:.4f}" + if isinstance(peak_gb, float) and peak_gb >= 0 + else peak_gb + ), res.get("peak_cuda_reserved_bytes", -1), - f"{peak_reserved_gb:.4f}" - if isinstance(peak_reserved_gb, float) and peak_reserved_gb >= 0 - else peak_reserved_gb, + ( + f"{peak_reserved_gb:.4f}" + if isinstance(peak_reserved_gb, float) and peak_reserved_gb >= 0 + else peak_reserved_gb + ), res.get("nvml_peak_bytes", -1), - f"{nvml_peak_gb:.4f}" - if isinstance(nvml_peak_gb, float) and nvml_peak_gb >= 0 - else nvml_peak_gb, + ( + f"{nvml_peak_gb:.4f}" + if isinstance(nvml_peak_gb, float) and nvml_peak_gb >= 0 + else nvml_peak_gb + ), res["status"], res["error"], ], diff --git a/hf_test.py b/hf_test.py new file mode 100644 index 0000000..dba29e5 --- /dev/null +++ b/hf_test.py @@ -0,0 +1,24 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +from src import OpenTSLM, TextPrompt, TextTimeSeriesPrompt, FullPrompt + +# Load model +model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-har-flamingo") + +# Create prompt with raw time series data (normalization handled automatically) +prompt = FullPrompt( + pre_prompt=TextPrompt("You are an expert in HAR analysis."), + text_time_series_prompt_list=[ + TextTimeSeriesPrompt("X-axis accelerometer", [2.34, 2.34, 7.657, 3.21, -1.2]) + ], + post_prompt=TextPrompt("What activity is this? Reasn step by step providing a full rationale before replying.") +) + +# Generate response +output = model.eval_prompt(prompt, normalize=True) +print(output) diff --git a/plot_memory_simulation.py b/plot_memory_simulation.py new file mode 100644 index 0000000..250a537 --- /dev/null +++ b/plot_memory_simulation.py @@ -0,0 +1,220 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Plot memory usage on simulation datasets from memory_simulation.csv. + +- Only uses datasets starting with 'Simulation-'. +- Extracts time series length (L) and number of series (N). +- Computes total_length = N * L. 
+- Plots memory vs total_length per base model, comparing SoftPrompt vs Flamingo. +- OOM runs (> 180GB) are shown with a dashed line, red X, and "OOM" label. +- Always shows panels in order: gemma-270m, gemma-1b, llama-1b, llama-3b. +""" + +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +import matplotlib +import re + +OOM_THRESHOLD = 180 # GB + + +def parse_model_name(llm_id, model_type): + """Return base_model, config (SoftPrompt or Flamingo).""" + if llm_id.startswith("meta-llama/"): + base_name = llm_id.replace("meta-llama/", "") + elif llm_id.startswith("google/"): + base_name = llm_id.replace("google/", "") + else: + base_name = llm_id + + # Normalize base model names to match expected order + if "Llama-3.2-1B" in base_name: + base_name = "Llama-3.2-1B" + elif "Llama-3.2-3B" in base_name: + base_name = "Llama-3.2-3B" + elif "gemma-3-270m" in base_name: + base_name = "Gemma-3-270M" + elif "gemma-3-1b-pt" in base_name: + base_name = "Gemma-3-1B-pt" + + if model_type == "EmbedHealthSP": + type_name = "SoftPrompt" + elif model_type == "EmbedHealthFlamingo": + type_name = "Flamingo" + else: + type_name = model_type + + return base_name, type_name + + +def parse_simulation_dataset(name): + """Parse Simulation dataset name like 'Simulation-L10-N5' β†’ (L=10, N=5).""" + match = re.match(r"Simulation-L(\d+)-N(\d+)", name) + if match: + return int(match.group(1)), int(match.group(2)) + return None, None + + +def plot_memory_usage_sim(csv_file="memory_simulation.csv"): + # --- Paper-style settings --- + plt.style.use("seaborn-v0_8-white") + matplotlib.rcParams.update({ + "font.family": "serif", + "font.serif": ["Palatino", "Times New Roman", "DejaVu Serif"], + "font.size": 16, + "axes.labelsize": 18, + "axes.titlesize": 18, + "legend.fontsize": 15, + "xtick.labelsize": 15, + "ytick.labelsize": 15, + "axes.linewidth": 0.6, + "axes.edgecolor": "0.15", + }) + + df = pd.read_csv(csv_file) + + # Replace -1 with NaN (ignore failed runs) + df["peak_cuda_reserved_gb"] = df["peak_cuda_reserved_gb"].replace(-1, pd.NA) + + # Keep only simulation datasets + df = df[df["dataset"].str.startswith("Simulation-")] + + # Parse model name and dataset details + df[["base_model", "config"]] = df.apply( + lambda row: pd.Series(parse_model_name(row["llm_id"], row["model"])), axis=1 + ) + df[["L", "N"]] = df["dataset"].apply( + lambda s: pd.Series(parse_simulation_dataset(s)) + ) + df = df.dropna(subset=["L", "N"]) + df["L"] = df["L"].astype(int) + df["N"] = df["N"].astype(int) + + # Compute total sequence length + df["total_length"] = df["L"] * df["N"] + + # Sort + df = df.sort_values(by=["base_model", "config", "total_length"]) + + # Fixed base_model order + base_model_order = ["Gemma-3-270M", "Gemma-3-1B-pt", "Llama-3.2-1B", "Llama-3.2-3B"] + + # One subplot per model (always 4) + n_models = len(base_model_order) + fig, axes = plt.subplots(1, n_models, figsize=(3.2 * n_models, 4.5), sharey=True) + + if n_models == 1: + axes = [axes] + + # Muted palette for configs - order matters for legend + palette = {"SoftPrompt": "#4477AA", "Flamingo": "#CC6677"} + config_order = ["SoftPrompt", "Flamingo"] + + for ax, base_model in zip(axes, base_model_order): + subdf = df[df["base_model"] == base_model] + + if subdf.empty: + ax.set_title(base_model, fontsize=13, fontweight="bold") + ax.set_facecolor("#F8F9FA") + ax.text( + 0.5, 0.5, "No data", + ha="center", va="center", + fontsize=10, color="gray" + ) + ax.set_xticks([]) + ax.set_yticks([]) + continue + + for cfg in config_order: + cfg_df = 
subdf[subdf["config"] == cfg] + if cfg_df.empty: + continue + cfg_df = cfg_df.sort_values("total_length") + color = palette[cfg] + + # Successful runs (≀ threshold) + ok_df = cfg_df[cfg_df["peak_cuda_reserved_gb"] <= OOM_THRESHOLD] + ax.plot( + ok_df["total_length"], ok_df["peak_cuda_reserved_gb"], + label=cfg, + color=color, + linewidth=3.0, + alpha=0.9, + ) + + # First OOM run (if any) + oom_df = cfg_df[cfg_df["peak_cuda_reserved_gb"] > OOM_THRESHOLD] + if not oom_df.empty and not ok_df.empty: + first_oom = oom_df.iloc[0] + last_ok = ok_df.iloc[-1] + + # dashed line up to OOM + ax.plot( + [last_ok["total_length"], first_oom["total_length"]], + [last_ok["peak_cuda_reserved_gb"], OOM_THRESHOLD * 1.05], + color=color, + linestyle="--", + linewidth=1.5, + alpha=0.8, + ) + + # red X marker + ax.scatter( + first_oom["total_length"], OOM_THRESHOLD * 1.05, + color="red", + marker="x", + s=70, + linewidth=2, + zorder=5, + ) + ax.text( + first_oom["total_length"], OOM_THRESHOLD * 1.05, + "OOM", color="red", fontsize=9, + fontweight="bold", ha="center", va="bottom" + ) + + # Titles & labels + ax.set_title(base_model, fontsize=17, fontweight="bold") + + # Only show axis labels on specific subplots + if ax == axes[0]: # Leftmost subplot + ax.set_ylabel("Peak CUDA Reserved (GB)", fontsize=16, fontweight="bold") + ax.set_xlabel("Total Sequence Length (N Γ— L)", fontsize=16, fontweight="bold") + else: + ax.set_ylabel("") + ax.set_xlabel("") + ax.set_facecolor("#F8F9FA") + ax.grid(True, which="major", linestyle="-", linewidth=0.4, alpha=0.5) + ax.grid(True, which="minor", linestyle=":", linewidth=0.3, alpha=0.3) + ax.minorticks_on() + ax.tick_params(axis="both", labelsize=15) + + # Legend only in first subplot + if ax == axes[0]: + leg = ax.legend(title=None, fontsize=15, loc="best", frameon=True, + framealpha=0.95, edgecolor="0.3") + for text in leg.get_texts(): + text.set_fontweight("bold") + + plt.tight_layout(pad=0.5) + for fmt in ["png", "pdf"]: + plt.savefig( + f"memory_usage_simulation.{fmt}", + dpi=300 if fmt == "png" else None, + bbox_inches="tight", + pad_inches=0, + facecolor="white", + format=fmt, + ) + plt.show() + + +if __name__ == "__main__": + plot_memory_usage_sim() diff --git a/plot_memory_simulation_per_length.py b/plot_memory_simulation_per_length.py new file mode 100644 index 0000000..3431a14 --- /dev/null +++ b/plot_memory_simulation_per_length.py @@ -0,0 +1,286 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +""" +Paper-style plots: memory usage scaling with N for different lengths (L). + +- Rows = config (SoftPrompt, Flamingo) +- Cols = sequence lengths (L) [excluding L=1] +- Hue = base model +- Y-axis sharing logic: + * Flamingo: all panels share y-axis + * SoftPrompt: all panels have independent y-axes +- OOM cases (peak_cuda_reserved_gb > OOM_THRESHOLD) are shown by extending the + line upward and marking with a red X + "OOM". 
+""" + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +import matplotlib +import re +from matplotlib.lines import Line2D + +OOM_THRESHOLD = 180 # GB + + +def parse_model_name(llm_id, model_type): + """Return base_model, config (SoftPrompt or Flamingo).""" + if llm_id.startswith("meta-llama/"): + base_name = llm_id.replace("meta-llama/", "") + elif llm_id.startswith("google/"): + base_name = llm_id.replace("google/", "") + else: + base_name = llm_id + + # Normalize base model names to match expected order + if "Llama-3.2-1B" in base_name: + base_name = "Llama-3.2-1B" + elif "Llama-3.2-3B" in base_name: + base_name = "Llama-3.2-3B" + elif "gemma-3-270m" in base_name: + base_name = "Gemma-3-270M" + elif "gemma-3-1b-pt" in base_name: + base_name = "Gemma-3-1B-pt" + + if model_type == "EmbedHealthSP": + type_name = "SoftPrompt" + elif model_type == "EmbedHealthFlamingo": + type_name = "Flamingo" + else: + type_name = model_type + + return base_name, type_name + + +def parse_simulation_dataset(name): + """Parse Simulation dataset name like 'Simulation-L10-N5' β†’ (L=10, N=5).""" + match = re.match(r"Simulation-L(\d+)-N(\d+)", name) + if match: + return int(match.group(1)), int(match.group(2)) + return None, None + + +def plot_memory_usage_paper(csv_file="memory_simulation.csv"): + # Publication style + plt.style.use("seaborn-v0_8-white") + matplotlib.rcParams.update({ + "font.family": "serif", + "font.serif": ["Palatino", "Times New Roman", "DejaVu Serif"], + "font.size": 18, + "axes.labelsize": 20, + "axes.titlesize": 20, + "legend.fontsize": 17, + "xtick.labelsize": 17, + "ytick.labelsize": 17, + "axes.linewidth": 0.6, + "axes.edgecolor": "0.15", + }) + + # Load & preprocess + df = pd.read_csv(csv_file) + df["peak_cuda_reserved_gb"] = df["peak_cuda_reserved_gb"].replace(-1, pd.NA) + df = df[df["dataset"].str.startswith("Simulation-")] + df[["base_model", "config"]] = df.apply( + lambda row: pd.Series(parse_model_name(row["llm_id"], row["model"])), axis=1 + ) + df[["L", "N"]] = df["dataset"].apply( + lambda s: pd.Series(parse_simulation_dataset(s)) + ) + df = df.dropna(subset=["L", "N"]) + df["L"] = df["L"].astype(int) + df["N"] = df["N"].astype(int) + df = df[df["L"] != 1] + df = df.sort_values(by=["base_model", "config", "L", "N"]) + + # Palette + markers - use consistent colors with the other script + # Define consistent order and colors for each model + model_order = ["Gemma-3-270M", "Gemma-3-1B-pt", "Llama-3.2-1B", "Llama-3.2-3B"] + color_map = { + "Gemma-3-270M": "#4477AA", + "Gemma-3-1B-pt": "#66CCEE", + "Llama-3.2-1B": "#228833", + "Llama-3.2-3B": "#CC6677" + } + + # Get base models in the specified order + base_models = [bm for bm in model_order if bm in df["base_model"].unique()] + custom_palette = [color_map.get(bm, "#888888") for bm in base_models] + markers_dict = dict(zip( + base_models, + ["o", "s", "^", "D", "p", "X", "*"] + )) + + # Unique sequence lengths + unique_L = sorted(df["L"].unique()) + + # Create subplot grid manually: 2 rows (SoftPrompt, Flamingo) + fig, axes = plt.subplots( + 2, len(unique_L), + figsize=(3.2 * len(unique_L), 6), + sharex="col", + ) + + # Row mapping + row_map = {"SoftPrompt": 0, "Flamingo": 1} + + # Precompute Flamingo y-lims + flamingo_df = df[df["config"] == "Flamingo"] + flamingo_ymin, flamingo_ymax = None, None + if not flamingo_df.empty: + flamingo_ymin = flamingo_df["peak_cuda_reserved_gb"].min(skipna=True) + flamingo_ymax = flamingo_df["peak_cuda_reserved_gb"].max(skipna=True) + + flamingo_ymin = 0 + 
flamingo_ymax = max(flamingo_ymax if flamingo_ymax else 0, OOM_THRESHOLD * 1.1) + + # Iterate configs + for cfg in ["SoftPrompt", "Flamingo"]: + cfg_df = df[df["config"] == cfg] + + for j, L in enumerate(unique_L): + ax = axes[row_map[cfg], j] + subdf = cfg_df[cfg_df["L"] == L] + + for bm, sdf in subdf.groupby("base_model"): + sdf = sdf.sort_values("N") + + # Split successful vs OOM runs + ok_df = sdf[sdf["peak_cuda_reserved_gb"] <= OOM_THRESHOLD] + oom_df = sdf[sdf["peak_cuda_reserved_gb"] > OOM_THRESHOLD] + + # Normal line + if not ok_df.empty: + ax.plot( + ok_df["N"], ok_df["peak_cuda_reserved_gb"], + label=bm, + color=custom_palette[base_models.index(bm)], + marker=markers_dict[bm], + linewidth=2.2, + markersize=5, + alpha=0.9, + ) + + # OOM handling + if not oom_df.empty: + first_oom = oom_df.iloc[0] + last_ok_y = ok_df["peak_cuda_reserved_gb"].iloc[-1] if not ok_df.empty else OOM_THRESHOLD * 0.9 + + # extend line upward + ax.plot( + [ok_df["N"].iloc[-1] if not ok_df.empty else first_oom["N"], first_oom["N"]], + [last_ok_y, OOM_THRESHOLD * 1.05], + color=custom_palette[base_models.index(bm)], + linestyle="--", + linewidth=1.5, + alpha=0.8, + ) + + # red X marker at OOM + ax.scatter( + first_oom["N"], OOM_THRESHOLD * 1.05, + color="red", + marker="x", + s=70, + linewidth=2, + zorder=5, + ) + ax.text( + first_oom["N"], OOM_THRESHOLD * 1.05, + "OOM", color="red", fontsize=9, + fontweight="bold", ha="center", va="bottom" + ) + + # Titles + if row_map[cfg] == 0: + ax.set_title(f"L = {L}", fontsize=19, fontweight="bold") + + # Y labels only leftmost col + if j == 0: + ax.set_ylabel( + f"{cfg}", + fontsize=18, + fontweight="bold" + ) + else: + ax.set_ylabel("") + + # Styling + ax.set_facecolor("#F8F9FA") + ax.grid(True, which="major", linestyle="-", linewidth=0.4, alpha=0.5) + ax.grid(True, which="minor", linestyle=":", linewidth=0.3, alpha=0.3) + ax.minorticks_on() + # Remove individual x-axis labels - will add global one later + ax.set_xlabel("") + ax.set_xticks([1, 2, 3, 4, 5]) + ax.tick_params(axis="x", which="both", labelbottom=True) + + # Y-axis rules + if cfg == "Flamingo": + ax.set_ylim(0, 65) + elif j <= 2: + ax.set_ylim(0, 25) + + # Legend (global, right side) - create in specified order + # Get handles and labels from first subplot + handles_dict = {} + labels_dict = {} + for handle, label in zip(*axes[0, 0].get_legend_handles_labels()): + handles_dict[label] = handle + labels_dict[label] = label + + # Create ordered handles and labels + ordered_handles = [] + ordered_labels = [] + for bm in base_models: + if bm in handles_dict: + ordered_handles.append(handles_dict[bm]) + ordered_labels.append(labels_dict[bm]) + + # Add OOM handle + oom_handle = Line2D([0], [0], color="red", marker="x", linestyle="--", + markersize=8, label="Out of Memory (OOM)") + ordered_handles.append(oom_handle) + ordered_labels.append("Out of Memory (OOM)") + + # Legend in top left plot only + top_left_ax = axes[0, 0] + top_left_ax.legend( + ordered_handles, ordered_labels, + title=None, + loc="upper right", + frameon=True, + framealpha=0.95, + edgecolor="0.3", + fontsize=12, + ) + + # Add main vertical title + fig.text(0.02, 0.5, "Peak Memory (GB)", rotation=90, fontsize=20, fontweight="bold", ha="center", va="center") + + # Add global x-axis label spanning the bottom row + fig.text(0.5, 0.01, "Number of Time Series (N)", fontsize=16, fontweight="bold", ha="center", va="center") + + # Layout - no longer need space for bottom legend + plt.tight_layout(pad=0.5) + plt.subplots_adjust(left=0.08) + + # Save + 
for fmt in ["png", "pdf"]: + plt.savefig( + f"memory_usage_paper.{fmt}", + dpi=300 if fmt == "png" else None, + bbox_inches="tight", + pad_inches=0, + facecolor="white", + format=fmt, + ) + plt.show() + + +if __name__ == "__main__": + plot_memory_usage_paper() diff --git a/src/model/__init__.py b/src/model/__init__.py index 7f7c8a3..7566f7d 100644 --- a/src/model/__init__.py +++ b/src/model/__init__.py @@ -4,4 +4,4 @@ # SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) # # SPDX-License-Identifier: MIT -# \ No newline at end of file +# diff --git a/src/model/encoder/CNNTokenizer.py b/src/model/encoder/CNNTokenizer.py index 74338ec..bee3780 100644 --- a/src/model/encoder/CNNTokenizer.py +++ b/src/model/encoder/CNNTokenizer.py @@ -21,7 +21,7 @@ def __init__( dropout: float = 0.0, transformer_input_dim: int = TRANSFORMER_INPUT_DIM, patch_size: int = PATCH_SIZE, - max_patches: int = 2600, + max_patches: int = 1024, ): """ Args: diff --git a/src/model/encoder/TransformerCNNEncoder.py b/src/model/encoder/TransformerCNNEncoder.py index d5e6a28..5c96844 100644 --- a/src/model/encoder/TransformerCNNEncoder.py +++ b/src/model/encoder/TransformerCNNEncoder.py @@ -24,7 +24,7 @@ def __init__( num_layers: int = 6, patch_size: int = PATCH_SIZE, ff_dim: int = 1024, - max_patches: int = 2600, + max_patches: int = 1024, ): """ Args: diff --git a/src/model/encoder/TransformerMLPEncoder.py b/src/model/encoder/TransformerMLPEncoder.py index efe4328..6aa59e5 100644 --- a/src/model/encoder/TransformerMLPEncoder.py +++ b/src/model/encoder/TransformerMLPEncoder.py @@ -24,7 +24,7 @@ def __init__( num_layers: int = 6, patch_size: int = PATCH_SIZE, ff_dim: int = 2048, - max_patches: int = 2600, + max_patches: int = 1024, ): """ Args: diff --git a/src/model/encoder/__init__.py b/src/model/encoder/__init__.py index 7f7c8a3..7566f7d 100644 --- a/src/model/encoder/__init__.py +++ b/src/model/encoder/__init__.py @@ -4,4 +4,4 @@ # SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) # # SPDX-License-Identifier: MIT -# \ No newline at end of file +# diff --git a/src/model/llm/OpenTSLM.py b/src/model/llm/OpenTSLM.py new file mode 100644 index 0000000..d705e9b --- /dev/null +++ b/src/model/llm/OpenTSLM.py @@ -0,0 +1,175 @@ +# +# This source file is part of the OpenTSLM open-source project +# +# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# +# SPDX-License-Identifier: MIT +# +import torch +from typing import Optional, Union +from enum import Enum +from huggingface_hub import hf_hub_download + +from .OpenTSLMSP import OpenTSLMSP +from .OpenTSLMFlamingo import OpenTSLMFlamingo + + +class ModelType(Enum): + """Enumeration of supported model types.""" + + SP = "sp" + FLAMINGO = "flamingo" + + +class OpenTSLM: + """ + Factory class for loading EmbedHealth models from Hugging Face Hub. + + Automatically detects model type based on repository ID suffix and returns + the appropriate model instance (EmbedHealthSP or EmbedHealthFlamingo) with + optimal parameters from curriculum learning training. 
+ + - Repository IDs ending with "-sp" load EmbedHealthSP models + - Repository IDs ending with "-flamingo" load EmbedHealthFlamingo models + + The factory automatically applies the exact same parameters used in curriculum learning: + - EmbedHealthSP: Uses default constructor parameters + - EmbedHealthFlamingo: cross_attn_every_n_layers=1, gradient_checkpointing=False + + These parameters are fixed and cannot be overridden since they were determined during training. + + Example: + >>> model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-sleep-flamingo") + >>> + >>> from prompt.full_prompt import FullPrompt + >>> prompt = FullPrompt(...) + >>> response = model.eval_prompt(prompt) + """ + + @classmethod + def load_pretrained( + cls, + repo_id: str, + device: Optional[str] = None, + cache_dir: Optional[str] = None, + enable_lora: Optional[bool] = False, + ) -> Union[OpenTSLMSP, OpenTSLMFlamingo]: + """ + Load a pretrained model from Hugging Face Hub. + + Args: + repo_id: Hugging Face repository ID (e.g., "OpenTSLM/gemma-3-270m-pt-sleep-flamingo") + device: Device to load the model on (default: auto-detect) + cache_dir: Directory to cache downloaded models (optional) + enable_lora: Whether to enable LoRA (default: False) + + Returns: + Union[OpenTSLMSP, OpenTSLMFlamingo]: The loaded model instance + + Example: + >>> model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-sleep-flamingo") + >>> prompt = FullPrompt(...) + >>> response = model.eval_prompt(prompt) + """ + device = cls._get_device(device) + model_type = cls._detect_model_type(repo_id) + checkpoint_path = cls._download_model_files(repo_id, cache_dir) + base_llm_id = cls._get_base_llm_id(repo_id) + + print(f"πŸš€ Loading {model_type.value.upper()} model...") + print(f" Repository: {repo_id}") + print(f" Base LLM: {base_llm_id}") + print(f" Device: {device}") + + # Instantiate model with fixed training parameters + if model_type == ModelType.SP: + # OpenTSLMSP uses default parameters from curriculum learning + model = OpenTSLMSP(llm_id=base_llm_id, device=device) + if enable_lora: + model.enable_lora() + elif model_type == ModelType.FLAMINGO: + # OpenTSLMFlamingo with fixed parameters from curriculum learning + model = OpenTSLMFlamingo( + device=device, + llm_id=base_llm_id, + cross_attn_every_n_layers=1, + gradient_checkpointing=False, + ) + else: + raise ValueError(f"Unknown model type: {model_type}") + + # Load the checkpoint + model.load_from_file(checkpoint_path) + model.eval() + + print(f"βœ… {model_type.value.upper()} model loaded successfully!") + return model + + @staticmethod + def _get_device(device: Optional[str]) -> str: + """Auto-detect device if not specified.""" + if device is not None: + return device + + if torch.cuda.is_available(): + return "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + + @staticmethod + def _detect_model_type(repo_id: str) -> ModelType: + """Detect model type from repository ID suffix.""" + if repo_id.endswith("-sp"): + return ModelType.SP + elif repo_id.endswith("-flamingo"): + return ModelType.FLAMINGO + else: + raise ValueError( + f"Repository ID '{repo_id}' must end with either '-sp' or '-flamingo' " + f"to indicate the model type." 
+ ) + + @staticmethod + def _download_model_files(repo_id: str, cache_dir: Optional[str] = None) -> str: + """Download model checkpoint from Hugging Face Hub.""" + try: + # Download the main model checkpoint file + checkpoint_path = hf_hub_download( + repo_id=repo_id, + filename="model_checkpoint.pt", + cache_dir=cache_dir, + local_files_only=False, + ) + print(f"βœ… Downloaded model checkpoint from {repo_id}") + return checkpoint_path + + except Exception as e: + raise RuntimeError( + f"Failed to download model from {repo_id}. " + f"Tried 'model_checkpoint.pt'. " + f"Original error: {e}" + ) + + @staticmethod + def _get_base_llm_id(repo_id: str) -> str: + """Get the base LLM ID from static mapping based on repository ID pattern.""" + repo_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id + + # Extract base model from repository name pattern + if repo_name.startswith("llama-3.2-3b"): + return "meta-llama/Llama-3.2-3B" + elif repo_name.startswith("llama-3.2-1b"): + return "meta-llama/Llama-3.2-1B" + elif repo_name.startswith("gemma-3-1b"): + return "google/gemma-3-1b" + elif repo_name.startswith("gemma-3-270m"): + return "google/gemma-3-270m" + else: + # Raise exception if pattern doesn't match + raise ValueError( + f"Unable to determine base LLM ID from repository name '{repo_name}'. " + f"Repository name must start with one of: 'llama-3.2-3b', 'llama-3.2-1b', " + f"'gemma-3-1b', or 'gemma-3-270m'." + ) diff --git a/src/model/llm/OpenTSLMFlamingo.py b/src/model/llm/OpenTSLMFlamingo.py index 6f43181..58517c7 100644 --- a/src/model/llm/OpenTSLMFlamingo.py +++ b/src/model/llm/OpenTSLMFlamingo.py @@ -349,19 +349,24 @@ def load_from_file(self, path: str = "best_model.pt"): print(f" ... and {len(unexpected_keys) - 10} more keys") self.to(self.device) - def eval_prompt(self, prompt: FullPrompt, max_new_tokens: int = 30000) -> str: + def eval_prompt( + self, prompt: FullPrompt, max_new_tokens: int = 1000, normalize: bool = False + ) -> str: """ Evaluate a prompt and return the generated text. """ # Temporarily disable compilation to avoid data-dependent operation issues original_disable = torch._dynamo.config.disable torch._dynamo.config.disable = True - try: batch = [prompt.to_dict()] self.eval() - batch = extend_time_series_to_match_patch_size_and_aggregate(batch) + batch = extend_time_series_to_match_patch_size_and_aggregate( + batch, normalize=normalize + ) + print("Generating") output = self.generate(batch, max_new_tokens=max_new_tokens) + print(f"Generated output: {output[0]}") return output[0] finally: # Restore original compilation setting diff --git a/src/model/llm/OpenTSLMSP.py b/src/model/llm/OpenTSLMSP.py index 5bc4302..d594433 100644 --- a/src/model/llm/OpenTSLMSP.py +++ b/src/model/llm/OpenTSLMSP.py @@ -21,9 +21,9 @@ print("Warning: peft not available. 
LoRA fine-tuning will be disabled.") from model_config import ENCODER_OUTPUT_DIM -from model.llm.TimeSeriesLLM import TimeSeriesLLM -from model.encoder.TransformerCNNEncoder import TransformerCNNEncoder -from model.projector.MLPProjector import MLPProjector +from .TimeSeriesLLM import TimeSeriesLLM +from ..encoder.TransformerCNNEncoder import TransformerCNNEncoder +from ..projector.MLPProjector import MLPProjector from prompt.full_prompt import FullPrompt from time_series_datasets.util import ( extend_time_series_to_match_patch_size_and_aggregate, @@ -493,13 +493,17 @@ def save_lora_state_to_checkpoint(self, checkpoint: dict): return 0 - def eval_prompt(self, prompt: FullPrompt, max_new_tokens: int = 30000) -> str: + def eval_prompt( + self, prompt: FullPrompt, max_new_tokens: int = 30000, normalize: bool = False + ) -> str: """ Evaluate a prompt and return the generated text. """ batch = [prompt.to_dict()] self.eval() - batch = extend_time_series_to_match_patch_size_and_aggregate(batch) + batch = extend_time_series_to_match_patch_size_and_aggregate( + batch, normalize=normalize + ) output = self.generate(batch, max_new_tokens=max_new_tokens) return output[0] diff --git a/src/model/llm/__init__.py b/src/model/llm/__init__.py index 7f7c8a3..7566f7d 100644 --- a/src/model/llm/__init__.py +++ b/src/model/llm/__init__.py @@ -4,4 +4,4 @@ # SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) # # SPDX-License-Identifier: MIT -# \ No newline at end of file +# diff --git a/src/model/projector/__init__.py b/src/model/projector/__init__.py index 7f7c8a3..7566f7d 100644 --- a/src/model/projector/__init__.py +++ b/src/model/projector/__init__.py @@ -4,4 +4,4 @@ # SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) # # SPDX-License-Identifier: MIT -# \ No newline at end of file +# diff --git a/src/time_series_datasets/har_cot/HARAccQADataset.py b/src/time_series_datasets/har_cot/HARAccQADataset.py index d1ee850..8250b03 100644 --- a/src/time_series_datasets/har_cot/HARAccQADataset.py +++ b/src/time_series_datasets/har_cot/HARAccQADataset.py @@ -28,8 +28,16 @@ class HARAccQADataset(QADataset): - def __init__(self, split: Literal["train", "test", "validation"], EOS_TOKEN: str, format_sample_str: bool = False, time_series_format_function=None): - super().__init__(split, EOS_TOKEN, format_sample_str, time_series_format_function) + def __init__( + self, + split: Literal["train", "test", "validation"], + EOS_TOKEN: str, + format_sample_str: bool = False, + time_series_format_function=None, + ): + super().__init__( + split, EOS_TOKEN, format_sample_str, time_series_format_function + ) def _load_splits(self) -> Tuple[Dataset, Dataset, Dataset]: """ @@ -107,7 +115,9 @@ def _format_sample(self, row): dataset_val = HARAccQADataset(split="validation", EOS_TOKEN="") dataset_test = HARAccQADataset(split="test", EOS_TOKEN="") - print(f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}") + print( + f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}" + ) dataloader = DataLoader( dataset_test, diff --git a/src/time_series_datasets/m4/m4_loader.py b/src/time_series_datasets/m4/m4_loader.py index 14fc5ea..34c8974 100644 --- a/src/time_series_datasets/m4/m4_loader.py +++ b/src/time_series_datasets/m4/m4_loader.py @@ -28,13 +28,16 @@ from typing import Dict, List, Literal, Optional, Tuple from datasets import Dataset from 
sklearn.model_selection import train_test_split +from time_series_datasets.constants import RAW_DATA # --------------------------- # Constants # --------------------------- RELEASE_URL = "https://polybox.ethz.ch/index.php/s/MT3y9WdEebT8wfj/download/M4TimeSeriesCaptionDatasetV02.zip" -DATA_DIR = "data/M4TimeSeriesCaptionDataset" + + +DATA_DIR = os.path.join(RAW_DATA, "M4TimeSeriesCaptionDataset") GENERATED_DATA_DIR = os.path.join(DATA_DIR, "M4TimeSeriesCaptionDataset") AVAILABLE_FREQUENCIES = ["Daily", "Hourly", "Monthly", "Quarterly", "Weekly", "Yearly"] diff --git a/src/time_series_datasets/sleep/SleepEDFQADataset.py b/src/time_series_datasets/sleep/SleepEDFQADataset.py index 6bba122..24c7ead 100644 --- a/src/time_series_datasets/sleep/SleepEDFQADataset.py +++ b/src/time_series_datasets/sleep/SleepEDFQADataset.py @@ -10,15 +10,27 @@ from typing import List, Tuple, Literal import sys import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) + +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) from prompt.text_time_series_prompt import TextTimeSeriesPrompt from time_series_datasets.QADataset import QADataset from time_series_datasets.sleep.sleepedf_cot_loader import load_sleepedf_cot_splits import numpy as np + class SleepEDFCoTQADataset(QADataset): - def __init__(self, split: Literal["train", "test", "validation"], EOS_TOKEN: str, format_sample_str: bool = False, time_series_format_function=None): - super().__init__(split, EOS_TOKEN, format_sample_str, time_series_format_function) + def __init__( + self, + split: Literal["train", "test", "validation"], + EOS_TOKEN: str, + format_sample_str: bool = False, + time_series_format_function=None, + ): + super().__init__( + split, EOS_TOKEN, format_sample_str, time_series_format_function + ) def _load_splits(self) -> Tuple[Dataset, Dataset, Dataset]: return load_sleepedf_cot_splits() @@ -71,13 +83,20 @@ def _get_text_time_series_prompt_list(self, row) -> List[TextTimeSeriesPrompt]: std = max(std, min_std) series_norm = (series - mean) / std text_prompt = f"The following is the EEG time series, it has mean {mean:.4f} and std {std:.4f}:" - + return [TextTimeSeriesPrompt(text_prompt, series_norm.tolist())] @staticmethod def get_labels() -> List[str]: # This could be made dynamic, but for now, use the standard sleep stages - return ["Wake", "Non-REM stage 1", "Non-REM stage 2", "Non-REM stage 3", "REM sleep", "Movement"] + return [ + "Wake", + "Non-REM stage 1", + "Non-REM stage 2", + "Non-REM stage 3", + "REM sleep", + "Movement", + ] def _format_sample(self, row): sample = super()._format_sample(row) @@ -85,15 +104,21 @@ def _format_sample(self, row): sample["original_data"] = row["time_series"] return sample + if __name__ == "__main__": dataset = SleepEDFCoTQADataset(split="train", EOS_TOKEN="") dataset_val = SleepEDFCoTQADataset(split="validation", EOS_TOKEN="") dataset_test = SleepEDFCoTQADataset(split="test", EOS_TOKEN="") - print(f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}") + print( + f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}" + ) if len(dataset) > 0: sample = dataset[0] print("Sample keys:", sample.keys()) print("Sample answer:", sample["answer"]) - print("Sample time series text:", sample["time_series_text"] if "time_series_text" in sample else "N/A") + print( + "Sample time series text:", + sample["time_series_text"] if 
"time_series_text" in sample else "N/A", + ) print("Sample pre prompt:", sample["pre_prompt"]) - print("Sample post prompt:", sample["post_prompt"]) \ No newline at end of file + print("Sample post prompt:", sample["post_prompt"]) diff --git a/src/time_series_datasets/util.py b/src/time_series_datasets/util.py index 6850a4d..3649245 100644 --- a/src/time_series_datasets/util.py +++ b/src/time_series_datasets/util.py @@ -19,9 +19,12 @@ def extend_time_series_to_match_patch_size_and_aggregate( - batch, *, patch_size: int = PATCH_SIZE + batch, *, patch_size: int = PATCH_SIZE, normalize: bool = False ): - """Pad variable-length series so each sample length is a multiple of *patch_size*.""" + """ + Pad variable-length series so each sample length is a multiple of *patch_size*. + Optionally normalize each time series to have zero mean and unit variance. + """ for element in batch: # 1) pull out the list of (1D) time‑series @@ -30,13 +33,26 @@ def extend_time_series_to_match_patch_size_and_aggregate( # 2) convert each to a torch.Tensor (float) ts_tensors = [torch.as_tensor(ts, dtype=torch.float32) for ts in ts_list] - # 3) find the longest series length + # 3) normalize each time series if requested + if normalize: + normalized_tensors = [] + for ts in ts_tensors: + mean = ts.mean() + std = ts.std() + if std > 1e-8: # Avoid division by zero + ts_normalized = (ts - mean) / std + else: + ts_normalized = ts - mean + normalized_tensors.append(ts_normalized) + ts_tensors = normalized_tensors + + # 4) find the longest series length max_len = max([ts.size(0) for ts in ts_tensors]) - # 4) round up to nearest multiple of patch_size + # 5) round up to nearest multiple of patch_size padded_len = ((max_len + patch_size - 1) // patch_size) * patch_size - # 5) pad (or trim) each series to padded_len + # 6) pad (or trim) each series to padded_len padded = [] for ts in ts_tensors: L = ts.size(0) @@ -47,7 +63,7 @@ def extend_time_series_to_match_patch_size_and_aggregate( ts = ts[:padded_len] padded.append(ts) - # 6) stack into a single 2D tensor: (num_series, padded_len) + # 7) stack into a single 2D tensor: (num_series, padded_len) element["time_series"] = torch.stack(padded, dim=0) return batch