diff --git a/.gitignore b/.gitignore
index 49b785f..b397a0b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,5 @@ raw_data
*.ts
*.zip
-./__pycache__
\ No newline at end of file
+./__pycache__
+upload_to_huggingface.py
diff --git a/.linkspector.yml b/.linkspector.yml
new file mode 100644
index 0000000..9de8111
--- /dev/null
+++ b/.linkspector.yml
@@ -0,0 +1,15 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+
+dirs:
+ - .
+useGitIgnore: true
+ignorePatterns:
+ - pattern: "doc:/"
+ - pattern: "http://localhost"
+ - pattern: "https://doi.org"
diff --git a/README.md b/README.md
index 98e6931..35a8c0d 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,6 @@ OpenTSLM models can reason over multiple time series of any length at once, gene
-
## Installation
1. **Clone the Repository**
@@ -80,7 +79,83 @@ OpenTSLM has been tested and works with the following models:
Other variants may work but have not been extensively tested.
-## Multi-stage training (Curriculum)
+
+## 🚀 Quickstart with pretrained models
+
+The repository provides a factory class, `OpenTSLM`, for easily loading pretrained models from the Hugging Face Hub. Its `load_pretrained` method automatically detects the model type and returns the appropriate model instance.
+
+
+```python
+from src import OpenTSLM, TextPrompt, TextTimeSeriesPrompt, FullPrompt
+
+# Load model
+model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-har-flamingo")
+
+# Create prompt with raw time series data (normalization handled automatically)
+prompt = FullPrompt(
+ pre_prompt=TextPrompt("You are an expert in HAR analysis."),
+ text_time_series_prompt_list=[
+ TextTimeSeriesPrompt("X-axis accelerometer", [2.34, 2.34, 7.657, 3.21, -1.2])
+ ],
+    post_prompt=TextPrompt("What activity is this? Reason step by step, providing a full rationale before replying.")
+)
+
+# Generate response
+output = model.eval_prompt(prompt, normalize=True)
+print(output)
+```
+
+### 🤗 HuggingFace Demo Scripts
+
+We provide ready-to-use demo scripts in the `demo/huggingface/` directory that demonstrate how to load pretrained models from HuggingFace Hub and run inference on the evaluation sets for each task:
+
+- **`01_test_hf_tsqa.py`** - Test TSQA (Time Series Question Answering) models
+- **`02_test_hf_m4.py`** - Test M4 (Time Series Captioning) models
+- **`03_test_hf_har_cot.py`** - Test HAR CoT (Human Activity Recognition Chain-of-Thought) models
+- **`04_test_hf_sleep_cot.py`** - Test Sleep CoT (Sleep Stage Classification) models
+- **`05_test_hf_ecg_qa_cot.py`** - Test ECG QA CoT (ECG Question Answering) models
+
+Each script:
+1. Downloads the model checkpoint from HuggingFace Hub automatically (change the repo ID as needed)
+2. Loads the corresponding test dataset
+3. Runs inference on the evaluation set
+4. Prints model outputs with sample information
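+
+All five scripts share the same basic structure; the condensed sketch below uses the TSQA script as the example, so swap the dataset class and `REPO_ID` for the other tasks:
+
+```python
+# Condensed sketch of the shared demo-script flow (see demo/huggingface/ for the full versions).
+# Note: the full scripts first add ../../src to sys.path so these imports resolve.
+from model.llm.OpenTSLM import OpenTSLM
+from time_series_datasets.TSQADataset import TSQADataset
+from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate
+from torch.utils.data import DataLoader
+from model_config import PATCH_SIZE
+
+model = OpenTSLM.load_pretrained("OpenTSLM/llama-3.2-1b-tsqa-sp")      # 1. download + load checkpoint
+test_dataset = TSQADataset("test", EOS_TOKEN=model.get_eos_token())    # 2. load the test split
+loader = DataLoader(
+    test_dataset,
+    batch_size=1,
+    shuffle=False,
+    collate_fn=lambda b: extend_time_series_to_match_patch_size_and_aggregate(b, patch_size=PATCH_SIZE),
+)
+for batch in loader:                                                   # 3. run inference
+    predictions = model.generate(batch, max_new_tokens=200)
+    for sample, pred in zip(batch, predictions):                       # 4. print outputs
+        print(sample.get("pre_prompt", "N/A"), "->", pred)
+    break  # demo: stop after the first batch
+```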
+
+**Note:** The scripts above use the OpenTSLM-SP models (except for ECG-QA), as these require less VRAM and should run on most hardware. Change the model checkpoints in each file as needed.
+
+**Usage:**
+
+```bash
+# Run any of the demo scripts
+python demo/huggingface/01_test_hf_tsqa.py
+python demo/huggingface/02_test_hf_m4.py
+python demo/huggingface/03_test_hf_har_cot.py
+python demo/huggingface/04_test_hf_sleep_cot.py
+python demo/huggingface/05_test_hf_ecg_qa_cot.py
+```
+
+**Customizing the model:**
+
+Edit the `REPO_ID` variable at the top of each script to test different model variants. For example:
+
+```python
+# In 01_test_hf_tsqa.py
+REPO_ID = "OpenTSLM/llama-3.2-1b-tsqa-sp" # Soft Prompt model
+# or
+REPO_ID = "OpenTSLM/llama-3.2-1b-tsqa-flamingo" # Flamingo model
+```
+
+**Available models on HuggingFace:**
+
+All pretrained models are available under the `OpenTSLM` organization on HuggingFace Hub. Model names follow the pattern:
+- `OpenTSLM/{base_model}-{dataset}-{model_type}`
+ - `base_model`: `llama-3.2-1b`, `llama-3.2-3b`, `gemma-3-1b-pt`, `gemma-3-270m`
+ - `dataset`: `tsqa`, `m4`, `har`, `sleep`, `ecg`
+ - `model_type`: `sp` (Soft Prompt) or `flamingo` (Flamingo)
+
+Example: `OpenTSLM/llama-3.2-1b-ecg-flamingo`
+
+## Training: Multi-stage curriculum
OpenTSLM uses curriculum learning with progressive training stages:
@@ -137,6 +212,11 @@ python curriculum_learning.py --model OpenTSLMFlamingo --eval_only
- `--gradient_checkpointing`: Enable gradient checkpointing for memory efficiency
- `--verbose`: Enable verbose logging
+### Repository Naming Convention
+
+- Repository IDs ending with `-sp` will load and return `OpenTSLMSP` models
+- Repository IDs ending with `-flamingo` will load and return `OpenTSLMFlamingo` models
+
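+A minimal sketch of this suffix rule (illustrative only; the actual dispatch happens inside `OpenTSLM.load_pretrained`):
+
+```python
+# Illustrative sketch of suffix-based model-type detection, not the library's exact code.
+def resolve_model_type(repo_id: str) -> str:
+    if repo_id.endswith("-sp"):
+        return "OpenTSLMSP"        # soft-prompt variant
+    if repo_id.endswith("-flamingo"):
+        return "OpenTSLMFlamingo"  # Flamingo-style variant
+    raise ValueError(f"Cannot infer model type from repository ID: {repo_id}")
+
+print(resolve_model_type("OpenTSLM/llama-3.2-1b-ecg-flamingo"))  # OpenTSLMFlamingo
+```
+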
 ## 📊 Results Structure
diff --git a/create_doctor_eval_dataset.py b/create_doctor_eval_dataset.py
new file mode 100644
index 0000000..7c7216c
--- /dev/null
+++ b/create_doctor_eval_dataset.py
@@ -0,0 +1,419 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Script to create doctor evaluation dataset with correct model predictions.
+This script extracts ECG-QA templates with correct model outputs from llama3b_flamingo_predictions.jsonl
+and creates organized folders with ECG plots, CSV data, and evaluation materials.
+"""
+
+import json
+import os
+import sys
+import re
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import wfdb
+from pathlib import Path
+from typing import Dict, List, Set, Tuple
+from collections import defaultdict
+from tqdm import tqdm
+import shutil
+
+# Add the src directory to the path
+sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
+from time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset
+from time_series_datasets.ecg_qa.plot_example import draw_ecg, get_ptbxl_ecg_path
+
+# Configuration
+MODEL_PREDICTIONS_FILE = "/Users/planger/Development/EmbedHealth/evaluation/embedhealth/ecg_qa_cot/llama3b_flamingo_predictions.jsonl"
+OUTPUT_DIR = "ecg_doctor_eval"
+SAMPLES_PER_TEMPLATE = 2
+
+def extract_answer_from_generated(generated_text: str) -> str:
+ """Extract the final answer from generated text after 'Answer: '"""
+ if "Answer: " not in generated_text:
+ return generated_text.strip()
+
+ answer = generated_text.split("Answer: ")[-1].strip()
+ # Remove any end-of-text tokens and trailing punctuation
+ answer = re.sub(r'<\|.*?\|>|$', '', answer).strip()
+ answer = re.sub(r'\.$', '', answer).strip()
+ return answer
+
+def is_correct_prediction(generated_text: str, correct_answer: str) -> bool:
+ """Check if the model prediction matches the correct answer"""
+ predicted_answer = extract_answer_from_generated(generated_text)
+ return predicted_answer.lower().strip() == correct_answer.lower().strip()
+
+def load_model_predictions() -> Dict[int, List[Dict]]:
+ """Load model predictions and group by template_id"""
+ print(f"Loading model predictions from {MODEL_PREDICTIONS_FILE}")
+
+ template_predictions = defaultdict(list)
+
+ with open(MODEL_PREDICTIONS_FILE, 'r', encoding='utf-8') as f:
+ for line_num, line in enumerate(tqdm(f, desc="Loading predictions"), 1):
+ try:
+ data = json.loads(line.strip())
+
+ template_id = data.get('template_id')
+ if template_id is None:
+ continue
+
+ # Check if prediction is correct
+ generated_text = data.get('generated', '')
+ correct_answer = data.get('correct_answer', '')
+
+ if is_correct_prediction(generated_text, correct_answer):
+ template_predictions[template_id].append({
+ 'template_id': template_id,
+ 'ecg_id': data.get('ecg_id', [None])[0] if data.get('ecg_id') else None,
+ 'generated': generated_text,
+ 'correct_answer': correct_answer,
+ 'pre_prompt': data.get('pre_prompt', ''),
+ 'line_number': line_num
+ })
+
+ except Exception as e:
+ print(f"Error processing line {line_num}: {e}")
+ continue
+
+ print(f"Found correct predictions for {len(template_predictions)} templates")
+ return template_predictions
+
+def extract_clinical_context(pre_prompt: str) -> str:
+ """Extract clinical context from pre_prompt"""
+ if 'Clinical Context:' in pre_prompt:
+ context_start = pre_prompt.find('Clinical Context:')
+ context_end = pre_prompt.find('\n\n', context_start)
+ if context_end == -1:
+ context_end = pre_prompt.find('\n', context_start)
+ if context_end != -1:
+ return pre_prompt[context_start:context_end].strip()
+ return "Clinical context not available"
+
+def extract_question_from_prompt(pre_prompt: str) -> str:
+ """Extract the question from pre_prompt - this is the authoritative question for each sample"""
+ if 'Question: ' in pre_prompt:
+ question_start = pre_prompt.find('Question: ')
+ question_end = pre_prompt.find('\n\n', question_start)
+ if question_end == -1:
+ question_end = pre_prompt.find('\n', question_start)
+ if question_end != -1:
+ return pre_prompt[question_start:question_end].strip()
+ return "Question not available"
+
+def get_answer_options_for_template(template_id: int) -> List[str]:
+ """Get answer options for a template"""
+ try:
+ return ECGQACoTQADataset.get_possible_answers_for_template(template_id)
+ except Exception as e:
+ print(f"Warning: Could not get answer options for template {template_id}: {e}")
+ return []
+
+def load_ecg_data(ecg_id: int) -> Tuple[np.ndarray, int]:
+ """Load ECG data for a given ECG ID"""
+ try:
+ ecg_path = get_ptbxl_ecg_path(ecg_id)
+
+ if not os.path.exists(ecg_path + '.dat'):
+ raise FileNotFoundError(f"ECG file not found: {ecg_path}.dat")
+
+ # Read ECG data using wfdb
+ ecg_data, meta = wfdb.rdsamp(ecg_path)
+
+ # Get sampling frequency
+ sampling_freq = meta.get('fs', 500) # Default to 500Hz if not specified
+
+ return ecg_data, sampling_freq
+
+ except Exception as e:
+ raise RuntimeError(f"Failed to load ECG {ecg_id}: {e}")
+
+def downsample_to_100hz(ecg_data: np.ndarray, original_freq: int) -> np.ndarray:
+ """Downsample ECG data to 100Hz"""
+ if original_freq == 100:
+ return ecg_data
+
+ # Calculate downsampling factor
+ downsample_factor = original_freq // 100
+
+ # Downsample by taking every nth sample
+ downsampled_data = ecg_data[::downsample_factor]
+
+ return downsampled_data
+
+def save_ecg_as_csv(ecg_data: np.ndarray, output_dir: str, ecg_id: int):
+ """Save ECG data as separate CSV files for each lead"""
+ # Downsample to 100Hz if needed
+ if ecg_data.shape[0] > 1000: # Likely 500Hz data
+ ecg_data = downsample_to_100hz(ecg_data, 500)
+
+ lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
+
+ for lead_idx, lead_name in enumerate(lead_names):
+ if lead_idx < ecg_data.shape[1]: # Make sure we don't exceed available leads
+ lead_data = ecg_data[:, lead_idx]
+
+ # Create DataFrame with time and signal values
+ time_points = np.arange(len(lead_data)) / 100.0 # 100Hz sampling
+ df = pd.DataFrame({
+ 'time_seconds': time_points,
+ 'signal_mV': lead_data
+ })
+
+ # Save to CSV
+ csv_filename = f"{output_dir}/lead_{lead_name}.csv"
+ df.to_csv(csv_filename, index=False)
+
+def create_ecg_plot(ecg_data: np.ndarray, template_id: int, ecg_id: int,
+ question: str, answer_options: List[str],
+ clinical_context: str, model_output: str,
+ correct_answer: str, output_dir: str):
+ """Create ECG plot with all information"""
+
+ # Downsample to 100Hz if needed
+ if ecg_data.shape[0] > 1000: # Likely 500Hz data
+ ecg_data = downsample_to_100hz(ecg_data, 500)
+
+ # Create the plot with all 12 leads
+ fig, axes = plt.subplots(12, 1, figsize=(14, 24))
+ fig.suptitle(f"Template {template_id}: ECG Analysis\nECG ID: {ecg_id}",
+ fontsize=16, fontweight='bold')
+
+ # Create time array for 100Hz sampling (10 seconds)
+ time_points = np.arange(0, 10, 0.01) # 100Hz for 10 seconds
+
+ # Plot all 12 leads
+ lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
+ for i, (ax, lead_name) in enumerate(zip(axes, lead_names)):
+ if i < ecg_data.shape[1]: # Make sure we don't exceed available leads
+ # Plot the ECG signal for this lead
+ ax.plot(time_points, ecg_data[:, i], linewidth=2, color="k", alpha=1.0)
+
+ # Add grid lines (millimeter paper style)
+ # Major grid lines (every 0.2s and 0.5mV)
+ ax.vlines(np.arange(0, 10, 0.2), -2.5, 2.5, colors="r", alpha=0.3, linewidth=0.5)
+ ax.hlines(np.arange(-2.5, 2.5, 0.5), 0, 10, colors="r", alpha=0.3, linewidth=0.5)
+
+ # Minor grid lines (every 0.04s and 0.1mV)
+ ax.vlines(np.arange(0, 10, 0.04), -2.5, 2.5, colors="r", alpha=0.1, linewidth=0.3)
+ ax.hlines(np.arange(-2.5, 2.5, 0.1), 0, 10, colors="r", alpha=0.1, linewidth=0.3)
+
+ ax.set_xticks(np.arange(0, 11, 1.0))
+ ax.set_ylabel(f'Lead {lead_name} (mV)', fontweight='bold')
+ ax.margins(0.0)
+ ax.set_ylim(-2.5, 2.5)
+ ax.set_title(f'Lead {lead_name}', fontweight='bold', pad=10)
+ else:
+ ax.set_title(f'Lead {lead_name} (not available)', fontweight='bold', pad=10)
+ ax.text(0.5, 0.5, 'Lead not available', ha='center', va='center', transform=ax.transAxes)
+
+ # Add information text box
+ info_text = f"""Question: {question}
+
+Answer Options: {' | '.join(answer_options[:5])}{'...' if len(answer_options) > 5 else ''}
+
+Clinical Context: {clinical_context[:200]}{'...' if len(clinical_context) > 200 else ''}
+
+Model Output: {model_output[:300]}{'...' if len(model_output) > 300 else ''}
+
+Expected Answer: {correct_answer}"""
+
+ fig.text(0.02, 0.02, info_text, fontsize=9, transform=fig.transFigure,
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8))
+
+ # Save plot
+ plot_filename = f"{output_dir}/ecg_plot.png"
+ fig.savefig(plot_filename, dpi=300, bbox_inches='tight', facecolor='white')
+ plt.close(fig)
+
+ return plot_filename
+
+def create_evaluation_text_file(output_dir: str, template_id: int, ecg_id: int,
+ question: str, answer_options: List[str],
+ clinical_context: str, model_output: str,
+ correct_answer: str):
+ """Create a text file with all evaluation information"""
+
+ txt_filename = f"{output_dir}/evaluation_info.txt"
+ with open(txt_filename, 'w') as f:
+ f.write(f"ECG-QA Doctor Evaluation\n")
+ f.write(f"=" * 50 + "\n\n")
+
+ f.write(f"Template ID: {template_id}\n")
+ f.write(f"ECG ID: {ecg_id}\n\n")
+
+ f.write(f"Question:\n{question}\n\n")
+
+ f.write(f"Answer Options:\n")
+ for i, option in enumerate(answer_options, 1):
+ f.write(f"{i}. {option}\n")
+ f.write(f"\n")
+
+ f.write(f"Clinical Context:\n{clinical_context}\n\n")
+
+ f.write(f"Model Output (Llama3B-Flamingo):\n{model_output}\n\n")
+
+ f.write(f"Expected Answer: {correct_answer}\n")
+
+def create_doctor_evaluation_dataset():
+ """Main function to create the doctor evaluation dataset"""
+
+ print("Creating doctor evaluation dataset...")
+ print("Note: Each template_id can have multiple different questions - using the specific question from each sample")
+
+ # Create output directory
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ # Load model predictions
+ template_predictions = load_model_predictions()
+
+ if not template_predictions:
+ print("No correct predictions found!")
+ return
+
+ # Process each template
+ processed_templates = 0
+
+ for template_id in sorted(template_predictions.keys()):
+ predictions = template_predictions[template_id]
+
+ if len(predictions) < SAMPLES_PER_TEMPLATE:
+ print(f"Template {template_id}: Only {len(predictions)} correct predictions, skipping")
+ continue
+
+ print(f"\nProcessing template {template_id} with {len(predictions)} correct predictions")
+
+ # Get answer options for this template
+ answer_options = get_answer_options_for_template(template_id)
+ if not answer_options:
+ print(f"Template {template_id}: No answer options found, skipping")
+ continue
+
+ # Process up to SAMPLES_PER_TEMPLATE samples
+ samples_to_process = predictions[:SAMPLES_PER_TEMPLATE]
+
+ for sample_idx, prediction in enumerate(samples_to_process, 1):
+ try:
+ ecg_id = prediction['ecg_id']
+ if ecg_id is None:
+ print(f"Template {template_id}, Sample {sample_idx}: No ECG ID, skipping")
+ continue
+
+ print(f" Processing sample {sample_idx}: ECG {ecg_id}")
+
+ # Create sample directory
+ sample_dir = f"{OUTPUT_DIR}/template_{template_id:02d}/sample{sample_idx}"
+ os.makedirs(sample_dir, exist_ok=True)
+
+ # Extract information
+ clinical_context = extract_clinical_context(prediction['pre_prompt'])
+ question = extract_question_from_prompt(prediction['pre_prompt']) # Use the specific question from this sample
+ model_output = prediction['generated']
+ correct_answer = prediction['correct_answer']
+
+ # Load ECG data
+ try:
+ ecg_data, sampling_freq = load_ecg_data(ecg_id)
+ print(f" Loaded ECG data: {ecg_data.shape}, {sampling_freq}Hz")
+ except Exception as e:
+ print(f" Error loading ECG {ecg_id}: {e}")
+ continue
+
+ # Save ECG as CSV files
+ try:
+ save_ecg_as_csv(ecg_data, sample_dir, ecg_id)
+ print(f" Saved ECG CSV files")
+ except Exception as e:
+ print(f" Error saving ECG CSV: {e}")
+
+ # Create ECG plot
+ try:
+ plot_filename = create_ecg_plot(
+ ecg_data, template_id, ecg_id, question, answer_options,
+ clinical_context, model_output, correct_answer, sample_dir
+ )
+ print(f" Created ECG plot: {plot_filename}")
+ except Exception as e:
+ print(f" Error creating ECG plot: {e}")
+
+ # Create evaluation text file
+ try:
+ create_evaluation_text_file(
+ sample_dir, template_id, ecg_id, question, answer_options,
+ clinical_context, model_output, correct_answer
+ )
+ print(f" Created evaluation text file")
+ except Exception as e:
+ print(f" Error creating text file: {e}")
+
+ except Exception as e:
+ print(f" Error processing sample {sample_idx}: {e}")
+ continue
+
+ processed_templates += 1
+ print(f"Completed template {template_id}")
+
+ print(f"\nDoctor evaluation dataset creation completed!")
+ print(f"Processed {processed_templates} templates")
+ print(f"Output directory: {OUTPUT_DIR}")
+
+ # Create summary file
+ create_summary_file(template_predictions)
+
+def create_summary_file(template_predictions: Dict[int, List[Dict]]):
+ """Create a summary file with statistics"""
+ summary_file = f"{OUTPUT_DIR}/dataset_summary.txt"
+
+ with open(summary_file, 'w') as f:
+ f.write("ECG-QA Doctor Evaluation Dataset Summary\n")
+ f.write("=" * 50 + "\n\n")
+
+ f.write(f"Total templates with correct predictions: {len(template_predictions)}\n")
+ f.write(f"Samples per template: {SAMPLES_PER_TEMPLATE}\n")
+ f.write(f"Total samples created: {len(template_predictions) * SAMPLES_PER_TEMPLATE}\n\n")
+
+ f.write("Template Statistics:\n")
+ f.write("-" * 30 + "\n")
+
+ for template_id in sorted(template_predictions.keys()):
+ predictions = template_predictions[template_id]
+ f.write(f"Template {template_id:2d}: {len(predictions):3d} correct predictions\n")
+
+ f.write(f"\nDataset Structure:\n")
+ f.write(f"ecg_doctor_eval/\n")
+        f.write(f"├── template_01/\n")
+        f.write(f"│   ├── sample1/\n")
+        f.write(f"│   │   ├── ecg_plot.png\n")
+        f.write(f"│   │   ├── evaluation_info.txt\n")
+        f.write(f"│   │   ├── lead_I.csv\n")
+        f.write(f"│   │   ├── lead_II.csv\n")
+        f.write(f"│   │   └── ... (all 12 leads)\n")
+        f.write(f"│   └── sample2/\n")
+        f.write(f"│       └── ... (same structure)\n")
+        f.write(f"├── template_02/\n")
+        f.write(f"│   └── ...\n")
+        f.write(f"└── dataset_summary.txt\n")
+
+ f.write(f"\nNotes:\n")
+ f.write(f"- All predictions are CORRECT (model answer matches expected answer)\n")
+ f.write(f"- ECG data is downsampled to 100Hz for consistency\n")
+ f.write(f"- Each sample includes clinical context, question, answer options, and model reasoning\n")
+ f.write(f"- CSV files contain time series data for each ECG lead\n")
+
+ print(f"Summary file created: {summary_file}")
+
+if __name__ == "__main__":
+ try:
+ create_doctor_evaluation_dataset()
+ except Exception as e:
+ print(f"Error: {e}")
+ import traceback
+ traceback.print_exc()
diff --git a/curriculum_learning.py b/curriculum_learning.py
index a9cf893..1269414 100644
--- a/curriculum_learning.py
+++ b/curriculum_learning.py
@@ -781,7 +781,6 @@ def _evaluate_stage(
result["ecg_id"] = sample["ecg_id"]
if "correct_answer" in sample:
result["correct_answer"] = sample["correct_answer"]
-
results.append(result)
# Stream write each result immediately to per-rank file
results_fp.write(json.dumps(result, ensure_ascii=False) + "\n")
@@ -822,11 +821,7 @@ def _evaluate_stage(
print(f"Merged per-rank predictions into: {final_results_file}")
finally:
pass
-
- # Report test loss as NaN since we skip explicit loss computation during evaluation
- # Before, we were computing the loss explicitly, but this required to run the model twice, once for loss and once for predictions.
avg_test_loss = float("nan")
-
# Calculate stage-specific metrics
metrics = {"test_loss": avg_test_loss}
if epoch is not None:
@@ -848,7 +843,6 @@ def _evaluate_stage(
continue
additional_metrics = metric_func(predictions, gold_answers)
metrics.update(additional_metrics)
-
# Save results only on rank 0 (or when not distributed)
if (not dist.is_initialized()) or (self.rank == 0):
# Save metrics
@@ -1366,58 +1360,6 @@ def stage5_ecg_cot(
sampler=sampler,
)
- def stage4_sleep_cot(
- self, batch_size: int = None, eval_only: bool = False
- ) -> Dict[str, Any]:
- """Stage 4: Chain-of-Thought Reasoning (SleepEDF).
-
- Configuration:
- - Epochs: 60
- - OpenTSLMSP: encoder_lr=2e-4, projector_lr=1e-4
- - OpenTSLMFlamingo: base_lr=2e-4
- - Metric: Test loss only (chain-of-thought reasoning)
- """
- sampler = None
-
- return self._train_stage(
- stage_name="stage4_sleep_cot",
- dataset_class=SleepEDFCoTQADataset,
- num_epochs=60,
- lr_encoder=2e-4,
- lr_projector=1e-4,
- lr_base=2e-4,
- metric_func=None, # Only test loss for chain-of-thought reasoning
- batch_size=batch_size,
- eval_only=eval_only,
- sampler=sampler,
- )
-
- def stage5_ecg_cot(
- self, batch_size: int = None, eval_only: bool = False
- ) -> Dict[str, Any]:
- """Stage 5: Chain-of-Thought Reasoning (ECG QA CoT).
-
- Configuration:
- - Epochs: 60
- - OpenTSLMSP: encoder_lr=2e-4, projector_lr=1e-4
- - OpenTSLMFlamingo: base_lr=2e-4
- - Metric: Test loss only (chain-of-thought reasoning)
- """
- sampler = None
-
- return self._train_stage(
- stage_name="stage5_ecg_cot",
- dataset_class=ECGQACoTQADataset,
- num_epochs=60,
- lr_encoder=2e-4,
- lr_projector=1e-4,
- lr_base=2e-4,
- metric_func=None, # Only test loss for chain-of-thought reasoning
- batch_size=batch_size,
- eval_only=eval_only,
- sampler=sampler,
- )
-
def run_curriculum(
self, stages: List[str] = None, batch_size: int = None, eval_only: bool = False
):
diff --git a/demo/huggingface/.gitignore b/demo/huggingface/.gitignore
new file mode 100644
index 0000000..6320cd2
--- /dev/null
+++ b/demo/huggingface/.gitignore
@@ -0,0 +1 @@
+data
\ No newline at end of file
diff --git a/demo/huggingface/01_test_hf_tsqa.py b/demo/huggingface/01_test_hf_tsqa.py
new file mode 100755
index 0000000..0280ba7
--- /dev/null
+++ b/demo/huggingface/01_test_hf_tsqa.py
@@ -0,0 +1,81 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Demo script for testing TSQA (Time Series Question Answering) model from HuggingFace.
+
+This script:
+1. Loads a pretrained model from HuggingFace Hub
+2. Loads the TSQA test dataset
+3. Generates predictions on the evaluation set
+4. Prints model outputs
+"""
+
+import sys
+import os
+
+# Add src to path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src")))
+
+from model.llm.OpenTSLM import OpenTSLM
+from time_series_datasets.TSQADataset import TSQADataset
+from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate
+from torch.utils.data import DataLoader
+from model_config import PATCH_SIZE
+
+# Model repository ID - change this to test different models
+REPO_ID = "OpenTSLM/llama-3.2-1b-tsqa-sp"
+
+def main():
+ print("=" * 60)
+ print("TSQA Model Demo")
+ print("=" * 60)
+
+ # Load model from HuggingFace
+    print(f"\n📥 Loading model from {REPO_ID}...")
+ model = OpenTSLM.load_pretrained(REPO_ID)
+
+ # Create dataset
+    print("\n📊 Loading TSQA test dataset...")
+ test_dataset = TSQADataset("test", EOS_TOKEN=model.get_eos_token())
+
+ # Create data loader
+ test_loader = DataLoader(
+ test_dataset,
+ shuffle=False,
+ batch_size=1,
+ collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate(
+ batch, patch_size=PATCH_SIZE
+ ),
+ )
+
+    print(f"\n🚀 Running inference on {len(test_dataset)} test samples...")
+ print("=" * 60)
+
+ # Iterate over evaluation set
+ for i, batch in enumerate(test_loader):
+ # Generate predictions
+ predictions = model.generate(batch, max_new_tokens=200)
+
+ # Print results
+ for sample, pred in zip(batch, predictions):
+            print(f"\n📝 Sample {i + 1}:")
+ print(f" Question: {sample.get('pre_prompt', 'N/A')}")
+ if 'time_series_text' in sample:
+ print(f" Time Series Info: {sample['time_series_text'][:100]}...")
+ print(f" Gold Answer: {sample.get('answer', 'N/A')}")
+ print(f" Model Output: {pred}")
+ print("-" * 60)
+
+ # Limit to first 5 samples for demo
+ if i >= 4:
+            print("\n✅ Demo complete! (Showing first 5 samples)")
+ break
+
+if __name__ == "__main__":
+ main()
+
diff --git a/demo/huggingface/02_test_hf_m4.py b/demo/huggingface/02_test_hf_m4.py
new file mode 100755
index 0000000..dc33e7d
--- /dev/null
+++ b/demo/huggingface/02_test_hf_m4.py
@@ -0,0 +1,82 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Demo script for testing M4 (Time Series Captioning) model from HuggingFace.
+
+This script:
+1. Loads a pretrained model from HuggingFace Hub
+2. Loads the M4 test dataset
+3. Generates predictions on the evaluation set
+4. Prints model outputs
+"""
+
+import sys
+import os
+
+# Add src to path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src")))
+
+from model.llm.OpenTSLM import OpenTSLM
+from time_series_datasets.m4.M4QADataset import M4QADataset
+from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate
+from torch.utils.data import DataLoader
+from model_config import PATCH_SIZE
+
+# Model repository ID - change this to test different models
+REPO_ID = "OpenTSLM/llama-3.2-1b-m4-sp"
+
+def main():
+ print("=" * 60)
+ print("M4 Captioning Model Demo")
+ print("=" * 60)
+
+ # Load model from HuggingFace
+    print(f"\n📥 Loading model from {REPO_ID}...")
+ model = OpenTSLM.load_pretrained(REPO_ID)
+
+ # Create dataset
+    print("\n📊 Loading M4 test dataset...")
+ test_dataset = M4QADataset("test", EOS_TOKEN=model.get_eos_token())
+
+ # Create data loader
+ test_loader = DataLoader(
+ test_dataset,
+ shuffle=False,
+ batch_size=1,
+ collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate(
+ batch, patch_size=PATCH_SIZE
+ ),
+ )
+
+    print(f"\n🚀 Running inference on {len(test_dataset)} test samples...")
+ print("=" * 60)
+
+ # Iterate over evaluation set
+ for i, batch in enumerate(test_loader):
+ # Generate predictions
+ predictions = model.generate(batch, max_new_tokens=200)
+
+ # Print results
+ for sample, pred in zip(batch, predictions):
+            print(f"\n📝 Sample {i + 1}:")
+ if 'id' in sample:
+ print(f" Time Series ID: {sample['id']}")
+ if 'pre_prompt' in sample:
+ print(f" Prompt: {sample['pre_prompt']}")
+ print(f" Gold Caption: {sample.get('answer', 'N/A')}")
+ print(f" Model Output: {pred}")
+ print("-" * 60)
+
+        # Limit to first 10 samples for demo
+        if i >= 9:
+            print("\n✅ Demo complete! (Showing first 10 samples)")
+ break
+
+if __name__ == "__main__":
+ main()
+
diff --git a/demo/huggingface/03_test_hf_har_cot.py b/demo/huggingface/03_test_hf_har_cot.py
new file mode 100755
index 0000000..3c1f484
--- /dev/null
+++ b/demo/huggingface/03_test_hf_har_cot.py
@@ -0,0 +1,85 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Demo script for testing HAR CoT (Human Activity Recognition Chain-of-Thought) model from HuggingFace.
+
+This script:
+1. Loads a pretrained model from HuggingFace Hub
+2. Loads the HAR CoT test dataset
+3. Generates predictions on the evaluation set
+4. Prints model outputs
+"""
+
+import sys
+import os
+
+# Add src to path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src")))
+
+from model.llm.OpenTSLM import OpenTSLM
+from time_series_datasets.har_cot.HARCoTQADataset import HARCoTQADataset
+from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate
+from torch.utils.data import DataLoader
+from model_config import PATCH_SIZE
+
+# Model repository ID - change this to test different models
+REPO_ID = "OpenTSLM/llama-3.2-1b-har-sp"
+
+def main():
+ print("=" * 60)
+ print("HAR CoT Model Demo")
+ print("=" * 60)
+
+ # Load model from HuggingFace
+    print(f"\n📥 Loading model from {REPO_ID}...")
+ enable_lora = False
+ if "-sp" in REPO_ID:
+ enable_lora = True
+ model = OpenTSLM.load_pretrained(REPO_ID, enable_lora=enable_lora)
+
+ # Create dataset
+    print("\n📊 Loading HAR CoT test dataset...")
+ test_dataset = HARCoTQADataset("test", EOS_TOKEN=model.get_eos_token())
+
+ # Create data loader
+ test_loader = DataLoader(
+ test_dataset,
+ shuffle=False,
+ batch_size=1,
+ collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate(
+ batch, patch_size=PATCH_SIZE
+ ),
+ )
+
+    print(f"\n🚀 Running inference on {len(test_dataset)} test samples...")
+ print("=" * 60)
+
+ # Iterate over evaluation set
+ for i, batch in enumerate(test_loader):
+ # Generate predictions
+ predictions = model.generate(batch, max_new_tokens=500)
+
+ # Print results
+ for sample, pred in zip(batch, predictions):
+            print(f"\n📝 Sample {i + 1}:")
+ if 'pre_prompt' in sample:
+ print(f" Question: {sample['pre_prompt']}")
+ print(f" Gold Answer: {sample.get('answer', 'N/A')}")
+ print(f" Model Output: {pred}")
+ print("-" * 60)
+
+ # Limit to first 5 samples for demo
+ if i >= 4:
+            print("\n✅ Demo complete! (Showing first 5 samples)")
+ break
+
+if __name__ == "__main__":
+ main()
+
diff --git a/demo/huggingface/04_test_hf_sleep_cot.py b/demo/huggingface/04_test_hf_sleep_cot.py
new file mode 100755
index 0000000..6258be5
--- /dev/null
+++ b/demo/huggingface/04_test_hf_sleep_cot.py
@@ -0,0 +1,83 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Demo script for testing Sleep CoT (Sleep Stage Classification Chain-of-Thought) model from HuggingFace.
+
+This script:
+1. Loads a pretrained model from HuggingFace Hub
+2. Loads the Sleep CoT test dataset
+3. Generates predictions on the evaluation set
+4. Prints model outputs
+"""
+
+import sys
+import os
+
+# Add src to path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src")))
+
+from model.llm.OpenTSLM import OpenTSLM
+from time_series_datasets.sleep.SleepEDFCoTQADataset import SleepEDFCoTQADataset
+from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate
+from torch.utils.data import DataLoader
+from model_config import PATCH_SIZE
+
+# Model repository ID - change this to test different models
+REPO_ID = "OpenTSLM/llama-3.2-1b-sleep-sp"
+
+def main():
+ print("=" * 60)
+ print("Sleep CoT Model Demo")
+ print("=" * 60)
+
+ # Load model from HuggingFace
+    print(f"\n📥 Loading model from {REPO_ID}...")
+ enable_lora = False
+ if "-sp" in REPO_ID:
+ enable_lora = True
+ model = OpenTSLM.load_pretrained(REPO_ID, enable_lora=enable_lora)
+
+ # Create dataset
+    print("\n📊 Loading Sleep CoT test dataset...")
+ test_dataset = SleepEDFCoTQADataset("test", EOS_TOKEN=model.get_eos_token())
+
+ # Create data loader
+ test_loader = DataLoader(
+ test_dataset,
+ shuffle=False,
+ batch_size=1,
+ collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate(
+ batch, patch_size=PATCH_SIZE
+ ),
+ )
+
+    print(f"\n🚀 Running inference on {len(test_dataset)} test samples...")
+ print("=" * 60)
+
+ # Iterate over evaluation set
+ for i, batch in enumerate(test_loader):
+ # Generate predictions
+ predictions = model.generate(batch, max_new_tokens=500)
+
+ # Print results
+ for sample, pred in zip(batch, predictions):
+            print(f"\n📝 Sample {i + 1}:")
+ if 'pre_prompt' in sample:
+ print(f" Question: {sample['pre_prompt']}")
+ print(f" Gold Answer: {sample.get('answer', 'N/A')}")
+ print(f" Model Output: {pred}")
+ print("-" * 60)
+
+ # Limit to first 5 samples for demo
+ if i >= 4:
+            print("\n✅ Demo complete! (Showing first 5 samples)")
+ break
+
+if __name__ == "__main__":
+ main()
+
diff --git a/demo/huggingface/05_test_hf_ecg_qa_cot.py b/demo/huggingface/05_test_hf_ecg_qa_cot.py
new file mode 100755
index 0000000..7c7ac98
--- /dev/null
+++ b/demo/huggingface/05_test_hf_ecg_qa_cot.py
@@ -0,0 +1,87 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Demo script for testing ECG QA CoT (ECG Question Answering Chain-of-Thought) model from HuggingFace.
+
+This script:
+1. Loads a pretrained model from HuggingFace Hub
+2. Loads the ECG QA CoT test dataset
+3. Generates predictions on the evaluation set
+4. Prints model outputs
+"""
+
+import sys
+import os
+
+# Add src to path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src")))
+
+from model.llm.OpenTSLM import OpenTSLM
+from time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset
+from time_series_datasets.util import extend_time_series_to_match_patch_size_and_aggregate
+from torch.utils.data import DataLoader
+from model_config import PATCH_SIZE
+
+# Model repository ID - change this to test different models
+REPO_ID = "OpenTSLM/llama-3.2-1b-ecg-sp"
+
+def main():
+ print("=" * 60)
+ print("ECG QA CoT Model Demo")
+ print("=" * 60)
+
+ # Load model from HuggingFace
+    print(f"\n📥 Loading model from {REPO_ID}...")
+ enable_lora = False
+ if "-sp" in REPO_ID:
+ enable_lora = True
+ model = OpenTSLM.load_pretrained(REPO_ID, enable_lora=enable_lora)
+
+ # Create dataset
+    print("\n📊 Loading ECG QA CoT test dataset...")
+ test_dataset = ECGQACoTQADataset("test", EOS_TOKEN=model.get_eos_token())
+
+ # Create data loader
+ test_loader = DataLoader(
+ test_dataset,
+ shuffle=False,
+ batch_size=1,
+ collate_fn=lambda batch: extend_time_series_to_match_patch_size_and_aggregate(
+ batch, patch_size=PATCH_SIZE
+ ),
+ )
+
+    print(f"\n🚀 Running inference on {len(test_dataset)} test samples...")
+ print("=" * 60)
+
+ # Iterate over evaluation set
+ for i, batch in enumerate(test_loader):
+ # Generate predictions
+ predictions = model.generate(batch, max_new_tokens=500)
+
+ # Print results
+ for sample, pred in zip(batch, predictions):
+            print(f"\n📝 Sample {i + 1}:")
+ if 'pre_prompt' in sample:
+ print(f" Question: {sample['pre_prompt']}")
+ if 'template_id' in sample:
+ print(f" Template ID: {sample['template_id']}")
+ if 'ecg_id' in sample:
+ print(f" ECG ID: {sample['ecg_id']}")
+ print(f" Gold Answer: {sample.get('answer', 'N/A')}")
+ print(f" Model Output: {pred}")
+ print("-" * 60)
+
+ # Limit to first 5 samples for demo
+ if i >= 4:
+            print("\n✅ Demo complete! (Showing first 5 samples)")
+ break
+
+if __name__ == "__main__":
+ main()
+
diff --git a/evaluation/baseline/common_evaluator.py b/evaluation/baseline/common_evaluator.py
index b6e5cd1..9edbec4 100644
--- a/evaluation/baseline/common_evaluator.py
+++ b/evaluation/baseline/common_evaluator.py
@@ -498,9 +498,9 @@ def _consolidate_jsonl_results(self, model_name: str, dataset_name: str) -> str:
"dataset_name": dataset_name,
"total_samples": total_samples,
"successful_inferences": successful_inferences,
- "success_rate": successful_inferences / total_samples
- if total_samples > 0
- else 0.0,
+ "success_rate": (
+ successful_inferences / total_samples if total_samples > 0 else 0.0
+ ),
"metrics": aggregate_metrics,
"detailed_results": individual_results,
}
diff --git a/evaluation/baseline/parse_predictions_sleep_baseline.py b/evaluation/baseline/parse_predictions_sleep_baseline.py
index e8b9fd0..6352c8f 100644
--- a/evaluation/baseline/parse_predictions_sleep_baseline.py
+++ b/evaluation/baseline/parse_predictions_sleep_baseline.py
@@ -58,6 +58,7 @@
# --- Inline minimal utilities (avoid importing modules that require extra packages) ---
import re
+
def extract_answer(text: str) -> str:
"""Extract the final answer from text by taking content after 'Answer:'
and trimming trailing special tokens.
@@ -65,21 +66,21 @@ def extract_answer(text: str) -> str:
if "Answer: " not in text:
return text
answer = text.split("Answer: ")[-1].strip()
- answer = re.sub(r'<\|.*?\|>$', '', answer).strip()
+ answer = re.sub(r"<\|.*?\|>$", "", answer).strip()
return answer
def calculate_f1_score(prediction: str, ground_truth: str):
"""Binary exact-match F1 on normalized strings (lower/strip/punct)."""
- pred_normalized = prediction.lower().strip().rstrip('.,!?;:')
- truth_normalized = ground_truth.lower().strip().rstrip('.,!?;:')
+ pred_normalized = prediction.lower().strip().rstrip(".,!?;:")
+ truth_normalized = ground_truth.lower().strip().rstrip(".,!?;:")
f1 = 1.0 if pred_normalized == truth_normalized else 0.0
return {
- 'f1_score': f1,
- 'precision': f1,
- 'recall': f1,
- 'prediction_normalized': pred_normalized,
- 'ground_truth_normalized': truth_normalized,
+ "f1_score": f1,
+ "precision": f1,
+ "recall": f1,
+ "prediction_normalized": pred_normalized,
+ "ground_truth_normalized": truth_normalized,
}
@@ -122,7 +123,11 @@ def calculate_f1_stats(data_points: List[Dict], allowed_labels=None):
tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
- f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+ f1 = (
+ 2 * (precision * recall) / (precision + recall)
+ if (precision + recall) > 0
+ else 0.0
+ )
class_f1_scores[class_name] = {
"f1": f1,
"precision": precision,
@@ -186,8 +191,8 @@ def canonicalize_sleep_label(s: str) -> str:
# If it's an option-like artifact (e.g., "(a) wake"), keep only the label part
# but since SleepEDF doesn't use options, just strip leading option markers if present
- if len(t) > 3 and t[0] == '(' and ')' in t[:4]:
- t = t.split(')', 1)[-1].strip()
+ if len(t) > 3 and t[0] == "(" and ")" in t[:4]:
+ t = t.split(")", 1)[-1].strip()
# Short aliases
if t in {"w"}:
@@ -290,18 +295,18 @@ def extract_structured_data(obj: Dict) -> List[Dict]:
# Binary exact-match accuracy on normalized labels handled in calculate_f1_score, but keep explicit flag
f1_result = calculate_f1_score(model_prediction, ground_truth)
- accuracy = f1_result['f1_score'] == 1.0
+ accuracy = f1_result["f1_score"] == 1.0
data_point = {
"generated": generated,
"model_prediction": model_prediction,
"ground_truth": ground_truth,
"accuracy": accuracy,
- "f1_score": f1_result['f1_score'],
- "precision": f1_result['precision'],
- "recall": f1_result['recall'],
- "prediction_normalized": f1_result['prediction_normalized'],
- "ground_truth_normalized": f1_result['ground_truth_normalized'],
+ "f1_score": f1_result["f1_score"],
+ "precision": f1_result["precision"],
+ "recall": f1_result["recall"],
+ "prediction_normalized": f1_result["prediction_normalized"],
+ "ground_truth_normalized": f1_result["ground_truth_normalized"],
}
data_points.append(data_point)
@@ -316,16 +321,16 @@ def main():
"--detailed-json",
type=Path,
required=True,
- help="Path to a single results JSON file containing 'detailed_results'"
+ help="Path to a single results JSON file containing 'detailed_results'",
)
ap.add_argument(
"--clean-out",
type=Path,
- help="Optional path to write clean JSONL of parsed per-sample points"
+ help="Optional path to write clean JSONL of parsed per-sample points",
)
args = ap.parse_args()
- with args.detailed_json.open('r', encoding='utf-8') as f:
+ with args.detailed_json.open("r", encoding="utf-8") as f:
obj = json.load(f)
# Extract per-sample points
@@ -374,9 +379,9 @@ def main():
print(f"Macro-F1 Score: {f1_stats.get('macro_f1', 0.0):.4f}")
print(f"Total Classes: {f1_stats.get('total_classes', 0)}")
- if f1_stats.get('class_f1_scores'):
+ if f1_stats.get("class_f1_scores"):
print(f"\nPer-Class F1 Scores:")
- for class_name, scores in f1_stats['class_f1_scores'].items():
+ for class_name, scores in f1_stats["class_f1_scores"].items():
print(
f" {class_name}: F1={scores['f1']:.4f}, "
f"P={scores['precision']:.4f}, R={scores['recall']:.4f}"
@@ -384,7 +389,7 @@ def main():
# Optional clean JSONL output
if args.clean_out:
- with args.clean_out.open('w', encoding='utf-8') as f:
+ with args.clean_out.open("w", encoding="utf-8") as f:
for item in data_points:
f.write(json.dumps(item, indent=2) + "\n")
print(f"\nData saved to {args.clean_out}")
diff --git a/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py b/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py
index fe416a5..c72223d 100644
--- a/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py
+++ b/evaluation/opentslm/ecg_qa_cot/parse_ecg_qa_cot_data.py
@@ -18,17 +18,18 @@
from tqdm import tqdm
# Add the src directory to the path to import from the dataset class
-sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "src"))
from time_series_datasets.ecg_qa.ECGQACoTQADataset import ECGQACoTQADataset
+
def calculate_f1_score(prediction, ground_truth, possible_answers):
"""Calculate F1 score for single-label classification with template-specific answers.
-
+
Args:
prediction: Model's predicted answer
ground_truth: Ground truth answer
possible_answers: List of valid answers for this template
-
+
Returns:
Dict with F1 metrics and metadata
"""
@@ -38,77 +39,78 @@ def calculate_f1_score(prediction, ground_truth, possible_answers):
raise ValueError("Ground truth cannot be None")
if not possible_answers:
raise ValueError("Possible answers list cannot be empty")
-
+
# Normalize predictions and ground truth
- pred_normalized = prediction.lower().strip().rstrip('.,!?;:')
- truth_normalized = ground_truth.lower().strip().rstrip('.,!?;:')
-
+ pred_normalized = prediction.lower().strip().rstrip(".,!?;:")
+ truth_normalized = ground_truth.lower().strip().rstrip(".,!?;:")
+
# Check if prediction is in supported answers
possible_answers_lower = [ans.lower().strip() for ans in possible_answers]
pred_supported = pred_normalized in possible_answers_lower
truth_supported = truth_normalized in possible_answers_lower
-
+
# Calculate F1 (exact match after normalization)
f1 = 1.0 if pred_normalized == truth_normalized else 0.0
-
+
return {
- 'f1_score': f1,
- 'precision': f1,
- 'recall': f1,
- 'prediction_normalized': pred_normalized,
- 'ground_truth_normalized': truth_normalized,
- 'prediction_supported': pred_supported,
- 'ground_truth_supported': truth_supported,
- 'possible_answers': possible_answers,
+ "f1_score": f1,
+ "precision": f1,
+ "recall": f1,
+ "prediction_normalized": pred_normalized,
+ "ground_truth_normalized": truth_normalized,
+ "prediction_supported": pred_supported,
+ "ground_truth_supported": truth_supported,
+ "possible_answers": possible_answers,
}
+
def calculate_template_f1_stats(data_points):
"""Calculate F1 statistics per template.
-
+
Args:
data_points: List of data points with template_id and metrics
-
+
Returns:
Dict with per-template and overall statistics
"""
if not data_points:
return {}
-
+
# Group by template_id
template_groups = defaultdict(list)
for point in data_points:
- template_id = point.get('template_id')
+ template_id = point.get("template_id")
if template_id is None:
raise ValueError(f"Missing template_id in data point: {point}")
template_groups[template_id].append(point)
-
+
# Calculate per-template metrics
template_stats = {}
total_samples = 0
total_correct = 0
total_f1_sum = 0
total_macro_f1_weighted_sum = 0
-
+
for template_id, points in template_groups.items():
if not points:
continue
-
+
# Get possible answers for this template - required for evaluation
- possible_answers = points[0].get('possible_answers', [])
+ possible_answers = points[0].get("possible_answers", [])
if not possible_answers:
raise ValueError(f"No possible answers found for template {template_id}")
-
+
# Calculate per-template F1 stats
class_predictions = {}
for answer in possible_answers:
class_predictions[answer.lower()] = {"tp": 0, "fp": 0, "fn": 0}
-
+
# Count TP, FP, FN for each class in this template
for point in points:
gt_class = point.get("ground_truth_normalized", "")
pred_class = point.get("prediction_normalized", "")
pred_supported = point.get("prediction_supported", False)
-
+
# Only count if ground truth is in supported answers
if gt_class in class_predictions:
if pred_class == gt_class:
@@ -119,19 +121,23 @@ def calculate_template_f1_stats(data_points):
# Count FP only if prediction is supported
if pred_supported and pred_class in class_predictions:
class_predictions[pred_class]["fp"] += 1
-
+
# Calculate per-class F1 for this template
class_f1_scores = {}
template_f1_sum = 0
valid_classes = 0
-
+
for class_name, counts in class_predictions.items():
tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]
-
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
- f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
+ f1 = (
+ 2 * (precision * recall) / (precision + recall)
+ if (precision + recall) > 0
+ else 0
+ )
+
class_f1_scores[class_name] = {
"f1": f1,
"precision": precision,
@@ -140,20 +146,24 @@ def calculate_template_f1_stats(data_points):
"fp": fp,
"fn": fn,
}
-
+
template_f1_sum += f1
valid_classes += 1
-
+
# Calculate macro-F1 for this template
macro_f1 = template_f1_sum / valid_classes if valid_classes > 0 else 0
-
+
# Calculate accuracy for this template
template_correct = sum(1 for point in points if point.get("accuracy", False))
template_accuracy = template_correct / len(points) if points else 0
-
+
# Calculate average F1 for this template
- template_avg_f1 = sum(point.get("f1_score", 0) for point in points) / len(points) if points else 0
-
+ template_avg_f1 = (
+ sum(point.get("f1_score", 0) for point in points) / len(points)
+ if points
+ else 0
+ )
+
template_stats[template_id] = {
"num_samples": len(points),
"accuracy": template_accuracy,
@@ -163,19 +173,21 @@ def calculate_template_f1_stats(data_points):
"num_classes": valid_classes,
"correct_predictions": template_correct,
}
-
+
total_samples += len(points)
total_correct += template_correct
total_f1_sum += template_avg_f1 * len(points) # Weighted by number of samples
total_macro_f1_weighted_sum += macro_f1 * len(points)
-
+
# Calculate overall statistics
overall_accuracy = total_correct / total_samples if total_samples > 0 else 0
overall_avg_f1 = total_f1_sum / total_samples if total_samples > 0 else 0
-
+
# Calculate macro-F1 across templates
template_macro_f1s = [stats["macro_f1"] for stats in template_stats.values()]
- overall_macro_f1 = sum(template_macro_f1s) / len(template_macro_f1s) if template_macro_f1s else 0
+ overall_macro_f1 = (
+ sum(template_macro_f1s) / len(template_macro_f1s) if template_macro_f1s else 0
+ )
overall_macro_f1_weighted = (
total_macro_f1_weighted_sum / total_samples if total_samples > 0 else 0
)
@@ -183,9 +195,11 @@ def calculate_template_f1_stats(data_points):
# Calculate unweighted average of per-template accuracies
template_accuracies = [stats["accuracy"] for stats in template_stats.values()]
overall_template_accuracy_avg = (
- sum(template_accuracies) / len(template_accuracies) if template_accuracies else 0
+ sum(template_accuracies) / len(template_accuracies)
+ if template_accuracies
+ else 0
)
-
+
return {
"overall": {
"total_samples": total_samples,
@@ -199,36 +213,38 @@ def calculate_template_f1_stats(data_points):
"per_template": template_stats,
}
+
def calculate_accuracy_stats(data_points):
"""Calculate overall accuracy statistics from data points"""
if not data_points:
return {}
-
+
total = len(data_points)
correct = sum(1 for point in data_points if point.get("accuracy", False))
accuracy_percentage = (correct / total) * 100 if total > 0 else 0
-
+
return {
"total_samples": total,
"correct_predictions": correct,
"incorrect_predictions": total - correct,
- "accuracy_percentage": accuracy_percentage
+ "accuracy_percentage": accuracy_percentage,
}
+
def parse_ecg_qa_cot_jsonl(input_file, output_file=None):
"""Parse ECG-QA CoT JSONL file and extract JSON objects with per-template metrics."""
if output_file is None:
input_path = Path(input_file)
output_file = str(input_path.parent / f"{input_path.stem}.clean.jsonl")
-
+
print(f"Parsing {input_file}")
print(f"Output will be saved to {output_file}")
-
+
extracted_data = extract_structured_data(input_file)
-
+
if extracted_data:
print(f"Extracted {len(extracted_data)} data points")
-
+
# Calculate and display overall accuracy statistics
accuracy_stats = calculate_accuracy_stats(extracted_data)
print(f"\nOverall Accuracy Statistics:")
@@ -236,23 +252,31 @@ def parse_ecg_qa_cot_jsonl(input_file, output_file=None):
print(f"Correct predictions: {accuracy_stats['correct_predictions']}")
print(f"Incorrect predictions: {accuracy_stats['incorrect_predictions']}")
print(f"Accuracy: {accuracy_stats['accuracy_percentage']:.2f}%")
-
+
# Calculate and display per-template F1 statistics
f1_stats = calculate_template_f1_stats(extracted_data)
print(f"\nOverall F1 Statistics:")
overall = f1_stats.get("overall", {})
print(f"Total templates: {overall.get('total_templates', 0)}")
print(f"Average F1 Score (sample-weighted): {overall.get('average_f1', 0):.4f}")
- print(f"Macro-F1 Score (unweighted over templates): {overall.get('macro_f1', 0):.4f}")
- print(f"Macro-F1 Score (sample-weighted over templates): {overall.get('macro_f1_weighted', 0):.4f}")
- print(f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}")
+ print(
+ f"Macro-F1 Score (unweighted over templates): {overall.get('macro_f1', 0):.4f}"
+ )
+ print(
+ f"Macro-F1 Score (sample-weighted over templates): {overall.get('macro_f1_weighted', 0):.4f}"
+ )
+ print(
+ f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}"
+ )
# Final single-value summary (aggregated across all templates)
print("\nFinal Results (aggregated across all templates):")
print(f"Final Accuracy (micro over samples): {overall.get('accuracy', 0):.4f}")
print(f"Final F1 (micro over samples): {overall.get('accuracy', 0):.4f}")
- print(f"Final Macro-F1 (weighted by template size): {overall.get('macro_f1_weighted', 0):.4f}")
-
+ print(
+ f"Final Macro-F1 (weighted by template size): {overall.get('macro_f1_weighted', 0):.4f}"
+ )
+
# Display per-template statistics
per_template = f1_stats.get("per_template", {})
if per_template:
@@ -263,18 +287,22 @@ def parse_ecg_qa_cot_jsonl(input_file, output_file=None):
print(f" Accuracy: {stats['accuracy']:.4f}")
print(f" Average F1: {stats['average_f1']:.4f}")
print(f" Macro-F1: {stats['macro_f1']:.4f}")
-
+
# Show per-class F1 scores for this template
- if stats['class_f1_scores']:
+ if stats["class_f1_scores"]:
print(f" Per-class F1:")
- for class_name, scores in stats['class_f1_scores'].items():
- if scores['tp'] + scores['fp'] + scores['fn'] > 0: # Only show classes with samples
- print(f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}")
-
- with open(output_file, 'w', encoding='utf-8') as f:
+ for class_name, scores in stats["class_f1_scores"].items():
+ if (
+ scores["tp"] + scores["fp"] + scores["fn"] > 0
+ ): # Only show classes with samples
+ print(
+ f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}"
+ )
+
+ with open(output_file, "w", encoding="utf-8") as f:
for item in extracted_data:
f.write(json.dumps(item, indent=2) + "\n")
-
+
print(f"\nData saved to {output_file}")
# Print concise overall stats at the end as a final summary
@@ -284,59 +312,73 @@ def parse_ecg_qa_cot_jsonl(input_file, output_file=None):
print(f"F1 (micro): {overall.get('accuracy', 0):.4f}")
print(f"Macro-F1 (unweighted): {overall.get('macro_f1', 0):.4f}")
print(f"Macro-F1 (weighted): {overall.get('macro_f1_weighted', 0):.4f}")
- print(f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}")
+ print(
+ f"Template Accuracy Avg (unweighted): {overall.get('template_accuracy_avg', 0):.4f}"
+ )
return extracted_data
else:
print("No data could be extracted from the file.")
return []
+
def extract_structured_data(input_file):
"""Extract structured data from JSONL file"""
data_points = []
-
- with open(input_file, 'r', encoding='utf-8') as f:
+
+ with open(input_file, "r", encoding="utf-8") as f:
for line_num, line in tqdm(enumerate(f, 1), desc="Processing JSONL"):
try:
# Parse JSON line
data = json.loads(line.strip())
-
+
# Extract generated and gold fields - these are required
- generated_text = data.get('generated_answer')
- gold_text = data.get('target_answer')
-
+ generated_text = data.get("generated_answer")
+ gold_text = data.get("target_answer")
+
if generated_text is None:
raise ValueError(f"Missing 'generated' field in line {line_num}")
if gold_text is None:
raise ValueError(f"Missing 'gold' field in line {line_num}")
-
+
# Extract template_id and ecg_id - these are required for ECG-QA evaluation
- template_id = data.get('template_id')
- ecg_id = data.get('ecg_id', "None")
-
+ template_id = data.get("template_id")
+ ecg_id = data.get("ecg_id", "None")
+
if template_id is None:
raise ValueError(f"Missing template_id in line {line_num}")
if ecg_id is None:
raise ValueError(f"Missing ecg_id in line {line_num}")
-
+
# Extract answers from both fields
model_prediction_raw = extract_answer(generated_text)
ground_truth_raw = extract_answer(gold_text)
-
+
# Get possible answers for this template - required for evaluation
try:
- possible_answers = ECGQACoTQADataset.get_possible_answers_for_template(template_id)
+ possible_answers = (
+ ECGQACoTQADataset.get_possible_answers_for_template(template_id)
+ )
except Exception as e:
- raise ValueError(f"Could not get possible answers for template {template_id}: {e}")
-
+ raise ValueError(
+ f"Could not get possible answers for template {template_id}: {e}"
+ )
+
if not possible_answers:
- raise ValueError(f"No possible answers found for template {template_id}")
-
+ raise ValueError(
+ f"No possible answers found for template {template_id}"
+ )
+
# Calculate F1 score with template-specific answers
- f1_result = calculate_f1_score(model_prediction_raw, ground_truth_raw, possible_answers)
-
+ f1_result = calculate_f1_score(
+ model_prediction_raw, ground_truth_raw, possible_answers
+ )
+
# Calculate accuracy (exact match)
- accuracy = (f1_result['prediction_normalized'] == f1_result['ground_truth_normalized']) and f1_result['ground_truth_supported']
-
+ accuracy = (
+ f1_result["prediction_normalized"]
+ == f1_result["ground_truth_normalized"]
+ ) and f1_result["ground_truth_supported"]
+
data_point = {
"generated": generated_text,
"model_prediction": model_prediction_raw,
@@ -344,15 +386,15 @@ def extract_structured_data(input_file):
"template_id": template_id,
"ecg_id": ecg_id,
"accuracy": accuracy,
- "f1_score": f1_result['f1_score'],
- "precision": f1_result['precision'],
- "recall": f1_result['recall'],
- "prediction_normalized": f1_result['prediction_normalized'],
- "ground_truth_normalized": f1_result['ground_truth_normalized'],
- "prediction_supported": f1_result['prediction_supported'],
- "ground_truth_supported": f1_result['ground_truth_supported'],
+ "f1_score": f1_result["f1_score"],
+ "precision": f1_result["precision"],
+ "recall": f1_result["recall"],
+ "prediction_normalized": f1_result["prediction_normalized"],
+ "ground_truth_normalized": f1_result["ground_truth_normalized"],
+ "prediction_supported": f1_result["prediction_supported"],
+ "ground_truth_supported": f1_result["ground_truth_supported"],
"possible_answers": possible_answers,
- "line_number": line_num
+ "line_number": line_num,
}
data_points.append(data_point)
except json.JSONDecodeError as e:
@@ -361,24 +403,30 @@ def extract_structured_data(input_file):
except Exception as e:
print(f"Unexpected error on line {line_num}: {e}")
continue
-
+
return data_points
+
def extract_answer(text):
"""Extract the final answer from text"""
if "Answer: " not in text:
return text
-
+
answer = text.split("Answer: ")[-1].strip()
# Remove any end-of-text tokens (including <|...|>-style tokens)
- answer = re.sub(r'<\|.*?\|>|$', '', answer).strip()
+ answer = re.sub(r"<\|.*?\|>|$", "", answer).strip()
# Remove trailing periods and normalize
- answer = re.sub(r'\.$', '', answer).strip()
+ answer = re.sub(r"\.$", "", answer).strip()
return answer
+
if __name__ == "__main__":
current_dir = Path(__file__).parent
- input_file = current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.jsonl"
- clean_output = current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.clean.jsonl"
-
+ input_file = (
+ current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.jsonl"
+ )
+ clean_output = (
+ current_dir / "evaluation_results_openai-gpt-4o_ecgqacotqadataset.clean.jsonl"
+ )
+
parse_ecg_qa_cot_jsonl(input_file, clean_output)
diff --git a/evaluation/opentslm/parse_predictions.py b/evaluation/opentslm/parse_predictions.py
index a686fe1..8f1c67a 100644
--- a/evaluation/opentslm/parse_predictions.py
+++ b/evaluation/opentslm/parse_predictions.py
@@ -17,8 +17,8 @@
from collections import Counter
# Add the src directory to the path to import from the dataset class
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
-sys.path.append(os.path.join(project_root, 'src'))
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.append(os.path.join(project_root, "src"))
# Import the dataset class to get labels
from time_series_datasets.har_cot.HARCoTQADataset import HARCoTQADataset
@@ -26,23 +26,25 @@
# Get the supported labels from the dataset class
SUPPORTED_LABELS = HARCoTQADataset.get_labels()
+
def calculate_f1_score(prediction, ground_truth):
"""Calculate F1 score for classification labels"""
# Normalize labels for comparison (lowercase, strip whitespace and trailing punctuation)
- pred_normalized = prediction.lower().strip().rstrip('.,!?;:')
- truth_normalized = ground_truth.lower().strip().rstrip('.,!?;:')
-
+ pred_normalized = prediction.lower().strip().rstrip(".,!?;:")
+ truth_normalized = ground_truth.lower().strip().rstrip(".,!?;:")
+
# For single prediction vs single ground truth, F1 is binary
f1 = 1.0 if pred_normalized == truth_normalized else 0.0
-
+
return {
- 'f1_score': f1,
- 'precision': f1, # For single-label classification, precision = recall = f1
- 'recall': f1,
- 'prediction_normalized': pred_normalized,
- 'ground_truth_normalized': truth_normalized
+ "f1_score": f1,
+ "precision": f1, # For single-label classification, precision = recall = f1
+ "recall": f1,
+ "prediction_normalized": pred_normalized,
+ "ground_truth_normalized": truth_normalized,
}
+
def calculate_f1_stats(data_points, allowed_labels=None):
"""Calculate both macro-F1 and average F1 (micro-F1) statistics.
@@ -53,11 +55,11 @@ def calculate_f1_stats(data_points, allowed_labels=None):
"""
if not data_points:
return {}
-
+
# Calculate average F1 (micro-F1) - simple average across all predictions
f1_scores = [point.get("f1_score", 0) for point in data_points]
average_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
-
+
# Group by ground truth class for macro-F1
class_predictions = {}
if allowed_labels:
@@ -66,10 +68,10 @@ def calculate_f1_stats(data_points, allowed_labels=None):
for point in data_points:
gt_class = point.get("ground_truth_normalized", "")
pred_class = point.get("prediction_normalized", "")
-
+
if gt_class not in class_predictions:
class_predictions[gt_class] = {"tp": 0, "fp": 0, "fn": 0}
-
+
# True positive: prediction matches ground truth
if pred_class == gt_class:
class_predictions[gt_class]["tp"] += 1
@@ -82,73 +84,79 @@ def calculate_f1_stats(data_points, allowed_labels=None):
class_predictions[pred_class]["fp"] += 1
else:
class_predictions[pred_class] = {"tp": 0, "fp": 1, "fn": 0}
-
+
# Calculate F1 per class
class_f1_scores = {}
total_f1 = 0
valid_classes = 0
-
+
for class_name, counts in class_predictions.items():
tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]
-
+
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
- f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
+ f1 = (
+ 2 * (precision * recall) / (precision + recall)
+ if (precision + recall) > 0
+ else 0
+ )
+
class_f1_scores[class_name] = {
"f1": f1,
"precision": precision,
"recall": recall,
"tp": tp,
"fp": fp,
- "fn": fn
+ "fn": fn,
}
-
+
total_f1 += f1
valid_classes += 1
-
+
# Calculate macro-F1 (average across all classes)
macro_f1 = total_f1 / valid_classes if valid_classes > 0 else 0
-
+
return {
"average_f1": average_f1,
"macro_f1": macro_f1,
"class_f1_scores": class_f1_scores,
- "total_classes": valid_classes
+ "total_classes": valid_classes,
}
+
def calculate_accuracy_stats(data_points):
"""Calculate accuracy statistics from data points"""
if not data_points:
return {}
-
+
total = len(data_points)
correct = sum(1 for point in data_points if point.get("accuracy", False))
accuracy_percentage = (correct / total) * 100 if total > 0 else 0
-
+
return {
"total_samples": total,
"correct_predictions": correct,
"incorrect_predictions": total - correct,
- "accuracy_percentage": accuracy_percentage
+ "accuracy_percentage": accuracy_percentage,
}
-
def parse_rtf_jsonl(input_file, output_file=None):
"""Parse RTF-formatted JSONL file and extract JSON objects."""
if output_file is None:
input_path = Path(input_file)
- output_file = str(input_path.parent / f"{input_path.stem.split('.')[0]}.clean.jsonl")
-
+ output_file = str(
+ input_path.parent / f"{input_path.stem.split('.')[0]}.clean.jsonl"
+ )
+
print(f"Parsing {input_file}")
print(f"Output will be saved to {output_file}")
-
- with open(input_file, 'rb') as f:
- rtf_content = f.read().decode('utf-8', errors='ignore')
-
+
+ with open(input_file, "rb") as f:
+ rtf_content = f.read().decode("utf-8", errors="ignore")
+
extracted_data = extract_structured_data(rtf_content)
-
+
# Use the predefined supported labels from the dataset class
# This ensures consistency and prevents OOV predictions from creating new classes
allowed_labels = set(SUPPORTED_LABELS)
@@ -159,12 +167,14 @@ def parse_rtf_jsonl(input_file, output_file=None):
point["excluded"] = not is_valid_prediction
if not is_valid_prediction:
excluded_count += 1
-
+
if extracted_data:
print(f"Extracted {len(extracted_data)} data points")
if excluded_count > 0:
- print(f"Excluded {excluded_count} predictions not in the label set from metrics")
-
+ print(
+ f"Excluded {excluded_count} predictions not in the label set from metrics"
+ )
+
# Calculate and display accuracy statistics (include all samples)
accuracy_stats = calculate_accuracy_stats(extracted_data)
print(f"\nAccuracy Statistics:")
@@ -172,80 +182,85 @@ def parse_rtf_jsonl(input_file, output_file=None):
print(f"Correct predictions: {accuracy_stats['correct_predictions']}")
print(f"Incorrect predictions: {accuracy_stats['incorrect_predictions']}")
print(f"Accuracy: {accuracy_stats['accuracy_percentage']:.2f}%")
-
+
# Calculate and display F1 statistics (prevent OOV predictions from creating new classes)
f1_stats = calculate_f1_stats(extracted_data, allowed_labels=allowed_labels)
print(f"\nF1 Score Statistics:")
print(f"Average F1 Score: {f1_stats['average_f1']:.4f}")
print(f"Macro-F1 Score: {f1_stats['macro_f1']:.4f}")
print(f"Total Classes: {f1_stats['total_classes']}")
-
+
# Display per-class F1 scores
- if f1_stats['class_f1_scores']:
+ if f1_stats["class_f1_scores"]:
print(f"\nPer-Class F1 Scores:")
- for class_name, scores in f1_stats['class_f1_scores'].items():
- print(f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}")
-
- with open(output_file, 'w', encoding='utf-8') as f:
+ for class_name, scores in f1_stats["class_f1_scores"].items():
+ print(
+ f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}"
+ )
+
+ with open(output_file, "w", encoding="utf-8") as f:
for item in extracted_data:
f.write(json.dumps(item, indent=2) + "\n")
-
+
print(f"\nData saved to {output_file}")
return extracted_data
else:
print("No data could be extracted from the file.")
return []
+
def extract_structured_data(rtf_content):
"""Extract structured data from RTF content"""
data_points = []
-
+
# Find key components
generated_pattern = r'generated":\s*"(.*?)"'
generated_matches = re.findall(generated_pattern, rtf_content)
-
+
gold_pattern = r'gold":\s*"(.*?)"'
gold_matches = re.findall(gold_pattern, rtf_content)
-
+
min_length = min(len(generated_matches), len(gold_matches))
-
+
for i in range(min_length):
model_prediction = extract_answer(generated_matches[i]).replace("", "")
ground_truth = extract_answer(gold_matches[i]).replace("", "")
-
+
# Calculate accuracy (exact match)
accuracy = model_prediction == ground_truth
-
+
# Calculate F1 score
f1_result = calculate_f1_score(model_prediction, ground_truth)
-
+
data_point = {
"generated": generated_matches[i],
"model_prediction": model_prediction,
"ground_truth": ground_truth,
"accuracy": accuracy,
- "f1_score": f1_result['f1_score'],
- "precision": f1_result['precision'],
- "recall": f1_result['recall'],
- "prediction_normalized": f1_result['prediction_normalized'],
- "ground_truth_normalized": f1_result['ground_truth_normalized']
+ "f1_score": f1_result["f1_score"],
+ "precision": f1_result["precision"],
+ "recall": f1_result["recall"],
+ "prediction_normalized": f1_result["prediction_normalized"],
+ "ground_truth_normalized": f1_result["ground_truth_normalized"],
}
data_points.append(data_point)
-
+
return data_points
+
def extract_answer(text):
"""Extract the final answer from text"""
if "Answer: " not in text:
return text
-
+
answer = text.split("Answer: ")[-1].strip()
- answer = re.sub(r'<\|.*?\|>$', '', answer).strip()
+ answer = re.sub(r"<\|.*?\|>$", "", answer).strip()
return answer
+
if __name__ == "__main__":
current_dir = Path(__file__).parent
input_file = current_dir / "gemma3_270m_sp_har.jsonl"
clean_output = current_dir / "gemma3_270m_sp_har.clean.jsonl"
-
+
parse_rtf_jsonl(input_file, clean_output)
diff --git a/evaluation/opentslm/sleep/parse_sleep_cot_data.py b/evaluation/opentslm/sleep/parse_sleep_cot_data.py
index bed0b88..1fb2086 100644
--- a/evaluation/opentslm/sleep/parse_sleep_cot_data.py
+++ b/evaluation/opentslm/sleep/parse_sleep_cot_data.py
@@ -18,7 +18,7 @@
from tqdm import tqdm
# Add the src directory to the path to import from the dataset class
-sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..", "src"))
from time_series_datasets.sleep.SleepEDFCoTQADataset import SleepEDFCoTQADataset
# We'll determine supported labels dynamically from the actual ground truth data
@@ -26,6 +26,7 @@
FALLBACK_LABELS = SleepEDFCoTQADataset.get_labels()
SUPPORTED_LABELS = [] # Will be populated dynamically
+
def _canonicalize_label(text):
"""Return canonical label with stage 4 merged into stage 3.
@@ -39,8 +40,8 @@ def _canonicalize_label(text):
cleaned = str(text).strip()
# Remove any end-of-text tokens and trailing period
- cleaned = re.sub(r'<\|.*?\|>|$', '', cleaned).strip()
- cleaned = re.sub(r'\.$', '', cleaned).strip()
+ cleaned = re.sub(r"<\|.*?\|>|$", "", cleaned).strip()
+ cleaned = re.sub(r"\.$", "", cleaned).strip()
lowered = cleaned.lower()
@@ -76,6 +77,8 @@ def _canonicalize_label(text):
label_set = SUPPORTED_LABELS if SUPPORTED_LABELS else FALLBACK_LABELS
is_supported = canonical in label_set
return canonical if canonical else cleaned, is_supported
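+ # Note (illustrative): per the docstring above, a raw "Non-REM stage 4" label is canonicalized
+ # to "Non-REM stage 3" before comparison, so stage-4 epochs count toward stage 3 in accuracy and F1.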
+
+
def calculate_f1_score(prediction, ground_truth):
"""Calculate F1 score for single-label classification with supported labels.
@@ -89,15 +92,16 @@ def calculate_f1_score(prediction, ground_truth):
f1 = 1.0 if pred_canon == truth_canon else 0.0
return {
- 'f1_score': f1,
- 'precision': f1,
- 'recall': f1,
- 'prediction_normalized': pred_canon.lower().strip(),
- 'ground_truth_normalized': truth_canon.lower().strip(),
- 'prediction_supported': pred_supported,
- 'ground_truth_supported': truth_supported,
+ "f1_score": f1,
+ "precision": f1,
+ "recall": f1,
+ "prediction_normalized": pred_canon.lower().strip(),
+ "ground_truth_normalized": truth_canon.lower().strip(),
+ "prediction_supported": pred_supported,
+ "ground_truth_supported": truth_supported,
}
+
def calculate_f1_stats(data_points):
"""Calculate both macro-F1 and average F1 (micro-F1) statistics.
@@ -107,16 +111,18 @@ def calculate_f1_stats(data_points):
"""
if not data_points:
return {}
-
+
# Calculate average F1 (micro-F1) - simple average across all predictions
f1_scores = [point.get("f1_score", 0) for point in data_points]
average_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
-
+
# Initialize class buckets for only supported classes (lowercased for consistency)
# Use discovered labels if available, otherwise fall back to dataset labels
labels_to_use = SUPPORTED_LABELS if SUPPORTED_LABELS else FALLBACK_LABELS
supported_lower = {label.lower(): label for label in labels_to_use}
- class_predictions = {lab.lower(): {"tp": 0, "fp": 0, "fn": 0} for lab in labels_to_use}
+ class_predictions = {
+ lab.lower(): {"tp": 0, "fp": 0, "fn": 0} for lab in labels_to_use
+ }
for point in data_points:
gt_class = point.get("ground_truth_normalized", "")
@@ -147,7 +153,11 @@ def calculate_f1_stats(data_points):
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
- f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+ f1 = (
+ 2 * (precision * recall) / (precision + recall)
+ if (precision + recall) > 0
+ else 0
+ )
# Use canonical casing in output keys
pretty_name = supported_lower.get(class_name, class_name)
@@ -172,45 +182,47 @@ def calculate_f1_stats(data_points):
"total_classes": valid_classes,
}
+
def calculate_accuracy_stats(data_points):
"""Calculate accuracy statistics from data points"""
if not data_points:
return {}
-
+
total = len(data_points)
correct = sum(1 for point in data_points if point.get("accuracy", False))
accuracy_percentage = (correct / total) * 100 if total > 0 else 0
-
+
return {
"total_samples": total,
"correct_predictions": correct,
"incorrect_predictions": total - correct,
- "accuracy_percentage": accuracy_percentage
+ "accuracy_percentage": accuracy_percentage,
}
+
def parse_sleep_cot_jsonl(input_file, output_file=None):
"""Parse sleep COT JSONL file and extract JSON objects."""
if output_file is None:
input_path = Path(input_file)
output_file = str(input_path.parent / f"{input_path.stem}.clean.jsonl")
-
+
print(f"Parsing {input_file}")
print(f"Output will be saved to {output_file}")
-
+
# First, discover the actual labels from the ground truth data
global SUPPORTED_LABELS
discovered_labels = discover_ground_truth_labels(input_file)
SUPPORTED_LABELS = discovered_labels
-
+
print(f"Discovered {len(discovered_labels)} labels from ground truth data:")
for label in sorted(discovered_labels):
print(f" - {label}")
-
+
extracted_data = extract_structured_data(input_file)
-
+
if extracted_data:
print(f"Extracted {len(extracted_data)} data points")
-
+
# Calculate and display accuracy statistics
accuracy_stats = calculate_accuracy_stats(extracted_data)
print(f"\nAccuracy Statistics:")
@@ -218,89 +230,93 @@ def parse_sleep_cot_jsonl(input_file, output_file=None):
print(f"Correct predictions: {accuracy_stats['correct_predictions']}")
print(f"Incorrect predictions: {accuracy_stats['incorrect_predictions']}")
print(f"Accuracy: {accuracy_stats['accuracy_percentage']:.2f}%")
-
+
# Calculate and display F1 statistics
f1_stats = calculate_f1_stats(extracted_data)
print(f"\nF1 Score Statistics:")
print(f"Average F1 Score: {f1_stats['average_f1']:.4f}")
print(f"Macro-F1 Score: {f1_stats['macro_f1']:.4f}")
print(f"Total Classes: {f1_stats['total_classes']}")
-
+
# Display per-class F1 scores
- if f1_stats['class_f1_scores']:
+ if f1_stats["class_f1_scores"]:
print(f"\nPer-Class F1 Scores:")
- for class_name, scores in f1_stats['class_f1_scores'].items():
- print(f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}")
+ for class_name, scores in f1_stats["class_f1_scores"].items():
+ print(
+ f" {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}"
+ )
pass
-
- with open(output_file, 'w', encoding='utf-8') as f:
+
+ with open(output_file, "w", encoding="utf-8") as f:
for item in extracted_data:
f.write(json.dumps(item, indent=2) + "\n")
-
+
print(f"\nData saved to {output_file}")
return extracted_data
else:
print("No data could be extracted from the file.")
return []
+
def discover_ground_truth_labels(input_file):
"""Discover actual labels from ground truth data in the JSONL file"""
discovered_labels = set()
-
- with open(input_file, 'r', encoding='utf-8') as f:
+
+ with open(input_file, "r", encoding="utf-8") as f:
for line in f:
try:
data = json.loads(line.strip())
- gold_text = data.get('gold', '')
+ gold_text = data.get("gold", "")
ground_truth_raw = extract_answer(gold_text)
gt_canon, _ = _canonicalize_label(ground_truth_raw)
if gt_canon:
discovered_labels.add(gt_canon)
except (json.JSONDecodeError, Exception):
continue
-
+
return list(discovered_labels)
+
def extract_structured_data(input_file):
"""Extract structured data from JSONL file"""
data_points = []
-
- with open(input_file, 'r', encoding='utf-8') as f:
+
+ with open(input_file, "r", encoding="utf-8") as f:
for line_num, line in tqdm(enumerate(f, 1)):
try:
# Parse JSON line
data = json.loads(line.strip())
-
+
# Extract generated and gold fields
- generated_text = data.get('generated', '')
- gold_text = data.get('gold', '')
-
+ generated_text = data.get("generated", "")
+ gold_text = data.get("gold", "")
+
# Extract answers from both fields
model_prediction_raw = extract_answer(generated_text)
ground_truth_raw = extract_answer(gold_text)
# Canonicalize labels and merge stage 4 -> stage 3
pred_canon, pred_supported = _canonicalize_label(model_prediction_raw)
gt_canon, gt_supported = _canonicalize_label(ground_truth_raw)
-
+
# Calculate accuracy (exact match)
accuracy = (pred_canon == gt_canon) and gt_supported
-
+
# Calculate F1 score
f1_result = calculate_f1_score(model_prediction_raw, ground_truth_raw)
-
+
data_point = {
"generated": generated_text,
"model_prediction": model_prediction_raw,
"ground_truth": ground_truth_raw,
"accuracy": accuracy,
- "f1_score": f1_result['f1_score'],
- "precision": f1_result['precision'],
- "recall": f1_result['recall'],
- "prediction_normalized": f1_result['prediction_normalized'],
- "ground_truth_normalized": f1_result['ground_truth_normalized'],
- "prediction_supported": f1_result['prediction_supported'],
- "ground_truth_supported": f1_result['ground_truth_supported'],
- "line_number": line_num
+ "f1_score": f1_result["f1_score"],
+ "precision": f1_result["precision"],
+ "recall": f1_result["recall"],
+ "prediction_normalized": f1_result["prediction_normalized"],
+ "ground_truth_normalized": f1_result["ground_truth_normalized"],
+ "prediction_supported": f1_result["prediction_supported"],
+ "ground_truth_supported": f1_result["ground_truth_supported"],
+ "line_number": line_num,
}
data_points.append(data_point)
except json.JSONDecodeError as e:
@@ -309,24 +325,26 @@ def extract_structured_data(input_file):
except Exception as e:
print(f"Unexpected error on line {line_num}: {e}")
continue
-
+
return data_points
+
def extract_answer(text):
"""Extract the final answer from text"""
if "Answer: " not in text:
return text
-
+
answer = text.split("Answer: ")[-1].strip()
# Remove any end-of-text tokens (including <|...|>-style tokens)
- answer = re.sub(r'<\|.*?\|>|$', '', answer).strip()
+ answer = re.sub(r"<\|.*?\|>|$", "", answer).strip()
# Remove trailing periods and normalize
- answer = re.sub(r'\.$', '', answer).strip()
+ answer = re.sub(r"\.$", "", answer).strip()
return answer
+
if __name__ == "__main__":
current_dir = Path(__file__).parent
input_file = current_dir / "llama_1b_flamingo_predictions.jsonl"
clean_output = current_dir / "llama_1b_flamingo_predictions.clean.jsonl"
-
+
parse_sleep_cot_jsonl(input_file, clean_output)
diff --git a/evaluation/opentslm/sleep/plot_sleep_predictions.py b/evaluation/opentslm/sleep/plot_sleep_predictions.py
index 17cb618..c613859 100644
--- a/evaluation/opentslm/sleep/plot_sleep_predictions.py
+++ b/evaluation/opentslm/sleep/plot_sleep_predictions.py
@@ -22,30 +22,31 @@
OUTPUT_DIR = "sleep_cot_plots"
# Publication style
-plt.style.use('seaborn-v0_8')
+plt.style.use("seaborn-v0_8")
sns.set_palette("colorblind")
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
display_label_map = {
- 'W': 'Wake',
- 'N1': 'Non-REM stage 1',
- 'N2': 'Non-REM stage 2',
- 'N3': 'Non-REM stage 3',
- 'N4': 'Non-REM stage 4',
- 'REM': 'REM sleep',
- 'M': 'Movement',
- 'Unknown': 'Unknown'
+ "W": "Wake",
+ "N1": "Non-REM stage 1",
+ "N2": "Non-REM stage 2",
+ "N3": "Non-REM stage 3",
+ "N4": "Non-REM stage 4",
+ "REM": "REM sleep",
+ "M": "Movement",
+ "Unknown": "Unknown",
}
+
def plot_sample(row, idx):
- eeg_data = np.array(json.loads(row['eeg_data']))
- full_pred = row['full_prediction']
- gt_label = row['ground_truth_label']
- pred_label = row['predicted_label']
- sample_idx = row['sample_index']
- series_length = row['series_length']
+ eeg_data = np.array(json.loads(row["eeg_data"]))
+ full_pred = row["full_prediction"]
+ gt_label = row["ground_truth_label"]
+ pred_label = row["predicted_label"]
+ sample_idx = row["sample_index"]
+ series_length = row["series_length"]
# Map labels to pretty names
pretty_gt = display_label_map.get(gt_label, gt_label)
@@ -59,7 +60,7 @@ def plot_sample(row, idx):
elif len(full_pred) > text_length:
# Truncate if longer
full_pred = full_pred[:text_length]
-
+
# Add extra newlines to ensure consistent text box height
full_pred = full_pred + "\n"
@@ -71,27 +72,42 @@ def plot_sample(row, idx):
fig, ax1 = plt.subplots(figsize=(12, 7))
t = np.arange(len(eeg_plot))
# Use the same blue color as PAMAP2 plots (first color from colorblind palette)
- ax1.plot(t, eeg_plot, linewidth=2.5, color='#0173B2', alpha=0.8, label='EEG')
- ax1.set_xlabel('Time Step', fontsize=26)
- ax1.set_ylabel('Normalized EEG Amplitude', fontsize=26)
- ax1.set_title(f"Sample {sample_idx} | GT: {pretty_gt} | Pred: {pretty_pred}", fontsize=22, fontweight='bold')
- ax1.legend(fontsize=13, loc='upper right')
+ ax1.plot(t, eeg_plot, linewidth=2.5, color="#0173B2", alpha=0.8, label="EEG")
+ ax1.set_xlabel("Time Step", fontsize=26)
+ ax1.set_ylabel("Normalized EEG Amplitude", fontsize=26)
+ ax1.set_title(
+ f"Sample {sample_idx} | GT: {pretty_gt} | Pred: {pretty_pred}",
+ fontsize=22,
+ fontweight="bold",
+ )
+ ax1.legend(fontsize=13, loc="upper right")
ax1.grid(True, alpha=0.3)
- ax1.tick_params(axis='both', which='major', labelsize=26)
+ ax1.tick_params(axis="both", which="major", labelsize=26)
ax1.set_ylim(-3, 3)
ax1.set_yticks(np.linspace(-3, 3, 7))
# Add full_prediction as a text box below the plot (same as PAMAP2)
- plt.gcf().text(0.01, -0.02, f"Prediction:\n{full_pred}", fontsize=30, ha='left', va='top', wrap=True,
- bbox=dict(boxstyle='round', facecolor='whitesmoke', alpha=0.9, edgecolor='gray'))
+ plt.gcf().text(
+ 0.01,
+ -0.02,
+ f"Prediction:\n{full_pred}",
+ fontsize=30,
+ ha="left",
+ va="top",
+ wrap=True,
+ bbox=dict(
+ boxstyle="round", facecolor="whitesmoke", alpha=0.9, edgecolor="gray"
+ ),
+ )
plt.tight_layout(rect=[0, 0.05, 1, 1])
fname = f"sample_{idx+1:03d}_gt_{pretty_gt.lower().replace(' ', '_').replace('-', '_')}.png"
- plt.savefig(os.path.join(OUTPUT_DIR, fname), dpi=300, bbox_inches='tight')
+ plt.savefig(os.path.join(OUTPUT_DIR, fname), dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved {fname}")
+
def main():
df = pd.read_csv(CSV_PATH)
print(f"Loaded {len(df)} samples from {CSV_PATH}")
@@ -99,5 +115,6 @@ def main():
plot_sample(row, idx)
print(f"All plots saved to {OUTPUT_DIR}/")
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/get_memory_use.py b/get_memory_use.py
index a13d2cf..f39f4e1 100644
--- a/get_memory_use.py
+++ b/get_memory_use.py
@@ -420,17 +420,23 @@ def main():
res["dataset"],
res["loss"],
res["peak_cuda_bytes"],
- f"{peak_gb:.4f}"
- if isinstance(peak_gb, float) and peak_gb >= 0
- else peak_gb,
+ (
+ f"{peak_gb:.4f}"
+ if isinstance(peak_gb, float) and peak_gb >= 0
+ else peak_gb
+ ),
res.get("peak_cuda_reserved_bytes", -1),
- f"{peak_reserved_gb:.4f}"
- if isinstance(peak_reserved_gb, float) and peak_reserved_gb >= 0
- else peak_reserved_gb,
+ (
+ f"{peak_reserved_gb:.4f}"
+ if isinstance(peak_reserved_gb, float) and peak_reserved_gb >= 0
+ else peak_reserved_gb
+ ),
res.get("nvml_peak_bytes", -1),
- f"{nvml_peak_gb:.4f}"
- if isinstance(nvml_peak_gb, float) and nvml_peak_gb >= 0
- else nvml_peak_gb,
+ (
+ f"{nvml_peak_gb:.4f}"
+ if isinstance(nvml_peak_gb, float) and nvml_peak_gb >= 0
+ else nvml_peak_gb
+ ),
res["status"],
res["error"],
],
diff --git a/hf_test.py b/hf_test.py
new file mode 100644
index 0000000..dba29e5
--- /dev/null
+++ b/hf_test.py
@@ -0,0 +1,24 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+from src import OpenTSLM, TextPrompt, TextTimeSeriesPrompt, FullPrompt
+
+# Load model
+model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-har-flamingo")
+
+# Create prompt with raw time series data (normalization handled automatically)
+prompt = FullPrompt(
+ pre_prompt=TextPrompt("You are an expert in HAR analysis."),
+ text_time_series_prompt_list=[
+ TextTimeSeriesPrompt("X-axis accelerometer", [2.34, 2.34, 7.657, 3.21, -1.2])
+ ],
+ post_prompt=TextPrompt("What activity is this? Reasn step by step providing a full rationale before replying.")
+)
+
+# Generate response
+output = model.eval_prompt(prompt, normalize=True)
+print(output)
diff --git a/plot_memory_simulation.py b/plot_memory_simulation.py
new file mode 100644
index 0000000..250a537
--- /dev/null
+++ b/plot_memory_simulation.py
@@ -0,0 +1,220 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Plot memory usage on simulation datasets from memory_simulation.csv.
+
+- Only uses datasets starting with 'Simulation-'.
+- Extracts time series length (L) and number of series (N).
+- Computes total_length = N * L.
+- Plots memory vs total_length per base model, comparing SoftPrompt vs Flamingo.
+- OOM runs (> 180GB) are shown with a dashed line, red X, and "OOM" label.
+- Always shows panels in order: gemma-270m, gemma-1b, llama-1b, llama-3b.
+"""
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import matplotlib
+import re
+
+OOM_THRESHOLD = 180 # GB
+
+
+def parse_model_name(llm_id, model_type):
+ """Return base_model, config (SoftPrompt or Flamingo)."""
+ if llm_id.startswith("meta-llama/"):
+ base_name = llm_id.replace("meta-llama/", "")
+ elif llm_id.startswith("google/"):
+ base_name = llm_id.replace("google/", "")
+ else:
+ base_name = llm_id
+
+ # Normalize base model names to match expected order
+ if "Llama-3.2-1B" in base_name:
+ base_name = "Llama-3.2-1B"
+ elif "Llama-3.2-3B" in base_name:
+ base_name = "Llama-3.2-3B"
+ elif "gemma-3-270m" in base_name:
+ base_name = "Gemma-3-270M"
+ elif "gemma-3-1b-pt" in base_name:
+ base_name = "Gemma-3-1B-pt"
+
+ if model_type == "EmbedHealthSP":
+ type_name = "SoftPrompt"
+ elif model_type == "EmbedHealthFlamingo":
+ type_name = "Flamingo"
+ else:
+ type_name = model_type
+
+ return base_name, type_name
+
+
+def parse_simulation_dataset(name):
+ """Parse Simulation dataset name like 'Simulation-L10-N5' β (L=10, N=5)."""
+ match = re.match(r"Simulation-L(\d+)-N(\d+)", name)
+ if match:
+ return int(match.group(1)), int(match.group(2))
+ return None, None
+
+
+def plot_memory_usage_sim(csv_file="memory_simulation.csv"):
+ # --- Paper-style settings ---
+ plt.style.use("seaborn-v0_8-white")
+ matplotlib.rcParams.update({
+ "font.family": "serif",
+ "font.serif": ["Palatino", "Times New Roman", "DejaVu Serif"],
+ "font.size": 16,
+ "axes.labelsize": 18,
+ "axes.titlesize": 18,
+ "legend.fontsize": 15,
+ "xtick.labelsize": 15,
+ "ytick.labelsize": 15,
+ "axes.linewidth": 0.6,
+ "axes.edgecolor": "0.15",
+ })
+
+ df = pd.read_csv(csv_file)
+
+ # Replace -1 with NaN (ignore failed runs)
+ df["peak_cuda_reserved_gb"] = df["peak_cuda_reserved_gb"].replace(-1, pd.NA)
+
+ # Keep only simulation datasets
+ df = df[df["dataset"].str.startswith("Simulation-")]
+
+ # Parse model name and dataset details
+ df[["base_model", "config"]] = df.apply(
+ lambda row: pd.Series(parse_model_name(row["llm_id"], row["model"])), axis=1
+ )
+ df[["L", "N"]] = df["dataset"].apply(
+ lambda s: pd.Series(parse_simulation_dataset(s))
+ )
+ df = df.dropna(subset=["L", "N"])
+ df["L"] = df["L"].astype(int)
+ df["N"] = df["N"].astype(int)
+
+ # Compute total sequence length
+ df["total_length"] = df["L"] * df["N"]
+
+ # Sort
+ df = df.sort_values(by=["base_model", "config", "total_length"])
+
+ # Fixed base_model order
+ base_model_order = ["Gemma-3-270M", "Gemma-3-1B-pt", "Llama-3.2-1B", "Llama-3.2-3B"]
+
+ # One subplot per model (always 4)
+ n_models = len(base_model_order)
+ fig, axes = plt.subplots(1, n_models, figsize=(3.2 * n_models, 4.5), sharey=True)
+
+ if n_models == 1:
+ axes = [axes]
+
+ # Muted palette for configs - order matters for legend
+ palette = {"SoftPrompt": "#4477AA", "Flamingo": "#CC6677"}
+ config_order = ["SoftPrompt", "Flamingo"]
+
+ for ax, base_model in zip(axes, base_model_order):
+ subdf = df[df["base_model"] == base_model]
+
+ if subdf.empty:
+ ax.set_title(base_model, fontsize=13, fontweight="bold")
+ ax.set_facecolor("#F8F9FA")
+ ax.text(
+ 0.5, 0.5, "No data",
+ ha="center", va="center",
+ fontsize=10, color="gray"
+ )
+ ax.set_xticks([])
+ ax.set_yticks([])
+ continue
+
+ for cfg in config_order:
+ cfg_df = subdf[subdf["config"] == cfg]
+ if cfg_df.empty:
+ continue
+ cfg_df = cfg_df.sort_values("total_length")
+ color = palette[cfg]
+
+ # Successful runs (≤ threshold)
+ ok_df = cfg_df[cfg_df["peak_cuda_reserved_gb"] <= OOM_THRESHOLD]
+ ax.plot(
+ ok_df["total_length"], ok_df["peak_cuda_reserved_gb"],
+ label=cfg,
+ color=color,
+ linewidth=3.0,
+ alpha=0.9,
+ )
+
+ # First OOM run (if any)
+ oom_df = cfg_df[cfg_df["peak_cuda_reserved_gb"] > OOM_THRESHOLD]
+ if not oom_df.empty and not ok_df.empty:
+ first_oom = oom_df.iloc[0]
+ last_ok = ok_df.iloc[-1]
+
+ # dashed line up to OOM
+ ax.plot(
+ [last_ok["total_length"], first_oom["total_length"]],
+ [last_ok["peak_cuda_reserved_gb"], OOM_THRESHOLD * 1.05],
+ color=color,
+ linestyle="--",
+ linewidth=1.5,
+ alpha=0.8,
+ )
+
+ # red X marker
+ ax.scatter(
+ first_oom["total_length"], OOM_THRESHOLD * 1.05,
+ color="red",
+ marker="x",
+ s=70,
+ linewidth=2,
+ zorder=5,
+ )
+ ax.text(
+ first_oom["total_length"], OOM_THRESHOLD * 1.05,
+ "OOM", color="red", fontsize=9,
+ fontweight="bold", ha="center", va="bottom"
+ )
+
+ # Titles & labels
+ ax.set_title(base_model, fontsize=17, fontweight="bold")
+
+ # Only show axis labels on specific subplots
+ if ax == axes[0]: # Leftmost subplot
+ ax.set_ylabel("Peak CUDA Reserved (GB)", fontsize=16, fontweight="bold")
+ ax.set_xlabel("Total Sequence Length (N Γ L)", fontsize=16, fontweight="bold")
+ else:
+ ax.set_ylabel("")
+ ax.set_xlabel("")
+ ax.set_facecolor("#F8F9FA")
+ ax.grid(True, which="major", linestyle="-", linewidth=0.4, alpha=0.5)
+ ax.grid(True, which="minor", linestyle=":", linewidth=0.3, alpha=0.3)
+ ax.minorticks_on()
+ ax.tick_params(axis="both", labelsize=15)
+
+ # Legend only in first subplot
+ if ax == axes[0]:
+ leg = ax.legend(title=None, fontsize=15, loc="best", frameon=True,
+ framealpha=0.95, edgecolor="0.3")
+ for text in leg.get_texts():
+ text.set_fontweight("bold")
+
+ plt.tight_layout(pad=0.5)
+ for fmt in ["png", "pdf"]:
+ plt.savefig(
+ f"memory_usage_simulation.{fmt}",
+ dpi=300 if fmt == "png" else None,
+ bbox_inches="tight",
+ pad_inches=0,
+ facecolor="white",
+ format=fmt,
+ )
+ plt.show()
+
+
+if __name__ == "__main__":
+ plot_memory_usage_sim()
diff --git a/plot_memory_simulation_per_length.py b/plot_memory_simulation_per_length.py
new file mode 100644
index 0000000..3431a14
--- /dev/null
+++ b/plot_memory_simulation_per_length.py
@@ -0,0 +1,286 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+"""
+Paper-style plots: memory usage scaling with N for different lengths (L).
+
+- Rows = config (SoftPrompt, Flamingo)
+- Cols = sequence lengths (L) [excluding L=1]
+- Hue = base model
+- Y-axis sharing logic:
+ * Flamingo: all panels share y-axis
+ * SoftPrompt: all panels have independent y-axes
+- OOM cases (peak_cuda_reserved_gb > OOM_THRESHOLD) are shown by extending the
+ line upward and marking with a red X + "OOM".
+"""
+
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+import matplotlib
+import re
+from matplotlib.lines import Line2D
+
+OOM_THRESHOLD = 180 # GB
+
+
+def parse_model_name(llm_id, model_type):
+ """Return base_model, config (SoftPrompt or Flamingo)."""
+ if llm_id.startswith("meta-llama/"):
+ base_name = llm_id.replace("meta-llama/", "")
+ elif llm_id.startswith("google/"):
+ base_name = llm_id.replace("google/", "")
+ else:
+ base_name = llm_id
+
+ # Normalize base model names to match expected order
+ if "Llama-3.2-1B" in base_name:
+ base_name = "Llama-3.2-1B"
+ elif "Llama-3.2-3B" in base_name:
+ base_name = "Llama-3.2-3B"
+ elif "gemma-3-270m" in base_name:
+ base_name = "Gemma-3-270M"
+ elif "gemma-3-1b-pt" in base_name:
+ base_name = "Gemma-3-1B-pt"
+
+ if model_type == "EmbedHealthSP":
+ type_name = "SoftPrompt"
+ elif model_type == "EmbedHealthFlamingo":
+ type_name = "Flamingo"
+ else:
+ type_name = model_type
+
+ return base_name, type_name
+
+
+def parse_simulation_dataset(name):
+ """Parse Simulation dataset name like 'Simulation-L10-N5' β (L=10, N=5)."""
+ match = re.match(r"Simulation-L(\d+)-N(\d+)", name)
+ if match:
+ return int(match.group(1)), int(match.group(2))
+ return None, None
+
+
+def plot_memory_usage_paper(csv_file="memory_simulation.csv"):
+ # Publication style
+ plt.style.use("seaborn-v0_8-white")
+ matplotlib.rcParams.update({
+ "font.family": "serif",
+ "font.serif": ["Palatino", "Times New Roman", "DejaVu Serif"],
+ "font.size": 18,
+ "axes.labelsize": 20,
+ "axes.titlesize": 20,
+ "legend.fontsize": 17,
+ "xtick.labelsize": 17,
+ "ytick.labelsize": 17,
+ "axes.linewidth": 0.6,
+ "axes.edgecolor": "0.15",
+ })
+
+ # Load & preprocess
+ df = pd.read_csv(csv_file)
+ df["peak_cuda_reserved_gb"] = df["peak_cuda_reserved_gb"].replace(-1, pd.NA)
+ df = df[df["dataset"].str.startswith("Simulation-")]
+ df[["base_model", "config"]] = df.apply(
+ lambda row: pd.Series(parse_model_name(row["llm_id"], row["model"])), axis=1
+ )
+ df[["L", "N"]] = df["dataset"].apply(
+ lambda s: pd.Series(parse_simulation_dataset(s))
+ )
+ df = df.dropna(subset=["L", "N"])
+ df["L"] = df["L"].astype(int)
+ df["N"] = df["N"].astype(int)
+ df = df[df["L"] != 1]
+ df = df.sort_values(by=["base_model", "config", "L", "N"])
+
+ # Palette + markers - use consistent colors with the other script
+ # Define consistent order and colors for each model
+ model_order = ["Gemma-3-270M", "Gemma-3-1B-pt", "Llama-3.2-1B", "Llama-3.2-3B"]
+ color_map = {
+ "Gemma-3-270M": "#4477AA",
+ "Gemma-3-1B-pt": "#66CCEE",
+ "Llama-3.2-1B": "#228833",
+ "Llama-3.2-3B": "#CC6677"
+ }
+
+ # Get base models in the specified order
+ base_models = [bm for bm in model_order if bm in df["base_model"].unique()]
+ custom_palette = [color_map.get(bm, "#888888") for bm in base_models]
+ markers_dict = dict(zip(
+ base_models,
+ ["o", "s", "^", "D", "p", "X", "*"]
+ ))
+
+ # Unique sequence lengths
+ unique_L = sorted(df["L"].unique())
+
+ # Create subplot grid manually: 2 rows (SoftPrompt, Flamingo)
+ fig, axes = plt.subplots(
+ 2, len(unique_L),
+ figsize=(3.2 * len(unique_L), 6),
+ sharex="col",
+ )
+
+ # Row mapping
+ row_map = {"SoftPrompt": 0, "Flamingo": 1}
+
+ # Precompute Flamingo y-lims
+ flamingo_df = df[df["config"] == "Flamingo"]
+ flamingo_ymin, flamingo_ymax = None, None
+ if not flamingo_df.empty:
+ flamingo_ymin = flamingo_df["peak_cuda_reserved_gb"].min(skipna=True)
+ flamingo_ymax = flamingo_df["peak_cuda_reserved_gb"].max(skipna=True)
+
+ flamingo_ymin = 0
+ flamingo_ymax = max(flamingo_ymax if flamingo_ymax else 0, OOM_THRESHOLD * 1.1)
+
+ # Iterate configs
+ for cfg in ["SoftPrompt", "Flamingo"]:
+ cfg_df = df[df["config"] == cfg]
+
+ for j, L in enumerate(unique_L):
+ ax = axes[row_map[cfg], j]
+ subdf = cfg_df[cfg_df["L"] == L]
+
+ for bm, sdf in subdf.groupby("base_model"):
+ sdf = sdf.sort_values("N")
+
+ # Split successful vs OOM runs
+ ok_df = sdf[sdf["peak_cuda_reserved_gb"] <= OOM_THRESHOLD]
+ oom_df = sdf[sdf["peak_cuda_reserved_gb"] > OOM_THRESHOLD]
+
+ # Normal line
+ if not ok_df.empty:
+ ax.plot(
+ ok_df["N"], ok_df["peak_cuda_reserved_gb"],
+ label=bm,
+ color=custom_palette[base_models.index(bm)],
+ marker=markers_dict[bm],
+ linewidth=2.2,
+ markersize=5,
+ alpha=0.9,
+ )
+
+ # OOM handling
+ if not oom_df.empty:
+ first_oom = oom_df.iloc[0]
+ last_ok_y = ok_df["peak_cuda_reserved_gb"].iloc[-1] if not ok_df.empty else OOM_THRESHOLD * 0.9
+
+ # extend line upward
+ ax.plot(
+ [ok_df["N"].iloc[-1] if not ok_df.empty else first_oom["N"], first_oom["N"]],
+ [last_ok_y, OOM_THRESHOLD * 1.05],
+ color=custom_palette[base_models.index(bm)],
+ linestyle="--",
+ linewidth=1.5,
+ alpha=0.8,
+ )
+
+ # red X marker at OOM
+ ax.scatter(
+ first_oom["N"], OOM_THRESHOLD * 1.05,
+ color="red",
+ marker="x",
+ s=70,
+ linewidth=2,
+ zorder=5,
+ )
+ ax.text(
+ first_oom["N"], OOM_THRESHOLD * 1.05,
+ "OOM", color="red", fontsize=9,
+ fontweight="bold", ha="center", va="bottom"
+ )
+
+ # Titles
+ if row_map[cfg] == 0:
+ ax.set_title(f"L = {L}", fontsize=19, fontweight="bold")
+
+ # Y labels only leftmost col
+ if j == 0:
+ ax.set_ylabel(
+ f"{cfg}",
+ fontsize=18,
+ fontweight="bold"
+ )
+ else:
+ ax.set_ylabel("")
+
+ # Styling
+ ax.set_facecolor("#F8F9FA")
+ ax.grid(True, which="major", linestyle="-", linewidth=0.4, alpha=0.5)
+ ax.grid(True, which="minor", linestyle=":", linewidth=0.3, alpha=0.3)
+ ax.minorticks_on()
+ # Remove individual x-axis labels - will add global one later
+ ax.set_xlabel("")
+ ax.set_xticks([1, 2, 3, 4, 5])
+ ax.tick_params(axis="x", which="both", labelbottom=True)
+
+ # Y-axis rules
+ if cfg == "Flamingo":
+ ax.set_ylim(0, 65)
+ elif j <= 2:
+ ax.set_ylim(0, 25)
+
+ # Legend (global, right side) - create in specified order
+ # Get handles and labels from first subplot
+ handles_dict = {}
+ labels_dict = {}
+ for handle, label in zip(*axes[0, 0].get_legend_handles_labels()):
+ handles_dict[label] = handle
+ labels_dict[label] = label
+
+ # Create ordered handles and labels
+ ordered_handles = []
+ ordered_labels = []
+ for bm in base_models:
+ if bm in handles_dict:
+ ordered_handles.append(handles_dict[bm])
+ ordered_labels.append(labels_dict[bm])
+
+ # Add OOM handle
+ oom_handle = Line2D([0], [0], color="red", marker="x", linestyle="--",
+ markersize=8, label="Out of Memory (OOM)")
+ ordered_handles.append(oom_handle)
+ ordered_labels.append("Out of Memory (OOM)")
+
+ # Legend in top left plot only
+ top_left_ax = axes[0, 0]
+ top_left_ax.legend(
+ ordered_handles, ordered_labels,
+ title=None,
+ loc="upper right",
+ frameon=True,
+ framealpha=0.95,
+ edgecolor="0.3",
+ fontsize=12,
+ )
+
+ # Add main vertical title
+ fig.text(0.02, 0.5, "Peak Memory (GB)", rotation=90, fontsize=20, fontweight="bold", ha="center", va="center")
+
+ # Add global x-axis label spanning the bottom row
+ fig.text(0.5, 0.01, "Number of Time Series (N)", fontsize=16, fontweight="bold", ha="center", va="center")
+
+ # Layout - no longer need space for bottom legend
+ plt.tight_layout(pad=0.5)
+ plt.subplots_adjust(left=0.08)
+
+ # Save
+ for fmt in ["png", "pdf"]:
+ plt.savefig(
+ f"memory_usage_paper.{fmt}",
+ dpi=300 if fmt == "png" else None,
+ bbox_inches="tight",
+ pad_inches=0,
+ facecolor="white",
+ format=fmt,
+ )
+ plt.show()
+
+
+if __name__ == "__main__":
+ plot_memory_usage_paper()
diff --git a/src/model/__init__.py b/src/model/__init__.py
index 7f7c8a3..7566f7d 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -4,4 +4,4 @@
# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
#
# SPDX-License-Identifier: MIT
-#
\ No newline at end of file
+#
diff --git a/src/model/encoder/CNNTokenizer.py b/src/model/encoder/CNNTokenizer.py
index 74338ec..bee3780 100644
--- a/src/model/encoder/CNNTokenizer.py
+++ b/src/model/encoder/CNNTokenizer.py
@@ -21,7 +21,7 @@ def __init__(
dropout: float = 0.0,
transformer_input_dim: int = TRANSFORMER_INPUT_DIM,
patch_size: int = PATCH_SIZE,
- max_patches: int = 2600,
+ max_patches: int = 1024,
):
"""
Args:
diff --git a/src/model/encoder/TransformerCNNEncoder.py b/src/model/encoder/TransformerCNNEncoder.py
index d5e6a28..5c96844 100644
--- a/src/model/encoder/TransformerCNNEncoder.py
+++ b/src/model/encoder/TransformerCNNEncoder.py
@@ -24,7 +24,7 @@ def __init__(
num_layers: int = 6,
patch_size: int = PATCH_SIZE,
ff_dim: int = 1024,
- max_patches: int = 2600,
+ max_patches: int = 1024,
):
"""
Args:
diff --git a/src/model/encoder/TransformerMLPEncoder.py b/src/model/encoder/TransformerMLPEncoder.py
index efe4328..6aa59e5 100644
--- a/src/model/encoder/TransformerMLPEncoder.py
+++ b/src/model/encoder/TransformerMLPEncoder.py
@@ -24,7 +24,7 @@ def __init__(
num_layers: int = 6,
patch_size: int = PATCH_SIZE,
ff_dim: int = 2048,
- max_patches: int = 2600,
+ max_patches: int = 1024,
):
"""
Args:
diff --git a/src/model/encoder/__init__.py b/src/model/encoder/__init__.py
index 7f7c8a3..7566f7d 100644
--- a/src/model/encoder/__init__.py
+++ b/src/model/encoder/__init__.py
@@ -4,4 +4,4 @@
# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
#
# SPDX-License-Identifier: MIT
-#
\ No newline at end of file
+#
diff --git a/src/model/llm/OpenTSLM.py b/src/model/llm/OpenTSLM.py
new file mode 100644
index 0000000..d705e9b
--- /dev/null
+++ b/src/model/llm/OpenTSLM.py
@@ -0,0 +1,175 @@
+#
+# This source file is part of the OpenTSLM open-source project
+#
+# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
+#
+# SPDX-License-Identifier: MIT
+#
+import torch
+from typing import Optional, Union
+from enum import Enum
+from huggingface_hub import hf_hub_download
+
+from .OpenTSLMSP import OpenTSLMSP
+from .OpenTSLMFlamingo import OpenTSLMFlamingo
+
+
+class ModelType(Enum):
+ """Enumeration of supported model types."""
+
+ SP = "sp"
+ FLAMINGO = "flamingo"
+
+
+class OpenTSLM:
+ """
+ Factory class for loading OpenTSLM models from Hugging Face Hub.
+
+ Automatically detects the model type from the repository ID suffix and returns
+ the appropriate model instance (OpenTSLMSP or OpenTSLMFlamingo) with the
+ parameters used during curriculum-learning training.
+
+ - Repository IDs ending with "-sp" load OpenTSLMSP models
+ - Repository IDs ending with "-flamingo" load OpenTSLMFlamingo models
+
+ The factory applies exactly the parameters used in curriculum learning:
+ - OpenTSLMSP: default constructor parameters
+ - OpenTSLMFlamingo: cross_attn_every_n_layers=1, gradient_checkpointing=False
+
+ These parameters are fixed and cannot be overridden, since they were determined during training.
+
+ Example:
+ >>> model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-sleep-flamingo")
+ >>>
+ >>> from prompt.full_prompt import FullPrompt
+ >>> prompt = FullPrompt(...)
+ >>> response = model.eval_prompt(prompt)
+ """
+
+ @classmethod
+ def load_pretrained(
+ cls,
+ repo_id: str,
+ device: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ enable_lora: Optional[bool] = False,
+ ) -> Union[OpenTSLMSP, OpenTSLMFlamingo]:
+ """
+ Load a pretrained model from Hugging Face Hub.
+
+ Args:
+ repo_id: Hugging Face repository ID (e.g., "OpenTSLM/gemma-3-270m-pt-sleep-flamingo")
+ device: Device to load the model on (default: auto-detect)
+ cache_dir: Directory to cache downloaded models (optional)
+ enable_lora: Whether to enable LoRA (default: False)
+
+ Returns:
+ Union[OpenTSLMSP, OpenTSLMFlamingo]: The loaded model instance
+
+ Example:
+ >>> model = OpenTSLM.load_pretrained("OpenTSLM/gemma-3-270m-pt-sleep-flamingo")
+ >>> prompt = FullPrompt(...)
+ >>> response = model.eval_prompt(prompt)
+ """
+ device = cls._get_device(device)
+ model_type = cls._detect_model_type(repo_id)
+ checkpoint_path = cls._download_model_files(repo_id, cache_dir)
+ base_llm_id = cls._get_base_llm_id(repo_id)
+
+ print(f"π Loading {model_type.value.upper()} model...")
+ print(f" Repository: {repo_id}")
+ print(f" Base LLM: {base_llm_id}")
+ print(f" Device: {device}")
+
+ # Instantiate model with fixed training parameters
+ if model_type == ModelType.SP:
+ # OpenTSLMSP uses default parameters from curriculum learning
+ model = OpenTSLMSP(llm_id=base_llm_id, device=device)
+ if enable_lora:
+ model.enable_lora()
+ elif model_type == ModelType.FLAMINGO:
+ # OpenTSLMFlamingo with fixed parameters from curriculum learning
+ model = OpenTSLMFlamingo(
+ device=device,
+ llm_id=base_llm_id,
+ cross_attn_every_n_layers=1,
+ gradient_checkpointing=False,
+ )
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+
+ # Load the checkpoint
+ model.load_from_file(checkpoint_path)
+ model.eval()
+
+ print(f"β
{model_type.value.upper()} model loaded successfully!")
+ return model
+
+ @staticmethod
+ def _get_device(device: Optional[str]) -> str:
+ """Auto-detect device if not specified."""
+ if device is not None:
+ return device
+
+ if torch.cuda.is_available():
+ return "cuda"
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+ return "mps"
+ else:
+ return "cpu"
+
+ @staticmethod
+ def _detect_model_type(repo_id: str) -> ModelType:
+ """Detect model type from repository ID suffix."""
+ if repo_id.endswith("-sp"):
+ return ModelType.SP
+ elif repo_id.endswith("-flamingo"):
+ return ModelType.FLAMINGO
+ else:
+ raise ValueError(
+ f"Repository ID '{repo_id}' must end with either '-sp' or '-flamingo' "
+ f"to indicate the model type."
+ )
+
+ @staticmethod
+ def _download_model_files(repo_id: str, cache_dir: Optional[str] = None) -> str:
+ """Download model checkpoint from Hugging Face Hub."""
+ try:
+ # Download the main model checkpoint file
+ checkpoint_path = hf_hub_download(
+ repo_id=repo_id,
+ filename="model_checkpoint.pt",
+ cache_dir=cache_dir,
+ local_files_only=False,
+ )
+ print(f"β
Downloaded model checkpoint from {repo_id}")
+ return checkpoint_path
+
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to download model from {repo_id}. "
+ f"Tried 'model_checkpoint.pt'. "
+ f"Original error: {e}"
+ )
+
+ @staticmethod
+ def _get_base_llm_id(repo_id: str) -> str:
+ """Get the base LLM ID from static mapping based on repository ID pattern."""
+ repo_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id
+
+ # Extract base model from repository name pattern
+ if repo_name.startswith("llama-3.2-3b"):
+ return "meta-llama/Llama-3.2-3B"
+ elif repo_name.startswith("llama-3.2-1b"):
+ return "meta-llama/Llama-3.2-1B"
+ elif repo_name.startswith("gemma-3-1b"):
+ return "google/gemma-3-1b"
+ elif repo_name.startswith("gemma-3-270m"):
+ return "google/gemma-3-270m"
+ else:
+ # Raise exception if pattern doesn't match
+ raise ValueError(
+ f"Unable to determine base LLM ID from repository name '{repo_name}'. "
+ f"Repository name must start with one of: 'llama-3.2-3b', 'llama-3.2-1b', "
+ f"'gemma-3-1b', or 'gemma-3-270m'."
+ )
diff --git a/src/model/llm/OpenTSLMFlamingo.py b/src/model/llm/OpenTSLMFlamingo.py
index 6f43181..58517c7 100644
--- a/src/model/llm/OpenTSLMFlamingo.py
+++ b/src/model/llm/OpenTSLMFlamingo.py
@@ -349,19 +349,24 @@ def load_from_file(self, path: str = "best_model.pt"):
print(f" ... and {len(unexpected_keys) - 10} more keys")
self.to(self.device)
- def eval_prompt(self, prompt: FullPrompt, max_new_tokens: int = 30000) -> str:
+ def eval_prompt(
+ self, prompt: FullPrompt, max_new_tokens: int = 1000, normalize: bool = False
+ ) -> str:
"""
Evaluate a prompt and return the generated text.
"""
# Temporarily disable compilation to avoid data-dependent operation issues
original_disable = torch._dynamo.config.disable
torch._dynamo.config.disable = True
-
try:
batch = [prompt.to_dict()]
self.eval()
- batch = extend_time_series_to_match_patch_size_and_aggregate(batch)
+ batch = extend_time_series_to_match_patch_size_and_aggregate(
+ batch, normalize=normalize
+ )
+ print("Generating")
output = self.generate(batch, max_new_tokens=max_new_tokens)
+ print(f"Generated output: {output[0]}")
return output[0]
finally:
# Restore original compilation setting
diff --git a/src/model/llm/OpenTSLMSP.py b/src/model/llm/OpenTSLMSP.py
index 5bc4302..d594433 100644
--- a/src/model/llm/OpenTSLMSP.py
+++ b/src/model/llm/OpenTSLMSP.py
@@ -21,9 +21,9 @@
print("Warning: peft not available. LoRA fine-tuning will be disabled.")
from model_config import ENCODER_OUTPUT_DIM
-from model.llm.TimeSeriesLLM import TimeSeriesLLM
-from model.encoder.TransformerCNNEncoder import TransformerCNNEncoder
-from model.projector.MLPProjector import MLPProjector
+from .TimeSeriesLLM import TimeSeriesLLM
+from ..encoder.TransformerCNNEncoder import TransformerCNNEncoder
+from ..projector.MLPProjector import MLPProjector
from prompt.full_prompt import FullPrompt
from time_series_datasets.util import (
extend_time_series_to_match_patch_size_and_aggregate,
@@ -493,13 +493,17 @@ def save_lora_state_to_checkpoint(self, checkpoint: dict):
return 0
- def eval_prompt(self, prompt: FullPrompt, max_new_tokens: int = 30000) -> str:
+ def eval_prompt(
+ self, prompt: FullPrompt, max_new_tokens: int = 30000, normalize: bool = False
+ ) -> str:
"""
Evaluate a prompt and return the generated text.
"""
batch = [prompt.to_dict()]
self.eval()
- batch = extend_time_series_to_match_patch_size_and_aggregate(batch)
+ batch = extend_time_series_to_match_patch_size_and_aggregate(
+ batch, normalize=normalize
+ )
output = self.generate(batch, max_new_tokens=max_new_tokens)
return output[0]
diff --git a/src/model/llm/__init__.py b/src/model/llm/__init__.py
index 7f7c8a3..7566f7d 100644
--- a/src/model/llm/__init__.py
+++ b/src/model/llm/__init__.py
@@ -4,4 +4,4 @@
# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
#
# SPDX-License-Identifier: MIT
-#
\ No newline at end of file
+#
diff --git a/src/model/projector/__init__.py b/src/model/projector/__init__.py
index 7f7c8a3..7566f7d 100644
--- a/src/model/projector/__init__.py
+++ b/src/model/projector/__init__.py
@@ -4,4 +4,4 @@
# SPDX-FileCopyrightText: 2025 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
#
# SPDX-License-Identifier: MIT
-#
\ No newline at end of file
+#
diff --git a/src/time_series_datasets/har_cot/HARAccQADataset.py b/src/time_series_datasets/har_cot/HARAccQADataset.py
index d1ee850..8250b03 100644
--- a/src/time_series_datasets/har_cot/HARAccQADataset.py
+++ b/src/time_series_datasets/har_cot/HARAccQADataset.py
@@ -28,8 +28,16 @@
class HARAccQADataset(QADataset):
- def __init__(self, split: Literal["train", "test", "validation"], EOS_TOKEN: str, format_sample_str: bool = False, time_series_format_function=None):
- super().__init__(split, EOS_TOKEN, format_sample_str, time_series_format_function)
+ def __init__(
+ self,
+ split: Literal["train", "test", "validation"],
+ EOS_TOKEN: str,
+ format_sample_str: bool = False,
+ time_series_format_function=None,
+ ):
+ super().__init__(
+ split, EOS_TOKEN, format_sample_str, time_series_format_function
+ )
def _load_splits(self) -> Tuple[Dataset, Dataset, Dataset]:
"""
@@ -107,7 +115,9 @@ def _format_sample(self, row):
dataset_val = HARAccQADataset(split="validation", EOS_TOKEN="")
dataset_test = HARAccQADataset(split="test", EOS_TOKEN="")
- print(f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}")
+ print(
+ f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}"
+ )
dataloader = DataLoader(
dataset_test,
diff --git a/src/time_series_datasets/m4/m4_loader.py b/src/time_series_datasets/m4/m4_loader.py
index 14fc5ea..34c8974 100644
--- a/src/time_series_datasets/m4/m4_loader.py
+++ b/src/time_series_datasets/m4/m4_loader.py
@@ -28,13 +28,16 @@
from typing import Dict, List, Literal, Optional, Tuple
from datasets import Dataset
from sklearn.model_selection import train_test_split
+from time_series_datasets.constants import RAW_DATA
# ---------------------------
# Constants
# ---------------------------
RELEASE_URL = "https://polybox.ethz.ch/index.php/s/MT3y9WdEebT8wfj/download/M4TimeSeriesCaptionDatasetV02.zip"
-DATA_DIR = "data/M4TimeSeriesCaptionDataset"
+
+
+DATA_DIR = os.path.join(RAW_DATA, "M4TimeSeriesCaptionDataset")
GENERATED_DATA_DIR = os.path.join(DATA_DIR, "M4TimeSeriesCaptionDataset")
AVAILABLE_FREQUENCIES = ["Daily", "Hourly", "Monthly", "Quarterly", "Weekly", "Yearly"]
diff --git a/src/time_series_datasets/sleep/SleepEDFQADataset.py b/src/time_series_datasets/sleep/SleepEDFQADataset.py
index 6bba122..24c7ead 100644
--- a/src/time_series_datasets/sleep/SleepEDFQADataset.py
+++ b/src/time_series_datasets/sleep/SleepEDFQADataset.py
@@ -10,15 +10,27 @@
from typing import List, Tuple, Literal
import sys
import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) )
+
+sys.path.append(
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
from prompt.text_time_series_prompt import TextTimeSeriesPrompt
from time_series_datasets.QADataset import QADataset
from time_series_datasets.sleep.sleepedf_cot_loader import load_sleepedf_cot_splits
import numpy as np
+
class SleepEDFCoTQADataset(QADataset):
- def __init__(self, split: Literal["train", "test", "validation"], EOS_TOKEN: str, format_sample_str: bool = False, time_series_format_function=None):
- super().__init__(split, EOS_TOKEN, format_sample_str, time_series_format_function)
+ def __init__(
+ self,
+ split: Literal["train", "test", "validation"],
+ EOS_TOKEN: str,
+ format_sample_str: bool = False,
+ time_series_format_function=None,
+ ):
+ super().__init__(
+ split, EOS_TOKEN, format_sample_str, time_series_format_function
+ )
def _load_splits(self) -> Tuple[Dataset, Dataset, Dataset]:
return load_sleepedf_cot_splits()
@@ -71,13 +83,20 @@ def _get_text_time_series_prompt_list(self, row) -> List[TextTimeSeriesPrompt]:
std = max(std, min_std)
series_norm = (series - mean) / std
text_prompt = f"The following is the EEG time series, it has mean {mean:.4f} and std {std:.4f}:"
-
+
return [TextTimeSeriesPrompt(text_prompt, series_norm.tolist())]
@staticmethod
def get_labels() -> List[str]:
# This could be made dynamic, but for now, use the standard sleep stages
- return ["Wake", "Non-REM stage 1", "Non-REM stage 2", "Non-REM stage 3", "REM sleep", "Movement"]
+ return [
+ "Wake",
+ "Non-REM stage 1",
+ "Non-REM stage 2",
+ "Non-REM stage 3",
+ "REM sleep",
+ "Movement",
+ ]
def _format_sample(self, row):
sample = super()._format_sample(row)
@@ -85,15 +104,21 @@ def _format_sample(self, row):
sample["original_data"] = row["time_series"]
return sample
+
if __name__ == "__main__":
dataset = SleepEDFCoTQADataset(split="train", EOS_TOKEN="")
dataset_val = SleepEDFCoTQADataset(split="validation", EOS_TOKEN="")
dataset_test = SleepEDFCoTQADataset(split="test", EOS_TOKEN="")
- print(f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}")
+ print(
+ f"Dataset sizes: Train: {len(dataset)}, Validation: {len(dataset_val)}, Test: {len(dataset_test)}"
+ )
if len(dataset) > 0:
sample = dataset[0]
print("Sample keys:", sample.keys())
print("Sample answer:", sample["answer"])
- print("Sample time series text:", sample["time_series_text"] if "time_series_text" in sample else "N/A")
+ print(
+ "Sample time series text:",
+ sample["time_series_text"] if "time_series_text" in sample else "N/A",
+ )
print("Sample pre prompt:", sample["pre_prompt"])
- print("Sample post prompt:", sample["post_prompt"])
\ No newline at end of file
+ print("Sample post prompt:", sample["post_prompt"])
diff --git a/src/time_series_datasets/util.py b/src/time_series_datasets/util.py
index 6850a4d..3649245 100644
--- a/src/time_series_datasets/util.py
+++ b/src/time_series_datasets/util.py
@@ -19,9 +19,12 @@
def extend_time_series_to_match_patch_size_and_aggregate(
- batch, *, patch_size: int = PATCH_SIZE
+ batch, *, patch_size: int = PATCH_SIZE, normalize: bool = False
):
- """Pad variable-length series so each sample length is a multiple of *patch_size*."""
+ """
+ Pad variable-length series so each sample length is a multiple of *patch_size*.
+ Optionally normalize each time series to have zero mean and unit variance.
+ """
for element in batch:
# 1) pull out the list of (1D) timeβseries
@@ -30,13 +33,26 @@ def extend_time_series_to_match_patch_size_and_aggregate(
# 2) convert each to a torch.Tensor (float)
ts_tensors = [torch.as_tensor(ts, dtype=torch.float32) for ts in ts_list]
- # 3) find the longest series length
+ # 3) normalize each time series if requested
+ if normalize:
+ normalized_tensors = []
+ for ts in ts_tensors:
+ mean = ts.mean()
+ std = ts.std()
+ if std > 1e-8: # Avoid division by zero
+ ts_normalized = (ts - mean) / std
+ else:
+ ts_normalized = ts - mean
+ normalized_tensors.append(ts_normalized)
+ ts_tensors = normalized_tensors
+
+ # 4) find the longest series length
max_len = max([ts.size(0) for ts in ts_tensors])
- # 4) round up to nearest multiple of patch_size
+ # 5) round up to nearest multiple of patch_size
padded_len = ((max_len + patch_size - 1) // patch_size) * patch_size
- # 5) pad (or trim) each series to padded_len
+ # 6) pad (or trim) each series to padded_len
padded = []
for ts in ts_tensors:
L = ts.size(0)
@@ -47,7 +63,7 @@ def extend_time_series_to_match_patch_size_and_aggregate(
ts = ts[:padded_len]
padded.append(ts)
- # 6) stack into a single 2D tensor: (num_series, padded_len)
+ # 7) stack into a single 2D tensor: (num_series, padded_len)
element["time_series"] = torch.stack(padded, dim=0)
return batch