diff --git a/experiments/exp10_routing_temperature_specialization/EXPERIMENT_CARD.txt b/experiments/exp10_routing_temperature_specialization/EXPERIMENT_CARD.txt new file mode 100644 index 0000000..354a844 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/EXPERIMENT_CARD.txt @@ -0,0 +1,163 @@ +╔══════════════════════════════════════════════════════════════════════════════╗ +║ EXPERIMENT 10: ROUTING TEMPERATURE ║ +║ & EXPERT SPECIALIZATION ANALYSIS ║ +╚══════════════════════════════════════════════════════════════════════════════╝ + +📋 OVERVIEW +─────────── +Systematic exploration of routing temperature effects on MoE training: +• How temperature affects convergence speed and final performance +• Expert utilization and load balancing dynamics +• Temperature scheduling strategies (exploration → exploitation) +• Expert specialization patterns under different routing regimes + +🔬 RESEARCH QUESTIONS +───────────────────── +1. What is the optimal routing temperature for MoE training? +2. Does temperature scheduling improve upon constant temperature? +3. How does temperature affect expert specialization? +4. Can we reduce load balancing loss through temperature tuning? + +🏗️ ARCHITECTURE +──────────────── +Model: MoE Transformer (classic attention) +Experts: 8 experts, top-2 routing +Size: ~79M total params (~28.4% active) +Dimensions: d_model=384, n_heads=8, n_layers=6, d_ff=1536 + +⚙️ EXPERIMENTS +─────────────── +Temperature Ablation (500 steps each): + • temp_0.5 - Very sharp routing (strong exploitation) + • temp_0.7 - Sharp routing + • temp_1.0 - Standard softmax (baseline) + • temp_1.5 - Slightly softer routing + • temp_2.0 - Softer routing (more exploration) + • temp_3.0 - Soft routing (high exploration) + • temp_5.0 - Very soft routing + • temp_10.0 - Nearly uniform routing + +Temperature Schedules (500 steps each): + • schedule_linear - Linear decay from 5.0 → 1.0 + • schedule_cosine - Cosine decay from 5.0 → 1.0 + • schedule_exp - Exponential decay from 5.0 → 1.0 + • schedule_step - Step decay: 5.0→2.0→1.0 + +Extended Training: + • temp_best_long - Best temperature, 1000 steps + +📊 METRICS TRACKED +────────────────── +Performance: + • Validation loss, accuracy, perplexity + • Training time (wall-clock) + +Routing: + • Expert utilization distribution + • Load balancing loss + • Routing entropy (diversity measure) + • Expert selection confidence + +Specialization: + • Expert activation patterns + • Gini coefficient (utilization inequality) + • Utilization variance + +🚀 QUICK START +────────────── +# List all experiments +python run_experiment.py --list + +# Run quick demo (3 temperatures) +bash quick_demo.sh + +# Run full temperature ablation +python run_experiment.py --ablation + +# Run temperature schedules +python run_experiment.py --schedules + +# Run single temperature +python run_experiment.py --experiment temp_2.0 + +# Generate visualizations +python plot_results.py --results-dir ./results --output-dir ./analysis +python analyze_specialization.py --results-dir ./results --output-dir ./analysis + +📈 EXPECTED OUTCOMES +──────────────────── +• Temperature ~1.5-2.0 likely optimal (based on theory) +• Very low temperature (0.5) → load imbalance +• Very high temperature (10.0) → insufficient specialization +• Temperature scheduling should combine exploration + exploitation +• Clear trade-off between load balancing and specialization + +🎯 KEY CONTRIBUTIONS +──────────────────── +1. Optimal routing temperature for MoE training +2. 
Temperature scheduling strategies +3. Expert specialization dynamics under different routing regimes +4. Load balancing effectiveness as function of temperature +5. Comprehensive routing metrics and visualization toolkit + +📁 OUTPUT FILES +─────────────── +Each experiment produces: + • results/{exp_name}/metrics.json - Complete training history + • results/{exp_name}/model.pt - Model checkpoint + • results/{exp_name}/logs/ - Training logs + +Analysis generates: + • analysis/temperature_ablation_comprehensive.png + • analysis/routing_dynamics.png + • analysis/expert_utilization.png + • analysis/expert_utilization_analysis.png + • analysis/entropy_analysis.png + • analysis/schedule_comparison.png + • analysis/summary_report.json + • analysis/specialization_report.json + +🔧 CONFIGURATION +──────────────── +Optimizer: Muon (hybrid) with optimal settings from exp9 + • Muon LR: 0.07 + • AdamW LR: 0.007 + • Momentum: 0.9 + • Weight decay: 0.2 + +Training: + • Steps: 500 (1000 for extended) + • Batch size: 24 + • Grad accum: 4 + • LR schedule: Cosine with 5% warmup + • Load bal: 0.01 + +Dataset: HuggingFaceTB/smollm-corpus (cosmopedia-v2) + • Train docs: 1,800 + • Val docs: 200 + • Seq length: 512 tokens + +💡 KEY INSIGHTS +─────────────── +Temperature controls exploration-exploitation trade-off: + • Low temp (< 1.0): Sharp routing, fast specialization, risk of imbalance + • Medium temp (1-2): Balanced routing, good for most cases + • High temp (> 2): Exploratory routing, better load balance, slower convergence + +Scheduling strategy: + • Start high (exploration) to find good expert assignments + • Decay to lower values (exploitation) for final refinement + • Cosine/exponential schedules likely superior to linear + +📚 REFERENCES +───────────── +• Switch Transformers (Fedus+ 2021) - Load balancing in MoE +• GShard (Lepikhin+ 2020) - Scaling MoE models +• Expert Choice Routing (Zhou+ 2022) - Alternative routing +• Soft MoE (Puigcerver+ 2023) - Soft expert assignments + +─────────────────────────────────────────────────────────────────────────────── +Created: November 11, 2025 +Branch: exp10-routing-temperature-analysis +─────────────────────────────────────────────────────────────────────────────── + diff --git a/experiments/exp10_routing_temperature_specialization/EXPERIMENT_SUMMARY.md b/experiments/exp10_routing_temperature_specialization/EXPERIMENT_SUMMARY.md new file mode 100644 index 0000000..30de0ae --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/EXPERIMENT_SUMMARY.md @@ -0,0 +1,504 @@ +# Experiment 10: Routing Temperature & Expert Specialization Analysis + +## Executive Summary + +This experiment systematically explores how **routing temperature** affects Mixture-of-Experts (MoE) model training. Temperature controls the sharpness of the routing distribution, creating a fundamental trade-off between exploration (uniform routing) and exploitation (sharp, confident routing). + +### Key Innovation + +Unlike previous MoE work that uses fixed temperature (typically 1.0), we: +1. **Systematically ablate temperatures** from 0.5 (sharp) to 10.0 (uniform) +2. **Explore temperature scheduling** (high→low over training) +3. **Comprehensively track routing dynamics** (entropy, utilization, specialization) +4. 
**Generate rich visualizations** to understand expert behavior + +## Motivation + +**Why does temperature matter?** + +In MoE models, the routing network uses softmax to assign tokens to experts: + +``` +router_probs = softmax(logits / temperature) +``` + +- **Low temperature (< 1.0)**: Sharp routing + - ✅ Strong expert specialization + - ✅ Fast convergence + - ❌ Risk of load imbalance + - ❌ Risk of premature specialization + +- **High temperature (> 1.0)**: Soft routing + - ✅ Better load balancing + - ✅ More exploration + - ❌ Slower specialization + - ❌ Potentially worse final performance + +- **Temperature scheduling**: Best of both worlds? + - Start high: Exploration phase, find good expert assignments + - End low: Exploitation phase, refine specializations + +## Experiment Design + +### Model Architecture + +``` +MoE Transformer +├── Layers: 6 +├── d_model: 384 +├── n_heads: 8 +├── d_ff: 1536 +├── Experts: 8 +├── Top-k: 2 +└── Total params: ~79M (28.4% active) +``` + +### Experiment Matrix + +#### 1. Temperature Ablation (8 experiments × 500 steps) + +| Experiment | Temperature | Description | Expected Behavior | +|------------|-------------|-------------|-------------------| +| temp_0.5 | 0.5 | Very sharp | Strong specialization, risk imbalance | +| temp_0.7 | 0.7 | Sharp | Moderate specialization | +| **temp_1.0** | **1.0** | **Baseline** | **Standard softmax** | +| temp_1.5 | 1.5 | Slightly soft | More balanced | +| temp_2.0 | 2.0 | Soft | Good exploration | +| temp_3.0 | 3.0 | Very soft | High exploration | +| temp_5.0 | 5.0 | Nearly uniform | Maximum exploration | +| temp_10.0 | 10.0 | Uniform | Minimal specialization | + +#### 2. Temperature Scheduling (4 experiments × 500 steps) + +| Experiment | Schedule | Formula | Rationale | +|------------|----------|---------|-----------| +| schedule_linear | Linear | `5.0 + (1.0 - 5.0) * progress` | Simple decay | +| schedule_cosine | Cosine | `1.0 + (5.0 - 1.0) * 0.5 * (1 + cos(π*progress))` | Smooth decay | +| schedule_exp | Exponential | `5.0 * (1.0/5.0)^progress` | Fast early decay | +| schedule_step | Step | `5.0 → 2.0 → 1.0` | Discrete phases | + +#### 3. 
Extended Training (1 experiment × 1000 steps) + +| Experiment | Temperature | Description | +|------------|-------------|-------------| +| temp_best_long | TBD (best from ablation) | Longer training with optimal temp | + +**Total: 13 experiments** + +### Training Configuration + +Based on optimal settings from Experiment 9 (Muon vs Adam): + +```python +Optimizer: Muon (hybrid) + muon_lr: 0.07 + adamw_lr: 0.007 + momentum: 0.9 + weight_decay: 0.2 + +Training: + steps: 500 (1000 for extended) + batch_size: 24 + grad_accumulation: 4 + effective_batch: 96 + +LR Schedule: Cosine decay + warmup_ratio: 0.05 (25 steps) + min_lr_ratio: 0.1 + +Regularization: + dropout: 0.1 + grad_clip: 1.0 + load_bal_weight: 0.01 + +Data: HuggingFaceTB/smollm-corpus (cosmopedia-v2) + train_docs: 1,800 + val_docs: 200 + seq_length: 512 tokens +``` + +## Metrics & Analysis + +### Primary Metrics + +**Performance Metrics:** +- Validation loss (primary objective) +- Validation accuracy +- Validation perplexity +- Training time (wall-clock) + +**Routing Metrics:** +- **Routing entropy**: Measures routing diversity + - High entropy = uniform routing + - Low entropy = sharp routing + - Formula: `H = -Σ p_i log(p_i)` + +- **Selection confidence**: How strongly top expert is preferred + - Range: [0.5, 1.0] for top-2 routing + - Higher = more confident routing + +- **Expert utilization**: Fraction of tokens per expert + - Ideal: 1/8 = 0.125 (uniform) + - Gini coefficient: measures inequality + +- **Load balancing loss**: Auxiliary loss to encourage balance + - Lower = better load distribution + +### Visualization Suite + +Our comprehensive analysis generates: + +#### 1. Temperature Comparison Plots +- Loss curves for all temperatures +- Performance vs temperature (log scale) +- Accuracy vs temperature +- Routing entropy vs temperature +- Summary statistics table + +#### 2. Routing Dynamics Analysis +- Entropy evolution over training +- Selection confidence evolution +- Load balancing loss trends +- Temperature vs final routing metrics + +#### 3. Expert Utilization Patterns +- Per-expert utilization bars for each temperature +- Heatmap: temperatures × experts +- Gini coefficient analysis +- Utilization variance analysis + +#### 4. Schedule Comparison +- Loss/accuracy curves for all schedules +- Temperature evolution visualization +- Final performance comparison + +#### 5. Specialization Analysis +- Gini coefficient vs temperature +- Utilization variance vs temperature +- Expert activation heatmaps +- Entropy change rate analysis + +## Expected Results & Hypotheses + +### Hypothesis 1: Optimal Temperature ≈ 1.5-2.0 + +**Reasoning:** +- temp = 1.0 is arbitrary (just standard softmax) +- Slightly higher temp should improve exploration +- Too high temp loses specialization benefits + +**Expected curve:** +``` +Loss + ^ + | * + | * * + | * * + | * * + | * * + +-----------> Temperature + 0.5 1.0 2.0 5.0 10.0 +``` + +### Hypothesis 2: Temperature Scheduling Helps + +**Reasoning:** +- Early training: Need exploration to find good expert assignments +- Late training: Need exploitation to refine specializations +- Cosine schedule likely best (smooth transition) + +**Expected ranking:** +1. Cosine schedule (smooth decay) +2. Exponential schedule (fast early exploration) +3. Step schedule (abrupt transitions) +4. 
Linear schedule (too slow early) + +### Hypothesis 3: Low Temp → Load Imbalance + +**Reasoning:** +- Sharp routing leads to winner-take-all dynamics +- Some experts become dominant, others unused +- Higher Gini coefficient, higher load balancing loss + +**Expected:** +- temp=0.5: High Gini (> 0.3), poor load balancing +- temp=1.0: Moderate Gini (~0.2) +- temp=2.0: Low Gini (< 0.15), good balance + +### Hypothesis 4: Entropy Decreases Over Training + +**Reasoning:** +- Early training: High uncertainty, higher entropy +- Late training: Experts specialize, lower entropy +- Effect more pronounced at lower temperatures + +## Implementation Details + +### Key Components + +#### 1. `TemperatureRouter` (temperature_router.py) +```python +class TemperatureRouter(nn.Module): + def forward(self, x): + logits = self.gate(x) + scaled_logits = logits / self.current_temperature + probs = softmax(scaled_logits) + # ... track statistics +``` + +**Features:** +- Dynamic temperature setting +- Comprehensive routing statistics +- Entropy & confidence tracking +- Expert utilization monitoring + +#### 2. `TemperatureMoE` (temperature_moe.py) +```python +class TemperatureMoE(nn.Module): + def set_temperature(self, temp): + self.router.set_temperature(temp) + + def forward(self, x, return_routing_stats=True): + # Route tokens to experts + # Track activation patterns + # Return stats if requested +``` + +**Features:** +- Temperature-aware routing +- Expert activation tracking +- Detailed routing statistics +- Load balancing loss computation + +#### 3. `train_with_temperature_tracking` (tracking_trainer.py) +```python +def train_with_temperature_tracking(...): + for step in range(max_steps): + # Update temperature (scheduled or constant) + current_temp = temp_config.get_temperature_at_step(step) + model.set_temperature(current_temp) + + # Training step + # Collect routing stats at eval points + # Save comprehensive history +``` + +**Features:** +- Temperature scheduling support +- Comprehensive metric tracking +- Routing statistics collection +- Rich training history + +### Data Flow + +``` +Input Tokens + ↓ +[Token Embeddings] + ↓ +[Attention Layer] + ↓ +[RMS Norm] + ↓ +┌─────────────────────────┐ +│ TemperatureMoE Layer │ +│ │ +│ [TemperatureRouter] │ +│ - Apply temperature │ +│ - Compute probs │ +│ - Select top-k │ +│ - Track stats │ +│ ↓ │ +│ [Expert Processing] │ +│ - Route tokens │ +│ - Apply experts │ +│ - Weighted combine │ +└─────────────────────────┘ + ↓ +[RMS Norm] + ↓ +[LM Head] + ↓ +Logits + Aux Loss + Routing Stats +``` + +## Usage Examples + +### Basic Usage + +```bash +# List all experiments +python run_experiment.py --list + +# Run single temperature +python run_experiment.py --experiment temp_1.0 + +# Run temperature ablation (8 experiments) +python run_experiment.py --ablation + +# Run temperature schedules (4 experiments) +python run_experiment.py --schedules + +# Run all 13 experiments +python run_experiment.py --all + +# Custom temperature +python run_experiment.py --temperature 1.5 +``` + +### Quick Demo + +```bash +# Run 3 representative temperatures (temp_0.7, temp_1.0, temp_2.0) +bash quick_demo.sh +``` + +This runs 3 experiments (500 steps each, ~2-3 min per experiment) and generates all visualizations. 
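+
+### Routing Step in Miniature
+
+For intuition, here is a small, self-contained sketch of the temperature-scaled top-2 routing step that all of these runs exercise. It is simplified from the `TemperatureRouter` snippet in Implementation Details: the function name and shapes here are illustrative, and the real router tracks additional utilization and load-balancing statistics.
+
+```python
+import torch
+import torch.nn.functional as F
+
+def route_with_temperature(logits: torch.Tensor, temperature: float, top_k: int = 2):
+    """Temperature-scaled top-k routing (illustrative sketch)."""
+    # T > 1 flattens the distribution (exploration); T < 1 sharpens it (exploitation)
+    probs = F.softmax(logits / temperature, dim=-1)     # [n_tokens, n_experts]
+    topk_probs, topk_idx = probs.topk(top_k, dim=-1)    # [n_tokens, top_k]
+    # Renormalize so each token's selected expert weights sum to 1
+    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
+    # Routing entropy H = -sum_i p_i log p_i, averaged over tokens
+    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(dim=-1).mean()
+    # Selection confidence: weight on the top-1 expert (in [0.5, 1.0] for top-2)
+    confidence = topk_probs[..., 0].mean()
+    return topk_idx, topk_probs, entropy, confidence
+
+# Higher temperature -> higher entropy, lower confidence
+logits = torch.randn(4, 8)  # 4 tokens, 8 experts
+for t in (0.5, 1.0, 5.0):
+    _, _, H, c = route_with_temperature(logits, t)
+    print(f"T={t}: entropy={H.item():.3f}, confidence={c.item():.3f}")
+```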
+ +### Analysis & Visualization + +```bash +# Generate comprehensive plots +python plot_results.py \ + --results-dir ./results \ + --output-dir ./analysis + +# Analyze expert specialization +python analyze_specialization.py \ + --results-dir ./results \ + --output-dir ./analysis +``` + +### Output Structure + +``` +exp10_routing_temperature_specialization/ +├── results/ +│ ├── temp_0.5/ +│ │ ├── metrics.json # Complete training history +│ │ ├── model.pt # Model checkpoint +│ │ └── logs/ # Training logs +│ ├── temp_1.0/ +│ │ └── ... +│ └── ... +│ +└── analysis/ + ├── temperature_ablation_comprehensive.png + ├── routing_dynamics.png + ├── expert_utilization.png + ├── expert_utilization_analysis.png + ├── entropy_analysis.png + ├── schedule_comparison.png + ├── summary_report.json + └── specialization_report.json +``` + +## Knowledge Generated + +This experiment will yield deep insights into: + +### 1. **Optimal Temperature Discovery** +- Empirically determine best temperature for MoE training +- Understand temperature-performance relationship +- Quantify sensitivity to temperature choice + +### 2. **Routing Dynamics Understanding** +- How routing evolves during training +- When/how experts specialize +- Impact of temperature on specialization patterns + +### 3. **Load Balancing Insights** +- Trade-off between specialization and balance +- Effectiveness of load balancing loss at different temperatures +- Alternative approaches to temperature for better balance + +### 4. **Schedule Design Principles** +- Does scheduling help? By how much? +- What schedule shape is optimal? +- When to transition from exploration to exploitation? + +### 5. **Practical Guidelines** +- Actionable recommendations for MoE practitioners +- Temperature tuning as hyperparameter +- Integration with other training techniques + +## Extensions & Future Work + +### Immediate Extensions + +1. **Longer Training**: Run best temperature for 5k-10k steps +2. **Larger Models**: Scale to 1B+ parameters +3. **Different Architectures**: Test with different expert/attention designs +4. **Per-Layer Temperatures**: Different temps for different layers + +### Research Directions + +1. **Adaptive Temperature**: Learn temperature during training +2. **Token-Dependent Temperature**: Different temps for different tokens +3. **Expert-Specific Temperature**: Per-expert routing sharpness +4. **Annealing Strategies**: More sophisticated scheduling +5. **Uncertainty-Based Routing**: Use temperature to model uncertainty + +### Integration Opportunities + +1. **Combine with Exp9**: Optimal optimizer + optimal temperature +2. **Architecture Search**: Find best temp for different architectures +3. **Dataset Effects**: How does optimal temp vary by dataset? +4. 
**Multi-Objective**: Balance loss, speed, and expert utilization
+
+## Technical Notes
+
+### Reproducibility
+
+- All experiments use seed=42
+- Data split before tokenization (no leakage)
+- Deterministic operations where possible
+- Complete configuration saved with results
+
+### Computational Requirements
+
+- **Per experiment**: ~2-3 minutes on GPU (V100/A100)
+- **Full ablation**: ~20-25 minutes (8 experiments)
+- **Complete suite**: ~40-50 minutes (13 experiments)
+- **Memory**: ~8-10 GB GPU RAM
+
+### Implementation Choices
+
+**Why these temperatures?**
+- 0.5, 0.7: Sharp routing regime
+- 1.0: Standard baseline
+- 1.5, 2.0, 3.0: Exploration regime
+- 5.0, 10.0: Extreme exploration
+
+**Why these schedules?**
+- Linear: Simple baseline
+- Cosine: Smooth, widely used
+- Exponential: Fast early exploration
+- Step: Test discrete phase transitions
+
+**Why 500 steps?**
+- Fast iteration for ablation
+- Sufficient to see convergence trends
+- Extended run (1000 steps) validates best setting
+
+## Conclusion
+
+Experiment 10 represents a **comprehensive, systematic exploration** of routing temperature in MoE models—a fundamental but under-studied hyperparameter. Through careful ablation, rich tracking, and extensive visualization, we will:
+
+1. ✅ Discover optimal temperature for MoE training
+2. ✅ Understand exploration-exploitation trade-offs
+3. ✅ Quantify impact on expert specialization
+4. ✅ Develop practical temperature scheduling strategies
+5. ✅ Generate actionable insights for MoE practitioners
+6. ✅ Create comprehensive visualization toolkit
+7. ✅ Build foundation for future routing research
+
+The experiment is **ready to run** and will generate **significant new knowledge** about MoE training dynamics.
+
+---
+
+**Created**: November 11, 2025
+**Branch**: `exp10-routing-temperature-analysis`
+**Status**: ✅ Ready to execute
+**Estimated Time**: 40-50 minutes for complete suite
+
diff --git a/experiments/exp10_routing_temperature_specialization/GETTING_STARTED.md b/experiments/exp10_routing_temperature_specialization/GETTING_STARTED.md
new file mode 100644
index 0000000..247283b
--- /dev/null
+++ b/experiments/exp10_routing_temperature_specialization/GETTING_STARTED.md
@@ -0,0 +1,438 @@
+# Getting Started with Experiment 10
+
+## 🚀 Quick Start (under 10 minutes)
+
+The fastest way to see the experiment in action:
+
+```bash
+cd /root/blueberry-llm/experiments/exp10_routing_temperature_specialization
+
+# Run quick demo: 3 temperatures, 500 steps each (~6-9 minutes total)
+bash quick_demo.sh
+```
+
+This will:
+1. Run 3 experiments: temp_0.7, temp_1.0, temp_2.0
+2. Generate all visualizations
+3. Create analysis reports
+
+Results will be in:
+- `./results/` - Individual experiment results
+- `./analysis/` - Comparative plots and analysis
+
+## 📋 What Has Been Created
+
+### Core Components (2,200+ lines of code)
+
+1. **`temperature_router.py`** (250 lines)
+   - Temperature-controlled routing with softmax scaling
+   - Comprehensive statistics tracking
+   - Routing entropy and confidence metrics
+
+2. **`temperature_moe.py`** (200 lines)
+   - MoE layer with temperature support
+   - Expert activation tracking
+   - Load balancing loss computation
+
+3. **`temperature_model.py`** (180 lines)
+   - Complete model with temperature-aware MoE
+   - Model creation utilities
+   - Parameter counting
+
+4. **`tracking_trainer.py`** (300 lines)
+   - Custom trainer with routing statistics
+   - Temperature scheduling support
+   - Comprehensive history tracking
+
+5. 
**`run_experiment.py`** (350 lines) + - Main experiment runner + - Multiple experiment support + - Data loading and preparation + +6. **`plot_results.py`** (450 lines) + - 6 comprehensive visualizations + - Temperature comparison plots + - Routing dynamics analysis + - Expert utilization patterns + - Schedule comparison + - Summary report generation + +7. **`analyze_specialization.py`** (350 lines) + - Expert specialization analysis + - Gini coefficient computation + - Utilization variance analysis + - Entropy trend analysis + +8. **`config.py`** (150 lines) + - 13 experiment configurations + - Temperature scheduling functions + - Configuration management + +### Documentation (3,000+ words) + +1. **`README.md`** - Complete experiment documentation +2. **`EXPERIMENT_SUMMARY.md`** - Comprehensive technical summary +3. **`EXPERIMENT_CARD.txt`** - Quick reference card +4. **`GETTING_STARTED.md`** - This file! + +### Utilities + +1. **`quick_demo.sh`** - Fast demo script +2. **`quick_test.py`** - Configuration verification + +## 🎯 Experiment Overview + +### What We're Studying + +**Routing Temperature** controls how sharply the MoE router selects experts: + +``` +router_probs = softmax(logits / temperature) +``` + +- **Low temp (0.5)**: Sharp, confident routing → strong specialization +- **Medium temp (1.0)**: Balanced routing → baseline +- **High temp (5.0)**: Soft, exploratory routing → better load balance + +### The 13 Experiments + +#### Temperature Ablation (8 experiments) +- `temp_0.5` - Very sharp (exploitation) +- `temp_0.7` - Sharp +- `temp_1.0` - **Baseline** +- `temp_1.5` - Slightly soft +- `temp_2.0` - Soft (exploration) +- `temp_3.0` - Very soft +- `temp_5.0` - Nearly uniform +- `temp_10.0` - Uniform (extreme exploration) + +#### Temperature Schedules (4 experiments) +- `schedule_linear` - Linear decay 5.0→1.0 +- `schedule_cosine` - Cosine decay 5.0→1.0 +- `schedule_exp` - Exponential decay 5.0→1.0 +- `schedule_step` - Step decay 5.0→2.0→1.0 + +#### Extended Training (1 experiment) +- `temp_best_long` - Best temperature, 1000 steps + +## 📊 What You'll Get + +### Per-Experiment Outputs + +Each experiment produces: +``` +results/temp_1.0/ +├── metrics.json # Complete training history +│ ├── val_losses: [...] +│ ├── val_accuracies: [...] +│ ├── routing_entropies: [...] +│ ├── selection_confidences: [...] +│ ├── expert_utilizations: [...] +│ └── ... 
(20+ metrics) +│ +├── model.pt # Model checkpoint +└── logs/ # Training logs +``` + +### Analysis Outputs + +The analysis scripts generate: +``` +analysis/ +├── temperature_ablation_comprehensive.png +│ ├── Loss vs steps (all temps) +│ ├── Loss vs time +│ ├── Performance vs temperature +│ ├── Accuracy vs temperature +│ ├── Routing entropy vs temperature +│ ├── Accuracy evolution +│ └── Summary statistics table +│ +├── routing_dynamics.png +│ ├── Routing entropy evolution +│ ├── Selection confidence evolution +│ ├── Load balancing loss trends +│ └── Temperature vs final metrics +│ +├── expert_utilization.png +│ └── Per-expert utilization bars (all temps) +│ +├── expert_utilization_analysis.png +│ ├── Gini coefficient vs temperature +│ ├── Utilization variance vs temperature +│ ├── Expert utilization heatmap +│ └── Statistics summary table +│ +├── entropy_analysis.png +│ ├── Entropy evolution over training +│ └── Entropy change rate +│ +├── schedule_comparison.png +│ ├── Loss comparison (all schedules) +│ ├── Temperature evolution +│ ├── Accuracy comparison +│ └── Final performance bars +│ +├── summary_report.json +│ ├── Best results +│ ├── Temperature analysis +│ └── Schedule analysis +│ +└── specialization_report.json + ├── Per-experiment analysis + ├── Utilization metrics + ├── Routing entropy metrics + └── Key insights +``` + +## 🎨 Sample Commands + +### Run Single Experiment +```bash +python run_experiment.py --experiment temp_2.0 +``` + +### Run Temperature Ablation (all 8 temps) +```bash +python run_experiment.py --ablation +``` + +### Run Temperature Schedules (all 4 schedules) +```bash +python run_experiment.py --schedules +``` + +### Run Everything (13 experiments) +```bash +python run_experiment.py --all +``` + +### Custom Temperature +```bash +python run_experiment.py --temperature 1.5 +``` + +### List All Available Experiments +```bash +python run_experiment.py --list +``` + +### Generate Visualizations +```bash +# After running experiments +python plot_results.py --results-dir ./results --output-dir ./analysis +python analyze_specialization.py --results-dir ./results --output-dir ./analysis +``` + +## 🔬 Expected Insights + +After running the experiments, you'll discover: + +### 1. Optimal Temperature +- What temperature gives best validation loss? +- How sensitive is performance to temperature? +- Is temp=1.0 (default) actually optimal? + +### 2. Exploration-Exploitation Trade-off +- How does temperature affect convergence speed? +- When is high temperature (exploration) beneficial? +- When is low temperature (exploitation) better? + +### 3. Expert Specialization +- How does temperature affect expert utilization? +- Do different temperatures lead to different specialization patterns? +- What's the relationship between specialization and performance? + +### 4. Load Balancing +- Can temperature tuning reduce load balancing loss? +- Is there a temperature that balances performance and load distribution? +- How does temperature affect the Gini coefficient? + +### 5. Scheduling Strategies +- Does temperature scheduling help? +- What schedule shape is optimal? +- When should we transition from exploration to exploitation? 
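+
+To make the scheduling questions above concrete, here is a standalone sketch of how each decay schedule maps training progress to temperature. It mirrors the formulas in `config.py`'s `get_temperature_at_step` (5.0 → 1.0 decays; the step schedule uses hard-coded 100/300 step boundaries):
+
+```python
+import math
+
+def temperature_at(step: int, max_steps: int, schedule: str,
+                   start: float = 5.0, end: float = 1.0) -> float:
+    """Temperature at a given step under each decay schedule."""
+    p = step / max_steps  # training progress in [0, 1]
+    if schedule == "linear":
+        return start + (end - start) * p
+    if schedule == "cosine":
+        return end + (start - end) * 0.5 * (1 + math.cos(math.pi * p))
+    if schedule == "exponential":
+        return start * (end / start) ** p
+    if schedule == "step":
+        return 5.0 if step < 100 else (2.0 if step < 300 else 1.0)
+    return start  # constant temperature (no schedule)
+
+# Temperatures at the midpoint of a 500-step run:
+# linear 3.00, cosine 3.00, exponential 2.24, step 2.00
+for name in ("linear", "cosine", "exponential", "step"):
+    print(f"{name:12s} T(250) = {temperature_at(250, 500, name):.2f}")
+```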
+ +## 📈 Interpreting Results + +### Key Metrics to Watch + +**Validation Loss** (lower is better) +- Primary performance metric +- Look for U-shaped curve: optimal temperature in middle + +**Routing Entropy** (context-dependent) +- High entropy = uniform routing (more exploration) +- Low entropy = sharp routing (strong specialization) +- Optimal depends on training phase + +**Expert Utilization** (balanced is good) +- Ideal: all experts used equally (~12.5% each for 8 experts) +- Gini coefficient: 0 = perfect balance, higher = more inequality +- Low temperature tends to increase inequality + +**Selection Confidence** (higher = sharper) +- How strongly is the top expert preferred? +- Should decrease with temperature +- Trade-off: too high = imbalance, too low = weak specialization + +### Visualization Interpretation + +#### Temperature Ablation Plot +Look for: +- **Optimal temperature**: Minimum of the loss curve +- **Sensitivity**: How steep is the curve? Flat = robust, steep = sensitive +- **Sweet spot**: Usually between 1.0 and 3.0 + +#### Routing Dynamics +Look for: +- **Entropy trends**: Does entropy decrease over training? +- **Confidence trends**: Does confidence increase over training? +- **Temperature effects**: How do different temps affect trends? + +#### Expert Utilization +Look for: +- **Balance**: Are all experts used roughly equally? +- **Temperature effect**: Higher temp → better balance? +- **Specialization**: Low utilization variance = good balance + +#### Schedule Comparison +Look for: +- **Best schedule**: Which achieves lowest loss? +- **Early vs late**: Does high early temp help? +- **Smooth vs discrete**: Cosine vs step schedule? + +## 💡 Tips for Running Experiments + +### For Quick Iteration +```bash +# Run just 3 representative temps (~10 min total) +python run_experiment.py --experiments temp_0.7 temp_1.0 temp_2.0 +``` + +### For Comprehensive Analysis +```bash +# Run full ablation (~20-25 min) +python run_experiment.py --ablation + +# Then run schedules (~10-12 min) +python run_experiment.py --schedules + +# Finally, extended run with best temp (~4-5 min) +# Update temp_best_long in config.py with best temp first +python run_experiment.py --experiment temp_best_long +``` + +### For Custom Exploration +```bash +# Test specific temperature +python run_experiment.py --temperature 1.8 + +# Multiple custom temps +python run_experiment.py --experiments temp_1.5 temp_2.5 temp_3.5 +``` + +## 🐛 Troubleshooting + +### Out of Memory +- Reduce batch size in `MoEModelConfig` +- Reduce number of documents: `num_documents=1000` instead of 2000 +- Disable AMP: `use_amp=False` + +### Slow Training +- Reduce eval frequency: `eval_every=20` instead of 10 +- Reduce eval steps: `eval_steps=50` instead of 100 +- Use fewer workers: `num_workers=1` in data loader + +### Import Errors +Make sure you're running from the experiment directory: +```bash +cd /root/blueberry-llm/experiments/exp10_routing_temperature_specialization +python run_experiment.py --list +``` + +### Plotting Errors +Install required packages: +```bash +pip install matplotlib seaborn numpy +``` + +## 📚 Next Steps + +After running the experiments: + +1. **Analyze Results** + ```bash + python plot_results.py + python analyze_specialization.py + ``` + +2. **Review Visualizations** + - Open `analysis/*.png` files + - Check `analysis/summary_report.json` + - Review `analysis/specialization_report.json` + +3. **Document Findings** + - What was the best temperature? + - How much improvement over baseline? 
+ - What scheduling strategy worked best? + - Any surprising results? + +4. **Extend Experiments** + - Test longer training (5k-10k steps) + - Try different model sizes + - Test on different datasets + - Combine with other techniques (from exp9, etc.) + +5. **Share Results** + - Add findings to README + - Create summary plots + - Document insights for future experiments + +## 🎓 Learning Resources + +To understand the experiment better, read: + +1. **`EXPERIMENT_SUMMARY.md`** - Technical deep dive +2. **`README.md`** - Complete documentation +3. **`EXPERIMENT_CARD.txt`** - Quick reference + +Key papers: +- Switch Transformers (Fedus et al., 2021) - Load balancing +- GShard (Lepikhin et al., 2020) - Scaling MoE +- Expert Choice Routing (Zhou et al., 2022) - Alternative routing + +## ✅ Checklist + +Before running experiments: +- [ ] Verify GPU available: `nvidia-smi` +- [ ] Check disk space: `df -h` +- [ ] Test configuration: `python quick_test.py` +- [ ] Review experiment list: `python run_experiment.py --list` + +After running experiments: +- [ ] Check all experiments completed successfully +- [ ] Generate visualizations +- [ ] Review summary reports +- [ ] Document key findings +- [ ] Save important plots + +## 🎉 You're Ready! + +Everything is set up and ready to run. Start with: + +```bash +bash quick_demo.sh +``` + +Then explore the results in `./analysis/`! + +--- + +**Questions?** Check the documentation: +- `README.md` - Complete experiment documentation +- `EXPERIMENT_SUMMARY.md` - Technical details +- `EXPERIMENT_CARD.txt` - Quick reference + +**Happy experimenting! 🚀** + diff --git a/experiments/exp10_routing_temperature_specialization/README.md b/experiments/exp10_routing_temperature_specialization/README.md new file mode 100644 index 0000000..7953fb8 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/README.md @@ -0,0 +1,216 @@ +# Experiment 10: Routing Temperature and Expert Specialization Analysis + +## Overview + +This experiment systematically explores how **routing temperature** affects Mixture-of-Experts (MoE) model training. Temperature controls the sharpness of the routing distribution - high temperature leads to more uniform routing (exploration), while low temperature leads to sharper, more confident routing decisions (exploitation). + +## Research Questions + +1. **How does routing temperature affect convergence speed and final performance?** + - Does higher temperature lead to better exploration early in training? + - What is the optimal temperature for final performance? + +2. **How does temperature affect expert utilization?** + - Do different temperatures lead to different load balancing? + - Can we reduce load balancing loss with better temperature tuning? + +3. **Does temperature scheduling help?** + - Can we start with high temperature (exploration) and decrease over time (exploitation)? + - What is the optimal temperature schedule? + +4. **How do experts specialize under different temperatures?** + - Do experts develop distinct specializations? + - Does temperature affect the clarity of specialization patterns? + +## Experiment Design + +### Model Architecture +- **Type**: MoE Transformer with full attention (classic architecture) +- **Experts**: 8 experts with top-2 routing +- **Dimensions**: d_model=384, n_heads=8, n_layers=6, d_ff=1536 +- **Parameters**: ~79M total (~28.4% active per forward pass) + +### Experiments + +#### Temperature Ablation (500 steps) +1. **temp_0.5**: Very sharp routing (strong exploitation) +2. 
**temp_0.7**: Sharp routing (moderate exploitation)
+3. **temp_1.0**: Standard softmax (baseline)
+4. **temp_1.5**: Slightly softer routing
+5. **temp_2.0**: Softer routing (more exploration)
+6. **temp_3.0**: Soft routing (high exploration)
+7. **temp_5.0**: Very soft routing (maximum exploration)
+8. **temp_10.0**: Nearly uniform routing (extreme exploration)
+
+#### Temperature Scheduling (500 steps)
+9. **schedule_linear**: Linear decay from 5.0 → 1.0
+10. **schedule_cosine**: Cosine decay from 5.0 → 1.0
+11. **schedule_exp**: Exponential decay from 5.0 → 1.0
+12. **schedule_step**: Step decay: 5.0 (0-100) → 2.0 (100-300) → 1.0 (300+)
+
+#### Extended Training (1000 steps)
+13. **temp_best_long**: Best temperature from ablation, trained longer
+
+### Training Configuration
+- **Steps**: 500 (1000 for extended run)
+- **Batch size**: 24
+- **Gradient accumulation**: 4 steps
+- **Optimizer**: Muon (hybrid) with optimal settings from exp9
+  - Muon LR: 0.07
+  - AdamW LR: 0.007
+  - Momentum: 0.9
+  - Weight decay: 0.2
+- **LR schedule**: Cosine decay with 5% warmup
+- **Load balancing weight**: 0.01
+- **Dataset**: HuggingFaceTB/smollm-corpus (cosmopedia-v2)
+  - Training docs: 1,800
+  - Validation docs: 200
+  - Sequence length: 512 tokens
+
+### Metrics Tracked
+
+For each experiment:
+- **Performance metrics**:
+  - Validation loss
+  - Validation accuracy
+  - Validation perplexity
+  - Training time
+
+- **Routing metrics**:
+  - Expert utilization distribution
+  - Load balancing loss
+  - Routing entropy (diversity measure)
+  - Expert selection confidence
+
+- **Specialization metrics**:
+  - Per-expert token type distribution
+  - Expert activation patterns
+  - Specialization clarity score
+
+## Directory Structure
+
+```
+exp10_routing_temperature_specialization/
+├── __init__.py
+├── README.md
+├── config.py                    # Experiment configurations
+├── temperature_router.py        # Router with temperature control
+├── temperature_moe.py           # MoE with temperature support
+├── tracking_trainer.py          # Trainer with routing metrics tracking
+├── run_experiment.py            # Main experiment runner
+├── analyze_specialization.py    # Expert specialization analysis
+├── plot_results.py              # Comprehensive visualization
+├── results/                     # Generated during training
+│   ├── temp_0.5/
+│   │   ├── metrics.json         # Complete training history
+│   │   ├── model.pt             # Model checkpoint
+│   │   └── logs/                # Training logs
+│   ├── temp_1.0/
+│   │   └── ...
+│   └── ...
+└── analysis/                    # Generated after all experiments
+    ├── temperature_ablation_comprehensive.png
+    ├── routing_dynamics.png
+    ├── expert_utilization.png
+    ├── expert_utilization_analysis.png
+    ├── entropy_analysis.png
+    ├── schedule_comparison.png
+    ├── summary_report.json
+    └── specialization_report.json
+```
+
+## Usage
+
+### Run All Temperature Ablation Experiments
+```bash
+cd experiments/exp10_routing_temperature_specialization
+python run_experiment.py --ablation
+```
+
+### Run Specific Temperature
+```bash
+python run_experiment.py --temperature 2.0
+```
+
+### Run All Scheduling Experiments
+```bash
+python run_experiment.py --schedules
+```
+
+### Run Complete Suite
+```bash
+python run_experiment.py --all
+```
+
+### Analyze Results
+```bash
+python analyze_specialization.py --results-dir ./results
+python plot_results.py --results-dir ./results --output-dir ./analysis
+```
+
+## Expected Results
+
+### Hypotheses
+
+1. **Temperature 1.0 (baseline)** will provide reasonable performance but may not be optimal
+2. **Slightly higher temperature (2.0-3.0)** may improve early exploration and lead to better final performance
+3. **Very high temperature (10.0)** will hurt performance due to insufficient specialization
+4. **Very low temperature (0.5)** may lead to premature specialization and suboptimal convergence
+5. 
**Temperature scheduling** should combine benefits of exploration (early) and exploitation (late) + +### Key Metrics to Watch + +- **Final validation loss**: Primary performance indicator +- **Expert utilization entropy**: How evenly experts are used +- **Routing confidence**: How sharply the router selects experts +- **Convergence speed**: Steps to reach good performance +- **Expert specialization**: Do experts develop distinct roles? + +## Analysis Tools + +The experiment includes comprehensive analysis and visualization: + +1. **Temperature comparison plots**: + - Loss curves for all temperatures + - Convergence speed comparison + - Final performance summary + +2. **Expert utilization visualizations**: + - Expert usage distribution over training + - Load balancing effectiveness + - Routing entropy evolution + +3. **Specialization analysis**: + - Expert activation heatmaps + - Token type preferences per expert + - Specialization clarity metrics + +4. **Routing dynamics**: + - Routing confidence over time + - Expert selection patterns + - Temperature schedule effectiveness + +## Key Contributions + +This experiment will provide insights into: + +1. **Optimal routing temperature** for MoE training +2. **Temperature scheduling strategies** that balance exploration and exploitation +3. **Expert specialization dynamics** under different routing regimes +4. **Load balancing effectiveness** as a function of temperature + +## Notes + +- All experiments use the same random seed (42) for reproducibility +- Data is split before tokenization to prevent leakage +- AMP (Automatic Mixed Precision) is enabled by default +- Routing metrics are tracked at every evaluation step +- Expert statistics are saved for offline analysis + +## References + +- **Switch Transformers** (Fedus et al., 2021): Load balancing in MoE +- **GShard** (Lepikhin et al., 2020): Scaling MoE models +- **Expert Choice Routing** (Zhou et al., 2022): Alternative routing strategies +- **Soft MoE** (Puigcerver et al., 2023): Soft expert assignments + +--- + +**Created**: November 11, 2025 +**Branch**: exp10-routing-temperature-analysis + diff --git a/experiments/exp10_routing_temperature_specialization/__init__.py b/experiments/exp10_routing_temperature_specialization/__init__.py new file mode 100644 index 0000000..cc3e07b --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/__init__.py @@ -0,0 +1,10 @@ +""" +Experiment 10: Routing Temperature and Expert Specialization Analysis + +This experiment explores how routing temperature affects: +1. Expert utilization and load balancing +2. Convergence speed and final performance +3. Expert specialization patterns +4. 
Model training dynamics +""" + diff --git a/experiments/exp10_routing_temperature_specialization/analysis/entropy_analysis.png b/experiments/exp10_routing_temperature_specialization/analysis/entropy_analysis.png new file mode 100644 index 0000000..9c15c4e Binary files /dev/null and b/experiments/exp10_routing_temperature_specialization/analysis/entropy_analysis.png differ diff --git a/experiments/exp10_routing_temperature_specialization/analysis/expert_utilization.png b/experiments/exp10_routing_temperature_specialization/analysis/expert_utilization.png new file mode 100644 index 0000000..788baef Binary files /dev/null and b/experiments/exp10_routing_temperature_specialization/analysis/expert_utilization.png differ diff --git a/experiments/exp10_routing_temperature_specialization/analysis/expert_utilization_analysis.png b/experiments/exp10_routing_temperature_specialization/analysis/expert_utilization_analysis.png new file mode 100644 index 0000000..56e9891 Binary files /dev/null and b/experiments/exp10_routing_temperature_specialization/analysis/expert_utilization_analysis.png differ diff --git a/experiments/exp10_routing_temperature_specialization/analysis/routing_dynamics.png b/experiments/exp10_routing_temperature_specialization/analysis/routing_dynamics.png new file mode 100644 index 0000000..f9b5e78 Binary files /dev/null and b/experiments/exp10_routing_temperature_specialization/analysis/routing_dynamics.png differ diff --git a/experiments/exp10_routing_temperature_specialization/analysis/specialization_report.json b/experiments/exp10_routing_temperature_specialization/analysis/specialization_report.json new file mode 100644 index 0000000..c6fcc91 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/analysis/specialization_report.json @@ -0,0 +1,104 @@ +{ + "summary": { + "best_temperature": { + "experiment": "temp_0.7", + "temperature": 0.7, + "loss": 14.585194446172816 + }, + "most_balanced_experts": { + "experiment": "temp_2.0", + "gini_coefficient": 0.039511590395894025 + } + }, + "per_experiment": { + "temp_0.7": { + "temperature": 0.7, + "final_loss": 14.585194446172816, + "final_accuracy": 0.01640931313229101, + "expert_utilization": { + "distribution": [ + 0.1447695568203926, + 0.12556708479921022, + 0.11510686203837395, + 0.12321331351995468, + 0.10939040159185727, + 0.13584287712971369, + 0.12208833297093709, + 0.12402133643627167 + ], + "mean": 0.1249999706633389, + "std": 0.01038560616620274, + "min": 0.10939040159185727, + "max": 0.1447695568203926, + "gini": 0.045322315694041215 + }, + "routing_entropy": { + "initial": 0.0, + "final": 0.0, + "mean": 0.0, + "std": 0.0 + } + }, + "temp_1.0": { + "temperature": 1.0, + "final_loss": 14.623773783761283, + "final_accuracy": 0.01640931313229101, + "expert_utilization": { + "distribution": [ + 0.14738632986942926, + 0.12114603569110234, + 0.11430833364526431, + 0.12586971869071326, + 0.10711100200812022, + 0.13220180943608284, + 0.12949158623814583, + 0.12248493855198224 + ], + "mean": 0.12499996926635504, + "std": 0.011343782081391005, + "min": 0.10711100200812022, + "max": 0.14738632986942926, + "gini": 0.049977025508198825 + }, + "routing_entropy": { + "initial": 0.0, + "final": 0.0, + "mean": 0.0, + "std": 0.0 + } + }, + "temp_2.0": { + "temperature": 2.0, + "final_loss": 14.633379750875196, + "final_accuracy": 0.01640931313229101, + "expert_utilization": { + "distribution": [ + 0.13873382657766342, + 0.12168020009994507, + 0.12013104309638341, + 0.12457200636466344, + 0.1075517050921917, + 
0.130401990065972, + 0.13411077360312143, + 0.1228182278573513 + ], + "mean": 0.12499997159466146, + "std": 0.008978584248845484, + "min": 0.1075517050921917, + "max": 0.13873382657766342, + "gini": 0.039511590395894025 + }, + "routing_entropy": { + "initial": 0.0, + "final": 0.0, + "mean": 0.0, + "std": 0.0 + } + } + }, + "insights": [ + "Lower temperature (< 1.0) leads to sharper routing but may cause load imbalance", + "Higher temperature (> 1.0) improves load balancing but may reduce specialization", + "Optimal temperature balances exploration and exploitation" + ] +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/analysis/summary_report.json b/experiments/exp10_routing_temperature_specialization/analysis/summary_report.json new file mode 100644 index 0000000..e64d51d --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/analysis/summary_report.json @@ -0,0 +1,35 @@ +{ + "experiment_overview": { + "total_experiments": 3, + "experiment_names": [ + "temp_0.7", + "temp_1.0", + "temp_2.0" + ] + }, + "best_results": { + "temperature_ablation": { + "experiment": "temp_0.7", + "temperature": 0.7, + "best_loss": 10.933640250890077, + "final_loss": 14.585194446172816, + "final_accuracy": 0.01640931313229101 + } + }, + "temperature_analysis": { + "tested_temperatures": [ + 0.7, + 1.0, + 2.0 + ], + "losses": [ + 10.933640250890077, + 10.93908197382735, + 10.947579276014132 + ], + "best_temperature": 0.7, + "worst_temperature": 2.0, + "improvement": 0.12732518096118828 + }, + "schedule_analysis": {} +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/analysis/temperature_ablation_comprehensive.png b/experiments/exp10_routing_temperature_specialization/analysis/temperature_ablation_comprehensive.png new file mode 100644 index 0000000..095217f Binary files /dev/null and b/experiments/exp10_routing_temperature_specialization/analysis/temperature_ablation_comprehensive.png differ diff --git a/experiments/exp10_routing_temperature_specialization/analyze_specialization.py b/experiments/exp10_routing_temperature_specialization/analyze_specialization.py new file mode 100644 index 0000000..f07860d --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/analyze_specialization.py @@ -0,0 +1,338 @@ +""" +Analyze expert specialization patterns from routing statistics +""" +import argparse +import json +from pathlib import Path +from typing import Dict +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + + +def analyze_expert_specialization(results_dir: Path, output_dir: Path): + """Analyze how experts specialize under different temperatures""" + + print("Loading experiment results...") + results = {} + for exp_dir in results_dir.iterdir(): + if exp_dir.is_dir(): + metrics_file = exp_dir / "metrics.json" + if metrics_file.exists(): + with open(metrics_file, 'r') as f: + results[exp_dir.name] = json.load(f) + + if not results: + print(f"❌ No results found in {results_dir}") + return + + print(f"Found {len(results)} experiments\n") + + # Analyze expert utilization distribution + analyze_utilization_distribution(results, output_dir) + + # Analyze routing entropy trends + analyze_entropy_trends(results, output_dir) + + # Generate specialization report + generate_specialization_report(results, output_dir) + + +def analyze_utilization_distribution(results: Dict, output_dir: Path): + """Analyze how evenly experts are utilized""" + print("Analyzing expert utilization 
distribution...") + + temp_results = {k: v for k, v in results.items() if k.startswith('temp_')} + sorted_results = sorted(temp_results.items(), key=lambda x: x[1]['temperature']) + + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + fig.suptitle('Expert Utilization Distribution Analysis', fontsize=18, fontweight='bold') + + # Plot 1: Gini coefficient vs Temperature + ax = axes[0, 0] + temps = [] + gini_coeffs = [] + + for name, data in sorted_results: + history = data['history'] + if 'expert_utilizations' in history and history['expert_utilizations']: + final_util = np.array(history['expert_utilizations'][-1]) + # Compute Gini coefficient (measure of inequality) + sorted_util = np.sort(final_util) + n = len(sorted_util) + gini = (2 * np.sum((np.arange(1, n+1)) * sorted_util)) / (n * np.sum(sorted_util)) - (n + 1) / n + + temps.append(data['temperature']) + gini_coeffs.append(gini) + + ax.plot(temps, gini_coeffs, 'o-', linewidth=2, markersize=8, color='darkblue') + ax.set_xlabel('Temperature', fontsize=12) + ax.set_ylabel('Gini Coefficient', fontsize=12) + ax.set_title('Utilization Inequality vs Temperature\n(Lower Gini = More Balanced)', + fontsize=14, fontweight='bold') + ax.set_xscale('log') + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='red', linestyle='--', label='Perfect Balance') + ax.legend() + + # Plot 2: Utilization variance vs Temperature + ax = axes[0, 1] + variances = [] + + for name, data in sorted_results: + history = data['history'] + if 'expert_utilizations' in history and history['expert_utilizations']: + final_util = np.array(history['expert_utilizations'][-1]) + variances.append(np.var(final_util)) + + ax.plot(temps, variances, 'o-', linewidth=2, markersize=8, color='darkgreen') + ax.set_xlabel('Temperature', fontsize=12) + ax.set_ylabel('Utilization Variance', fontsize=12) + ax.set_title('Utilization Variance vs Temperature\n(Lower = More Balanced)', + fontsize=14, fontweight='bold') + ax.set_xscale('log') + ax.grid(True, alpha=0.3) + + # Plot 3: Heatmap of expert utilization across temperatures + ax = axes[1, 0] + utilization_matrix = [] + temp_labels = [] + + for name, data in sorted_results: + history = data['history'] + if 'expert_utilizations' in history and history['expert_utilizations']: + final_util = history['expert_utilizations'][-1] + utilization_matrix.append(final_util) + temp_labels.append(f"T={data['temperature']:.1f}") + + if utilization_matrix: + utilization_matrix = np.array(utilization_matrix) + sns.heatmap(utilization_matrix, annot=True, fmt='.3f', cmap='RdYlGn', + xticklabels=[f'E{i}' for i in range(utilization_matrix.shape[1])], + yticklabels=temp_labels, ax=ax, cbar_kws={'label': 'Utilization'}) + ax.set_title('Expert Utilization Heatmap', fontsize=14, fontweight='bold') + ax.set_xlabel('Expert Index', fontsize=12) + ax.set_ylabel('Temperature', fontsize=12) + + # Plot 4: Distribution statistics + ax = axes[1, 1] + ax.axis('tight') + ax.axis('off') + + # Create summary table + table_data = [['Temperature', 'Gini', 'Variance', 'Min Util', 'Max Util']] + for i, (name, data) in enumerate(sorted_results): + if i < len(gini_coeffs): + history = data['history'] + if 'expert_utilizations' in history and history['expert_utilizations']: + final_util = np.array(history['expert_utilizations'][-1]) + table_data.append([ + f"{data['temperature']:.1f}", + f"{gini_coeffs[i]:.3f}", + f"{variances[i]:.4f}", + f"{np.min(final_util):.3f}", + f"{np.max(final_util):.3f}", + ]) + + table = ax.table(cellText=table_data, cellLoc='center', loc='center', + 
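+                     # one width per column: Temperature, Gini, Variance, Min Util, Max Util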
colWidths=[0.15, 0.15, 0.2, 0.15, 0.15]) + table.auto_set_font_size(False) + table.set_fontsize(9) + table.scale(1, 2) + + # Style header + for i in range(5): + table[(0, i)].set_facecolor('#4CAF50') + table[(0, i)].set_text_props(weight='bold', color='white') + + ax.set_title('Utilization Statistics Summary', fontsize=14, fontweight='bold', pad=20) + + plt.tight_layout() + plt.savefig(output_dir / 'expert_utilization_analysis.png', dpi=300, bbox_inches='tight') + print(f"✅ Saved: {output_dir / 'expert_utilization_analysis.png'}") + plt.close() + + +def analyze_entropy_trends(results: Dict, output_dir: Path): + """Analyze routing entropy evolution over training""" + print("Analyzing routing entropy trends...") + + temp_results = {k: v for k, v in results.items() if k.startswith('temp_')} + sorted_results = sorted(temp_results.items(), key=lambda x: x[1]['temperature']) + + fig, axes = plt.subplots(1, 2, figsize=(16, 6)) + fig.suptitle('Routing Entropy Analysis', fontsize=18, fontweight='bold') + + # Plot 1: Entropy evolution + ax = axes[0] + temps = [r[1]['temperature'] for r in sorted_results] + colors = plt.cm.RdYlBu_r(np.linspace(0.2, 0.8, len(temps))) + + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + if 'routing_entropies' in history and history['routing_entropies']: + ax.plot(history['steps'], history['routing_entropies'], + label=f"T={data['temperature']:.1f}", + color=colors[i], linewidth=2, alpha=0.8) + + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Routing Entropy', fontsize=12) + ax.set_title('Routing Entropy Evolution', fontsize=14, fontweight='bold') + ax.legend(fontsize=9, ncol=2) + ax.grid(True, alpha=0.3) + + # Plot 2: Entropy change rate + ax = axes[1] + + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + if 'routing_entropies' in history and history['routing_entropies'] and len(history['routing_entropies']) > 1: + entropies = np.array(history['routing_entropies']) + steps = np.array(history['steps']) + + # Compute rate of change + entropy_diff = np.diff(entropies) + step_diff = np.diff(steps) + entropy_rate = entropy_diff / step_diff + + ax.plot(steps[1:], entropy_rate, + label=f"T={data['temperature']:.1f}", + color=colors[i], linewidth=2, alpha=0.8) + + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Entropy Change Rate', fontsize=12) + ax.set_title('Routing Entropy Change Rate', fontsize=14, fontweight='bold') + ax.axhline(y=0, color='red', linestyle='--', linewidth=1) + ax.legend(fontsize=9, ncol=2) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(output_dir / 'entropy_analysis.png', dpi=300, bbox_inches='tight') + print(f"✅ Saved: {output_dir / 'entropy_analysis.png'}") + plt.close() + + +def generate_specialization_report(results: Dict, output_dir: Path): + """Generate detailed specialization report""" + print("Generating specialization report...") + + report = { + 'summary': {}, + 'per_experiment': {}, + 'insights': [] + } + + temp_results = {k: v for k, v in results.items() if k.startswith('temp_')} + + for name, data in temp_results.items(): + history = data['history'] + + analysis = { + 'temperature': data['temperature'], + 'final_loss': data['final_metrics']['val_loss'], + 'final_accuracy': data['final_metrics']['val_accuracy'], + } + + # Analyze expert utilization + if 'expert_utilizations' in history and history['expert_utilizations']: + final_util = np.array(history['expert_utilizations'][-1]) + + analysis['expert_utilization'] = { + 'distribution': 
final_util.tolist(), + 'mean': float(np.mean(final_util)), + 'std': float(np.std(final_util)), + 'min': float(np.min(final_util)), + 'max': float(np.max(final_util)), + 'gini': float((2 * np.sum((np.arange(1, len(final_util)+1)) * np.sort(final_util))) / + (len(final_util) * np.sum(final_util)) - (len(final_util) + 1) / len(final_util)), + } + + # Analyze routing entropy + if 'routing_entropies' in history and history['routing_entropies']: + entropies = history['routing_entropies'] + analysis['routing_entropy'] = { + 'initial': entropies[0] if entropies else None, + 'final': entropies[-1] if entropies else None, + 'mean': float(np.mean(entropies)), + 'std': float(np.std(entropies)), + } + + report['per_experiment'][name] = analysis + + # Generate insights + if temp_results: + best_exp = min(temp_results.items(), key=lambda x: x[1]['final_metrics']['val_loss']) + report['summary']['best_temperature'] = { + 'experiment': best_exp[0], + 'temperature': best_exp[1]['temperature'], + 'loss': best_exp[1]['final_metrics']['val_loss'], + } + + # Find most balanced utilization + gini_scores = {name: data['expert_utilization']['gini'] + for name, data in report['per_experiment'].items() + if 'expert_utilization' in data} + if gini_scores: + most_balanced = min(gini_scores.items(), key=lambda x: x[1]) + report['summary']['most_balanced_experts'] = { + 'experiment': most_balanced[0], + 'gini_coefficient': most_balanced[1], + } + + # Insights + report['insights'].append("Lower temperature (< 1.0) leads to sharper routing but may cause load imbalance") + report['insights'].append("Higher temperature (> 1.0) improves load balancing but may reduce specialization") + report['insights'].append("Optimal temperature balances exploration and exploitation") + + # Save report + report_file = output_dir / 'specialization_report.json' + with open(report_file, 'w') as f: + json.dump(report, f, indent=2) + print(f"✅ Saved: {report_file}") + + # Print summary + print(f"\n{'='*80}") + print("SPECIALIZATION ANALYSIS SUMMARY") + print(f"{'='*80}\n") + + if 'best_temperature' in report['summary']: + bt = report['summary']['best_temperature'] + print(f"🏆 Best Performance: {bt['experiment']}") + print(f" Temperature: {bt['temperature']:.2f}") + print(f" Loss: {bt['loss']:.4f}\n") + + if 'most_balanced_experts' in report['summary']: + mb = report['summary']['most_balanced_experts'] + print(f"⚖️ Most Balanced Experts: {mb['experiment']}") + print(f" Gini Coefficient: {mb['gini_coefficient']:.3f}\n") + + print("💡 Key Insights:") + for insight in report['insights']: + print(f" • {insight}") + + print(f"\n{'='*80}\n") + + +def main(): + parser = argparse.ArgumentParser(description="Analyze expert specialization patterns") + parser.add_argument('--results-dir', type=str, default='./results', + help='Directory containing experiment results') + parser.add_argument('--output-dir', type=str, default='./analysis', + help='Directory to save analysis outputs') + + args = parser.parse_args() + + results_dir = Path(args.results_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + analyze_expert_specialization(results_dir, output_dir) + + print(f"\n{'='*80}") + print(f"✅ Specialization analysis complete!") + print(f"{'='*80}\n") + + +if __name__ == "__main__": + main() + diff --git a/experiments/exp10_routing_temperature_specialization/config.py b/experiments/exp10_routing_temperature_specialization/config.py new file mode 100644 index 0000000..266d9c6 --- /dev/null +++ 
b/experiments/exp10_routing_temperature_specialization/config.py @@ -0,0 +1,186 @@ +""" +Configuration for routing temperature experiments +""" +from dataclasses import dataclass +from typing import Optional, Literal + + +@dataclass +class TemperatureConfig: + """Configuration for a single temperature experiment""" + name: str + description: str + + # Temperature settings + temperature: float = 1.0 + temperature_schedule: Optional[Literal["linear", "cosine", "exponential", "step"]] = None + temperature_start: Optional[float] = None + temperature_end: Optional[float] = None + + # Training settings (inherit from MoEModelConfig defaults) + max_steps: int = 500 # Full training + + def get_temperature_at_step(self, step: int) -> float: + """Calculate temperature at given training step""" + if self.temperature_schedule is None: + return self.temperature + + # Use schedule + start_temp = self.temperature_start if self.temperature_start is not None else self.temperature + end_temp = self.temperature_end if self.temperature_end is not None else self.temperature + + progress = step / self.max_steps + + if self.temperature_schedule == "linear": + return start_temp + (end_temp - start_temp) * progress + + elif self.temperature_schedule == "cosine": + import math + return end_temp + (start_temp - end_temp) * 0.5 * (1 + math.cos(math.pi * progress)) + + elif self.temperature_schedule == "exponential": + import math + # Exponential decay: temp = start * (end/start)^progress + return start_temp * (end_temp / start_temp) ** progress + + elif self.temperature_schedule == "step": + # Step schedule: 5.0 (0-100) → 2.0 (100-300) → 1.0 (300+) + if step < 100: + return 5.0 + elif step < 300: + return 2.0 + else: + return 1.0 + + return self.temperature + + +# Temperature Ablation Experiments +TEMPERATURE_ABLATION = { + "temp_0.5": TemperatureConfig( + name="temp_0.5", + description="Very sharp routing (strong exploitation)", + temperature=0.5, + ), + "temp_0.7": TemperatureConfig( + name="temp_0.7", + description="Sharp routing (moderate exploitation)", + temperature=0.7, + ), + "temp_1.0": TemperatureConfig( + name="temp_1.0", + description="Standard softmax (baseline)", + temperature=1.0, + ), + "temp_1.5": TemperatureConfig( + name="temp_1.5", + description="Slightly softer routing", + temperature=1.5, + ), + "temp_2.0": TemperatureConfig( + name="temp_2.0", + description="Softer routing (more exploration)", + temperature=2.0, + ), + "temp_3.0": TemperatureConfig( + name="temp_3.0", + description="Soft routing (high exploration)", + temperature=3.0, + ), + "temp_5.0": TemperatureConfig( + name="temp_5.0", + description="Very soft routing (maximum exploration)", + temperature=5.0, + ), + "temp_10.0": TemperatureConfig( + name="temp_10.0", + description="Nearly uniform routing (extreme exploration)", + temperature=10.0, + ), +} + +# Temperature Scheduling Experiments +TEMPERATURE_SCHEDULES = { + "schedule_linear": TemperatureConfig( + name="schedule_linear", + description="Linear decay from 5.0 → 1.0", + temperature=5.0, + temperature_schedule="linear", + temperature_start=5.0, + temperature_end=1.0, + ), + "schedule_cosine": TemperatureConfig( + name="schedule_cosine", + description="Cosine decay from 5.0 → 1.0", + temperature=5.0, + temperature_schedule="cosine", + temperature_start=5.0, + temperature_end=1.0, + ), + "schedule_exp": TemperatureConfig( + name="schedule_exp", + description="Exponential decay from 5.0 → 1.0", + temperature=5.0, + temperature_schedule="exponential", + temperature_start=5.0, + 
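# NOTE: for these endpoints the exponential branch of get_temperature_at_step gives + # T(p) = 5.0 * (1.0 / 5.0) ** p, so T(0) = 5.0, T(0.5) ≈ 2.24, T(1.0) = 1.0 -- a + # geometric decay between the same endpoints the other schedules use. +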
temperature_end=1.0, + ), + "schedule_step": TemperatureConfig( + name="schedule_step", + description="Step decay: 5.0 (0-100) → 2.0 (100-300) → 1.0 (300+)", + temperature=5.0, + temperature_schedule="step", + ), +} + +# Extended training with best temperature +EXTENDED_TRAINING = { + "temp_best_long": TemperatureConfig( + name="temp_best_long", + description="Best temperature from ablation, trained for 1000 steps", + temperature=2.0, # Will be updated after ablation + max_steps=1000, # Extended training + ), +} + +# All experiments +ALL_EXPERIMENTS = { + **TEMPERATURE_ABLATION, + **TEMPERATURE_SCHEDULES, + **EXTENDED_TRAINING, +} + + +def get_experiment_config(name: str) -> TemperatureConfig: + """Get experiment configuration by name""" + if name not in ALL_EXPERIMENTS: + raise ValueError(f"Unknown experiment: {name}. Available: {list(ALL_EXPERIMENTS.keys())}") + return ALL_EXPERIMENTS[name] + + +def list_experiments(): + """Print all available experiments""" + print("\n" + "="*80) + print("AVAILABLE EXPERIMENTS") + print("="*80) + + print("\n📊 TEMPERATURE ABLATION (500 steps)") + print("-" * 80) + for name, config in TEMPERATURE_ABLATION.items(): + print(f" {name:20s} - Temp: {config.temperature:5.1f} - {config.description}") + + print("\n📈 TEMPERATURE SCHEDULES (500 steps)") + print("-" * 80) + for name, config in TEMPERATURE_SCHEDULES.items(): + print(f" {name:20s} - Schedule: {config.temperature_schedule or 'none':12s} - {config.description}") + + print("\n🔬 EXTENDED TRAINING (1000 steps)") + print("-" * 80) + for name, config in EXTENDED_TRAINING.items(): + print(f" {name:20s} - Temp: {config.temperature:5.1f} - {config.description}") + + print("\n" + "="*80) + print(f"Total: {len(ALL_EXPERIMENTS)} experiments (~3-4 min each @ 500 steps)") + print("="*80 + "\n") + diff --git a/experiments/exp10_routing_temperature_specialization/plot_results.py b/experiments/exp10_routing_temperature_specialization/plot_results.py new file mode 100644 index 0000000..1142536 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/plot_results.py @@ -0,0 +1,518 @@ +""" +Comprehensive visualization of temperature experiment results +""" +import argparse +import json +from pathlib import Path +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from typing import Dict, List + + +def load_experiment_results(results_dir: Path) -> Dict: + """Load all experiment results from results directory""" + results = {} + + for exp_dir in results_dir.iterdir(): + if exp_dir.is_dir(): + metrics_file = exp_dir / "metrics.json" + if metrics_file.exists(): + with open(metrics_file, 'r') as f: + data = json.load(f) + results[exp_dir.name] = data + + return results + + +def plot_temperature_comparison(results: Dict, output_dir: Path): + """Plot comparison of different temperatures""" + fig = plt.figure(figsize=(20, 12)) + gs = gridspec.GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3) + + fig.suptitle('Temperature Ablation: Comprehensive Comparison', fontsize=18, fontweight='bold') + + # Filter temperature ablation experiments + temp_results = {k: v for k, v in results.items() if k.startswith('temp_') and not k.endswith('_long')} + + # Sort by temperature + sorted_results = sorted(temp_results.items(), key=lambda x: x[1]['temperature']) + + # Extract data + temps = [r[1]['temperature'] for r in sorted_results] + names = [r[0] for r in 
sorted_results] + + # Colors based on temperature + colors = plt.cm.RdYlBu_r(np.linspace(0.2, 0.8, len(temps))) + + # Plot 1: Validation Loss over Steps + ax1 = fig.add_subplot(gs[0, :2]) + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + ax1.plot(history['steps'], history['val_losses'], + label=f"T={data['temperature']:.1f}", + color=colors[i], linewidth=2, marker='o', markersize=4) + ax1.set_xlabel('Training Steps', fontsize=12) + ax1.set_ylabel('Validation Loss', fontsize=12) + ax1.set_title('Validation Loss vs Training Steps', fontsize=14, fontweight='bold') + ax1.legend(fontsize=10, ncol=2) + ax1.grid(True, alpha=0.3) + + # Plot 2: Validation Loss over Time + ax2 = fig.add_subplot(gs[0, 2]) + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + ax2.plot(history['elapsed_times'], history['val_losses'], + color=colors[i], linewidth=2, alpha=0.7) + ax2.set_xlabel('Time (minutes)', fontsize=12) + ax2.set_ylabel('Validation Loss', fontsize=12) + ax2.set_title('Loss vs Wall-Clock Time', fontsize=14, fontweight='bold') + ax2.grid(True, alpha=0.3) + + # Plot 3: Final Performance vs Temperature + ax3 = fig.add_subplot(gs[1, 0]) + final_losses = [r[1]['final_metrics']['val_loss'] for r in sorted_results] + best_losses = [min(r[1]['history']['val_losses']) for r in sorted_results] + ax3.plot(temps, final_losses, 'o-', color='darkred', linewidth=2, markersize=8, label='Final Loss') + ax3.plot(temps, best_losses, 's-', color='darkblue', linewidth=2, markersize=8, label='Best Loss') + ax3.set_xlabel('Temperature', fontsize=12) + ax3.set_ylabel('Validation Loss', fontsize=12) + ax3.set_title('Performance vs Temperature', fontsize=14, fontweight='bold') + ax3.legend(fontsize=10) + ax3.grid(True, alpha=0.3) + ax3.set_xscale('log') + + # Plot 4: Validation Accuracy vs Temperature + ax4 = fig.add_subplot(gs[1, 1]) + final_accs = [r[1]['final_metrics']['val_accuracy'] for r in sorted_results] + ax4.plot(temps, final_accs, 'o-', color='darkgreen', linewidth=2, markersize=8) + ax4.set_xlabel('Temperature', fontsize=12) + ax4.set_ylabel('Validation Accuracy', fontsize=12) + ax4.set_title('Accuracy vs Temperature', fontsize=14, fontweight='bold') + ax4.grid(True, alpha=0.3) + ax4.set_xscale('log') + + # Plot 5: Routing Entropy vs Temperature + ax5 = fig.add_subplot(gs[1, 2]) + final_entropies = [] + for name, data in sorted_results: + if 'routing_entropies' in data['history'] and data['history']['routing_entropies']: + final_entropies.append(data['history']['routing_entropies'][-1]) + else: + final_entropies.append(0) + ax5.plot(temps, final_entropies, 'o-', color='purple', linewidth=2, markersize=8) + ax5.set_xlabel('Temperature', fontsize=12) + ax5.set_ylabel('Routing Entropy', fontsize=12) + ax5.set_title('Routing Entropy vs Temperature', fontsize=14, fontweight='bold') + ax5.grid(True, alpha=0.3) + ax5.set_xscale('log') + + # Plot 6: Accuracy over Steps + ax6 = fig.add_subplot(gs[2, :2]) + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + ax6.plot(history['steps'], history['val_accuracies'], + label=f"T={data['temperature']:.1f}", + color=colors[i], linewidth=2, marker='o', markersize=4) + ax6.set_xlabel('Training Steps', fontsize=12) + ax6.set_ylabel('Validation Accuracy', fontsize=12) + ax6.set_title('Validation Accuracy vs Training Steps', fontsize=14, fontweight='bold') + ax6.legend(fontsize=10, ncol=2) + ax6.grid(True, alpha=0.3) + + # Plot 7: Summary Statistics Table + ax7 = fig.add_subplot(gs[2, 2]) + 
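# NOTE: the 'Improvement' row in the table below is the relative gap between the + # worst and best temperatures, (worst_loss - best_loss) / worst_loss * 100; e.g. + # (illustrative numbers only) worst = 24.9, best = 23.8 -> (24.9 - 23.8) / 24.9 * 100 ≈ 4.4%. +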
ax7.axis('tight') + ax7.axis('off') + + # Find best temperature + best_idx = np.argmin(best_losses) + best_temp = temps[best_idx] + best_loss = best_losses[best_idx] + + table_data = [ + ['Metric', 'Value'], + ['Best Temperature', f'{best_temp:.2f}'], + ['Best Loss', f'{best_loss:.4f}'], + ['Worst Loss', f'{max(best_losses):.4f}'], + ['Improvement', f'{((max(best_losses) - best_loss) / max(best_losses) * 100):.1f}%'], + ['Temps Tested', f'{len(temps)}'], + ] + + table = ax7.table(cellText=table_data, cellLoc='left', loc='center', + colWidths=[0.6, 0.4]) + table.auto_set_font_size(False) + table.set_fontsize(10) + table.scale(1, 2) + + # Style header + for i in range(2): + table[(0, i)].set_facecolor('#4CAF50') + table[(0, i)].set_text_props(weight='bold', color='white') + + # Highlight best result + table[(1, 1)].set_facecolor('#FFD700') + table[(2, 1)].set_facecolor('#FFD700') + + ax7.set_title('Summary Statistics', fontsize=14, fontweight='bold') + + plt.savefig(output_dir / 'temperature_ablation_comprehensive.png', + dpi=300, bbox_inches='tight') + print(f"✅ Saved: {output_dir / 'temperature_ablation_comprehensive.png'}") + plt.close() + + +def plot_routing_dynamics(results: Dict, output_dir: Path): + """Plot routing dynamics over training""" + temp_results = {k: v for k, v in results.items() if k.startswith('temp_') and not k.endswith('_long')} + sorted_results = sorted(temp_results.items(), key=lambda x: x[1]['temperature']) + + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + fig.suptitle('Routing Dynamics Analysis', fontsize=18, fontweight='bold') + + temps = [r[1]['temperature'] for r in sorted_results] + colors = plt.cm.RdYlBu_r(np.linspace(0.2, 0.8, len(temps))) + + # Plot 1: Routing Entropy Evolution + ax = axes[0, 0] + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + if 'routing_entropies' in history and history['routing_entropies']: + ax.plot(history['steps'], history['routing_entropies'], + label=f"T={data['temperature']:.1f}", + color=colors[i], linewidth=2) + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Routing Entropy', fontsize=12) + ax.set_title('Routing Entropy Over Training', fontsize=14, fontweight='bold') + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3) + + # Plot 2: Selection Confidence Evolution + ax = axes[0, 1] + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + if 'selection_confidences' in history and history['selection_confidences']: + ax.plot(history['steps'], history['selection_confidences'], + label=f"T={data['temperature']:.1f}", + color=colors[i], linewidth=2) + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Top-1 Selection Confidence', fontsize=12) + ax.set_title('Selection Confidence Over Training', fontsize=14, fontweight='bold') + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3) + + # Plot 3: Load Balancing Loss Evolution + ax = axes[1, 0] + for i, (name, data) in enumerate(sorted_results): + history = data['history'] + if 'load_balancing_losses' in history and history['load_balancing_losses']: + ax.plot(history['steps'], history['load_balancing_losses'], + label=f"T={data['temperature']:.1f}", + color=colors[i], linewidth=2) + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Load Balancing Loss', fontsize=12) + ax.set_title('Load Balancing Loss Over Training', fontsize=14, fontweight='bold') + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3) + + # Plot 4: Temperature vs Final Entropy and Confidence + ax = axes[1, 1] + final_entropies = [] + 
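# NOTE: 'routing_entropies' is assumed to log the Shannon entropy of the router's + # expert distribution, H = -sum_i p_i * log(p_i); with num_experts = 8 the uniform- + # routing ceiling is ln(8) ≈ 2.079, and confident top-1 routing drives H toward 0. +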
final_confidences = [] + for name, data in sorted_results: + history = data['history'] + if 'routing_entropies' in history and history['routing_entropies']: + final_entropies.append(history['routing_entropies'][-1]) + else: + final_entropies.append(0) + if 'selection_confidences' in history and history['selection_confidences']: + final_confidences.append(history['selection_confidences'][-1]) + else: + final_confidences.append(0) + + ax2 = ax.twinx() + line1 = ax.plot(temps, final_entropies, 'o-', color='purple', + linewidth=2, markersize=8, label='Entropy') + line2 = ax2.plot(temps, final_confidences, 's-', color='orange', + linewidth=2, markersize=8, label='Confidence') + + ax.set_xlabel('Temperature', fontsize=12) + ax.set_ylabel('Final Routing Entropy', fontsize=12, color='purple') + ax2.set_ylabel('Final Selection Confidence', fontsize=12, color='orange') + ax.set_title('Temperature vs Final Routing Metrics', fontsize=14, fontweight='bold') + ax.set_xscale('log') + ax.tick_params(axis='y', labelcolor='purple') + ax2.tick_params(axis='y', labelcolor='orange') + ax.grid(True, alpha=0.3) + + # Combine legends + lines = line1 + line2 + labels = [l.get_label() for l in lines] + ax.legend(lines, labels, fontsize=10) + + plt.tight_layout() + plt.savefig(output_dir / 'routing_dynamics.png', dpi=300, bbox_inches='tight') + print(f"✅ Saved: {output_dir / 'routing_dynamics.png'}") + plt.close() + + +def plot_expert_utilization(results: Dict, output_dir: Path): + """Plot expert utilization patterns""" + temp_results = {k: v for k, v in results.items() if k.startswith('temp_') and not k.endswith('_long')} + sorted_results = sorted(temp_results.items(), key=lambda x: x[1]['temperature']) + + fig, axes = plt.subplots(2, 4, figsize=(20, 10)) + fig.suptitle('Expert Utilization Patterns', fontsize=18, fontweight='bold') + + axes = axes.flatten() + + for idx, (name, data) in enumerate(sorted_results[:8]): # Plot up to 8 + ax = axes[idx] + + # Get final expert utilization + history = data['history'] + if 'expert_utilizations' in history and history['expert_utilizations']: + final_util = history['expert_utilizations'][-1] + + experts = list(range(len(final_util))) + bars = ax.bar(experts, final_util, color='steelblue', alpha=0.8) + + # Color bars by utilization + max_util = max(final_util) if final_util else 1 + for bar, util in zip(bars, final_util): + bar.set_color(plt.cm.RdYlGn(util / max_util)) + + ax.axhline(y=1.0/len(final_util), color='red', linestyle='--', + linewidth=2, label='Uniform') + ax.set_xlabel('Expert Index', fontsize=10) + ax.set_ylabel('Utilization', fontsize=10) + ax.set_title(f'Temperature = {data["temperature"]:.1f}', + fontsize=12, fontweight='bold') + ax.set_ylim(0, max(final_util) * 1.2 if final_util else 1) + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3, axis='y') + + # Hide unused subplots + for idx in range(len(sorted_results), 8): + axes[idx].axis('off') + + plt.tight_layout() + plt.savefig(output_dir / 'expert_utilization.png', dpi=300, bbox_inches='tight') + print(f"✅ Saved: {output_dir / 'expert_utilization.png'}") + plt.close() + + +def plot_schedule_comparison(results: Dict, output_dir: Path): + """Plot temperature schedule comparisons""" + schedule_results = {k: v for k, v in results.items() if k.startswith('schedule_')} + + if not schedule_results: + print("⚠️ No schedule experiments found, skipping schedule comparison") + return + + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + fig.suptitle('Temperature Schedule Comparison', fontsize=18, fontweight='bold') + + 
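# NOTE: all four schedules decay from T = 5.0 to T = 1.0 over the same 500 steps, so + # performance gaps between them isolate the shape of the decay (linear vs. cosine vs. + # exponential vs. step) rather than its endpoints. +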
# Plot 1: Validation Loss + ax = axes[0, 0] + for name, data in schedule_results.items(): + history = data['history'] + ax.plot(history['steps'], history['val_losses'], + label=name, linewidth=2, marker='o', markersize=4) + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Validation Loss', fontsize=12) + ax.set_title('Loss: Schedule Comparison', fontsize=14, fontweight='bold') + ax.legend(fontsize=10) + ax.grid(True, alpha=0.3) + + # Plot 2: Temperature Evolution + ax = axes[0, 1] + for name, data in schedule_results.items(): + history = data['history'] + if 'temperatures' in history: + ax.plot(history['steps'], history['temperatures'], + label=name, linewidth=2) + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Temperature', fontsize=12) + ax.set_title('Temperature Schedule Evolution', fontsize=14, fontweight='bold') + ax.legend(fontsize=10) + ax.grid(True, alpha=0.3) + + # Plot 3: Validation Accuracy + ax = axes[1, 0] + for name, data in schedule_results.items(): + history = data['history'] + ax.plot(history['steps'], history['val_accuracies'], + label=name, linewidth=2, marker='o', markersize=4) + ax.set_xlabel('Training Steps', fontsize=12) + ax.set_ylabel('Validation Accuracy', fontsize=12) + ax.set_title('Accuracy: Schedule Comparison', fontsize=14, fontweight='bold') + ax.legend(fontsize=10) + ax.grid(True, alpha=0.3) + + # Plot 4: Final Metrics Comparison + ax = axes[1, 1] + schedule_names = list(schedule_results.keys()) + final_losses = [data['final_metrics']['val_loss'] for data in schedule_results.values()] + final_accs = [data['final_metrics']['val_accuracy'] for data in schedule_results.values()] + + x = np.arange(len(schedule_names)) + width = 0.35 + + ax.bar(x - width/2, final_losses, width, label='Loss', alpha=0.8, color='coral') + ax2 = ax.twinx() + ax2.bar(x + width/2, final_accs, width, label='Accuracy', alpha=0.8, color='lightblue') + + ax.set_xlabel('Schedule', fontsize=12) + ax.set_ylabel('Final Loss', fontsize=12, color='coral') + ax2.set_ylabel('Final Accuracy', fontsize=12, color='lightblue') + ax.set_title('Final Performance Comparison', fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels(schedule_names, rotation=45, ha='right') + ax.tick_params(axis='y', labelcolor='coral') + ax2.tick_params(axis='y', labelcolor='lightblue') + ax.grid(True, alpha=0.3, axis='y') + + # Combine legends + lines1, labels1 = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax.legend(lines1 + lines2, labels1 + labels2, fontsize=10) + + plt.tight_layout() + plt.savefig(output_dir / 'schedule_comparison.png', dpi=300, bbox_inches='tight') + print(f"✅ Saved: {output_dir / 'schedule_comparison.png'}") + plt.close() + + +def generate_summary_report(results: Dict, output_dir: Path): + """Generate comprehensive summary report""" + report = { + 'experiment_overview': { + 'total_experiments': len(results), + 'experiment_names': list(results.keys()), + }, + 'best_results': {}, + 'temperature_analysis': {}, + 'schedule_analysis': {}, + } + + # Temperature ablation analysis + temp_results = {k: v for k, v in results.items() if k.startswith('temp_') and not k.endswith('_long')} + if temp_results: + best_exp = min(temp_results.items(), key=lambda x: min(x[1]['history']['val_losses'])) + best_name, best_data = best_exp + + report['best_results']['temperature_ablation'] = { + 'experiment': best_name, + 'temperature': best_data['temperature'], + 'best_loss': min(best_data['history']['val_losses']), + 'final_loss': 
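# NOTE: 'best_loss' above is the minimum validation loss seen over the run, while + # 'final_loss' here is the loss at the last evaluation; the two can diverge when the + # loss curve is non-monotonic, so the report keeps both. +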
best_data['final_metrics']['val_loss'], + 'final_accuracy': best_data['final_metrics']['val_accuracy'], + } + + # Temperature analysis + temps = sorted([(v['temperature'], min(v['history']['val_losses'])) + for v in temp_results.values()]) + report['temperature_analysis'] = { + 'tested_temperatures': [t[0] for t in temps], + 'losses': [t[1] for t in temps], + 'best_temperature': best_data['temperature'], + 'worst_temperature': max(temps, key=lambda x: x[1])[0], + 'improvement': ((max(t[1] for t in temps) - min(t[1] for t in temps)) / + max(t[1] for t in temps) * 100), + } + + # Schedule analysis + schedule_results = {k: v for k, v in results.items() if k.startswith('schedule_')} + if schedule_results: + best_schedule = min(schedule_results.items(), + key=lambda x: min(x[1]['history']['val_losses'])) + best_name, best_data = best_schedule + + report['best_results']['temperature_schedule'] = { + 'experiment': best_name, + 'schedule_type': best_data['temperature_schedule'], + 'best_loss': min(best_data['history']['val_losses']), + 'final_loss': best_data['final_metrics']['val_loss'], + 'final_accuracy': best_data['final_metrics']['val_accuracy'], + } + + schedule_losses = {k: min(v['history']['val_losses']) + for k, v in schedule_results.items()} + report['schedule_analysis'] = { + 'schedules_tested': list(schedule_losses.keys()), + 'losses': schedule_losses, + 'best_schedule': best_name, + } + + # Save report + report_file = output_dir / 'summary_report.json' + with open(report_file, 'w') as f: + json.dump(report, f, indent=2) + print(f"✅ Saved: {report_file}") + + # Print summary to console + print(f"\n{'='*80}") + print("EXPERIMENT SUMMARY") + print(f"{'='*80}\n") + + if 'temperature_ablation' in report['best_results']: + ta = report['best_results']['temperature_ablation'] + print(f"🏆 Best Temperature: {ta['temperature']:.2f}") + print(f" Loss: {ta['best_loss']:.4f}") + print(f" Accuracy: {ta['final_accuracy']:.4f}\n") + + if 'temperature_schedule' in report['best_results']: + ts = report['best_results']['temperature_schedule'] + print(f"🏆 Best Schedule: {ts['experiment']}") + print(f" Loss: {ts['best_loss']:.4f}") + print(f" Accuracy: {ts['final_accuracy']:.4f}\n") + + if 'temperature_analysis' in report: + ta = report['temperature_analysis'] + print(f"📊 Temperature Analysis:") + print(f" Improvement: {ta['improvement']:.2f}%") + print(f" Best: T={ta['best_temperature']:.2f}") + print(f" Worst: T={ta['worst_temperature']:.2f}\n") + + print(f"{'='*80}\n") + + +def main(): + parser = argparse.ArgumentParser(description="Plot temperature experiment results") + parser.add_argument('--results-dir', type=str, default='./results', + help='Directory containing experiment results') + parser.add_argument('--output-dir', type=str, default='./analysis', + help='Directory to save plots') + + args = parser.parse_args() + + results_dir = Path(args.results_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Loading results from {results_dir}...") + results = load_experiment_results(results_dir) + + if not results: + print(f"❌ No results found in {results_dir}") + return + + print(f"Found {len(results)} experiments") + print(f"Generating visualizations...\n") + + # Generate all plots + plot_temperature_comparison(results, output_dir) + plot_routing_dynamics(results, output_dir) + plot_expert_utilization(results, output_dir) + plot_schedule_comparison(results, output_dir) + generate_summary_report(results, output_dir) + + print(f"\n{'='*80}") + print(f"✅ 
All visualizations saved to {output_dir}") + print(f"{'='*80}\n") + + +if __name__ == "__main__": + main() + diff --git a/experiments/exp10_routing_temperature_specialization/quick_demo.sh b/experiments/exp10_routing_temperature_specialization/quick_demo.sh new file mode 100755 index 0000000..3bbdf25 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/quick_demo.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Quick demo script to run a subset of temperature experiments + +echo "=========================================" +echo "Temperature Routing Experiment - Quick Demo" +echo "=========================================" +echo "" +echo "This script runs a subset of temperature experiments:" +echo " 1. temp_0.7 - Sharp routing" +echo " 2. temp_1.0 - Baseline" +echo " 3. temp_2.0 - Soft routing" +echo "" +echo "Each experiment runs for 500 steps (~2-3 minutes on GPU)" +echo "" + +cd /root/blueberry-llm/experiments/exp10_routing_temperature_specialization + +# Run 3 representative temperatures +python run_experiment.py --experiments temp_0.7 temp_1.0 temp_2.0 --output-dir ./results + +echo "" +echo "=========================================" +echo "Generating visualizations..." +echo "=========================================" +echo "" + +# Generate plots and analysis +python plot_results.py --results-dir ./results --output-dir ./analysis +python analyze_specialization.py --results-dir ./results --output-dir ./analysis + +echo "" +echo "=========================================" +echo "Demo complete!" +echo "=========================================" +echo "" +echo "Results saved in:" +echo " - ./results/ - Individual experiment results" +echo " - ./analysis/ - Comparative plots and analysis" +echo "" + diff --git a/experiments/exp10_routing_temperature_specialization/quick_test.py b/experiments/exp10_routing_temperature_specialization/quick_test.py new file mode 100644 index 0000000..384a9bf --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/quick_test.py @@ -0,0 +1,59 @@ +""" +Quick test script - runs a very short experiment to verify the pipeline works +""" +import sys +import os +from pathlib import Path + +# Add paths +script_dir = Path(__file__).resolve().parent +project_root = script_dir.parent.parent +sys.path.insert(0, str(script_dir)) +sys.path.insert(0, str(project_root)) + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Import after paths are set +from config import TemperatureConfig, list_experiments + +# Create a very short test configuration +TEST_CONFIG = TemperatureConfig( + name="test_quick", + description="Quick test run with 50 steps", + temperature=1.0, + max_steps=50, # Very short for testing +) + +def main(): + print("\n" + "="*80) + print("QUICK TEST: Temperature Routing Experiment") + print("="*80 + "\n") + + print("This is a quick test to verify the experiment pipeline works.") + print(f"Running {TEST_CONFIG.max_steps} steps (should take <1 minute)\n") + + # List available experiments + print("Available experiment types:") + list_experiments() + + print("\n" + "="*80) + print("Test configuration:") + print("="*80) + print(f"Name: {TEST_CONFIG.name}") + print(f"Description: {TEST_CONFIG.description}") + print(f"Temperature: {TEST_CONFIG.temperature}") + print(f"Steps: {TEST_CONFIG.max_steps}") + print("="*80 + "\n") + + print("✅ Configuration test passed!") + print("\nTo run a full experiment, use:") + print(" python run_experiment.py --experiment temp_1.0") + print("\nTo run the quick demo (3 temps, 500 steps each):") + print(" bash 
quick_demo.sh") + print("\nTo run full temperature ablation:") + print(" python run_experiment.py --ablation") + print("") + +if __name__ == "__main__": + main() + diff --git a/experiments/exp10_routing_temperature_specialization/results/experiment_summary.json b/experiments/exp10_routing_temperature_specialization/results/experiment_summary.json new file mode 100644 index 0000000..6ace6f5 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/experiment_summary.json @@ -0,0 +1,19 @@ +{ + "experiments": [ + "temp_0.5", + "temp_0.7", + "temp_1.0", + "temp_1.5", + "temp_2.0", + "temp_3.0", + "temp_5.0", + "temp_10.0", + "schedule_linear", + "schedule_cosine", + "schedule_exp", + "schedule_step", + "temp_best_long" + ], + "num_completed": 13, + "num_requested": 13 +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/schedule_cosine/metrics.json b/experiments/exp10_routing_temperature_specialization/results/schedule_cosine/metrics.json new file mode 100644 index 0000000..f9f30ff --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/schedule_cosine/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "schedule_cosine", + "description": "Cosine decay from 5.0 \u2192 1.0", + "temperature": 5.0, + "temperature_schedule": "cosine", + "final_metrics": { + "val_loss": 24.06814131045931, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.069657039642334, + 4.464751672744751, + 3.1876614570617674, + 1.8690967321395875, + 0.7620087236166, + 0.3555591553449631, + 0.22341943085193633, + 0.16186958104372023, + 0.11357317417860031, + 0.09325627535581589, + 0.049703499488532546, + 0.017852141335606576, + 0.013934982474893332, + 0.011283675767481327, + 0.010491264518350363, + 0.010769488289952278, + 0.009670319315046073, + 0.008400135021656752, + 0.008567010285332799, + 0.007932101562619209, + 0.009679869608953595, + 0.0010790330707095563, + 0.0008006245130673051, + 0.0008102733991108835, + 0.0008768657746259124, + 0.0008497504226397723, + 0.0007379433052847161, + 0.0005650699196849019, + 0.0005565871018916368, + 0.00046176541363820436, + 0.0003697640495374799, + 0.00033591274841455744, + 0.000280905666295439, + 0.0002793343985104002, + 0.00028857155557489024, + 0.00027574378764256836, + 0.00026764055073726924, + 0.0002348719324800186, + 0.0002476606299751438, + 0.00022617928916588426, + 0.00020913604239467532, + 0.00023350362753262743, + 0.00017466716381022706, + 0.0001733527475153096, + 0.00018735515623120592, + 0.00019870568939950317, + 0.00018067110213451087, + 0.00019405925704631954, + 0.00017952858615899459, + 0.0001831858404329978 + ], + "val_losses": [ + 10.773804614063708, + 10.73688121566503, + 10.739477326086469, + 11.276299658596725, + 12.465000371629694, + 14.964032119238755, + 16.711366835415575, + 19.372482084132756, + 20.90315522911692, + 22.871643154022973, + 23.895708144764175, + 24.861502967537923, + 25.180418371733, + 25.198307340642167, + 25.085633099289748, + 24.409256736297067, + 24.310409889625575, + 24.249587015212636, + 24.029307213773155, + 
24.447470135907825, + 24.514754891816803, + 24.18615792749627, + 24.179366795839776, + 24.183270167967457, + 24.22122092634545, + 23.885647608618854, + 24.185753825696533, + 24.08758477524397, + 24.156039591812835, + 24.289343958608676, + 24.230099128750105, + 24.201055452596172, + 24.10836920047396, + 24.235484018763888, + 24.237351003047014, + 24.23093422394338, + 24.115427037431157, + 24.2962409986624, + 24.261355983073635, + 24.163045465314347, + 24.132357701817167, + 24.231192450641323, + 24.25127852342154, + 24.124626051832003, + 24.05677576368352, + 24.117051963671358, + 24.072576583484878, + 24.03192264705159, + 24.05440668166737, + 24.06814131045931 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 47753.35542228152, + 46022.294249708524, + 46141.92843201004, + 78928.65787734614, + 259108.03354063618, + 3153527.16484821, + 18098992.567186855, + 259037159.9629599, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.03468966484069824, + 0.0759025494257609, + 0.11283877293268839, + 0.15139784415562949, + 0.18865657647450765, + 0.22771941820780436, + 0.26534924904505414, + 0.3040532151858012, + 0.3415249824523926, + 0.38025893370310465, + 0.4173173268636068, + 0.45601539611816405, + 0.4932568669319153, + 0.5319392323493958, + 0.5693146387736002, + 0.607805609703064, + 0.6450861533482869, + 0.6838066299756368, + 0.7210066040356954, + 0.7700710892677307, + 0.8071453849474589, + 0.8458340962727865, + 0.8830396652221679, + 0.9215436100959777, + 0.9586852033933003, + 
0.9969743450482687, + 1.0341073592503867, + 1.0726868629455566, + 1.1101802666982015, + 1.149470082918803, + 1.1872127453486125, + 1.225475537776947, + 1.2624629577000936, + 1.3007176319758098, + 1.3381258368492126, + 1.3765281518300374, + 1.4140727877616883, + 1.4532278378804524, + 1.4906831860542298, + 1.5291387438774109, + 1.5664683341979981, + 1.6045465071996052, + 1.6416728218396506, + 1.6802144686381022, + 1.7172197739283244, + 1.7555572470029195, + 1.792522370815277, + 1.8409465670585632, + 1.877986188729604, + 1.9163748661677042 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 4.99680310021795, + 4.985765209139627, + 4.9668904099906594, + 4.940253192980212, + 4.905958683034438, + 4.864142224916422, + 4.814968849082234, + 4.758632620381112, + 4.6953558721701665, + 4.625388328866188, + 4.549006120397467, + 4.46651069244512, + 4.378227616774697, + 4.284505306353169, + 4.185713640322119, + 4.082242504253752, + 3.9745002514506647, + 3.862912091361918, + 3.7479184114756006, + 3.6299730393106096, + 3.5095414513667644, + 3.387098936101721, + 3.263128718184565, + 3.138120051428812, + 3.012566287931118, + 2.886962931035951, + 2.7618056798102604, + 2.6375884727457253, + 2.5148015384091855, + 2.3939294607344523, + 2.2754492665909085, + 2.1598285431763875, + 2.0475235926641213, + 1.9389776313865315, + 1.8346190406628482, + 1.7348596761737511, + 1.6400932425551615, + 1.5506937396259066, + 1.4670139863813005, + 1.3893842285777562, + 1.318110835403662, + 1.2534750903801601, + 1.1957320812635945, + 1.1451096933306775, + 1.101807710019411, + 1.06599702447513, + 1.037818965113332, + 1.017384737860987, + 1.0047749872775495, + 1.0000394782877258 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], 
+ "expert_utilizations": [ + [ + 0.12473323568701744, + 0.11382735644777615, + 0.139450674255689, + 0.12831438208619753, + 0.13368680079778036, + 0.12623234341541925, + 0.12073579803109169, + 0.11301920562982559 + ], + [ + 0.13074986139933267, + 0.11635972807804744, + 0.13683613141377768, + 0.12215333804488182, + 0.12273834273219109, + 0.12617023040850958, + 0.13310465092460314, + 0.11188747609655063 + ], + [ + 0.1346402702232202, + 0.11434124658505122, + 0.132917037854592, + 0.12263403460383415, + 0.11620396499832471, + 0.1256436457236608, + 0.13484868531425795, + 0.11877089738845825 + ], + [ + 0.13973531996210417, + 0.11876375352342923, + 0.13663236424326897, + 0.12599164744218191, + 0.10549194862445195, + 0.12095922604203224, + 0.13652594884236655, + 0.11589954420924187 + ], + [ + 0.13239124168952307, + 0.12700294330716133, + 0.14263620103398958, + 0.12503278503815332, + 0.10186273232102394, + 0.1242215596139431, + 0.13387257109085718, + 0.11297972500324249 + ], + [ + 0.12738784650961557, + 0.12895409390330315, + 0.139613206187884, + 0.1254501317938169, + 0.11713359256585439, + 0.12308299293120702, + 0.13411941130956015, + 0.10425849383076032 + ], + [ + 0.12448340406020482, + 0.12290012463927269, + 0.1353225732843081, + 0.12786233176787695, + 0.13233929251631102, + 0.11918221786618233, + 0.13031593461831412, + 0.10759389400482178 + ], + [ + 0.1220518263677756, + 0.12249967083334923, + 0.1253726544479529, + 0.13003160307804743, + 0.13357282554109892, + 0.12259802843133609, + 0.12276356667280197, + 0.12110959117611249 + ], + [ + 0.12160174796978633, + 0.12197340155641238, + 0.12153222287694614, + 0.13319215923547745, + 0.1325893277923266, + 0.12300913284222285, + 0.11601957430442174, + 0.13008220245440802 + ], + [ + 0.12285137176513672, + 0.12688441822926202, + 0.12505292519927025, + 0.12690721824765205, + 0.13247318441669145, + 0.11919539670149486, + 0.11709122980634372, + 0.12954403335849443 + ], + [ + 0.12633887181679407, + 0.12995417416095734, + 0.12844157591462135, + 0.12488291164239247, + 0.1297601635257403, + 0.12131790940960248, + 0.11710098758339882, + 0.12220316752791405 + ], + [ + 0.12519476438562074, + 0.1214562567571799, + 0.12268727521101634, + 0.12669246892134348, + 0.12849828973412514, + 0.12003543352087338, + 0.12331271047393481, + 0.13212257996201515 + ], + [ + 0.12595185140768686, + 0.11785857379436493, + 0.12245255087812741, + 0.12512790163358053, + 0.1291645703216394, + 0.11958248789111774, + 0.124679796397686, + 0.13518204043308893 + ], + [ + 0.13348439459999403, + 0.12604759633541107, + 0.11845917503039043, + 0.11861602341135342, + 0.12412416562438011, + 0.12362225602070491, + 0.1256026861568292, + 0.13004347681999207 + ], + [ + 0.12873519087831178, + 0.12751456225911775, + 0.10823257515827815, + 0.12079133962591489, + 0.12041651457548141, + 0.12902195130785307, + 0.12854893133044243, + 0.13673870265483856 + ], + [ + 0.13080348074436188, + 0.12207291647791862, + 0.12082832554976146, + 0.11692502349615097, + 0.13115477561950684, + 0.1197350745399793, + 0.12799837440252304, + 0.13048180441061655 + ], + [ + 0.11812729388475418, + 0.13565508648753166, + 0.11730894198020299, + 0.11602873851855595, + 0.13778389245271683, + 0.10869153589010239, + 0.12651361773411432, + 0.13989064594109854 + ], + [ + 0.12130231286088626, + 0.11515853057305019, + 0.12389607354998589, + 0.1297798603773117, + 0.1189784196515878, + 0.13041318580508232, + 0.11983061085144679, + 0.1406407654285431 + ], + [ + 0.13254226123293242, + 0.11489111185073853, + 0.12229649225870769, + 0.12720056747396788, 
+ 0.10790204505125682, + 0.13950017342964807, + 0.1212441051999728, + 0.1344230199853579 + ], + [ + 0.1307671827574571, + 0.12482758983969688, + 0.1261162223915259, + 0.11147621770699818, + 0.11841998373468716, + 0.13240713501969972, + 0.11380701760451, + 0.14217842866977057 + ], + [ + 0.132644505550464, + 0.11844757199287415, + 0.11880502477288246, + 0.11552153279383977, + 0.13142997026443481, + 0.12773855403065681, + 0.11954698339104652, + 0.13586563616991043 + ], + [ + 0.12930652871727943, + 0.12672042970856032, + 0.11539290224512418, + 0.12450442835688591, + 0.12756239250302315, + 0.12999829029043516, + 0.12612264851729074, + 0.12039215117692947 + ], + [ + 0.1163909062743187, + 0.11674513667821884, + 0.10233827059467633, + 0.12813481812675795, + 0.1336732879281044, + 0.14153830707073212, + 0.1250646449625492, + 0.13611439739664397 + ], + [ + 0.11738963052630424, + 0.12948323786258698, + 0.11958456287781398, + 0.11339040969808896, + 0.12744004900256792, + 0.13497323046127954, + 0.13164062922199568, + 0.1260980156560739 + ], + [ + 0.12214399129152298, + 0.1334065484503905, + 0.10820431510607402, + 0.11478110899527867, + 0.12978268538912138, + 0.13313823441664377, + 0.13090246667464575, + 0.12764043609301248 + ], + [ + 0.1293958599368731, + 0.12335519616802533, + 0.11419479176402092, + 0.13307242095470428, + 0.11853573719660442, + 0.13586700210968652, + 0.12430663282672565, + 0.12127214421828587 + ], + [ + 0.11442689721783002, + 0.12603950748840967, + 0.14377529049913088, + 0.12445535138249397, + 0.12229052806893985, + 0.1283992330233256, + 0.11827573925256729, + 0.12233723203341167 + ], + [ + 0.12955031668146452, + 0.1247330072025458, + 0.12181367352604866, + 0.11640448495745659, + 0.11735814188917477, + 0.12860990688204765, + 0.12919481347004572, + 0.13233542566498122 + ], + [ + 0.12579264491796494, + 0.1267194946606954, + 0.1290343850851059, + 0.11855429783463478, + 0.12952119608720145, + 0.12258608018358548, + 0.1221676655113697, + 0.12562399357557297 + ], + [ + 0.11103501543402672, + 0.12818804507454237, + 0.12510508919755617, + 0.12843561048309007, + 0.1346649006009102, + 0.12529908989866576, + 0.12607909614841142, + 0.12119293957948685 + ], + [ + 0.10968777785698573, + 0.1284580094118913, + 0.12039777884880702, + 0.13105766226847967, + 0.1255891720453898, + 0.12861098100741705, + 0.1216950664917628, + 0.1345033347606659 + ], + [ + 0.12550330037871996, + 0.11542637770374616, + 0.12164511904120445, + 0.12761637568473816, + 0.12272314354777336, + 0.1218589221437772, + 0.12981553872426352, + 0.13541100298364958 + ], + [ + 0.12639697765310606, + 0.12224535768230756, + 0.11700818190972011, + 0.12871456146240234, + 0.1233878992497921, + 0.12472214053074519, + 0.12426582475503285, + 0.13325883572300276 + ], + [ + 0.12453013534347217, + 0.12501890336473784, + 0.12742897247274718, + 0.12272479136784871, + 0.1265820823609829, + 0.1227731704711914, + 0.12878064066171646, + 0.1221610854069392 + ], + [ + 0.11887981990973155, + 0.12262246881922086, + 0.12840694934129715, + 0.11971587811907132, + 0.12567466497421265, + 0.1228187804420789, + 0.12781167030334473, + 0.13406953091422716 + ], + [ + 0.12234862893819809, + 0.1274310052394867, + 0.12108884006738663, + 0.13159961998462677, + 0.12342866137623787, + 0.12614962458610535, + 0.12644772231578827, + 0.12150566652417183 + ], + [ + 0.11623904357353847, + 0.12574931730826697, + 0.12459515656034152, + 0.1316892591615518, + 0.11918096989393234, + 0.12365549181898434, + 0.13525293270746866, + 0.12363758559028308 + ], + [ + 0.12697953854997954, + 
0.12157645324865977, + 0.119102676709493, + 0.12276362876097362, + 0.13029536480704942, + 0.13067887102564177, + 0.12420013422767322, + 0.12440310666958491 + ], + [ + 0.12084526444474857, + 0.12226648007829984, + 0.11892024924357732, + 0.1251565838853518, + 0.12728688369194666, + 0.12784678116440773, + 0.13224864502747855, + 0.12542887901266417 + ], + [ + 0.1230640560388565, + 0.12118242556850116, + 0.12733721857269606, + 0.13258815556764603, + 0.12404076506694157, + 0.12495979790886243, + 0.11910634984572728, + 0.1277210054298242 + ], + [ + 0.12727698559562364, + 0.11819724986950557, + 0.12680891901254654, + 0.12819817289710045, + 0.12524592503905296, + 0.12494295835494995, + 0.12223132327198982, + 0.12709823995828629 + ], + [ + 0.12163988873362541, + 0.1256983665128549, + 0.12736307208736738, + 0.12127974381049474, + 0.12543400873740515, + 0.12674884249766669, + 0.13067526618639627, + 0.12116058791677158 + ], + [ + 0.12180201212565105, + 0.12034623697400093, + 0.12037375569343567, + 0.12500147645672163, + 0.12568610409895578, + 0.1320569192369779, + 0.13425932948788008, + 0.12047395358482997 + ], + [ + 0.12279933566848437, + 0.12307638054092725, + 0.12647893776496252, + 0.12079178417722385, + 0.13047230119506517, + 0.1264300917585691, + 0.12567398448785147, + 0.12427695964773496 + ], + [ + 0.11546471218268077, + 0.11922568952043851, + 0.12555858368674913, + 0.12254588802655537, + 0.12672735129793486, + 0.13214183102051416, + 0.1292207626005014, + 0.12911495938897133 + ], + [ + 0.1240631639957428, + 0.12654056772589684, + 0.12577811007698378, + 0.1194851982096831, + 0.12523702283700308, + 0.12366656710704167, + 0.1270997958878676, + 0.1281293568511804 + ], + [ + 0.12079654633998871, + 0.124654454489549, + 0.12053526813785236, + 0.11557366574803989, + 0.13266202559073767, + 0.12400228654344876, + 0.1277042180299759, + 0.13407130042711893 + ], + [ + 0.11985222746928532, + 0.12334347764650981, + 0.12739611665407816, + 0.1287225882212321, + 0.1183634636302789, + 0.12666833400726318, + 0.13182911152640978, + 0.12382446105281512 + ], + [ + 0.12290976444880168, + 0.12571510300040245, + 0.1237668829659621, + 0.13069668784737587, + 0.12329148997863133, + 0.12108598525325458, + 0.13188782085975012, + 0.12064604833722115 + ], + [ + 0.12022721394896507, + 0.12498688821991284, + 0.12935581554969153, + 0.12518025562167168, + 0.12690354386965433, + 0.12309145058194797, + 0.12400985633333524, + 0.12624474118153253 + ] + ], + "load_balancing_losses": [ + 0.06004496216773987, + 0.060032990947365764, + 0.060045778006315234, + 0.06012898422777653, + 0.06028467863798141, + 0.060317789763212205, + 0.060203588008880614, + 0.06008061580359936, + 0.06000344827771187, + 0.059994347393512726, + 0.06003178991377354, + 0.05997398346662521, + 0.060018420591950415, + 0.06007139682769776, + 0.06022535488009453, + 0.06041265763342381, + 0.06053830087184906, + 0.06052085943520069, + 0.06074642017483711, + 0.06120501421391964, + 0.060415341332554814, + 0.06046857684850693, + 0.06050385870039463, + 0.060686058923602106, + 0.06057383455336094, + 0.060476109758019446, + 0.060880960896611214, + 0.06069233492016792, + 0.060470017790794375, + 0.060535190254449846, + 0.06045514903962612, + 0.060405534133315085, + 0.06034016571938992, + 0.06039765775203705, + 0.06030287072062492, + 0.06034483872354031, + 0.0603365920484066, + 0.06044961810112, + 0.06022889465093613, + 0.060366106778383256, + 0.06026962883770466, + 0.06024487540125847, + 0.06025802828371525, + 0.06027736067771912, + 0.06026877649128437, + 0.06021323874592781, + 
0.06026680096983909, + 0.06016013324260712, + 0.06024349182844162, + 0.0601899154484272 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.11895587171117465, + 0.12490880365173022, + 0.1310289315879345, + 0.12453617031375568, + 0.12802273780107498, + 0.12285855039954185, + 0.12266377235452335, + 0.1270249473551909 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/schedule_exp/metrics.json b/experiments/exp10_routing_temperature_specialization/results/schedule_exp/metrics.json new file mode 100644 index 0000000..4c927aa --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/schedule_exp/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "schedule_exp", + "description": "Exponential decay from 5.0 \u2192 1.0", + "temperature": 5.0, + "temperature_schedule": "exponential", + "final_metrics": { + "val_loss": 24.10983462384227, + "val_accuracy": 0.016402398124649928, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.068965244293213, + 4.465076065063476, + 3.1882030725479127, + 1.8691477060317994, + 0.7625107705593109, + 0.3560707598924637, + 0.22384800761938095, + 0.16213105320930482, + 0.11350352391600609, + 0.09394646733999253, + 0.04979750290513039, + 0.017670586239546537, + 0.01402327986434102, + 0.011493024136871099, + 0.010451155342161655, + 0.010522020468488335, + 0.009666885854676366, + 0.008464573975652456, + 0.008570067444816231, + 0.00792343313805759, + 0.009731251373887062, + 0.0010506233782507479, + 0.0008077768434304744, + 0.0008309792669024318, + 0.0008686139248311519, + 0.0009503068053163588, + 0.0007482638000510633, + 0.0005941699841059744, + 0.0005462797154905275, + 0.0004534960258752108, + 0.00038899176579434427, + 0.0003774055352550931, + 0.0003225495383958332, + 0.00027401681145420296, + 0.00027798613009508697, + 0.00028098177135689185, + 0.000261071661952883, + 0.00025207002327078956, + 0.00025684188294690103, + 0.00023173732770374045, + 0.00021625933877658098, + 0.00024576874129706995, + 0.00017804754606913774, + 0.0001766106055583805, + 0.00018065856711473315, + 0.00018787720910040662, + 0.00017227771750185638, + 0.0001732482123770751, + 0.0001662593858782202, + 0.00018492242088541389 + ], + "val_losses": [ + 10.7737612943346, + 10.737049510537947, + 10.739191857327842, + 11.27567576014111, + 12.462526122588573, + 14.964096274898246, + 16.718264731417275, + 19.37405140745345, + 20.89652683541126, + 22.901575519844837, + 23.90412493055364, + 24.8683575202214, + 25.095762043875435, + 25.169508418430286, + 24.955148622762188, + 24.887071070317244, + 24.73069033268905, + 24.244159132347512, + 24.34232870176066, + 23.800271805941847, + 23.849213158705208, + 24.217105090407518, + 23.9178601254844, + 23.96636208038869, + 23.856432742870318, + 24.068271414551212, + 24.108461043017492, + 24.17178448518679, + 24.260891075269072, + 24.285787461082002, + 24.2680606842041, + 24.06867962759712, + 24.157068596290614, + 24.168768589572434, + 
24.17590230261058, + 24.20804418691898, + 24.138611749709707, + 24.217521121560896, + 24.161620615227903, + 24.24126792206781, + 24.19493352398013, + 24.261854421123598, + 24.279224260956997, + 24.248636730989382, + 24.172418789813037, + 24.14021366645085, + 24.09944395462953, + 24.164141051760833, + 24.14445340591269, + 24.10983462384227 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928 + ], + "val_perplexities": [ + 47751.286804666895, + 46030.04021765572, + 46128.75823290956, + 78879.42976786029, + 258467.72820119502, + 3153729.4879532093, + 18224269.11120567, + 259443992.16114965, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.02916040817896525, + 0.07016578515370687, + 0.10749046007792155, + 0.14621901114781696, + 0.18361044724782308, + 0.2220892310142517, + 0.25910220940907797, + 0.2974821130434672, + 0.3343779643376668, + 0.3777258276939392, + 0.41875929832458497, + 0.4568229913711548, + 0.49366676807403564, + 0.5320061961809794, + 0.5688212831815084, + 0.6071259299914042, + 0.6440495769182841, + 0.6822756171226502, + 0.7192215522130331, + 0.7574461301167806, + 0.794350779056549, + 0.8328560511271159, + 0.8698898633321126, + 0.9082265575726827, + 0.9452325145403544, + 0.9834989428520202, + 1.020540710290273, + 1.0587744156519572, + 1.095594831307729, + 1.133586017290751, + 1.1701006809870402, + 1.2078976273536681, + 1.2447497924168904, + 1.282833707332611, + 1.3195554494857789, + 1.3575769742329915, + 1.4052613894144694, + 1.4433971405029298, + 1.4801481088002524, + 1.5182779908180237, + 
1.5550084471702577, + 1.5929984013239542, + 1.6298532128334045, + 1.6678994615872702, + 1.7044124046961466, + 1.742259148756663, + 1.7788713653882344, + 1.8168433388074239, + 1.853508198261261, + 1.8916389187177023 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 4.857228608082785, + 4.703369993913554, + 4.554385038174705, + 4.410119361817498, + 4.270423475937018, + 4.135152626870936, + 4.004166646204948, + 3.8773298055294743, + 3.7545106757971203, + 3.635581991135161, + 3.5204205169719223, + 3.4089069223404307, + 3.3009256562269913, + 3.1963648278365793, + 3.095116090650967, + 2.9970745301594564, + 2.9021385551458683, + 2.8102097924191565, + 2.721192984878554, + 2.6349958928076282, + 2.551529198294969, + 2.4707064126824654, + 2.392443786945275, + 2.316660224910611, + 2.2432771992254317, + 2.1722186699859583, + 2.1034110059446958, + 2.036782908213323, + 1.972265336382383, + 1.909791436981222, + 1.8492964742040474, + 1.7907177628303228, + 1.7339946032699887, + 1.679068218666207, + 1.6258816939904517, + 1.574379917066834, + 1.5245095214645599, + 1.476218831199331, + 1.4294578071864006, + 1.3841779953897928, + 1.340332476613958, + 1.2978758178858367, + 1.2567640253769616, + 1.216954498816801, + 1.1784059873501214, + 1.141078546792618, + 1.1049334982405283, + 1.069933387991332, + 1.0360419487340158, + 1.003224061968681 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12485561395684878, + 0.11367421845595042, + 0.1393864837785562, + 0.12826512133081755, + 0.13371374706427255, + 0.12623665109276772, + 0.12074708690245946, + 0.11312085886796315 + ], + [ + 0.12929308786988258, + 0.11662843823432922, + 0.13638648142417273, + 0.12217692534128825, + 
0.12327748412887256, + 0.12696860109766325, + 0.13329854980111122, + 0.1119702123105526 + ], + [ + 0.13398320972919464, + 0.1141152170797189, + 0.13255895425875983, + 0.12285911291837692, + 0.11669069776932399, + 0.1255652792751789, + 0.13464067627986273, + 0.11958662047982216 + ], + [ + 0.14168754716714224, + 0.1169067124525706, + 0.13430956502755484, + 0.12764071921507517, + 0.10625544935464859, + 0.12215130031108856, + 0.13621165851751962, + 0.11483683312932651 + ], + [ + 0.13145407289266586, + 0.12338733673095703, + 0.13988637800017992, + 0.1262051115433375, + 0.10298016046484311, + 0.1273453844090303, + 0.13716531420747438, + 0.11157601947585742 + ], + [ + 0.12532757222652435, + 0.12609850615262985, + 0.13879147544503212, + 0.12502315764625868, + 0.1176831026871999, + 0.12558555975556374, + 0.1358214591940244, + 0.10566893219947815 + ], + [ + 0.12476529305179913, + 0.1200369397799174, + 0.13493291661143303, + 0.12674071143070856, + 0.13195232674479485, + 0.12349242468674977, + 0.12947026267647743, + 0.10860891764362653 + ], + [ + 0.1257908046245575, + 0.12228637437025706, + 0.12306807314356168, + 0.1317208024362723, + 0.12956678370634714, + 0.12492959449688594, + 0.11894858380158742, + 0.12368876859545708 + ], + [ + 0.12341478963692983, + 0.12422532339890797, + 0.1197039857506752, + 0.13112345710396767, + 0.12947327519456545, + 0.12645240376393, + 0.11604409540692966, + 0.12956243008375168 + ], + [ + 0.12547581642866135, + 0.12763222803672156, + 0.1239625724653403, + 0.12617246061563492, + 0.13127280895908675, + 0.1218232586979866, + 0.117293285826842, + 0.1263673429687818 + ], + [ + 0.12698695063591003, + 0.12876654788851738, + 0.12788285439213118, + 0.12748648350437483, + 0.12781483804186186, + 0.1216656764348348, + 0.11484055717786153, + 0.1245558721323808 + ], + [ + 0.12584819768865904, + 0.12440149113535881, + 0.12666124353806177, + 0.12375274176398914, + 0.1291299027701219, + 0.1182375041147073, + 0.12045771504441898, + 0.13151097918550173 + ], + [ + 0.12211241448918979, + 0.12476342419783275, + 0.12939738109707832, + 0.12354674438635509, + 0.13137834767500559, + 0.11144621049364407, + 0.12395542984207471, + 0.13339982430140176 + ], + [ + 0.12572511161367098, + 0.1204067828754584, + 0.11719127496083577, + 0.12282430877288182, + 0.12182608370979627, + 0.13329841196537018, + 0.12436861669023831, + 0.1343591809272766 + ], + [ + 0.12376925597588222, + 0.11203268046180408, + 0.11202949533859889, + 0.12679147347807884, + 0.12852568800250688, + 0.1322839061419169, + 0.11925324300924937, + 0.14531399557987848 + ], + [ + 0.12788191437721252, + 0.12463614717125893, + 0.12495853876074155, + 0.11307935044169426, + 0.12449628487229347, + 0.12174329658349355, + 0.11947273338834445, + 0.14373150716225305 + ], + [ + 0.12148624286055565, + 0.13142489393552145, + 0.12118023013075192, + 0.11254092430075009, + 0.1266097662349542, + 0.13381477197011313, + 0.1168867473800977, + 0.13605618104338646 + ], + [ + 0.11908402231832345, + 0.11817820246020953, + 0.1342220070461432, + 0.12563073014219603, + 0.11666146541635196, + 0.13693335776527724, + 0.1105535663664341, + 0.13873642683029175 + ], + [ + 0.1275100608666738, + 0.11007784431179364, + 0.13981718694170317, + 0.11683579285939534, + 0.12403254707654317, + 0.13320632403095564, + 0.11032865320642789, + 0.13819136222203574 + ], + [ + 0.12579328194260597, + 0.12497226024667422, + 0.1237871100505193, + 0.12513664861520132, + 0.11656902730464935, + 0.12758862723906836, + 0.12390752633412679, + 0.13224529971679053 + ], + [ + 0.12160300711790721, + 
0.12060223892331123, + 0.11166134104132652, + 0.13417561103900275, + 0.13546468690037727, + 0.12476549421747525, + 0.12246636549631755, + 0.12926102677981058 + ], + [ + 0.12879217664400736, + 0.13069551065564156, + 0.11821090057492256, + 0.1229519322514534, + 0.11497649550437927, + 0.1315327857931455, + 0.12655684848626456, + 0.1262830967704455 + ], + [ + 0.12462397168080012, + 0.12686825667818388, + 0.12818598002195358, + 0.12869665895899138, + 0.12085322414835294, + 0.1286333203315735, + 0.11629413440823555, + 0.12584421907862028 + ], + [ + 0.12498151262601216, + 0.1316138133406639, + 0.12565426900982857, + 0.11612492675582568, + 0.12732199455300966, + 0.11626806482672691, + 0.12785907089710236, + 0.13017612198988596 + ], + [ + 0.12428018202384312, + 0.13228718439737955, + 0.13365565364559492, + 0.115531870474418, + 0.12151946065326531, + 0.11310807739694913, + 0.1291658654808998, + 0.1304514743387699 + ], + [ + 0.1153706523279349, + 0.12369164576133092, + 0.12635996316870055, + 0.1250646635890007, + 0.13067020227511725, + 0.12725465248028436, + 0.1277061328291893, + 0.12388186032573383 + ], + [ + 0.1171652947862943, + 0.12504743536313376, + 0.12780422841509184, + 0.13203463703393936, + 0.13035098587473234, + 0.1192055381834507, + 0.11782566706339519, + 0.13056597361962 + ], + [ + 0.12759397675593695, + 0.12865092729528746, + 0.11505670845508575, + 0.12240302935242653, + 0.1223981926838557, + 0.1283405969540278, + 0.12461712335546811, + 0.13093921914696693 + ], + [ + 0.12317551548282306, + 0.12624570727348328, + 0.1165752944846948, + 0.1222602128982544, + 0.1240575263897578, + 0.13042164221405983, + 0.1264446216324965, + 0.13081925238172212 + ], + [ + 0.12798567488789558, + 0.12160845597585042, + 0.1304963454604149, + 0.12116737787922223, + 0.1246810182929039, + 0.12305398782094319, + 0.1242618424197038, + 0.12674507250388464 + ], + [ + 0.12724032128850618, + 0.12301877637704213, + 0.1321321241557598, + 0.11895551905035973, + 0.1264340616762638, + 0.12383625656366348, + 0.11782035107413928, + 0.13056235884626707 + ], + [ + 0.12711667145291963, + 0.11668947090705235, + 0.13059024761120477, + 0.13473080222805342, + 0.11373796314001083, + 0.12449256579081218, + 0.12331300353010495, + 0.12932905678947768 + ], + [ + 0.1303877184788386, + 0.1224971463282903, + 0.1277124620974064, + 0.13271591688195863, + 0.11641449853777885, + 0.1219242699444294, + 0.12594126909971237, + 0.12240649511416753 + ], + [ + 0.12807516877849898, + 0.12466320519646008, + 0.12422435482343037, + 0.12767142554124197, + 0.1266568787395954, + 0.1230384608109792, + 0.11985690767566363, + 0.12581336994965872 + ], + [ + 0.12971856941779455, + 0.11915318295359612, + 0.1241951510310173, + 0.12639003743728003, + 0.1266961619257927, + 0.12051826963822047, + 0.12865556528170904, + 0.12467281768719356 + ], + [ + 0.12803923462828, + 0.12925764297445616, + 0.11308783416946729, + 0.12089732040961583, + 0.12992224593957266, + 0.11935270080963771, + 0.12622397392988205, + 0.1332188161710898 + ], + [ + 0.1287880577147007, + 0.1298906368513902, + 0.11492066582043965, + 0.11987444013357162, + 0.13056172182162604, + 0.11539661263426144, + 0.12775215630729994, + 0.13281550258398056 + ], + [ + 0.13005571688214937, + 0.11925998205939929, + 0.11906168113152187, + 0.11963683118422826, + 0.12606584653258324, + 0.1287190467119217, + 0.12141153340538342, + 0.13578914105892181 + ], + [ + 0.12974248826503754, + 0.12470538169145584, + 0.11994581048687299, + 0.12303535764416058, + 0.11913976569970448, + 0.12873672197262445, + 0.12205617502331734, + 
0.13263809184233347 + ], + [ + 0.12514144430557886, + 0.12544372429450354, + 0.12450891360640526, + 0.12541157628099123, + 0.1263533333937327, + 0.12651332964499792, + 0.12489095702767372, + 0.1217364991704623 + ], + [ + 0.1199805997312069, + 0.13138659546772638, + 0.11902318273981412, + 0.13307704652349153, + 0.12492071216305096, + 0.12559786066412926, + 0.12405492613712947, + 0.12195884560545285 + ], + [ + 0.12752681970596313, + 0.1236291912694772, + 0.11774665738145511, + 0.11949027826388676, + 0.1307892625530561, + 0.12515263880292574, + 0.12964054693778357, + 0.1260243703921636 + ], + [ + 0.12448018665115039, + 0.12275078768531482, + 0.1167822852730751, + 0.12093901261687279, + 0.12856840466459593, + 0.12701980397105217, + 0.12901113430658975, + 0.13044816131393114 + ], + [ + 0.12563513840238252, + 0.12517516439159712, + 0.12498832990725835, + 0.13020832960804304, + 0.12488080933690071, + 0.1204293929040432, + 0.1219904621442159, + 0.12669214606285095 + ], + [ + 0.12766009817520776, + 0.1252334030965964, + 0.1199633739888668, + 0.13300161063671112, + 0.12741689011454582, + 0.11826380218068759, + 0.11822403470675151, + 0.1302365499238173 + ], + [ + 0.1244881587723891, + 0.12193801378210385, + 0.1251870058476925, + 0.12510514756043753, + 0.1258563076456388, + 0.12090179945031802, + 0.13051972165703773, + 0.12600362052520117 + ], + [ + 0.12271783625086148, + 0.1192563995718956, + 0.12455520903070767, + 0.13248226419091225, + 0.12370283404986064, + 0.1182238794863224, + 0.13488741715749106, + 0.1241739292939504 + ], + [ + 0.11648796871304512, + 0.12621634329358736, + 0.13024556636810303, + 0.11987082163492839, + 0.12361180782318115, + 0.12266224871079127, + 0.12621753911177316, + 0.134687473376592 + ], + [ + 0.12044015650947888, + 0.12713673834999403, + 0.12766151378552118, + 0.12339554727077484, + 0.12033036102851231, + 0.1202590415875117, + 0.12801681210597357, + 0.1327595834930738 + ], + [ + 0.1297115795314312, + 0.12595757593711218, + 0.12865152955055237, + 0.11995997528235118, + 0.12986688315868378, + 0.11820754657189052, + 0.12342486158013344, + 0.12421980996926625 + ] + ], + "load_balancing_losses": [ + 0.06004593931138515, + 0.0600339375436306, + 0.060048482939600946, + 0.060137690603733064, + 0.0603069581091404, + 0.06034286506474018, + 0.06020378060638905, + 0.06007212102413177, + 0.06001120358705521, + 0.06000030003488064, + 0.06001151688396931, + 0.06001419834792614, + 0.06002929545938969, + 0.060153701528906825, + 0.060527877509593965, + 0.06061089225113392, + 0.060722329467535016, + 0.06101667806506157, + 0.06073810942471027, + 0.06076507270336151, + 0.060854962468147276, + 0.06103549487888813, + 0.06076062396168709, + 0.06106609776616097, + 0.06036446318030357, + 0.06077891327440739, + 0.06058696247637272, + 0.060572251304984094, + 0.060425080731511115, + 0.06038157232105732, + 0.060350871086120604, + 0.06036800928413868, + 0.06066816747188568, + 0.060368667170405385, + 0.060314373672008516, + 0.06036770828068257, + 0.06024561338126659, + 0.06031039804220199, + 0.060259034112095836, + 0.06028021611273289, + 0.06022758223116398, + 0.060240605100989345, + 0.06026870645582676, + 0.06022492237389088, + 0.060247373208403586, + 0.06015990637242794, + 0.060203001648187635, + 0.060231467336416246, + 0.06011442840099335, + 0.060170086473226546 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.13156402111053467, + 0.12605291853348413, + 0.12996131802598634, + 0.11828402305642764, + 0.13217767824729285, + 0.11647970353563626, + 
0.12244507918755214, + 0.12303501988450687 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/schedule_linear/metrics.json b/experiments/exp10_routing_temperature_specialization/results/schedule_linear/metrics.json new file mode 100644 index 0000000..3f77f88 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/schedule_linear/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "schedule_linear", + "description": "Linear decay from 5.0 \u2192 1.0", + "temperature": 5.0, + "temperature_schedule": "linear", + "final_metrics": { + "val_loss": 24.145434989524816, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.070031690597534, + 4.465091609954834, + 3.188120889663696, + 1.8695449113845826, + 0.7625617831945419, + 0.3557862102985382, + 0.22352709919214248, + 0.16159185469150544, + 0.11373347714543343, + 0.09363114908337593, + 0.04962252601981163, + 0.017872374691069125, + 0.01393712405115366, + 0.011232817452400923, + 0.010444090329110622, + 0.010588031448423862, + 0.00978593984618783, + 0.008294070605188608, + 0.008443481894209981, + 0.008057454833760858, + 0.009717215597629548, + 0.0011033077666070313, + 0.0007793460215907543, + 0.0007487315859179943, + 0.0008732144429814071, + 0.0008515115652699024, + 0.0007152528705773875, + 0.0005919538787566125, + 0.0005263863888103515, + 0.00046693679178133607, + 0.00037780235579703, + 0.00038398922915803266, + 0.0002719390075071715, + 0.00027064195164712147, + 0.0002750781291979365, + 0.0002567486691987142, + 0.00023856690240791066, + 0.0002285146591020748, + 0.00021825013682246207, + 0.00023597993422299623, + 0.0002029337629210204, + 0.0002184864366427064, + 0.00015284798719221725, + 0.00016130286239786075, + 0.00017042641993612052, + 0.00019311631040181966, + 0.00017337404715362937, + 0.0001750983894453384, + 0.00017176414839923382, + 0.0001713901772745885 + ], + "val_losses": [ + 10.773725516383303, + 10.736913546235318, + 10.739020883405166, + 11.276839121491665, + 12.462660425543364, + 14.965424824097974, + 16.712340614399725, + 19.379005937609993, + 20.913550366782466, + 22.861593232980464, + 23.895618209569278, + 24.852926065559522, + 25.1373522728155, + 25.19762152412334, + 25.148491128173397, + 24.31815361049908, + 24.481804574757497, + 24.511629515738875, + 24.47874001816389, + 24.338014723976592, + 24.649533955873956, + 24.286741762194954, + 24.26859600215413, + 23.86274015608609, + 23.92851288327059, + 24.124601134984317, + 24.117018622138897, + 24.21431399234192, + 24.34963379310635, + 24.343611282510388, + 24.2570842500289, + 24.29744011437514, + 24.32507728603619, + 24.339487534108518, + 24.346734987130855, + 24.35856858917344, + 24.20420335291131, + 24.35857063131703, + 24.276608200882013, + 24.343424187110926, + 24.25743609182405, + 24.30380964110681, + 24.32124428361549, + 24.23489137925866, + 24.134092027643966, + 24.241847135995386, + 24.149835242820714, + 24.209763260696466, + 24.171414324757066, + 
24.145434989524816 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016423143147573177, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 47749.57839201517, + 46023.782200780624, + 46120.87209234769, + 78971.24844660138, + 258502.4435119428, + 3157922.157207594, + 18116625.569713347, + 260732604.84343988, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.029142594337463378, + 0.07024418910344442, + 0.10747782786687216, + 0.14619672695795696, + 0.18348251978556315, + 0.22264109055201212, + 0.25998326142628986, + 0.2985979199409485, + 0.33575506607691447, + 0.37417654593785604, + 0.4111841599146525, + 0.4495713909467061, + 0.48660025199254353, + 0.5252195437749226, + 0.5708261211713155, + 0.6093855738639832, + 0.646639327208201, + 0.6851914564768473, + 0.7225726087888081, + 0.7612001379330953, + 0.7983153343200684, + 0.8372756759325664, + 0.8745674093564352, + 0.9132968584696451, + 0.9504027048746745, + 0.9885418812433878, + 1.025281063715617, + 1.0634507854779562, + 1.1004807154337566, + 1.1386935154596964, + 1.1756099939346314, + 1.2136433204015096, + 1.2504660725593566, + 1.2885148723920186, + 1.3253190358479818, + 1.36347439289093, + 1.400240703423818, + 1.4466269453366598, + 1.4835795283317565, + 1.5218237916628519, + 1.5590248147646586, + 1.5974567453066508, + 1.6347886284192403, + 1.6732629855473837, + 1.7106112599372865, + 1.7492258270581564, + 1.7862712860107421, + 1.8248716910680136, + 1.8621562401453653, + 1.901061487197876 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 
0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 4.928, + 4.848, + 4.768, + 4.688, + 4.608, + 4.5280000000000005, + 4.448, + 4.368, + 4.288, + 4.208, + 4.128, + 4.048, + 3.968, + 3.888, + 3.808, + 3.7279999999999998, + 3.6479999999999997, + 3.568, + 3.488, + 3.408, + 3.3280000000000003, + 3.248, + 3.168, + 3.088, + 3.008, + 2.928, + 2.848, + 2.768, + 2.688, + 2.608, + 2.528, + 2.448, + 2.368, + 2.288, + 2.208, + 2.128, + 2.048, + 1.968, + 1.888, + 1.8079999999999998, + 1.7280000000000002, + 1.6480000000000001, + 1.568, + 1.488, + 1.408, + 1.3279999999999998, + 1.2480000000000002, + 1.1680000000000001, + 1.088, + 1.008 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12469794352849324, + 0.11388363565007846, + 0.13924207414189974, + 0.1281409872074922, + 0.13378159701824188, + 0.12630454699198404, + 0.12079832702875137, + 0.11315067609151204 + ], + [ + 0.12935267885526022, + 0.11619463935494423, + 0.13630455483992895, + 0.12198563044269879, + 0.12278859689831734, + 0.12778587763508162, + 0.13311370834708214, + 0.11247409383455913 + ], + [ + 0.13470966493089995, + 0.11352355778217316, + 0.13292831679185232, + 0.12197398766875267, + 0.1164175992210706, + 0.12625022853414217, + 0.1348417835930983, + 0.11935463547706604 + ], + [ + 0.14101043716073036, + 0.116659810145696, + 0.135072011500597, + 0.12756881614526114, + 0.10644927745064099, + 0.12189064671595891, + 0.13637045274178186, + 0.1149783122042815 + ], + [ + 0.13111937418580055, + 0.12826944390932718, + 0.14080193638801575, + 0.12510155389706293, + 0.10103541612625122, + 0.12472505867481232, + 0.1342096192141374, + 0.11473733435074489 + ], + [ + 0.12698482473691305, + 0.1308593861758709, + 0.14028861622015634, + 0.12340415517489116, + 0.11703313514590263, + 0.12604564676682153, + 
0.13153670728206635, + 0.10384729877114296 + ], + [ + 0.12387339149912198, + 0.12169012427330017, + 0.13627174372474352, + 0.12566579381624857, + 0.13129279389977455, + 0.12439063563942909, + 0.12971521417299905, + 0.10710007821520169 + ], + [ + 0.12342374896009763, + 0.12326434006293614, + 0.12522225951155028, + 0.12855522831281027, + 0.13235349828998247, + 0.12650818253556886, + 0.12175159901380539, + 0.11892091731230418 + ], + [ + 0.12469503407677014, + 0.124389352897803, + 0.1220089594523112, + 0.13039343183239302, + 0.13061834002534548, + 0.12594481681783995, + 0.11773145322998364, + 0.12421839560071628 + ], + [ + 0.12188258891304334, + 0.12610707556207976, + 0.11911704267064731, + 0.12833569198846817, + 0.1283252090215683, + 0.1251093583802382, + 0.12106883401672046, + 0.13005397965510687 + ], + [ + 0.12068058922886848, + 0.12718840315937996, + 0.12203726917505264, + 0.12739302342136702, + 0.12486746907234192, + 0.12598958238959312, + 0.12311563516656558, + 0.12872780114412308 + ], + [ + 0.12502824142575264, + 0.12448026860753696, + 0.12404221668839455, + 0.12211523080865543, + 0.12746370459596315, + 0.12292117377122243, + 0.126586036135753, + 0.12736288458108902 + ], + [ + 0.12622454638282457, + 0.12528113151590028, + 0.12829751893877983, + 0.12254100292921066, + 0.13030771166086197, + 0.11728355164329211, + 0.12502766400575638, + 0.12503664195537567 + ], + [ + 0.12945967664321265, + 0.1228997843960921, + 0.11708857988317807, + 0.1293798660238584, + 0.1226284106572469, + 0.1310043310125669, + 0.12212288752198219, + 0.1254162254432837 + ], + [ + 0.13135185837745667, + 0.11744219685594241, + 0.11930737520257632, + 0.12207083155711491, + 0.11953859155376752, + 0.13493581861257553, + 0.11983300124605496, + 0.13552011052767435 + ], + [ + 0.12598646928866705, + 0.10853325948119164, + 0.11716299007336299, + 0.13630649199088415, + 0.12420849998792012, + 0.13760811711351076, + 0.11915805439154308, + 0.13103589663902918 + ], + [ + 0.12387570614616077, + 0.11745924750963847, + 0.12718180318673453, + 0.12296153977513313, + 0.12002382427453995, + 0.1362664078672727, + 0.11539214104413986, + 0.13683910171190897 + ], + [ + 0.11757323394219081, + 0.12475062906742096, + 0.11465857550501823, + 0.12496423969666164, + 0.13400891919930777, + 0.13301980371276537, + 0.12706253801782927, + 0.12396182492375374 + ], + [ + 0.11657704785466194, + 0.12942078337073326, + 0.12024860580762227, + 0.1262603352467219, + 0.12149768074353536, + 0.13377471764882407, + 0.12495888024568558, + 0.12726170693834624 + ], + [ + 0.12636959056059519, + 0.12664777537186941, + 0.10938482731580734, + 0.12277675171693166, + 0.12742514287432036, + 0.1284345549841722, + 0.12372781708836555, + 0.13523330291112265 + ], + [ + 0.12464109311501186, + 0.12064798052112262, + 0.12338299055894215, + 0.11302951474984486, + 0.13176817322770754, + 0.13359948620200157, + 0.115101491411527, + 0.13782903800408045 + ], + [ + 0.12713224068284035, + 0.12206538021564484, + 0.1373858725031217, + 0.11394193520148595, + 0.1279068092505137, + 0.1240323359767596, + 0.12591709941625595, + 0.12161810199419658 + ], + [ + 0.12685105949640274, + 0.12461868425210317, + 0.13376359517375627, + 0.11912581076224645, + 0.12921296308437982, + 0.12590766325592995, + 0.1295206012825171, + 0.11099938799937566 + ], + [ + 0.1285433111091455, + 0.11613670115669568, + 0.12639348953962326, + 0.12146028007070224, + 0.11565679187575977, + 0.1336175041894118, + 0.12855219841003418, + 0.1296394889553388 + ], + [ + 0.13442979380488396, + 0.11940562476714452, + 0.1252367409567038, + 
0.11640631780028343, + 0.12729826817909876, + 0.1192333089808623, + 0.1261553280055523, + 0.13183439895510674 + ], + [ + 0.1230368788043658, + 0.12078605592250824, + 0.12520241240660349, + 0.12587020049492517, + 0.1306257943312327, + 0.12741605192422867, + 0.12442437807718913, + 0.12263799210389455 + ], + [ + 0.1151316116253535, + 0.12293659150600433, + 0.12772803008556366, + 0.12716157113512358, + 0.12658913433551788, + 0.12907299771904945, + 0.1185876689851284, + 0.1327921859920025 + ], + [ + 0.13161567598581314, + 0.11673690006136894, + 0.12689935291806856, + 0.13067953288555145, + 0.12383261322975159, + 0.125590480864048, + 0.11484868576129277, + 0.12979653850197792 + ], + [ + 0.129223412523667, + 0.11433490241567294, + 0.12556362648804983, + 0.12273898969093959, + 0.13604150091608366, + 0.11973029747605324, + 0.1160333938896656, + 0.13633364563186964 + ], + [ + 0.1235324318210284, + 0.1258481778204441, + 0.11807644988099734, + 0.12651588146885237, + 0.1210456316669782, + 0.13223377615213394, + 0.12572738031546274, + 0.12702003493905067 + ], + [ + 0.12939325844248137, + 0.12632202729582787, + 0.12058161695798238, + 0.1230415366590023, + 0.12147464603185654, + 0.12438251450657845, + 0.1223433328171571, + 0.13246084252993265 + ], + [ + 0.12088499342401822, + 0.13107529655098915, + 0.11511276910702388, + 0.12342227747042973, + 0.1297032249470552, + 0.12656418730815253, + 0.12288010865449905, + 0.1303569090863069 + ], + [ + 0.1207020990550518, + 0.12032982334494591, + 0.11619511246681213, + 0.12561865275104842, + 0.1320885606110096, + 0.12892189621925354, + 0.125174880027771, + 0.13096874207258224 + ], + [ + 0.11659535393118858, + 0.12108862151702245, + 0.12394801775614421, + 0.12267851456999779, + 0.12799211591482162, + 0.12871196369330087, + 0.12417874361077945, + 0.1348064343134562 + ], + [ + 0.12341472133994102, + 0.12303343042731285, + 0.11957977960507075, + 0.11796309426426888, + 0.12829195583860079, + 0.13497542341550192, + 0.12229030206799507, + 0.13045107324918112 + ], + [ + 0.1215393381814162, + 0.12383247663577397, + 0.11887773250540097, + 0.1221556340654691, + 0.1312519982457161, + 0.12900342543919882, + 0.12437451630830765, + 0.12896466627717018 + ], + [ + 0.12197028597195943, + 0.12078525871038437, + 0.12600433205564818, + 0.12051290397842725, + 0.12822645157575607, + 0.1305689588189125, + 0.12612008675932884, + 0.12581149861216545 + ], + [ + 0.12489322697122891, + 0.13288636008898416, + 0.11925294995307922, + 0.12400917212168376, + 0.13207783550024033, + 0.1301965775589148, + 0.12007814397414525, + 0.11660551528135936 + ], + [ + 0.11834153532981873, + 0.13366939375797907, + 0.12364806731541951, + 0.12686879808704057, + 0.12241518124938011, + 0.12736316894491514, + 0.12264170621832211, + 0.12505192930499712 + ], + [ + 0.12399844204386075, + 0.12757805486520132, + 0.1231104942659537, + 0.10859241088231404, + 0.1244613379240036, + 0.1337025041381518, + 0.12963677570223808, + 0.12891975169380507 + ], + [ + 0.12426480650901794, + 0.13280999039610228, + 0.11895533526937167, + 0.10835339998205502, + 0.11907912418246269, + 0.13236587742964426, + 0.12689227859179178, + 0.13727894673744837 + ], + [ + 0.12530652433633804, + 0.1270507238805294, + 0.13280983393390974, + 0.11662950615088145, + 0.1235556664566199, + 0.12147925421595573, + 0.1248152069747448, + 0.12835306425889334 + ], + [ + 0.12291006992260615, + 0.12989936769008636, + 0.1310910421113173, + 0.11993459363778432, + 0.1250849279264609, + 0.11856419468919437, + 0.12336839859684308, + 0.12914716949065527 + ], + [ + 
0.12713438520828882, + 0.12698238467176756, + 0.11913535992304485, + 0.12211070209741592, + 0.12439057727654775, + 0.12495543683568637, + 0.12895508607228598, + 0.12633585557341576 + ], + [ + 0.123618067552646, + 0.12467818210522334, + 0.1254742443561554, + 0.1282066690425078, + 0.12193657457828522, + 0.12189907332261403, + 0.12139041970173518, + 0.13279655203223228 + ], + [ + 0.13040602703889212, + 0.12990628803769746, + 0.12115683530767758, + 0.1239775816599528, + 0.12245930110414822, + 0.12172008926669757, + 0.12624084452788034, + 0.12413281450668971 + ], + [ + 0.12633098910252252, + 0.1255469135940075, + 0.13031799842913946, + 0.12539088850220045, + 0.1254921294748783, + 0.12254236141840617, + 0.12154730906089146, + 0.12283118565877278 + ], + [ + 0.13056531051794687, + 0.12650401641925177, + 0.1308995820581913, + 0.12618571271499, + 0.11668508003155391, + 0.12460313737392426, + 0.12299310540159543, + 0.12156381706396739 + ], + [ + 0.12268947189052899, + 0.1283542439341545, + 0.13488017891844115, + 0.12108633667230606, + 0.11830473194519679, + 0.12569372355937958, + 0.1217837780714035, + 0.1272073102494081 + ], + [ + 0.1296000530322393, + 0.12586997573574385, + 0.12741975113749504, + 0.11512747034430504, + 0.12619579086701074, + 0.1248207576572895, + 0.12333483497301738, + 0.12763113901019096 + ] + ], + "load_balancing_losses": [ + 0.06004537232220173, + 0.06003381907939911, + 0.06004717685282231, + 0.06013218648731709, + 0.06029795669019222, + 0.06032728105783462, + 0.06020462512969971, + 0.06007204353809357, + 0.06000982187688351, + 0.05999823845922947, + 0.06001656837761402, + 0.05998987890779972, + 0.06000747457146645, + 0.06008636653423309, + 0.06028089486062527, + 0.06038857027888298, + 0.06075892895460129, + 0.06083236113190651, + 0.06033258102834225, + 0.06042707376182079, + 0.060734807327389714, + 0.06103012971580028, + 0.06049002073705197, + 0.060631276667118074, + 0.060927897319197656, + 0.060731960088014604, + 0.06047301962971687, + 0.06034380383789539, + 0.060457519814372064, + 0.060315877199172974, + 0.06031375601887703, + 0.06038779243826866, + 0.06037567481398583, + 0.060400766879320146, + 0.06028025932610035, + 0.06033550947904587, + 0.06029500253498554, + 0.06037123613059521, + 0.060221965238451955, + 0.06025633402168751, + 0.060248232632875445, + 0.0602688517421484, + 0.06030745953321457, + 0.06019967049360275, + 0.060255687311291696, + 0.06022413447499275, + 0.06019153743982315, + 0.06019332818686962, + 0.060173381492495535, + 0.060262182354927064 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.1314996456106504, + 0.1257651982208093, + 0.1272132694721222, + 0.11209296683470409, + 0.12733851994077364, + 0.12480566650629044, + 0.12325656786561012, + 0.12802794699867567 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/schedule_step/metrics.json b/experiments/exp10_routing_temperature_specialization/results/schedule_step/metrics.json new file mode 100644 index 0000000..e75d4e2 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/schedule_step/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "schedule_step", + "description": "Step decay: 5.0 (0-100) \u2192 2.0 (100-300) \u2192 1.0 (300+)", + "temperature": 5.0, + "temperature_schedule": "step", + "final_metrics": { + "val_loss": 24.094549772174958, + 
"val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.07071623802185, + 4.465154457092285, + 3.1885207653045655, + 1.868643581867218, + 0.762579345703125, + 0.35614443123340606, + 0.22326562702655792, + 0.1623530715703964, + 0.11411942094564438, + 0.09362907260656357, + 0.04971583783626556, + 0.017940876819193364, + 0.014104246255010366, + 0.011159149557352066, + 0.010378857795149087, + 0.010482308035716415, + 0.009924645349383355, + 0.00837753526866436, + 0.008655276522040367, + 0.007938291877508163, + 0.00976140475831926, + 0.0010861268092412502, + 0.000795760320033878, + 0.0008412151946686208, + 0.0008460664888843894, + 0.0008643631648737937, + 0.0007437039545038715, + 0.0005996645282721147, + 0.0005187173868762329, + 0.00047464616945944724, + 0.0004646173649234697, + 0.0003870115935569629, + 0.00030004128057044, + 0.00027745208353735507, + 0.0002667202890734188, + 0.00027821202238556, + 0.000258975918404758, + 0.0002218959867605008, + 0.000233758136164397, + 0.0002433822737657465, + 0.00021195087319938465, + 0.00025273076171288266, + 0.0001758968675858341, + 0.00017917404766194523, + 0.00018666021060198545, + 0.0001948652454302646, + 0.00017928506713360548, + 0.00018170062539866195, + 0.00017144189187092706, + 0.00017479144007666038 + ], + "val_losses": [ + 10.773711312364775, + 10.737200083243973, + 10.738981432291308, + 11.27408709711405, + 12.462973793488088, + 14.961436662572854, + 16.71159933649609, + 19.37254965684439, + 20.90263010840534, + 22.885077142883947, + 23.89394095592701, + 24.84343722852296, + 25.131157345990832, + 25.10689289241292, + 24.883487546401817, + 24.954385204786966, + 24.69574846058768, + 24.299748188194034, + 24.348401794164005, + 23.848283046547177, + 24.019065290794778, + 24.214635869218267, + 24.015344033392918, + 24.133731410697155, + 24.050027806851553, + 23.896345489016692, + 24.044844327461593, + 24.205259788162717, + 24.255766865221435, + 24.16331317483747, + 23.98499810990512, + 23.78645470959559, + 23.950416659297876, + 23.9963799196924, + 24.05398943735938, + 24.17783509379141, + 24.184680379321634, + 24.214976401716577, + 24.18391570101357, + 24.27957757201717, + 24.230000512760014, + 24.21488265182441, + 24.265513861558464, + 24.155658270360725, + 24.1332480966413, + 24.23034612028843, + 24.154558525489833, + 24.244177889065693, + 24.23645980619289, + 24.094549772174958 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 
0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 47748.90016093579, + 46036.97160719573, + 46119.052608462094, + 78754.21642143441, + 258583.46258509942, + 3145352.9342664676, + 18103201.09173924, + 259054664.39767745, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.029269937674204508, + 0.07058748404184977, + 0.10781962871551513, + 0.1467361847559611, + 0.18435633182525635, + 0.22385857502619425, + 0.2618454615275065, + 0.30055872599283856, + 0.337985094388326, + 0.3768680771191915, + 0.41428521474202473, + 0.4531663497289022, + 0.4976534763971964, + 0.540140982468923, + 0.5772211154301962, + 0.6152919212977092, + 0.6521628618240356, + 0.6903137524922689, + 0.7272444049517314, + 0.7652266502380372, + 0.8018552184104919, + 0.8400063991546631, + 0.8769107977549235, + 0.9152838428815205, + 0.9524499336878459, + 0.9908791899681091, + 1.0278935511906941, + 1.0661877512931823, + 1.1031587998072305, + 1.1415211359659831, + 1.1785194913546244, + 1.217111353079478, + 1.2542162537574768, + 1.2928197145462037, + 1.3300504644711812, + 1.3694180687268576, + 1.4063974857330321, + 1.4446008563041688, + 1.4815242091814678, + 1.5304654836654663, + 1.567549459139506, + 1.6058348536491394, + 1.6432595014572144, + 1.6819604992866517, + 1.7190428813298544, + 1.7576088945070902, + 1.7947184761365256, + 1.833271034558614, + 1.8704389890034994, + 1.908950916926066 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 
0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12477366750439008, + 0.1140362670024236, + 0.13932976375023523, + 0.12814310689767203, + 0.1337856501340866, + 0.12618506203095117, + 0.12057418624560039, + 0.1131720927854379 + ], + [ + 0.1318524293601513, + 0.115696981549263, + 0.13665896654129028, + 0.12187346071004868, + 0.12179289261500041, + 0.126997459679842, + 0.13321477423111597, + 0.11191281427939732 + ], + [ + 0.13643500208854675, + 0.1142299473285675, + 0.13236294935146967, + 0.12246725956598918, + 0.11585500463843346, + 0.12551955630381903, + 0.13501791035135588, + 0.11811213319500287 + ], + [ + 0.14069602265954018, + 0.11617022504409154, + 0.1385672390460968, + 0.12719231843948364, + 0.10580382744471233, + 0.1189831333855788, + 0.13735241070389748, + 0.11523459975918134 + ], + [ + 0.13112321744362512, + 0.12335971121986707, + 0.14395072932044664, + 0.12674802045027414, + 0.10137061402201653, + 0.11972772081693013, + 0.13870752478639284, + 0.11501221731305122 + ], + [ + 0.12519129241506258, + 0.12916060661276182, + 0.14236090456446013, + 0.1250865546365579, + 0.11374425888061523, + 0.12113221486409505, + 0.1342721271018187, + 0.1090518149236838 + ], + [ + 0.1253638962904612, + 0.12098499139149983, + 0.13632624099651972, + 0.1261260323226452, + 0.13058202589551607, + 0.12358010808626811, + 0.12985783318678537, + 0.10717864831288655 + ], + [ + 0.12317144001523654, + 0.12074599787592888, + 0.12649464110533395, + 0.12831137577692667, + 0.13066353524724642, + 0.12686927616596222, + 0.12450541059176128, + 0.11923810591300328 + ], + [ + 0.1256582277516524, + 0.12262818962335587, + 0.12104182069500287, + 0.1304104228814443, + 0.1313458370665709, + 0.12400708595911662, + 0.1182579609254996, + 0.12665020922819772 + ], + [ + 0.12580947702129683, + 0.12732602035005888, + 0.1234468308587869, + 0.12613261366883913, + 0.1319932353993257, + 0.12286854907870293, + 0.1177068365116914, + 0.1247162198026975 + ], + [ + 0.13080346584320068, + 0.12697291374206543, + 0.1306106001138687, + 0.1242411881685257, + 0.1304509254793326, + 0.12068192784984906, + 0.1197231076657772, + 0.11651562402645747 + ], + [ + 0.12057893102367719, + 
0.12035692979892094, + 0.11674173052112262, + 0.12889776875575384, + 0.12658174460132918, + 0.12808958192666373, + 0.12614516417185465, + 0.13260793313384056 + ], + [ + 0.11884483695030212, + 0.1222492940723896, + 0.11342890560626984, + 0.12628394116957983, + 0.12830450013279915, + 0.13156825179855028, + 0.12736538921793303, + 0.13195465380946794 + ], + [ + 0.12191858887672424, + 0.12652468557159105, + 0.11943374201655388, + 0.1234676589568456, + 0.13077114522457123, + 0.1136033721268177, + 0.12903592611352602, + 0.1352446327606837 + ], + [ + 0.12135304262240727, + 0.12679661065340042, + 0.12019409611821175, + 0.11740628133217494, + 0.1427453396221002, + 0.11463899289568265, + 0.14077800263961157, + 0.11608740563193957 + ], + [ + 0.12525992592175803, + 0.14050405969222388, + 0.1180319885412852, + 0.10591718927025795, + 0.1279467431207498, + 0.12505279978116354, + 0.12759792059659958, + 0.12968911106387773 + ], + [ + 0.11562805809080601, + 0.12839056675632796, + 0.11995011195540428, + 0.11963618795077006, + 0.12627016877134642, + 0.1318651090065638, + 0.12963976711034775, + 0.12861980001131693 + ], + [ + 0.1180493136246999, + 0.13157360379894575, + 0.13059414674838385, + 0.11437078813711803, + 0.12069108088811238, + 0.12894285718599954, + 0.12411849449078242, + 0.1316594866414865 + ], + [ + 0.1251655655602614, + 0.12215384592612584, + 0.1332979996999105, + 0.11680462459723155, + 0.1202247825761636, + 0.13291629776358604, + 0.11304903651277225, + 0.13638762260476747 + ], + [ + 0.12188639243443807, + 0.12651864935954413, + 0.12886597837011018, + 0.13183443993330002, + 0.11761683970689774, + 0.12824413801232973, + 0.1220944772164027, + 0.1229388676583767 + ], + [ + 0.11617748066782951, + 0.1268979236483574, + 0.12800212080279985, + 0.117298923432827, + 0.13653909166653952, + 0.12301822751760483, + 0.12627601251006126, + 0.12578998878598213 + ], + [ + 0.12478468691309293, + 0.12971306343873343, + 0.10870283717910449, + 0.12221584593256314, + 0.12794790044426918, + 0.1326818565527598, + 0.1281427058080832, + 0.12581086655457815 + ], + [ + 0.12738158678015074, + 0.12829257796208063, + 0.11542264744639397, + 0.11948423087596893, + 0.1264430470764637, + 0.12492277721563975, + 0.12151739746332169, + 0.13653550669550896 + ], + [ + 0.12154716501633327, + 0.12933520103494325, + 0.13164405897259712, + 0.1112684632341067, + 0.12304026012619336, + 0.12068304046988487, + 0.1279565691947937, + 0.13452501346667609 + ], + [ + 0.1194085602958997, + 0.12591787427663803, + 0.12605944772561392, + 0.12488897393147151, + 0.12758771081765494, + 0.12203469748298328, + 0.12328100949525833, + 0.1308214838306109 + ], + [ + 0.1248119759062926, + 0.12064410125215848, + 0.12917405366897583, + 0.11871604124704997, + 0.121875395377477, + 0.12454744180043538, + 0.1290691668788592, + 0.13116159538427988 + ], + [ + 0.1147387536863486, + 0.11347606033086777, + 0.12424837052822113, + 0.12912242114543915, + 0.12457458799084027, + 0.12595230465133986, + 0.13189987341562906, + 0.13598738610744476 + ], + [ + 0.1125618927180767, + 0.129627155760924, + 0.1313978247344494, + 0.12314011404911678, + 0.1253949503103892, + 0.12920589124162993, + 0.12501226862271628, + 0.12365967159469922 + ], + [ + 0.1155334102610747, + 0.12335794046521187, + 0.12765294313430786, + 0.12259259199102719, + 0.13059732566277185, + 0.12745845193664232, + 0.1184571422636509, + 0.13434997076789537 + ], + [ + 0.11972518637776375, + 0.13301543643077215, + 0.11573319882154465, + 0.12672620515028635, + 0.12475844845175743, + 0.12484427417318027, + 0.13055121898651123, 
+ 0.12464580933252971 + ], + [ + 0.11061582838495572, + 0.13066187997659048, + 0.11816610147555669, + 0.12362361575166385, + 0.12960200384259224, + 0.12913112342357635, + 0.12355597193042438, + 0.13464323431253433 + ], + [ + 0.12715447694063187, + 0.12036458154519399, + 0.11913050462802251, + 0.12772884964942932, + 0.1263183889289697, + 0.1213305542866389, + 0.12348261227210362, + 0.13448980947335562 + ], + [ + 0.12826630348960558, + 0.12164976075291634, + 0.1291889784236749, + 0.12304915611942609, + 0.1212444044649601, + 0.12528503437836966, + 0.12156184762716293, + 0.1297542875011762 + ], + [ + 0.1277348486085733, + 0.12765240917603174, + 0.12182928870121638, + 0.11657238006591797, + 0.12257601320743561, + 0.13117648288607597, + 0.12016312157114346, + 0.1322952260573705 + ], + [ + 0.123804056396087, + 0.12006213143467903, + 0.12397123873233795, + 0.11155272523562114, + 0.12702043975392976, + 0.13208813220262527, + 0.1280679019788901, + 0.13343314081430435 + ], + [ + 0.12667405232787132, + 0.12516785909732184, + 0.1231542353828748, + 0.1210236685971419, + 0.13159372905890146, + 0.12402419870098431, + 0.1192910298705101, + 0.12907099723815918 + ], + [ + 0.12331263720989227, + 0.12576153129339218, + 0.12121884400645892, + 0.11981530239184697, + 0.13310788323481879, + 0.12612053627769151, + 0.1209803856909275, + 0.1296826476852099 + ], + [ + 0.12407490362723668, + 0.12342624738812447, + 0.12340182686845462, + 0.12667138005296388, + 0.12664742022752762, + 0.11963717515269916, + 0.12501258651415506, + 0.13112823416789374 + ], + [ + 0.12559262290596962, + 0.12638258188962936, + 0.12107307836413383, + 0.12532181292772293, + 0.1285279504954815, + 0.12038104732831319, + 0.12107313300172488, + 0.13164753591020903 + ], + [ + 0.11952805146574974, + 0.12299402306477229, + 0.1194044401248296, + 0.12491398801406224, + 0.12806501115361849, + 0.12776660919189453, + 0.1281691255668799, + 0.12915852790077528 + ], + [ + 0.11956280718247096, + 0.12458479404449463, + 0.11809740836421649, + 0.12119243418176968, + 0.12687241658568382, + 0.1254558116197586, + 0.13116390506426492, + 0.1330701895058155 + ], + [ + 0.12234697366754214, + 0.12488103533784549, + 0.12445553267995517, + 0.12650651733080545, + 0.11996771643559138, + 0.13063446432352066, + 0.12146537750959396, + 0.12974215671420097 + ], + [ + 0.12650620068113008, + 0.12319091334939003, + 0.12313208480676015, + 0.13124867031971613, + 0.12123145287235577, + 0.13128752261400223, + 0.11796652028958003, + 0.12543639168143272 + ], + [ + 0.12263242403666179, + 0.12267122293512027, + 0.12937555462121964, + 0.12656973799069723, + 0.1256453556319078, + 0.1197126420835654, + 0.12112898255387942, + 0.13226384172836939 + ], + [ + 0.1236723226805528, + 0.12298436835408211, + 0.12774880602955818, + 0.13186448564132056, + 0.12423414240280788, + 0.11872113744417827, + 0.12023921807607015, + 0.1305353045463562 + ], + [ + 0.1215864544113477, + 0.1262065333624681, + 0.12275992582241695, + 0.12551250929633775, + 0.12626762191454569, + 0.12526891132195792, + 0.1228443259994189, + 0.1295534943540891 + ], + [ + 0.12414230778813362, + 0.1261384797592958, + 0.12097220371166865, + 0.12945099423329035, + 0.12688841670751572, + 0.12433202192187309, + 0.11993373557925224, + 0.12814161678155264 + ], + [ + 0.1195494644343853, + 0.12912695358196893, + 0.12485285724202792, + 0.122922799239556, + 0.13027895614504814, + 0.12379278366764386, + 0.12089732165137927, + 0.12857863679528236 + ], + [ + 0.12431702266136806, + 0.1296307717760404, + 0.12177605802814166, + 0.12177164107561111, + 
0.12856043254335722, + 0.12804650391141573, + 0.11929148559768994, + 0.12660585095485052 + ], + [ + 0.1248128612836202, + 0.12676760057608286, + 0.12432139863570531, + 0.12760947520534197, + 0.1243838481605053, + 0.11973448718587558, + 0.12222603460152943, + 0.13014407828450203 + ] + ], + "load_balancing_losses": [ + 0.06004493832588196, + 0.06003242656588555, + 0.060045579075813295, + 0.06012785471975803, + 0.06028471104800701, + 0.06030931584537029, + 0.06019735634326935, + 0.060086073353886604, + 0.0600259069353342, + 0.06000634171068668, + 0.060015855729579924, + 0.06019938327372074, + 0.06032469943165779, + 0.06041383557021618, + 0.06088683344423771, + 0.06121671535074711, + 0.06119343712925911, + 0.060738954320549966, + 0.06105032339692116, + 0.06099824942648411, + 0.060721630230546, + 0.06077244952321052, + 0.06074267402291298, + 0.06088492423295975, + 0.060579366609454156, + 0.0604194276034832, + 0.060493505001068114, + 0.0605908315628767, + 0.0605136726051569, + 0.0603449359536171, + 0.060860810428857805, + 0.06069913767278194, + 0.06077303811907768, + 0.06054488755762577, + 0.060510614141821864, + 0.06046343073248863, + 0.0602861650288105, + 0.0603190965950489, + 0.060254532098770144, + 0.060238535329699514, + 0.06027446873486042, + 0.06026398316025734, + 0.060227270051836965, + 0.06027206815779209, + 0.060248592495918275, + 0.06021219827234745, + 0.06019999235868454, + 0.060231142491102216, + 0.06003213934600353, + 0.06010932885110378 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.12483104318380356, + 0.12675206611553827, + 0.12467950582504272, + 0.1291585514942805, + 0.12385845184326172, + 0.11746429279446602, + 0.12175804749131203, + 0.1314978152513504 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_0.5/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_0.5/metrics.json new file mode 100644 index 0000000..120b730 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_0.5/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_0.5", + "description": "Very sharp routing (strong exploitation)", + "temperature": 0.5, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 23.70006944096973, + "val_accuracy": 0.016395483117008846, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.144560050964356, + 4.526667594909668, + 3.237144422531128, + 1.90134859085083, + 0.7837494105100632, + 0.3667227953672409, + 0.22926726788282395, + 0.16440271735191345, + 0.11588505581021309, + 0.09812678918242454, + 0.051057306677103044, + 0.01911924220621586, + 0.015134618338197469, + 0.01223741751164198, + 0.011992707569152116, + 0.01141246035695076, + 0.010454405844211579, + 0.008931473782286047, + 0.009475251939147711, + 0.008736734511330724, + 0.010804664529860019, + 0.0012871234386693687, + 0.0009459709574002773, + 0.000948911055456847, + 0.0010725339932832866, + 0.0010384018125478179, + 
0.0009104772587306797, + 0.0007092287065461278, + 0.0006396141048753634, + 0.0005995057494146749, + 0.0005254744959529489, + 0.00045139196154195813, + 0.0004262233996996656, + 0.00042784894758369775, + 0.0004688720509875566, + 0.0003861472563585266, + 0.0003662309580249712, + 0.0003619653667556122, + 0.00040222569077741355, + 0.00041492521413601937, + 0.0003454021061770618, + 0.00036678869219031185, + 0.00026380949857411905, + 0.0002892925447667949, + 0.0002916311626904644, + 0.00030481590365525333, + 0.0002845260431058705, + 0.0002980848774313927, + 0.0002694806331419386, + 0.00026188207411905753 + ], + "val_losses": [ + 10.775087359937256, + 10.734495075347146, + 10.735080395486245, + 11.25632559031564, + 12.413784128195827, + 14.875232838067907, + 16.61496204200987, + 19.166657329869356, + 20.566684770078627, + 22.5847883392981, + 23.434426169513394, + 24.257394541278746, + 24.535657154797665, + 23.970591912421238, + 23.997228689834003, + 23.904251866964064, + 23.599130313725016, + 22.89697472535258, + 22.839926271472297, + 22.670232294305052, + 22.608450839039293, + 23.095901239887144, + 23.013133840931598, + 23.044736437578504, + 23.077098057885053, + 23.3490033503556, + 23.500254883783025, + 23.656872105682698, + 23.855227130040685, + 23.952079227029646, + 23.84635386787118, + 23.881982479836832, + 23.89619252370019, + 24.04246560086631, + 24.012590718353596, + 23.985630123017113, + 23.99403113779668, + 23.927537540664943, + 23.86883792944595, + 23.853859331919534, + 23.827011984565655, + 23.837052186891807, + 23.867749096226778, + 23.79155464981133, + 23.748872878273467, + 23.85055281163947, + 23.796526447201785, + 23.831372682281604, + 23.82838166307645, + 23.70006944096973 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846 + ], + "val_perplexities": [ + 47814.65014626992, + 45912.6095115082, + 45939.49095285087, + 77367.77199056088, + 246171.5988893455, + 2885569.497916141, + 16435629.093374733, + 210849956.08258787, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 
485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.05945862929026286, + 0.09970323244730632, + 0.13675736983617146, + 0.17485446532567342, + 0.21180752515792847, + 0.24993139902750652, + 0.2868975003560384, + 0.32509169181187947, + 0.36561097304026285, + 0.4082769195238749, + 0.4452169934908549, + 0.4834104855855306, + 0.5202555855115255, + 0.5587355971336365, + 0.5958332578341167, + 0.6342577139536539, + 0.6712398568789164, + 0.7095060507456462, + 0.7464210232098897, + 0.7846049865086874, + 0.8213121136029561, + 0.8597021738688151, + 0.8966187040011088, + 0.9348151763280232, + 0.971663228670756, + 1.0100767652193705, + 1.057422113418579, + 1.0955376664797465, + 1.1323136885960896, + 1.1705454905827841, + 1.207456644376119, + 1.2454842964808146, + 1.2822751919428508, + 1.3204218745231628, + 1.35732581615448, + 1.395689602692922, + 1.4326767802238465, + 1.470899470647176, + 1.5076398054758708, + 1.545754595597585, + 1.5828205029169717, + 1.6217491110165914, + 1.6596386273701986, + 1.6985048532485962, + 1.7356390953063965, + 1.7740787943204244, + 1.8207136432329813, + 1.8590702374776205, + 1.8967880884806314, + 1.9357895453770955 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5, + 0.5 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, 
+ 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12489314253131549, + 0.11538034801681836, + 0.1392521932721138, + 0.12765915070970854, + 0.13447885463635126, + 0.12497525786360104, + 0.1217424434920152, + 0.11161840210358302 + ], + [ + 0.13475091879566511, + 0.11562763527035713, + 0.1333719144264857, + 0.1229245625436306, + 0.12197200953960419, + 0.13096169382333755, + 0.1288150871793429, + 0.11157594497005145 + ], + [ + 0.13582054153084755, + 0.11752270037929217, + 0.12508786966403326, + 0.12327940265337627, + 0.11954270924131076, + 0.1335661051174005, + 0.13265361761053404, + 0.11252681662638982 + ], + [ + 0.14121857533852258, + 0.1201777954896291, + 0.1219904695947965, + 0.12393187607328097, + 0.11556172867616017, + 0.13259286557634672, + 0.13163785512248674, + 0.11288861806193988 + ], + [ + 0.13601584484179816, + 0.12416341404120128, + 0.1203830527762572, + 0.12698513145248094, + 0.1107034869492054, + 0.13164812326431274, + 0.13132713486750922, + 0.11877357835570972 + ], + [ + 0.12934383377432823, + 0.12344019487500191, + 0.12022254119316737, + 0.12799510856469473, + 0.12728174279133478, + 0.1253740700582663, + 0.1255892055730025, + 0.12075307841102283 + ], + [ + 0.1273275427520275, + 0.1225972647468249, + 0.1249833678205808, + 0.12505940720438957, + 0.1352445942660173, + 0.12052597353855769, + 0.12117260818680127, + 0.1230890154838562 + ], + [ + 0.12810913970073065, + 0.1238622876505057, + 0.12659715116024017, + 0.12370188285907109, + 0.12651319553454718, + 0.12940546746055284, + 0.12162479634086291, + 0.12018585453430812 + ], + [ + 0.12368856867154439, + 0.1279382382829984, + 0.12462495888272922, + 0.12257636214296024, + 0.12905889252821603, + 0.12702859317262968, + 0.12853120639920235, + 0.11655294025937717 + ], + [ + 0.11895465105772018, + 0.1208535706003507, + 0.13220032304525375, + 0.1210919866959254, + 0.14198759694894156, + 0.12408760314186414, + 0.1253624272843202, + 0.11546161149938901 + ], + [ + 0.11656401927272479, + 0.13009908298651376, + 0.13193458691239357, + 0.1244925099114577, + 0.1314594807724158, + 0.12587662413716316, + 0.13644214843710264, + 0.1031313215692838 + ], + [ + 0.10956315944592158, + 0.13123837485909462, + 0.12477318445841472, + 0.11995861927668254, + 0.13450312738617262, + 0.11849225436647733, + 0.13217019910613695, + 0.12930085634191832 + ], + [ + 0.11641354486346245, + 0.12387130285302798, + 0.1523391641676426, + 0.1218426376581192, + 0.11966235376894474, + 0.12413237988948822, + 0.12205732117096584, + 0.1196810578306516 + ], + [ + 0.12082756807406743, + 0.12492124860485394, + 0.1115306740005811, + 0.1328707362214724, + 0.11915713797012965, + 0.12142320970694225, + 0.13976961001753807, + 0.12949959623316923 + ], + [ + 0.12409573048353195, + 0.11580195898811023, + 0.12089811265468597, + 0.13245859742164612, + 0.1311631128191948, + 0.12894239649176598, + 0.1299689101676146, + 0.11667094379663467 + ], + [ + 0.11291537806391716, + 0.12622884040077528, + 0.12993097181121507, + 0.12755161399642626, + 0.13166394581397375, + 0.1184805581967036, + 0.12563432256380716, + 0.12759415184458098 + ], + [ + 0.10947434604167938, + 0.12353318681319554, + 0.12812436992923418, + 0.12915965666373572, + 0.1350412406027317, + 0.12685905396938324, + 0.12290513639648755, + 
0.1249027947584788 + ], + [ + 0.11697788039843242, + 0.12664615859587988, + 0.12968199203411737, + 0.12836830193797746, + 0.12650749211510023, + 0.12662740796804428, + 0.12057684858640035, + 0.12461369236310323 + ], + [ + 0.1217580574254195, + 0.128155047694842, + 0.12898329024513563, + 0.12618319193522134, + 0.12871338178714117, + 0.12000210583209991, + 0.11673271531860034, + 0.12947197879354158 + ], + [ + 0.11736448978384335, + 0.12572254116336504, + 0.1197560653090477, + 0.12098861361543338, + 0.13338097060720125, + 0.12395314499735832, + 0.12381001561880112, + 0.13502392917871475 + ], + [ + 0.12727844839294752, + 0.1204719605545203, + 0.13443210845192274, + 0.11939483508467674, + 0.12722532575329146, + 0.12008217473824818, + 0.12073629597822826, + 0.13037861759463945 + ], + [ + 0.11982591450214386, + 0.1311459926267465, + 0.12544148042798042, + 0.12320375194152196, + 0.1318699630598227, + 0.11992330724994342, + 0.12716366474827132, + 0.12142569820086162 + ], + [ + 0.12001492455601692, + 0.12993334730466208, + 0.12616118043661118, + 0.12558047845959663, + 0.1293785534799099, + 0.1262988199790319, + 0.11762647330760956, + 0.12500600889325142 + ], + [ + 0.12432612851262093, + 0.12697297210494676, + 0.12139139696955681, + 0.12516535694400469, + 0.132976446300745, + 0.12418774267037709, + 0.11499736458063126, + 0.12998235722382864 + ], + [ + 0.12388576318820317, + 0.12411639218529065, + 0.1222571295996507, + 0.12633834655086199, + 0.13061792651812235, + 0.12673960998654366, + 0.1176796667277813, + 0.1283649317920208 + ], + [ + 0.12550217161575952, + 0.12844899048407873, + 0.12292332326372464, + 0.12429769833882649, + 0.1293539231022199, + 0.12310909976561864, + 0.12207399681210518, + 0.12429058303435643 + ], + [ + 0.1242099404335022, + 0.132162028302749, + 0.11925280715028445, + 0.12496921171744664, + 0.1291153977314631, + 0.12401337673266728, + 0.12548762559890747, + 0.12078938881556193 + ], + [ + 0.12300269802411397, + 0.12760676071047783, + 0.12121839076280594, + 0.12402171765764554, + 0.12774277354280153, + 0.12503591179847717, + 0.12774854277571043, + 0.12362298741936684 + ], + [ + 0.1208794539173444, + 0.1281136435767015, + 0.11974153046806653, + 0.12357223903139432, + 0.13088044275840124, + 0.12648063277204832, + 0.12779866655667624, + 0.12253315870960553 + ], + [ + 0.11865467578172684, + 0.12986683597167334, + 0.12611709038416544, + 0.1224153737227122, + 0.13093186790744463, + 0.12272992233435313, + 0.1218056579430898, + 0.1274783524374167 + ], + [ + 0.11571623881657918, + 0.1304558850824833, + 0.12886138136188188, + 0.12369953220089276, + 0.12517291928331056, + 0.1270586314300696, + 0.12378467991948128, + 0.12525050466259322 + ], + [ + 0.1295826956629753, + 0.12572671845555305, + 0.1227076289554437, + 0.1242753304541111, + 0.11937149862448375, + 0.12615970646341643, + 0.12635145088036856, + 0.12582474822799364 + ], + [ + 0.1286743904153506, + 0.12549515068531036, + 0.1256577211121718, + 0.12240219116210938, + 0.12598624701301256, + 0.12306476011872292, + 0.12694871798157692, + 0.12177061165372531 + ], + [ + 0.12314232935508092, + 0.12686155488093695, + 0.11951021229227383, + 0.12697360664606094, + 0.12791781375805536, + 0.125789658476909, + 0.12446790809432666, + 0.12533668180306753 + ], + [ + 0.1241135907669862, + 0.12972777461012205, + 0.11911755179365476, + 0.12830432380239168, + 0.12358428786198299, + 0.12492980683843295, + 0.12470047796765964, + 0.1255219615995884 + ], + [ + 0.12240238611896832, + 0.1311912996073564, + 0.12120572353402774, + 0.12344393258293469, + 
0.12595620627204576, + 0.12456697598099709, + 0.12551271791259447, + 0.12572052453955015 + ], + [ + 0.12427420789996783, + 0.1326091860731443, + 0.11788347860177358, + 0.12398536627491315, + 0.12529929851492247, + 0.1253680313626925, + 0.12684941291809082, + 0.12373079732060432 + ], + [ + 0.125291109085083, + 0.1312564785281817, + 0.11925191804766655, + 0.12449102724591891, + 0.12465670083959897, + 0.12118223930398624, + 0.12489944448073705, + 0.12897085895140967 + ], + [ + 0.12587743749221167, + 0.12903260191281637, + 0.12158710757891338, + 0.12311855579415958, + 0.12475524594386418, + 0.12417393550276756, + 0.12488512446482976, + 0.12656976530949274 + ], + [ + 0.12473955502112706, + 0.12809143712123236, + 0.12220078830917676, + 0.1234413670996825, + 0.12370897953708966, + 0.12586859737833342, + 0.12746144707004228, + 0.12448759501179059 + ], + [ + 0.12433573727806409, + 0.12748964875936508, + 0.12268320471048355, + 0.12315206105510394, + 0.12125281989574432, + 0.12636147563656172, + 0.12954217319687208, + 0.12518265470862389 + ], + [ + 0.11830837776263554, + 0.1296043060719967, + 0.12109218041102092, + 0.12455314149459203, + 0.12382195269068082, + 0.1270710453391075, + 0.12621599808335304, + 0.12933278332153955 + ], + [ + 0.11780968184272449, + 0.1294283593694369, + 0.12071596210201581, + 0.12650182594855627, + 0.12468611697355907, + 0.12508624295393625, + 0.1269967518746853, + 0.1287748341759046 + ], + [ + 0.12523077055811882, + 0.12627979988853136, + 0.12292365605632465, + 0.12404945368568103, + 0.12284694860378902, + 0.12474338461955388, + 0.12686366339524588, + 0.1270621009171009 + ], + [ + 0.12570624674359956, + 0.12653281167149544, + 0.12346539398034413, + 0.12407474095622699, + 0.12006995330254237, + 0.12720628455281258, + 0.12463844940066338, + 0.12830590829253197 + ], + [ + 0.12223443140586217, + 0.1298845075070858, + 0.12083819756905238, + 0.1235511377453804, + 0.11848084752758344, + 0.12701461339990297, + 0.12788705031077066, + 0.130109004676342 + ], + [ + 0.12085733438531558, + 0.12706730142235756, + 0.1209068310757478, + 0.1273010733226935, + 0.11718778436382611, + 0.12877833594878516, + 0.12815284977356592, + 0.1297482637067636 + ], + [ + 0.12403937925895055, + 0.12928534795840582, + 0.12168120841185252, + 0.12345623721679051, + 0.12253078321615855, + 0.12534143775701523, + 0.129843320697546, + 0.12382205078999202 + ], + [ + 0.12303353225191434, + 0.13010175650318465, + 0.12115228548645973, + 0.12308774515986443, + 0.12382585431138675, + 0.12673079346617064, + 0.12730770061413446, + 0.12476008757948875 + ], + [ + 0.11872967332601547, + 0.1277495672305425, + 0.12213504686951637, + 0.1231309895714124, + 0.12067561844984691, + 0.13095348328351974, + 0.12651504948735237, + 0.1301103631655375 + ] + ], + "load_balancing_losses": [ + 0.060456137359142306, + 0.060292360931634904, + 0.06032455898821354, + 0.060554913431406024, + 0.06062224581837654, + 0.06041250079870224, + 0.06013858206570148, + 0.06017278544604778, + 0.060145171359181404, + 0.060342563688755034, + 0.060983060672879216, + 0.06179597340524197, + 0.06163978241384029, + 0.06330655254423619, + 0.0628902368247509, + 0.06176042817533016, + 0.061490241810679434, + 0.06150852479040623, + 0.06064310297369957, + 0.0606524184346199, + 0.06090476848185063, + 0.060939608886837957, + 0.06066209562122822, + 0.060476723685860635, + 0.06057591550052166, + 0.060483630374073984, + 0.06039879322052002, + 0.06035055033862591, + 0.06033693589270115, + 0.06036769822239876, + 0.06037281677126884, + 0.06030120328068733, + 
0.0603391744196415, + 0.060228198766708374, + 0.060263966396451, + 0.060319313779473305, + 0.06021736338734627, + 0.06025288291275501, + 0.06022521629929543, + 0.060135988518595695, + 0.060185694321990016, + 0.060134252160787584, + 0.06017463430762291, + 0.060160320997238156, + 0.06013125330209732, + 0.06014687418937683, + 0.060117506980896, + 0.06017342619597912, + 0.06009880602359772, + 0.06017463803291321 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.11714443191885948, + 0.12773115436236063, + 0.12167937681078911, + 0.12277515605092049, + 0.11939165989557902, + 0.1326965851088365, + 0.12684522941708565, + 0.13173617919286093 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_0.7/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_0.7/metrics.json new file mode 100644 index 0000000..28773ba --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_0.7/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_0.7", + "description": "Sharp routing (moderate exploitation)", + "temperature": 0.7, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 24.06204849755385, + "val_accuracy": 0.016395483117008846, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.112067937850952, + 4.496769285202026, + 3.210463786125183, + 1.8810135126113892, + 0.7712685972452163, + 0.3616963684558868, + 0.22627402395009993, + 0.16312743723392487, + 0.11461180076003075, + 0.09478219896554947, + 0.05062285885214805, + 0.018510314356535672, + 0.014590423833578826, + 0.011878180038183927, + 0.011375087313354015, + 0.011114948987960815, + 0.010101511608809233, + 0.00876059541478753, + 0.009050211170688272, + 0.008253366919234395, + 0.010212093964219094, + 0.0011610745335929097, + 0.0008900154498405755, + 0.0009279754129238426, + 0.0009829458431340754, + 0.0009763274458236993, + 0.0008119105710648, + 0.0006678252131678164, + 0.0006168811582028866, + 0.0005886599596124142, + 0.00042787316197063775, + 0.00037357669207267464, + 0.00035634434898383914, + 0.0003332898369990289, + 0.0003163952234899625, + 0.00035783400817308576, + 0.0003093982246355154, + 0.00026944268611259756, + 0.0002971166497445665, + 0.00028948978579137476, + 0.0002538443906814791, + 0.00026563800929579884, + 0.00019774129759753123, + 0.00020743943750858306, + 0.00021030032658018172, + 0.00024408763565588742, + 0.0002110597284627147, + 0.00019490414706524462, + 0.00020896066998830066, + 0.0002056415833067149 + ], + "val_losses": [ + 10.774517844506793, + 10.7349182135646, + 10.7373808668696, + 11.269729310969161, + 12.438262336245694, + 14.913704396979126, + 16.65020022712411, + 19.264517174171473, + 20.772590590028795, + 22.582829202443044, + 23.60784914805274, + 24.343325928327474, + 24.69480530701762, + 24.33820892981, + 24.108848099994997, + 24.141309785337416, + 24.056912149220388, + 23.3765770335922, + 23.284320568448663, + 23.473308637369648, + 
23.355081167322165, + 23.471337153296588, + 23.285484394841816, + 23.095294419952502, + 23.005234337527003, + 23.30739671686934, + 23.376461116669457, + 23.57397195391436, + 23.76514703730391, + 23.92136697128889, + 23.816413919833018, + 23.867370436975055, + 23.92273301033586, + 24.1076847049457, + 24.12402884277775, + 24.087319390933843, + 24.081154253794534, + 24.151920466877968, + 24.01231571642333, + 24.053209250891587, + 23.987569849398447, + 24.182945089710895, + 24.18281513389345, + 24.091866914459338, + 24.052738371670458, + 24.132946856872774, + 24.06822200545995, + 24.150545868351266, + 24.130033290849557, + 24.06204849755385 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846 + ], + "val_perplexities": [ + 47787.426718026465, + 45932.041002068734, + 46045.295090187814, + 78411.769086319, + 252271.79484452913, + 2998744.913653764, + 17025116.06193612, + 232527062.52418977, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.02899393637975057, + 0.07029741207758586, + 0.10743565162022908, + 0.14592480659484863, + 0.1830313801765442, + 0.2214730461438497, + 0.2584453026453654, + 0.29664746125539143, + 0.34285277128219604, + 0.3811239957809448, + 0.4183243672053019, + 0.4570102532704671, + 0.49400624831517537, + 0.5322786490122478, + 0.5691534479459127, + 0.6072999318440755, + 0.6442915519078573, + 0.6827718615531921, + 0.7197797973950704, + 0.7580424189567566, + 0.7948365171750387, + 0.8331233620643616, + 0.8701134959856669, + 0.9081664880116781, + 0.9448094725608825, + 0.9828936497370402, + 
1.0197464108467102, + 1.0580111622810364, + 1.0948072751363118, + 1.1328449368476867, + 1.1696253339449565, + 1.2167354861895243, + 1.2532144943873087, + 1.2912716428438822, + 1.327982449531555, + 1.366007419427236, + 1.4030836701393128, + 1.4416930397351584, + 1.4788535873095194, + 1.5172484596570333, + 1.554825242360433, + 1.5931318879127503, + 1.6302163084348043, + 1.6686877489089966, + 1.7056994557380676, + 1.7441300789515177, + 1.7813162446022033, + 1.830000674724579, + 1.8678601384162903, + 1.9063661336898803 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7, + 0.7 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12453143174449603, + 0.11496849730610847, + 0.1393900103867054, + 0.12799495086073875, + 0.13397281616926193, + 0.1260800895591577, + 0.12121189509828885, + 0.11185011391838391 + ], + [ + 0.13227659339706102, + 0.11566060781478882, + 0.1363562506934007, + 0.12218702336152394, + 0.12148407846689224, + 0.1289877568682035, + 0.12936324998736382, + 0.11368423079450925 + ], + [ + 0.13637947912017503, + 0.11605746174852054, + 0.1259271614253521, + 0.12267788872122765, + 0.1192453180750211, + 0.13343347484866777, + 0.13287650545438132, + 0.11340246473749478 + ], + [ + 0.14725747207800546, + 0.11396152153611183, + 0.12428456793228786, + 0.12208759536345799, + 0.11511942744255066, + 0.13082696249087652, + 
0.13446866224209467, + 0.11199358229835828 + ], + [ + 0.14236059536536536, + 0.11933838948607445, + 0.12279383093118668, + 0.12463349973162015, + 0.11191822464267413, + 0.13222296411792436, + 0.13113710284233093, + 0.11559516812364261 + ], + [ + 0.13510920852422714, + 0.11966578289866447, + 0.12196752180655797, + 0.12424168487389882, + 0.1277990092833837, + 0.12884431332349777, + 0.12402083600560825, + 0.11835140983263652 + ], + [ + 0.12952748065193495, + 0.12345612173279126, + 0.12448244790236156, + 0.12950273106495538, + 0.13035212581356367, + 0.12151261543234189, + 0.12221999714771907, + 0.11894626418749492 + ], + [ + 0.1251110943655173, + 0.12001607194542885, + 0.12763463954130808, + 0.12912959853808084, + 0.13108844434221587, + 0.12541219343741736, + 0.12250244741638501, + 0.11910528192917506 + ], + [ + 0.12302539745966594, + 0.11866408834854762, + 0.12754279499252638, + 0.12626054510474205, + 0.13392052799463272, + 0.1267679197092851, + 0.12052155658602715, + 0.12329696118831635 + ], + [ + 0.1270607809225718, + 0.12723393738269806, + 0.12690831099947295, + 0.12423736229538918, + 0.12701247384150824, + 0.12451280280947685, + 0.12723606949051222, + 0.11579805240035057 + ], + [ + 0.12361529345313708, + 0.1265696461002032, + 0.1298068861166636, + 0.1200558344523112, + 0.13444407532612482, + 0.12331046164035797, + 0.12629258135954538, + 0.11590499927600224 + ], + [ + 0.11837935323516528, + 0.127372487137715, + 0.1259137230614821, + 0.12630994245409966, + 0.1252996933956941, + 0.12047553310791652, + 0.1276534932355086, + 0.12859554712971052 + ], + [ + 0.10550033052762349, + 0.1314114717145761, + 0.14918473859628043, + 0.12052537004152934, + 0.13746980081001917, + 0.11758442223072052, + 0.1231200968225797, + 0.11520353456338246 + ], + [ + 0.13388494650522867, + 0.12259298687179883, + 0.1287511611978213, + 0.12280752261479695, + 0.11934237678845723, + 0.12334975848595302, + 0.13346663738290468, + 0.11580439408620198 + ], + [ + 0.12943816805879274, + 0.10754191999634106, + 0.13460869590441385, + 0.13112997884551683, + 0.12401883800824483, + 0.1300627402961254, + 0.12237226714690526, + 0.12082716450095177 + ], + [ + 0.11855907614032428, + 0.12126392250259717, + 0.1440577395260334, + 0.12114035213987033, + 0.11847339197993279, + 0.12792957201600075, + 0.12670168528954187, + 0.12187402943770091 + ], + [ + 0.12161451329787572, + 0.13499638189872107, + 0.12065786123275757, + 0.12501701091726622, + 0.13020226980249086, + 0.12976282214124998, + 0.12540542458494505, + 0.11234350750843684 + ], + [ + 0.12973142166932425, + 0.13174738610784212, + 0.12156182030836742, + 0.12095919996500015, + 0.13606150448322296, + 0.1244531696041425, + 0.1186086895565192, + 0.11687659099698067 + ], + [ + 0.11833133175969124, + 0.11668251951535542, + 0.12443765128652255, + 0.12493685508767764, + 0.13784114023049673, + 0.12191308910648029, + 0.13168937588731447, + 0.12416783968607585 + ], + [ + 0.11963944633801778, + 0.1230054038266341, + 0.12803484747807184, + 0.12142223368088405, + 0.131513811647892, + 0.12900624424219131, + 0.11891813700397809, + 0.1284596547484398 + ], + [ + 0.12358661244312923, + 0.1176016591489315, + 0.12936269864439964, + 0.12325115377704303, + 0.12730239828427634, + 0.1339393568535646, + 0.11940756067633629, + 0.1255483292043209 + ], + [ + 0.12481745332479477, + 0.13168203085660934, + 0.1277110862235228, + 0.11503095676501592, + 0.13383876283963522, + 0.12488507603605588, + 0.119178406894207, + 0.12285600105921428 + ], + [ + 0.12199984242518742, + 0.1301881397763888, + 0.13069594651460648, + 
0.11823334296544392, + 0.13638809571663538, + 0.12914506097634634, + 0.11119408657153447, + 0.12215522925059001 + ], + [ + 0.1291948234041532, + 0.12748181695739427, + 0.1227190059920152, + 0.1270525778333346, + 0.13118489955862364, + 0.1218670184413592, + 0.11501999323566754, + 0.12547961995005608 + ], + [ + 0.12901792923609415, + 0.12177031611402829, + 0.12041119610269864, + 0.12369502211610477, + 0.131863571703434, + 0.12662316486239433, + 0.122923177977403, + 0.12369539837042491 + ], + [ + 0.11837125817934673, + 0.13158255442976952, + 0.12795273214578629, + 0.12508020674188933, + 0.1280895248055458, + 0.12880869209766388, + 0.11864763870835304, + 0.12146716192364693 + ], + [ + 0.12532461062073708, + 0.12721378604571024, + 0.12727057188749313, + 0.12146658698717754, + 0.12477250148852666, + 0.13018481557567915, + 0.11790328100323677, + 0.12586362287402153 + ], + [ + 0.12078424170613289, + 0.12637578075130781, + 0.12781731535991034, + 0.11976419016718864, + 0.13019287462035814, + 0.12740293269356093, + 0.12066513299942017, + 0.1269973044594129 + ], + [ + 0.12396290153265, + 0.12829026828209558, + 0.12758610770106316, + 0.11942330747842789, + 0.1291553055246671, + 0.12778126945098242, + 0.1237215759853522, + 0.12007903928558032 + ], + [ + 0.12240591024359067, + 0.13029284899433455, + 0.12588968873023987, + 0.12113125001390775, + 0.12899786358078322, + 0.13325552145640054, + 0.11435889825224876, + 0.12366778900225957 + ], + [ + 0.12865598127245903, + 0.127031572163105, + 0.12451412156224251, + 0.12327592571576436, + 0.1241954118013382, + 0.12527688468496004, + 0.11522976433237393, + 0.13182011991739273 + ], + [ + 0.11861250922083855, + 0.12714369843403497, + 0.12744780133167902, + 0.12245933463176091, + 0.12882312883933386, + 0.12698503956198692, + 0.12331462527314822, + 0.12521362925569215 + ], + [ + 0.12100931877891223, + 0.1273407724996408, + 0.12754465887943903, + 0.12195862457156181, + 0.1268415947755178, + 0.1275085707505544, + 0.12030911942323048, + 0.12748711183667183 + ], + [ + 0.12320331359903018, + 0.12195243189732234, + 0.1290948453048865, + 0.12181657925248146, + 0.12985377882917723, + 0.126993956665198, + 0.12525227293372154, + 0.12183261041839917 + ], + [ + 0.12247705087065697, + 0.1265727716187636, + 0.12870260576407114, + 0.12400179480512936, + 0.1264149770140648, + 0.1266352174182733, + 0.12139490619301796, + 0.12380044410626094 + ], + [ + 0.12582687785228094, + 0.12596194942792258, + 0.12499444683392842, + 0.12555823599298796, + 0.12745029603441557, + 0.12706583614150682, + 0.11914282788832982, + 0.1239993025859197 + ], + [ + 0.12873218084375063, + 0.12828142692645392, + 0.1235573614637057, + 0.12279041111469269, + 0.12636741499106088, + 0.1256412404278914, + 0.12011441215872765, + 0.12451532483100891 + ], + [ + 0.11999987314144771, + 0.12533981601397196, + 0.12709044168392816, + 0.1248030997812748, + 0.1285798338552316, + 0.1270418787995974, + 0.1185736854871114, + 0.12857115268707275 + ], + [ + 0.12316559006770451, + 0.12585913638273874, + 0.12433304513494174, + 0.12648453190922737, + 0.1273066128293673, + 0.12436126048366229, + 0.1223741148908933, + 0.12611548602581024 + ], + [ + 0.12334290146827698, + 0.12372694785396258, + 0.1239702266951402, + 0.12513155738512674, + 0.12601656715075174, + 0.12793964023391405, + 0.11898062253991763, + 0.13089129949609438 + ], + [ + 0.12190930545330048, + 0.1219970074792703, + 0.12516677503784499, + 0.12698117519418398, + 0.12417911365628242, + 0.12688796843091646, + 0.12129763389627139, + 0.13158080105980238 + ], + [ + 
0.12313465525706609, + 0.1303013637661934, + 0.12590567395091057, + 0.12068767473101616, + 0.12716281414031982, + 0.12722331037124, + 0.11955539012948672, + 0.12602889786163965 + ], + [ + 0.12320040787259738, + 0.12801932419339815, + 0.12599809716145197, + 0.12118489543596904, + 0.12826721494396529, + 0.12658055250843367, + 0.11931534980734189, + 0.12743391344944635 + ], + [ + 0.12490696460008621, + 0.12639600411057472, + 0.12540500486890474, + 0.12124392886956532, + 0.12679476042588553, + 0.12724610914786658, + 0.12013433128595352, + 0.127872663239638 + ], + [ + 0.12320289512475331, + 0.12719344968597093, + 0.12557957569758096, + 0.12252599621812503, + 0.12148617580533028, + 0.12646541371941566, + 0.12355413660407066, + 0.12999213859438896 + ], + [ + 0.12592874467372894, + 0.12387706836064656, + 0.12500338380535445, + 0.11868727455536525, + 0.12665331612030664, + 0.12947743758559227, + 0.12183257689078648, + 0.12853996207316717 + ], + [ + 0.12881905709703764, + 0.1244052139421304, + 0.12312877054015796, + 0.12160760536789894, + 0.1253266098598639, + 0.12823248902956644, + 0.12108215938011806, + 0.12739786505699158 + ], + [ + 0.1221221312880516, + 0.12825991585850716, + 0.12641378367940584, + 0.12271094198028247, + 0.1255770760277907, + 0.1289760246872902, + 0.12049512565135956, + 0.12544478724400202 + ], + [ + 0.12435842305421829, + 0.12852666775385538, + 0.12567375600337982, + 0.12097492938240369, + 0.12561175723870596, + 0.13016132886211076, + 0.11950503041346867, + 0.12518788749972978 + ], + [ + 0.12167050565282504, + 0.1271545017759005, + 0.125619205335776, + 0.12350450828671455, + 0.12360409523049991, + 0.13052166004975638, + 0.12296523402134578, + 0.12496006612976392 + ] + ], + "load_balancing_losses": [ + 0.06033435836434364, + 0.06022239103913307, + 0.06026037149131298, + 0.060485242307186125, + 0.060628772526979444, + 0.06041885502636433, + 0.060126332193613054, + 0.06010495983064175, + 0.06014377623796463, + 0.06015973538160324, + 0.06057863309979439, + 0.06095994673669338, + 0.061476511880755424, + 0.06260076276957989, + 0.06211287751793861, + 0.06199055723845959, + 0.062156564369797704, + 0.06126252263784408, + 0.06117308884859085, + 0.061188652738928796, + 0.060848427936434746, + 0.06091635078191757, + 0.060828397423028945, + 0.06061984449625015, + 0.060786657780408856, + 0.060439277440309525, + 0.06046934574842453, + 0.06045899540185928, + 0.06035973466932774, + 0.06043998375535011, + 0.06030420996248722, + 0.060284508392214775, + 0.060312869027256966, + 0.06038755737245083, + 0.06037391498684883, + 0.06016887500882149, + 0.06023351810872555, + 0.06018978990614414, + 0.06017041876912117, + 0.06018405072391033, + 0.06026977226138115, + 0.06019855812191963, + 0.06017014607787132, + 0.060180007666349414, + 0.06015354059636593, + 0.06014581024646759, + 0.06009987182915211, + 0.06012827195227146, + 0.0601534079760313, + 0.06009095869958401 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.12098584448297818, + 0.12730055674910545, + 0.1256709930797418, + 0.12364924574891727, + 0.12284448618690173, + 0.1314833971361319, + 0.12321910013755162, + 0.1248461405436198 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_1.0/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_1.0/metrics.json new file mode 100644 index 
0000000..a71abb1 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_1.0/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_1.0", + "description": "Standard softmax (baseline)", + "temperature": 1.0, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 24.14784315533857, + "val_accuracy": 0.016395483117008846, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.09113883972168, + 4.480490636825562, + 3.1976430892944334, + 1.871941614151001, + 0.7655209273099899, + 0.3588774919509888, + 0.22509971708059312, + 0.16330183446407318, + 0.1141592837870121, + 0.09381150230765342, + 0.05000832825899124, + 0.017994978558272125, + 0.014239040575921536, + 0.011616227310150861, + 0.01084116119891405, + 0.01057725427672267, + 0.009997121896594763, + 0.008504253765568136, + 0.008790405746549369, + 0.008216845290735363, + 0.009981157537549735, + 0.0011251060583163052, + 0.0008337775245308876, + 0.0008390640781726688, + 0.0009229881630744785, + 0.0009262891195248813, + 0.0007446481584338471, + 0.0005801771621918306, + 0.0005416615080321207, + 0.0005083174066385255, + 0.0004116814641747624, + 0.00037529021501541137, + 0.00030076043331064286, + 0.000311051748576574, + 0.0003004116573720239, + 0.0002794522748445161, + 0.00026996448723366485, + 0.00023380829952657223, + 0.00023257495486177503, + 0.00023896607453934848, + 0.00025653688062448053, + 0.00023033128964016213, + 0.00018986025388585405, + 0.00018278957286383958, + 0.00018535426788730548, + 0.00018667437107069417, + 0.00018351613398408516, + 0.0001852859539212659, + 0.0001780228951247409, + 0.0001856664166552946 + ], + "val_losses": [ + 10.773937434273979, + 10.736033197005309, + 10.739948950049733, + 11.275681165418861, + 12.455180340015424, + 14.939341049733516, + 16.685722148881784, + 19.332226824002216, + 20.83824706330316, + 22.759531054816904, + 23.76758439970522, + 24.48413311045077, + 24.779609444284606, + 24.67976861960475, + 24.552185537115847, + 24.056931229446466, + 23.87612458758135, + 24.288348059772183, + 24.253868999413804, + 23.709496535176523, + 23.709736409541154, + 23.540967368405614, + 23.5197979553007, + 23.17517275186815, + 23.209820184606546, + 23.544872937691085, + 23.524504671669682, + 23.643126248470885, + 23.811226396594368, + 23.885758888595095, + 23.870604868912444, + 23.851947919218784, + 23.974923555084338, + 24.186039462948855, + 24.17390729681763, + 24.24803643513063, + 24.21730394666692, + 24.245885626587345, + 24.18392424027406, + 24.29600830886895, + 24.242489777689688, + 24.30887871327754, + 24.33785851684139, + 24.319907501813802, + 24.336846718939793, + 24.217444483888443, + 24.23894689445361, + 24.260208824076837, + 24.22805333221759, + 24.14784315533857 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 
0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846 + ], + "val_perplexities": [ + 47759.69845422191, + 45983.283028881095, + 46163.69520363189, + 78879.85613423931, + 256576.0369503451, + 3076616.6156222885, + 17640750.42753917, + 248816646.5638863, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.0288797656695048, + 0.07006478706995646, + 0.107588263352712, + 0.14602460463841757, + 0.1833832343419393, + 0.22215340932210287, + 0.2593945463498433, + 0.2984349727630615, + 0.3358225464820862, + 0.37452745040257773, + 0.41175854206085205, + 0.4500883499781291, + 0.4870405356089274, + 0.5252404928207397, + 0.5620970646540324, + 0.6006728291511536, + 0.6418313185373942, + 0.6853635986646016, + 0.7223888317743937, + 0.7614094177881877, + 0.798551599184672, + 0.8370514988899231, + 0.8740403135617574, + 0.9129669467608134, + 0.9503952383995056, + 0.9892154057820638, + 1.0264459649721782, + 1.0651225129763284, + 1.1024566372235616, + 1.14079806804657, + 1.1777773300806682, + 1.2161459565162658, + 1.253126041094462, + 1.2915813644727072, + 1.3286664605140686, + 1.3668164054552714, + 1.4037321050961813, + 1.44761483669281, + 1.489804244041443, + 1.5283455967903137, + 1.5654837449391683, + 1.6044031023979186, + 1.6417177279790243, + 1.6804942727088927, + 1.7176470756530762, + 1.7565688769022623, + 1.793628489971161, + 1.8324754158655803, + 1.8700798432032266, + 1.9087648431460063 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 
0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12434487789869308, + 0.11439689124623935, + 0.13921349371472994, + 0.12820876265565553, + 0.13394259909788767, + 0.1266429846485456, + 0.12089247008164723, + 0.11235771079858144 + ], + [ + 0.13128342355291048, + 0.11351537828644116, + 0.13641736408074698, + 0.12234945967793465, + 0.12184692546725273, + 0.12907233337561289, + 0.1307966746389866, + 0.11471820250153542 + ], + [ + 0.1363802229364713, + 0.11212662234902382, + 0.1303014655907949, + 0.1212917019923528, + 0.11884549756844838, + 0.13127440586686134, + 0.1338394284248352, + 0.11594041312734286 + ], + [ + 0.14164022356271744, + 0.11179485668738683, + 0.12981270998716354, + 0.12350728238622348, + 0.11282514284054439, + 0.13022646059592566, + 0.13862709452708563, + 0.11156600217024486 + ], + [ + 0.13874489814043045, + 0.11908886457482974, + 0.12262414768338203, + 0.12564167007803917, + 0.11032796278595924, + 0.12965262060364088, + 0.13668192674716315, + 0.11723766848444939 + ], + [ + 0.1320007344086965, + 0.12069530785083771, + 0.12405812988678615, + 0.12654931098222733, + 0.12289359668890636, + 0.12867771089076996, + 0.12719210733970007, + 0.11793286850055058 + ], + [ + 0.12706061080098152, + 0.12407400707403819, + 0.12485427036881447, + 0.12521816293398538, + 0.12819632266958556, + 0.12472877651453018, + 0.12549391513069472, + 0.12037371844053268 + ], + [ + 0.12257216374079387, + 0.1252812718351682, + 0.12332311520973842, + 0.12687496468424797, + 0.12468025088310242, + 0.12848203629255295, + 0.12345575292905171, + 0.12533021345734596 + ], + [ + 0.12569964801271757, + 0.12504238883654276, + 0.12266098459561665, + 0.1252934585014979, + 0.12283783157666524, + 0.12893504897753397, + 0.12204102923472722, + 0.1274893879890442 + ], + [ + 0.1241666649778684, + 0.12545820325613022, + 0.1247733694811662, + 
0.12188234676917394, + 0.13003683338562647, + 0.12641836951176325, + 0.12277806177735329, + 0.12448594222466151 + ], + [ + 0.13071334486206374, + 0.12533838798602423, + 0.12794277692834535, + 0.12035442019502322, + 0.13564429432153702, + 0.11804419755935669, + 0.12109979490439098, + 0.12086255848407745 + ], + [ + 0.12798994282881418, + 0.1263785921037197, + 0.1127744975189368, + 0.12458395833770435, + 0.12577854469418526, + 0.12425562739372253, + 0.12120387454827626, + 0.13703472912311554 + ], + [ + 0.13041002675890923, + 0.12867620835701624, + 0.12115462993582089, + 0.11732279881834984, + 0.12475077932079633, + 0.12708219264944395, + 0.122138316432635, + 0.1284648155172666 + ], + [ + 0.1122934768597285, + 0.13360764707128206, + 0.13780534639954567, + 0.11906703685720761, + 0.1355229082206885, + 0.11220678314566612, + 0.12830886617302895, + 0.12118769312898318 + ], + [ + 0.129513098547856, + 0.1251868481437365, + 0.1188648280998071, + 0.11354967455069225, + 0.14039851228396097, + 0.12249288832147916, + 0.121981892734766, + 0.12801204125086466 + ], + [ + 0.1145308402677377, + 0.1230184758702914, + 0.12542719642321268, + 0.11447481686870258, + 0.13360965624451637, + 0.1333265999952952, + 0.12860063090920448, + 0.12701155742009482 + ], + [ + 0.12394643947482109, + 0.11446065083146095, + 0.12349232658743858, + 0.13022988786300024, + 0.13386522606015205, + 0.12588074058294296, + 0.11780523136258125, + 0.1303192414343357 + ], + [ + 0.12499604125817616, + 0.12121791144212087, + 0.13631032655636469, + 0.12124486640095711, + 0.13621129964788756, + 0.12020563582579295, + 0.11758425086736679, + 0.12222942585746448 + ], + [ + 0.1103852925201257, + 0.13740763440728188, + 0.11933425193031628, + 0.11140703906615575, + 0.1334521621465683, + 0.1267652486761411, + 0.11805485064784686, + 0.1431932896375656 + ], + [ + 0.11508527646462123, + 0.12559345488746962, + 0.13083524132768312, + 0.1320414344469706, + 0.12841453775763512, + 0.11866897841294606, + 0.11492867519458135, + 0.13443215936422348 + ], + [ + 0.11627472812930743, + 0.13153652101755142, + 0.12963611135880151, + 0.12677199766039848, + 0.12542630483706793, + 0.11930591613054276, + 0.12408128629128139, + 0.12696688125530878 + ], + [ + 0.1234877494474252, + 0.1225827510158221, + 0.1193958359460036, + 0.12666485582788786, + 0.1325356848537922, + 0.12866836786270142, + 0.12036860982577006, + 0.12629592046141624 + ], + [ + 0.12213220323125522, + 0.13608613734443983, + 0.12889859328667322, + 0.11959930509328842, + 0.12571467086672783, + 0.12602645282944044, + 0.12241137151916821, + 0.11913103734453519 + ], + [ + 0.12297567601005237, + 0.1265838456650575, + 0.12253548825780551, + 0.12760759641726813, + 0.134151807675759, + 0.11923159783085187, + 0.11881750325361888, + 0.12809624274571738 + ], + [ + 0.13069531818230948, + 0.13089853276809058, + 0.11867112666368484, + 0.1212017151216666, + 0.12660516798496246, + 0.12078311791022618, + 0.12157808616757393, + 0.12956670547525087 + ], + [ + 0.11950798084338506, + 0.1368637022872766, + 0.1302067736784617, + 0.12286792695522308, + 0.12405128652850787, + 0.12274115160107613, + 0.1170007052520911, + 0.12676023816068968 + ], + [ + 0.12232265373071034, + 0.12526065980394682, + 0.12844128534197807, + 0.1314048394560814, + 0.12170721466342609, + 0.1282384693622589, + 0.11439798027276993, + 0.12822666888435683 + ], + [ + 0.11936916038393974, + 0.12230427314837773, + 0.12743104745944342, + 0.1209075537820657, + 0.1284241129954656, + 0.12085816264152527, + 0.1291585167249044, + 0.13154693941275278 + ], + [ + 
0.12203991661469142, + 0.12505094955364862, + 0.12608455245693526, + 0.12083844716350238, + 0.12768646826346716, + 0.12343813230593999, + 0.12590215355157852, + 0.12895914042989412 + ], + [ + 0.12382717430591583, + 0.12852186957995096, + 0.1250376415749391, + 0.12402642145752907, + 0.12728481367230415, + 0.12147467210888863, + 0.12458273147543271, + 0.12524444113175073 + ], + [ + 0.12038155148426692, + 0.12821376944581667, + 0.13108545914292336, + 0.12159218887488048, + 0.1340160146355629, + 0.12215173865358035, + 0.12030848364035289, + 0.1222505656381448 + ], + [ + 0.12307899942000707, + 0.12642152607440948, + 0.11786011358102162, + 0.1278838999569416, + 0.12766646593809128, + 0.12390709295868874, + 0.11893854041894276, + 0.13424313813447952 + ], + [ + 0.12411728252967198, + 0.1296741304298242, + 0.12419415886203448, + 0.12676671768228212, + 0.1318709266682466, + 0.11860968420902888, + 0.11814658592144649, + 0.12662029514710108 + ], + [ + 0.1263708434998989, + 0.13159728795289993, + 0.12242995947599411, + 0.12435454999407132, + 0.13213580350081125, + 0.12388414517045021, + 0.11419734607140224, + 0.1250298097729683 + ], + [ + 0.12643771121899286, + 0.13027775039275488, + 0.1227647215127945, + 0.12599819774429002, + 0.13238563388586044, + 0.12271384646495183, + 0.11936441560586293, + 0.12005749841531117 + ], + [ + 0.1254092405239741, + 0.12836809332172075, + 0.12571025639772415, + 0.12284444520870845, + 0.126833309729894, + 0.12042707577347755, + 0.12164800117413203, + 0.12875933945178986 + ], + [ + 0.12303312122821808, + 0.1283995620906353, + 0.12262720490495364, + 0.12347207466761272, + 0.13011173903942108, + 0.11714374522368114, + 0.12444138775269191, + 0.13077092915773392 + ], + [ + 0.12402197966972987, + 0.13021314268310866, + 0.125326469540596, + 0.12685117373863855, + 0.12002443770567577, + 0.12418365105986595, + 0.12152611836791039, + 0.12785280992587408 + ], + [ + 0.12008632843693097, + 0.13112975656986237, + 0.12426272655526797, + 0.1224189040561517, + 0.1258206951121489, + 0.12541619564096132, + 0.1227942022184531, + 0.12807096416751543 + ], + [ + 0.1258283443748951, + 0.12800854071974754, + 0.12865645935138068, + 0.12098223343491554, + 0.1254639724890391, + 0.12303851669033368, + 0.12062683080633481, + 0.12739486371477446 + ], + [ + 0.12633213897546133, + 0.12195236856738727, + 0.12871958936254183, + 0.11841499184568723, + 0.12568736697236696, + 0.12559561555584273, + 0.12633172050118446, + 0.12696599091092745 + ], + [ + 0.1252294940253099, + 0.12782330438494682, + 0.12411715090274811, + 0.12459852794806163, + 0.12674092128872871, + 0.12490290775895119, + 0.12097565953930219, + 0.12561181435982385 + ], + [ + 0.1252512795229753, + 0.12745548163851103, + 0.12545490637421608, + 0.12578727304935455, + 0.12578758473197618, + 0.12563205262025198, + 0.12170236309369405, + 0.12292881558338802 + ], + [ + 0.12489505981405576, + 0.1286754459142685, + 0.12371337289611499, + 0.12271923447648685, + 0.12558831398685774, + 0.12168664981921513, + 0.12488722304503123, + 0.12783448646465936 + ], + [ + 0.12619990855455399, + 0.12945982192953429, + 0.12630522375305495, + 0.12421594932675362, + 0.12238456060489018, + 0.11799658586581548, + 0.12714539468288422, + 0.12629234294096628 + ], + [ + 0.12650888785719872, + 0.12378709266583125, + 0.12695964922507605, + 0.12054035440087318, + 0.1267902416487535, + 0.12696699673930803, + 0.1233333374063174, + 0.1251132215062777 + ], + [ + 0.12230642884969711, + 0.12733921284476915, + 0.13051322226723036, + 0.11977638925115268, + 0.12997990598281225, + 
0.1235600213209788, + 0.12314777697126071, + 0.12337683141231537 + ], + [ + 0.1269850768148899, + 0.12656385948260626, + 0.1273737202088038, + 0.12057073538502057, + 0.1251718190809091, + 0.12433378770947456, + 0.12215924387176831, + 0.12684154386321703 + ], + [ + 0.12745792046189308, + 0.1265581138432026, + 0.12788088247179985, + 0.12158059577147166, + 0.12522323802113533, + 0.12421569600701332, + 0.1213696151971817, + 0.12571370850006738 + ], + [ + 0.12524767716725668, + 0.12834259743491808, + 0.128499086946249, + 0.12479745720823605, + 0.12450083096822102, + 0.1230442648132642, + 0.12339879696567853, + 0.1221690687040488 + ] + ], + "load_balancing_losses": [ + 0.06023424156010151, + 0.060157297924160955, + 0.06018536314368248, + 0.0604026548564434, + 0.06059783287346363, + 0.06039171740412712, + 0.060144422203302385, + 0.06007430963218212, + 0.060014762356877326, + 0.06006012298166752, + 0.060176534205675127, + 0.060428624227643016, + 0.06087825559079647, + 0.061343811079859735, + 0.061993790417909624, + 0.06183895617723465, + 0.06149957664310932, + 0.0617400299757719, + 0.06165820844471455, + 0.061519939452409744, + 0.06081792041659355, + 0.06071990579366684, + 0.06079737544059753, + 0.060797262191772464, + 0.06082608960568905, + 0.06076700799167156, + 0.060684779286384584, + 0.060609976574778554, + 0.06047063022851944, + 0.06035999692976475, + 0.06038411594927311, + 0.06048354506492615, + 0.0603485681116581, + 0.06037918590009213, + 0.060289007052779196, + 0.0603128869086504, + 0.06020910255610943, + 0.06014908812940121, + 0.06023341231048107, + 0.06022571809589863, + 0.060230017825961116, + 0.06027653627097607, + 0.060231597349047664, + 0.0601585678756237, + 0.060177840292453766, + 0.06012939065694809, + 0.06012634225189686, + 0.06016096025705338, + 0.06015205271542072, + 0.06011253893375397 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.12492133180300395, + 0.12909898658593497, + 0.12932936598857245, + 0.1251355099181334, + 0.1242454784611861, + 0.12260574474930763, + 0.12337358668446541, + 0.12128978346784909 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_1.5/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_1.5/metrics.json new file mode 100644 index 0000000..b7b95f8 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_1.5/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_1.5", + "description": "Slightly softer routing", + "temperature": 1.5, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 24.230346302261623, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.07941255569458, + 4.470050525665283, + 3.189748001098633, + 1.8675041794776917, + 0.7632676482200622, + 0.35737786293029783, + 0.2243212342262268, + 0.16297896057367325, + 0.11432855054736138, + 0.09402412101626396, + 0.049738299660384656, + 0.01794921662658453, + 
0.0140377645380795, + 0.01125486772507429, + 0.010487127676606178, + 0.010802598716691137, + 0.009789128880947828, + 0.008598889177665114, + 0.008833114383742213, + 0.008047992503270507, + 0.00974704665131867, + 0.0010814637818839401, + 0.0008665073430165649, + 0.0007731647638138384, + 0.0009059289353899658, + 0.0008798401046078652, + 0.0007464868423994631, + 0.0006158098258310929, + 0.0005335523921530694, + 0.000530580009217374, + 0.0003918131784303114, + 0.0003698200103826821, + 0.0002865417030989192, + 0.00032454111496917906, + 0.00028384837933117526, + 0.0002613380755065009, + 0.0002685751984245144, + 0.00022631593892583624, + 0.0002360674800002016, + 0.0002152905595721677, + 0.00022060711780795826, + 0.0002190691127907485, + 0.00016181429236894473, + 0.00015173296123975887, + 0.00016477183380629866, + 0.0001685488285147585, + 0.0001629557751584798, + 0.00016697193059371783, + 0.00015870419592829422, + 0.0001628738158615306 + ], + "val_losses": [ + 10.773804263596821, + 10.736812830813783, + 10.740654288248123, + 11.277853797265582, + 12.459404136603796, + 14.953661210966615, + 16.698000358608503, + 19.345720452891644, + 20.877604346393277, + 22.796806503942914, + 23.77202825242976, + 24.744333017841246, + 25.005298459487754, + 25.068294403831025, + 24.77007559247236, + 24.610581691189285, + 24.156209197566703, + 24.439545735874784, + 24.288322489168955, + 24.16876786841942, + 24.121293677879308, + 23.651119872453776, + 23.5872390194411, + 23.53049388117167, + 23.29745666719578, + 23.723795597629074, + 23.661165763127087, + 23.739367852362644, + 23.807156424640347, + 23.97727686609059, + 23.890479751694752, + 23.994819041275726, + 24.072168619809638, + 24.047475565448668, + 24.210196404069556, + 24.242271888382444, + 24.25704604913826, + 24.30280567479218, + 24.24752543419073, + 24.298186878433498, + 24.302981481114042, + 24.265265178343434, + 24.3629711339836, + 24.295208711927433, + 24.262049071780364, + 24.291918737728267, + 24.144373957765396, + 24.30951096902045, + 24.255207358316483, + 24.230346302261623 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.01640931313229101, + 0.016416228139932095, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101 + ], + "val_perplexities": [ + 47753.338686314615, + 46019.14712957113, + 46196.26770720294, + 79051.41932608224, + 257662.05388259448, + 3120991.229108778, + 17858682.430600658, + 252196840.30600345, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 
485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.028845703601837157, + 0.069787069161733, + 0.10685143470764161, + 0.15474619070688883, + 0.19242485761642455, + 0.2314830501874288, + 0.2693642497062683, + 0.30776947339375815, + 0.3452380021413167, + 0.3839008847872416, + 0.4212732275327047, + 0.46029427846272786, + 0.49734952052434284, + 0.5361509482065837, + 0.5731298685073852, + 0.6119277715682984, + 0.6491577665011088, + 0.6879922072092692, + 0.7255515535672505, + 0.7643712043762207, + 0.8014694690704346, + 0.840302030245463, + 0.877705204486847, + 0.9165882507960001, + 0.9538931528727214, + 0.9926469524701437, + 1.0295473337173462, + 1.0681593100229898, + 1.1056783676147461, + 1.1444274107615153, + 1.181662952899933, + 1.2240190943082174, + 1.2675413052241007, + 1.306263542175293, + 1.343692457675934, + 1.3824226101239523, + 1.419699756304423, + 1.4582550168037414, + 1.495357088247935, + 1.5339864412943522, + 1.5710699796676635, + 1.6097832798957825, + 1.6473271131515503, + 1.6858972350756327, + 1.7229277650515238, + 1.7616557359695435, + 1.7988004604975383, + 1.8373739043871562, + 1.8742409388224284, + 1.9130041400591533 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5, + 1.5 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 
0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12438557545344035, + 0.11416050419211388, + 0.13931202640136084, + 0.12829205269614855, + 0.1339040199915568, + 0.12665784855683646, + 0.1206250861287117, + 0.1126626767218113 + ], + [ + 0.13092279310027757, + 0.11317815134922664, + 0.13608486577868462, + 0.12269389381011327, + 0.12112885837753613, + 0.12925957764188448, + 0.13316190615296364, + 0.11356973896423976 + ], + [ + 0.13285020117958388, + 0.1136388008793195, + 0.13313506543636322, + 0.12162574504812558, + 0.11911803856492043, + 0.12985794494549432, + 0.13315003365278244, + 0.11662394305070241 + ], + [ + 0.14122172445058823, + 0.11569391439358394, + 0.13163222124179205, + 0.1266154187421004, + 0.11264701435963313, + 0.12551368897159895, + 0.1393063080807527, + 0.1073694850007693 + ], + [ + 0.13782472411791483, + 0.12028268973032634, + 0.13253470758597055, + 0.12732083847125372, + 0.1101675418516, + 0.12192655603090923, + 0.1401393860578537, + 0.10980332021911939 + ], + [ + 0.13298237696290016, + 0.1253955438733101, + 0.1280648373067379, + 0.13001761212944984, + 0.1202141356964906, + 0.12076797460516293, + 0.12974747146169344, + 0.11280981078743935 + ], + [ + 0.12845013414820036, + 0.1256443994740645, + 0.12555263812343279, + 0.12675793220599493, + 0.12574927881360054, + 0.12323842073480289, + 0.12530120213826498, + 0.11930576463540395 + ], + [ + 0.12512410804629326, + 0.1263566936055819, + 0.12037284299731255, + 0.1284050059815248, + 0.12679601709047952, + 0.12488981957236926, + 0.12079626445968945, + 0.12725901727875075 + ], + [ + 0.12177896375457446, + 0.1277801531056563, + 0.11933862417936325, + 0.1270314355691274, + 0.12633414069811502, + 0.12746884673833847, + 0.12149357423186302, + 0.128774031996727 + ], + [ + 0.12377914786338806, + 0.12665440514683723, + 0.12570683906475702, + 0.12556021536389986, + 0.12935950979590416, + 0.12394789978861809, + 0.12639013677835464, + 0.11860160902142525 + ], + [ + 0.12556897227962813, + 0.12556802108883858, + 0.1361290750404199, + 0.12173316379388173, + 0.13365462174018225, + 0.12030337750911713, + 0.12259456515312195, + 0.11444796870152156 + ], + [ + 0.12251974269747734, + 0.12213814755280812, + 0.11841649934649467, + 0.1254787047704061, + 0.12309875463445981, + 0.13421429445346197, + 0.12538658206661543, + 0.1287470373014609 + ], + [ + 0.1337949683268865, + 0.11917304371794064, + 0.11193906391660373, + 0.1272085097928842, + 0.12231325854857762, + 0.13468137383460999, + 0.12233986208836238, + 0.1285497061908245 + ], + [ + 0.12283102547128995, + 0.12849227090676626, + 0.11863423387209575, + 0.13367295016845068, + 0.13477065910895666, + 0.12512044856945673, + 0.12127372870842616, + 0.11520446961124738 + ], + [ + 0.1233710174759229, + 0.12805797904729843, + 0.11766039828459422, + 0.12846644471089044, + 0.12605340282122293, + 0.12266098707914352, + 0.12727216258645058, + 0.12645737951000532 + ], + [ + 
0.12991464510560036, + 0.12015410140156746, + 0.13439288983742395, + 0.12505463262399039, + 0.11486523225903511, + 0.13050407792131105, + 0.1220979889233907, + 0.12301619971791904 + ], + [ + 0.13672183578213057, + 0.1247493326663971, + 0.13039072478810945, + 0.1244959607720375, + 0.11795181905229886, + 0.11843760311603546, + 0.12303643052776654, + 0.1242160511513551 + ], + [ + 0.12401789178450902, + 0.1314046954115232, + 0.11350093657771747, + 0.12536342442035675, + 0.12280518934130669, + 0.1282393385966619, + 0.11677200223008792, + 0.13789628694454828 + ], + [ + 0.12937326853473982, + 0.13642243420084318, + 0.10938338811198871, + 0.12648497770229974, + 0.1334418604771296, + 0.12862301617860794, + 0.11457288265228271, + 0.12169792751471202 + ], + [ + 0.12073943267265956, + 0.1260242611169815, + 0.12009921421607335, + 0.12048778558770816, + 0.11710826555887859, + 0.1315478285153707, + 0.11866983274618785, + 0.1453231597940127 + ], + [ + 0.11480336139599483, + 0.12643311296900114, + 0.12642310311396918, + 0.12638875593741736, + 0.13676346465945244, + 0.11970566461483638, + 0.1149379312992096, + 0.1345443738003572 + ], + [ + 0.1241380547483762, + 0.12243873005112012, + 0.1269079049428304, + 0.1280613417426745, + 0.12271630764007568, + 0.12194547553857167, + 0.1251202697555224, + 0.12867169082164764 + ], + [ + 0.12141696363687515, + 0.1183098703622818, + 0.13520249972740808, + 0.12820393467942873, + 0.12827587003509203, + 0.11684217179814975, + 0.115824144333601, + 0.13592431818445525 + ], + [ + 0.1333228088915348, + 0.12750348697106043, + 0.11685294782121976, + 0.12324087197581927, + 0.12668801471590996, + 0.11857536186774571, + 0.12169829259316127, + 0.13211797922849655 + ], + [ + 0.11646071821451187, + 0.12713789443174997, + 0.12590383862455687, + 0.12470094487071037, + 0.12404770155747731, + 0.1215706616640091, + 0.13077143703897795, + 0.12940658008058867 + ], + [ + 0.11766957864165306, + 0.128808772812287, + 0.12169709180792172, + 0.13192020605007806, + 0.11947798853119214, + 0.12368159120281537, + 0.12754168485601744, + 0.12920285885532698 + ], + [ + 0.11972353607416153, + 0.13117474814256033, + 0.12430170178413391, + 0.12345552320281665, + 0.12105988959471385, + 0.12799129262566566, + 0.1246196838716666, + 0.12767339994510016 + ], + [ + 0.1174546331167221, + 0.1302631509800752, + 0.13204246759414673, + 0.13427030170957246, + 0.12365019073088963, + 0.12363155434528987, + 0.11874865864713986, + 0.11993882308403651 + ], + [ + 0.11609988907972972, + 0.1367713063955307, + 0.134916124244531, + 0.12329821661114693, + 0.11552627260486285, + 0.12894495949149132, + 0.11901888375480969, + 0.12542411809166273 + ], + [ + 0.1295727603137493, + 0.12445501610636711, + 0.12729419892032942, + 0.12122323612372081, + 0.12303097918629646, + 0.12580526620149612, + 0.12381022796034813, + 0.12480808670322101 + ], + [ + 0.1273537315428257, + 0.11900346726179123, + 0.12272086987892787, + 0.12362919623653094, + 0.12456123779217403, + 0.13091702262560526, + 0.12144726763168971, + 0.13036698599656424 + ], + [ + 0.12543089936176935, + 0.1266807975868384, + 0.12529757246375084, + 0.12193675090869267, + 0.12841280301411948, + 0.12132296338677406, + 0.12614940231045088, + 0.12476859117547671 + ], + [ + 0.12358718365430832, + 0.13165529444813728, + 0.12561427553494772, + 0.12110686798890431, + 0.12332377831141154, + 0.11689541985591252, + 0.12265941873192787, + 0.13515754168232283 + ], + [ + 0.12490201493104298, + 0.12535703058044115, + 0.12288288896282513, + 0.1235720453162988, + 0.12638425702850023, + 
0.12470771744847298, + 0.13143361484011015, + 0.12076019868254662 + ], + [ + 0.12571966523925462, + 0.13027526065707207, + 0.12363024180134137, + 0.12079863622784615, + 0.12396549060940742, + 0.12675606707731882, + 0.12907325848937035, + 0.11978113527099292 + ], + [ + 0.12183178464571635, + 0.13026247918605804, + 0.12239019572734833, + 0.12114463249842326, + 0.12930464868744215, + 0.12590120360255241, + 0.12235495199759801, + 0.1268098863462607 + ], + [ + 0.12087582424283028, + 0.132381501297156, + 0.11640484010179837, + 0.12483899171153705, + 0.1250749702254931, + 0.13166004419326782, + 0.12195277710755666, + 0.12681081394354501 + ], + [ + 0.12601598724722862, + 0.12703165908654532, + 0.12979954356948534, + 0.11729113260904948, + 0.1291932798922062, + 0.12630534544587135, + 0.12265470375617345, + 0.12170811618367831 + ], + [ + 0.12518277267615, + 0.12639442334572473, + 0.1299180748562018, + 0.11909671624501546, + 0.12873251363635063, + 0.12777290865778923, + 0.12318987771868706, + 0.11971248934666316 + ], + [ + 0.12509067356586456, + 0.12398034955064456, + 0.12748890245954195, + 0.1268929714957873, + 0.12167892108360927, + 0.1264516587058703, + 0.12209489569067955, + 0.12632140144705772 + ], + [ + 0.12187027682860692, + 0.12412883838017781, + 0.12828159828980765, + 0.12790234511097273, + 0.12057583406567574, + 0.12852845340967178, + 0.12366065134604771, + 0.12505176290869713 + ], + [ + 0.1260147988796234, + 0.12601004540920258, + 0.1305252139767011, + 0.12630808974305788, + 0.12011448790629704, + 0.12747875601053238, + 0.11977786446611087, + 0.12377050643165906 + ], + [ + 0.12590845301747322, + 0.12769702076911926, + 0.13055924698710442, + 0.12631526589393616, + 0.11883044242858887, + 0.12562002738316855, + 0.12049156179030736, + 0.12457774331172307 + ], + [ + 0.12782422577341399, + 0.12179611499110858, + 0.13047875836491585, + 0.122393732269605, + 0.12181684498985608, + 0.12368230770031612, + 0.1254530275861422, + 0.12655474618077278 + ], + [ + 0.12869844834009805, + 0.12037623549501102, + 0.12877188374598822, + 0.12355334063371022, + 0.12133604163924853, + 0.1279363197584947, + 0.127300463616848, + 0.1220270407696565 + ], + [ + 0.12761169796188673, + 0.12509587158759436, + 0.1257390690346559, + 0.12475279346108437, + 0.13100110739469528, + 0.12227964401245117, + 0.1185175987581412, + 0.12500199427207312 + ], + [ + 0.12815691530704498, + 0.12185247863332431, + 0.1251034140586853, + 0.12850852186481157, + 0.13096537565191588, + 0.12077334895730019, + 0.12014709164698918, + 0.1244926390548547 + ], + [ + 0.12643119941155115, + 0.1281477821369966, + 0.1286927325030168, + 0.12155028432607651, + 0.12146355335911115, + 0.12657606725891432, + 0.12121976539492607, + 0.1259184069931507 + ], + [ + 0.12883725514014563, + 0.12798677509029707, + 0.12542845060427985, + 0.12221631531914075, + 0.12310976535081863, + 0.12573018421729407, + 0.12576384842395782, + 0.12092716867725055 + ], + [ + 0.13032918175061545, + 0.12565343206127486, + 0.1277940832078457, + 0.1244165872534116, + 0.1210847944021225, + 0.1249586654206117, + 0.12145774687329929, + 0.12430527185400327 + ] + ], + "load_balancing_losses": [ + 0.0601530771702528, + 0.060107164084911346, + 0.060138048604130745, + 0.06034043990075588, + 0.060609963908791545, + 0.060501329600811005, + 0.060178323462605474, + 0.06000780127942562, + 0.05998391322791576, + 0.06001683808863163, + 0.060155109688639644, + 0.060315899178385736, + 0.06062373332679272, + 0.061051640659570694, + 0.061039886251091954, + 0.06133884266018867, + 0.061058754846453664, + 
0.06172330603003502, + 0.06121929809451103, + 0.060974479839205745, + 0.06157100014388561, + 0.06093739382922649, + 0.060940145328640935, + 0.06091573052108288, + 0.060921508446335794, + 0.060844093933701514, + 0.06042032241821289, + 0.06055892035365105, + 0.060673868656158446, + 0.060558656603097914, + 0.060291763022542, + 0.06040613912045956, + 0.060207856073975566, + 0.06023397259414196, + 0.06024359874427319, + 0.060260246694087985, + 0.060203251987695695, + 0.060237521305680275, + 0.06014332696795464, + 0.060173075273633005, + 0.060148946940898895, + 0.060143569856882094, + 0.060256035253405574, + 0.06011180579662323, + 0.06025103032588959, + 0.06004696860909462, + 0.060220596939325334, + 0.060201894491910934, + 0.060018280521035194, + 0.060075846314430234 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.1315293957789739, + 0.1253383755683899, + 0.12861466531952223, + 0.124490886926651, + 0.12019731352726619, + 0.12489264458417892, + 0.12061314781506856, + 0.12432333330313365 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_10.0/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_10.0/metrics.json new file mode 100644 index 0000000..faf45a2 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_10.0/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_10.0", + "description": "Nearly uniform routing (extreme exploration)", + "temperature": 10.0, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 24.534089428797206, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.068264722824097, + 4.465277051925659, + 3.189622640609741, + 1.8716424226760864, + 0.7630594700574875, + 0.35631293058395386, + 0.22326055467128753, + 0.16183508336544036, + 0.11360260844230652, + 0.09381789118051528, + 0.04979727417230606, + 0.01808085720986128, + 0.01420236025005579, + 0.011518755275756121, + 0.01053143097087741, + 0.010752688348293304, + 0.010078129824250937, + 0.008437707088887691, + 0.008746477123349906, + 0.007960466342046857, + 0.009688302129507064, + 0.0010764821025077253, + 0.0007704558374825865, + 0.0007279628363903612, + 0.0008628889219835401, + 0.0008339468506164849, + 0.0007226858462672681, + 0.0005910470703383907, + 0.0005120214656926692, + 0.00046458949509542435, + 0.0003719866246683523, + 0.0003500081642414443, + 0.0002588351402664557, + 0.00024604533246019855, + 0.0002625833367346786, + 0.00023123554419726133, + 0.0002618829777929932, + 0.00020186282345093786, + 0.00020178912236588076, + 0.0001978141692234203, + 0.00018666159739950673, + 0.00019303785666124895, + 0.0001494130934588611, + 0.00014691594114992767, + 0.00015031059738248586, + 0.00016266530437860637, + 0.00014724691718583926, + 0.00015177241148194297, + 0.00014238867806852794, + 0.00014112655917415396 + ], + "val_losses": [ + 10.773727821377056, + 10.73678692659304, + 
10.738793858369753, + 11.27316135568248, + 12.463284465533684, + 14.96090166813072, + 16.701985558014457, + 19.373080169354225, + 20.90962668105486, + 22.901409816404957, + 23.926277147165035, + 24.90034708622909, + 25.251076998222, + 25.363293725273213, + 25.174692524616795, + 24.551542713448352, + 24.366343825107748, + 24.454975970642305, + 24.252385055218483, + 24.44813204654114, + 24.530196011277052, + 24.20904587519885, + 24.06261319406462, + 24.1793480256421, + 24.003676215667184, + 24.185511073459583, + 24.22671034310816, + 24.01071585530527, + 24.299354310591735, + 24.347810246497918, + 24.236002878249746, + 24.274673030570202, + 24.32213894469999, + 24.389316444262178, + 24.562161159178395, + 24.441820589476677, + 24.48110008239746, + 24.64803719941803, + 24.587256670840638, + 24.515058153509674, + 24.421561318657, + 24.61476015286395, + 24.602063903539005, + 24.481866513039957, + 24.464706468076674, + 24.54651286660993, + 24.53059490999148, + 24.523314978124397, + 24.47907891964323, + 24.534089428797206 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 47749.68845462194, + 46017.9550548654, + 46110.40268818106, + 78681.34411604276, + 258663.80971853615, + 3143670.637977793, + 17929994.84392696, + 259192132.59894717, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.02979321082433065, + 0.07130091985066732, + 0.10871097246805826, + 0.14937918186187743, + 0.1926921010017395, + 0.23119659821192423, + 0.2686460773150126, + 0.3073225458463033, + 0.34438505172729494, + 
0.3880822817484538, + 0.4246596693992615, + 0.4627709786097209, + 0.4995801091194153, + 0.5380072553952535, + 0.5748650272687276, + 0.613588273525238, + 0.6509426514307658, + 0.689512832959493, + 0.7270946860313415, + 0.7658172170321147, + 0.8030741612116495, + 0.8418009003003438, + 0.8790761868158976, + 0.917706286907196, + 0.9549813667933146, + 0.9936194062232971, + 1.0318836450576783, + 1.0706943233807882, + 1.1171416680018107, + 1.1557463645935058, + 1.1926999648412069, + 1.231392467021942, + 1.269055755933126, + 1.3080167373021443, + 1.3451811989148459, + 1.3838251908620198, + 1.4216639916102092, + 1.4605100512504579, + 1.4973213076591492, + 1.5356154759724936, + 1.581496294339498, + 1.6195396463076273, + 1.6565751830736797, + 1.6945164561271668, + 1.7311775048573812, + 1.7692173798878987, + 1.8058592518170675, + 1.8530375997225443, + 1.8898051182428997, + 1.9280782659848532 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0, + 10.0 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12475734949111938, + 0.11415268604954083, + 0.13942573964595795, + 0.1281355284154415, + 0.13373866180578867, + 0.12607929358879724, + 0.1205051839351654, + 0.11320533975958824 + ], + [ + 0.1301103321214517, + 0.11779473101099332, + 0.13712164387106895, + 0.12251835688948631, + 
0.12325665106376012, + 0.12545251349608103, + 0.13187644382317862, + 0.11186910048127174 + ], + [ + 0.13521152858932814, + 0.11539142330487569, + 0.13321960593263307, + 0.1236368256310622, + 0.1149498571952184, + 0.12498927985628445, + 0.134603434552749, + 0.11799781148632367 + ], + [ + 0.13625483959913254, + 0.12545987839500108, + 0.13867557421326637, + 0.12440796444813411, + 0.1016904500623544, + 0.12077739586432774, + 0.1335135499636332, + 0.11922009910146396 + ], + [ + 0.1324511058628559, + 0.12465024242798488, + 0.14583221822977066, + 0.1283962478240331, + 0.09982325633366902, + 0.12210110823313396, + 0.13640265663464865, + 0.11034289747476578 + ], + [ + 0.12888634825746217, + 0.1292747420569261, + 0.14769109835227331, + 0.1288108949859937, + 0.11176889886458714, + 0.11718731373548508, + 0.13174274687965712, + 0.10463768988847733 + ], + [ + 0.12434182191888492, + 0.12352693701783816, + 0.14615780239303908, + 0.13107160106301308, + 0.1275394670665264, + 0.11386569465200107, + 0.13063726077477136, + 0.1028591866294543 + ], + [ + 0.12477055937051773, + 0.11920305093129475, + 0.12877154598633447, + 0.12979640811681747, + 0.13505701224009195, + 0.12404882535338402, + 0.12141157065828641, + 0.11694079140822093 + ], + [ + 0.1258808026711146, + 0.11920333902041118, + 0.12159875283638637, + 0.1303616613149643, + 0.13190317898988724, + 0.1250385989745458, + 0.11742920676867168, + 0.1285842371483644 + ], + [ + 0.12281344706813495, + 0.12618222584327063, + 0.1248336856563886, + 0.1272098496556282, + 0.13572068760792413, + 0.12408339853088061, + 0.1124403749903043, + 0.1267160860200723 + ], + [ + 0.12377495691180229, + 0.13059894988934198, + 0.1284498025973638, + 0.12505865966280302, + 0.13317212089896202, + 0.12358132253090541, + 0.11173214390873909, + 0.12363182380795479 + ], + [ + 0.12657488510012627, + 0.1252892812093099, + 0.12383341044187546, + 0.1271346261103948, + 0.12945905327796936, + 0.12335907792051633, + 0.11511347567041715, + 0.12923596675197283 + ], + [ + 0.12970876569549242, + 0.12276192754507065, + 0.12170528123776118, + 0.12514248862862587, + 0.125508235146602, + 0.12501083314418793, + 0.11653861030936241, + 0.1336236260831356 + ], + [ + 0.12198760484655698, + 0.12941299999753633, + 0.11910657708843549, + 0.12415915230909984, + 0.12574579442540804, + 0.12231840938329697, + 0.12500870848695436, + 0.13226050635178885 + ], + [ + 0.12240646531184514, + 0.12910514821608862, + 0.11905434479316075, + 0.1270875222980976, + 0.12028478334347407, + 0.1266894874473413, + 0.12301312759518623, + 0.13235888381799063 + ], + [ + 0.12463104352355003, + 0.12012861917416255, + 0.13025695209701857, + 0.1250311533610026, + 0.12640166655182838, + 0.12276270364721616, + 0.1196834755440553, + 0.13110414519906044 + ], + [ + 0.1183101733525594, + 0.11791704470912616, + 0.1274061774214109, + 0.12678576757510504, + 0.13083267832795778, + 0.13351508354147276, + 0.11248345300555229, + 0.1327494097252687 + ], + [ + 0.12454496324062347, + 0.125611229489247, + 0.12495716537038486, + 0.1176088775197665, + 0.1359927405913671, + 0.11699694395065308, + 0.11986668407917023, + 0.13442117720842361 + ], + [ + 0.116949662566185, + 0.136666605869929, + 0.11298202226559322, + 0.12915883089105287, + 0.13171145568291345, + 0.12296587104598682, + 0.12201270585258801, + 0.12755261237422624 + ], + [ + 0.12290549899140994, + 0.12719006836414337, + 0.13708969950675964, + 0.11931110794345538, + 0.1265569031238556, + 0.12246136491497357, + 0.11677561079462369, + 0.12770951166749 + ], + [ + 0.1197788454592228, + 0.12957320859034857, 
+ 0.12590826178590456, + 0.11612184097369511, + 0.13170770183205605, + 0.1293474162618319, + 0.12048567955692609, + 0.12707680215438208 + ], + [ + 0.12340472886959712, + 0.1353023685514927, + 0.13883400335907936, + 0.11019724359114964, + 0.12916767597198486, + 0.11399855216344197, + 0.1270280902584394, + 0.122067096332709 + ], + [ + 0.1293817770977815, + 0.1283504217863083, + 0.14040563255548477, + 0.1138590710858504, + 0.1205001895626386, + 0.12060712277889252, + 0.12035812934239705, + 0.12653742109735808 + ], + [ + 0.1260175903638204, + 0.1279651535054048, + 0.13730273519953093, + 0.11376505717635155, + 0.1305093914270401, + 0.11493492498993874, + 0.12827244897683462, + 0.1212324673930804 + ], + [ + 0.12612541392445564, + 0.12312224507331848, + 0.14401683459679285, + 0.12110670407613118, + 0.12278088554739952, + 0.12290563310186069, + 0.11606776714324951, + 0.12387430171171825 + ], + [ + 0.1359871986011664, + 0.12329619626204173, + 0.12476186330119769, + 0.1101313903927803, + 0.12323386097947757, + 0.12056959296266238, + 0.11972042794028918, + 0.1422992448012034 + ], + [ + 0.1268480954070886, + 0.13417003552118936, + 0.12457818662126859, + 0.11963048328955968, + 0.12197400629520416, + 0.12433907141288121, + 0.1163644976913929, + 0.13209537665049234 + ], + [ + 0.12879868845144907, + 0.12528249869743982, + 0.12185345714290936, + 0.13196158160765967, + 0.12073765695095062, + 0.11881024638811748, + 0.13155735656619072, + 0.12099827577670415 + ], + [ + 0.12847638751069704, + 0.12298012649019559, + 0.12177741900086403, + 0.12294192736347516, + 0.11939886584877968, + 0.12320824215809505, + 0.13010661055644354, + 0.13111018389463425 + ], + [ + 0.12147548794746399, + 0.12652823328971863, + 0.11930768564343452, + 0.11896049603819847, + 0.12575763339797655, + 0.12381692851583163, + 0.13224623600641885, + 0.13190707812706629 + ], + [ + 0.12256162986159325, + 0.12961395954092345, + 0.11774891863266627, + 0.12076971555749576, + 0.12150900935133298, + 0.12269485990206401, + 0.13349668309092522, + 0.13160498564442 + ], + [ + 0.12418140346805255, + 0.12679996838172278, + 0.1345394862194856, + 0.12380540867646535, + 0.12605996057391167, + 0.11515754585464795, + 0.12190447996060054, + 0.12755152707298598 + ], + [ + 0.12563188249866167, + 0.1264975145459175, + 0.12911412492394447, + 0.12607293700178465, + 0.12585424507657686, + 0.1119723121325175, + 0.12872465451558432, + 0.12613212938110033 + ], + [ + 0.12813166032234827, + 0.12711484854420027, + 0.12829082210858664, + 0.12179117401440938, + 0.1247620830933253, + 0.11850947886705399, + 0.12621763969461122, + 0.12518205990393957 + ], + [ + 0.12180225923657417, + 0.13502820457021394, + 0.12743196512262026, + 0.11522972956299782, + 0.13340910648306212, + 0.11625270421306293, + 0.12462474157412846, + 0.1262210545440515 + ], + [ + 0.11791208510597546, + 0.12558827300866446, + 0.1289877158900102, + 0.12567409252127013, + 0.11663069079319636, + 0.12829248855511347, + 0.12422366191943486, + 0.13269076496362686 + ], + [ + 0.11928222576777141, + 0.12646425515413284, + 0.12420725698272388, + 0.12440399453043938, + 0.12493415176868439, + 0.13043782487511635, + 0.12195932616790135, + 0.12831074123581251 + ], + [ + 0.124223576237758, + 0.12304221590360005, + 0.1246515562136968, + 0.12422652915120125, + 0.12839695562918982, + 0.11355058724681537, + 0.12485586106777191, + 0.1370524726808071 + ], + [ + 0.12035450090964635, + 0.12361603230237961, + 0.12485636894901593, + 0.12080579871932666, + 0.13541978721817335, + 0.11755349238713582, + 0.12380445127685864, + 
0.13358933354417482 + ], + [ + 0.1320065644880136, + 0.11922232309977214, + 0.12546679750084877, + 0.12354664877057076, + 0.1200331375002861, + 0.13046452403068542, + 0.12505543480316797, + 0.12420434132218361 + ], + [ + 0.1268814653158188, + 0.12331616505980492, + 0.13032426064213118, + 0.11926669006546338, + 0.12287125115593274, + 0.12446459755301476, + 0.12445826455950737, + 0.12841705729564032 + ], + [ + 0.13259093960126242, + 0.12501884500185648, + 0.12780598054329553, + 0.12250152230262756, + 0.12661489844322205, + 0.1222449317574501, + 0.1196626362701257, + 0.12356001387039821 + ], + [ + 0.12960620348652205, + 0.12394314383467038, + 0.12426593527197838, + 0.1225111149251461, + 0.12416996931036313, + 0.12882959842681885, + 0.11837407201528549, + 0.12829973424474397 + ], + [ + 0.12571200852592787, + 0.12183291837573051, + 0.12704839060703912, + 0.1237550787627697, + 0.12681583066781363, + 0.12810961281259856, + 0.12831582749883333, + 0.11841009681423505 + ], + [ + 0.12985892966389656, + 0.12586528062820435, + 0.12964728847146034, + 0.11585574597120285, + 0.13315641755859056, + 0.11906841521461804, + 0.12670340140660605, + 0.11984427521626155 + ], + [ + 0.1214165488878886, + 0.12919339165091515, + 0.12395029639204343, + 0.12528432781497637, + 0.12205278252561887, + 0.12408674384156863, + 0.12896213556329408, + 0.1250535361468792 + ], + [ + 0.12035524969299634, + 0.12938466916481653, + 0.11548375338315964, + 0.11844343692064285, + 0.12930243338147798, + 0.12833304951588312, + 0.13140876218676567, + 0.12728841975331306 + ], + [ + 0.12706941117842993, + 0.12197142591079076, + 0.12877880533536276, + 0.12523424749573073, + 0.12844732279578844, + 0.12149745598435402, + 0.11527417227625847, + 0.13172692681352297 + ], + [ + 0.12883985166748366, + 0.12065920109550159, + 0.13134139652053514, + 0.12599028398593268, + 0.12598469232519469, + 0.12111465136210124, + 0.11948269108931224, + 0.12658701092004776 + ], + [ + 0.12157981842756271, + 0.1218073417743047, + 0.12362908696134885, + 0.12341074893871944, + 0.13080820441246033, + 0.12536800901095072, + 0.12306187922755878, + 0.13033468897144 + ] + ], + "load_balancing_losses": [ + 0.06002229750156403, + 0.060016362369060515, + 0.060022668540477754, + 0.06006469540297985, + 0.06015543900430202, + 0.0601876262575388, + 0.06013292707502842, + 0.06008040122687817, + 0.06003468371927738, + 0.060033298656344415, + 0.06006777808070183, + 0.06002825312316418, + 0.0599848248064518, + 0.0600111123174429, + 0.0599911343306303, + 0.05997913852334023, + 0.0600641205906868, + 0.06007573455572128, + 0.06015910059213638, + 0.060127760469913485, + 0.060149314999580386, + 0.060190624371170995, + 0.06027577221393585, + 0.060223388671875, + 0.06014956124126911, + 0.06019870862364769, + 0.06014889590442181, + 0.06027382574975491, + 0.060243554413318634, + 0.06029498241841793, + 0.05999616086483002, + 0.06011883318424225, + 0.06016740649938583, + 0.0600437019020319, + 0.06018727235496044, + 0.06010599434375763, + 0.060181071236729625, + 0.060141276195645334, + 0.06004407815635204, + 0.06006361022591591, + 0.060180061310529706, + 0.060103414580225945, + 0.06008672639727593, + 0.06010033488273621, + 0.06003771759569645, + 0.06004065088927746, + 0.06008836440742016, + 0.06002587974071503, + 0.06009359806776047, + 0.06010164059698582 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.11995649586121242, + 0.12131934861342113, + 0.12244458869099617, + 0.12273925791184108, + 0.13267036279042563, + 0.12591564282774925, + 
0.12301545962691307, + 0.13193860525886217 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_2.0/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_2.0/metrics.json new file mode 100644 index 0000000..f3c5a9c --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_2.0/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_2.0", + "description": "Softer routing (more exploration)", + "temperature": 2.0, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 24.250242354591826, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.074095726013184, + 4.4681319236755375, + 3.1874616146087646, + 1.8671492338180542, + 0.7619717627763748, + 0.356376177072525, + 0.22397017776966094, + 0.1624505713582039, + 0.11418198570609092, + 0.09382357224822044, + 0.04969589747488499, + 0.017800973262637854, + 0.01395582091063261, + 0.011227598693221808, + 0.010429644864052534, + 0.010469185328111053, + 0.009674670547246933, + 0.008378511853516101, + 0.008513005962595344, + 0.008007832616567612, + 0.009628096595406532, + 0.0010730466921813786, + 0.0008426614396739752, + 0.0008448214735835791, + 0.0009014891751576215, + 0.0008854725572746247, + 0.0007156451552873478, + 0.0005760653264587745, + 0.0005527318600798026, + 0.0004627990041626617, + 0.00038955551281105727, + 0.0003273580994573422, + 0.0002806221615173854, + 0.00029873685271013527, + 0.00028062340279575435, + 0.0002762895950581878, + 0.0002597459853859618, + 0.00023212085070554168, + 0.0002304781533894129, + 0.00022262510610744358, + 0.00020723446941701696, + 0.00021982782345730811, + 0.00015516662097070366, + 0.000162401823035907, + 0.00016333321109414102, + 0.00017892217292683199, + 0.0001647739074542187, + 0.0001759289880283177, + 0.0001543503618449904, + 0.0001639068650547415 + ], + "val_losses": [ + 10.773730109521441, + 10.737245370979444, + 10.740853990345878, + 11.277360751013873, + 12.462699910355962, + 14.959907278997738, + 16.71073891333051, + 19.36057002468581, + 20.888871364795698, + 22.859370309135517, + 23.88046497169737, + 24.74393996248818, + 25.059663071649236, + 25.051000001994964, + 24.872414814709774, + 24.85287452670795, + 24.578729642996098, + 24.32579605318211, + 24.298237008677777, + 24.08863448422705, + 24.123871476405924, + 23.841968185910066, + 23.618393294802825, + 23.950490493235232, + 23.76114335076969, + 23.81839840487962, + 23.78829224792049, + 23.935886618947816, + 24.06441830072302, + 24.182095578197035, + 24.06601018197966, + 24.09492320138237, + 24.195577520363745, + 24.183269123306545, + 24.219877384576698, + 24.3051082557166, + 24.317483976114765, + 24.32363692105027, + 24.270709223123827, + 24.283330095109164, + 24.292315601038847, + 24.287409098325263, + 24.410878737486716, + 24.304125492648605, + 24.263806218393277, + 24.34350709274885, + 24.274643759845425, + 24.326952425414177, + 24.31762618479375, + 24.250242354591826 + ], + 
"val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 47749.79771292846, + 46039.05656459891, + 46205.49412000883, + 79012.45292699913, + 258512.65063399286, + 3140546.1597887506, + 18087631.37738077, + 255969799.47547528, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.02928327719370524, + 0.07093901634216308, + 0.1082366426785787, + 0.14678502082824707, + 0.1840856631596883, + 0.22276692390441893, + 0.2603001117706299, + 0.2989465037981669, + 0.33612744410832723, + 0.37480958302815753, + 0.41186049381891887, + 0.4502035935719808, + 0.4966278274854024, + 0.5350424925486247, + 0.5724462191263835, + 0.6108430425326029, + 0.6483000834782918, + 0.687307580312093, + 0.7249638239542643, + 0.763864545027415, + 0.8012601852416992, + 0.8403975566228231, + 0.87832190990448, + 0.917172356446584, + 0.9548972924550374, + 0.9941118001937866, + 1.0312793930371602, + 1.0698648691177368, + 1.1073028564453125, + 1.145914328098297, + 1.1829926013946532, + 1.221861986319224, + 1.259406816959381, + 1.2980709552764893, + 1.3349693814913433, + 1.378048551082611, + 1.4202849706013998, + 1.4589483141899109, + 1.496508208910624, + 1.5360045870145163, + 1.5732601881027222, + 1.6116382280985515, + 1.6488594651222228, + 1.6873130957285563, + 1.7242108980814617, + 1.7626423994700113, + 1.7995170632998148, + 1.8379081646601358, + 1.8749777237574259, + 1.9134431719779967 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 
0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.1245174730817477, + 0.11406407629450162, + 0.13931066046158472, + 0.12827964747945467, + 0.13375522196292877, + 0.12656426802277565, + 0.12063625454902649, + 0.11287217463056247 + ], + [ + 0.13121316457788149, + 0.11531487728158633, + 0.13586007555325827, + 0.1220346949994564, + 0.12065183495481809, + 0.12843538199861845, + 0.13296451171239218, + 0.11352522547046344 + ], + [ + 0.13410628711183867, + 0.11401279022296269, + 0.13167570531368256, + 0.12109431127707164, + 0.11725631107886632, + 0.1318149690826734, + 0.1333083969851335, + 0.11673099299271901 + ], + [ + 0.1399474876622359, + 0.11670397718747456, + 0.13294423619906107, + 0.12612041706840196, + 0.11090035364031792, + 0.12656626477837563, + 0.13654237737258276, + 0.11027463525533676 + ], + [ + 0.13737289980053902, + 0.12233941505352657, + 0.13478067020575205, + 0.12593267733852068, + 0.10700869560241699, + 0.12248080720504124, + 0.13837861145536104, + 0.11170598119497299 + ], + [ + 0.13074414556225142, + 0.12472593908508618, + 0.13176733255386353, + 0.12880188102523485, + 0.11014247685670853, + 0.12091415996352832, + 0.1340667928258578, + 0.11883703370889027 + ], + [ + 0.1276616429289182, + 0.12514040619134903, + 0.12751680860916773, + 0.1265499691168467, + 0.12382955600817998, + 0.12388923019170761, + 0.12618380909164748, + 0.11922835186123848 + ], + [ + 0.1308002918958664, + 0.1262262836098671, + 
0.12252574041485786, + 0.1297115907073021, + 0.1252483588953813, + 0.12491714954376221, + 0.11918970445791881, + 0.12138067185878754 + ], + [ + 0.12953696896632513, + 0.12736436227957407, + 0.12352116405963898, + 0.12935289119680723, + 0.12269865224758784, + 0.12306055799126625, + 0.12102520341674487, + 0.1234399676322937 + ], + [ + 0.13130362828572592, + 0.12377394487460454, + 0.12103419999281566, + 0.12481699883937836, + 0.12288371349374454, + 0.12421268845597903, + 0.12616640826066336, + 0.12580818434556326 + ], + [ + 0.12551182880997658, + 0.12528547396262488, + 0.11757995809117953, + 0.12393354251980782, + 0.11792046576738358, + 0.12722831591963768, + 0.1227088322242101, + 0.13983136291305223 + ], + [ + 0.1330226163069407, + 0.12466675663987796, + 0.1339919144908587, + 0.11846163744727771, + 0.12841484198967615, + 0.12116376186410587, + 0.12407143786549568, + 0.11620679994424184 + ], + [ + 0.13542568435271582, + 0.12381013482809067, + 0.13177195812265077, + 0.11735910301407178, + 0.12286354725559552, + 0.12672228614489237, + 0.12073874101042747, + 0.12130832175413768 + ], + [ + 0.1276869960129261, + 0.12634232640266418, + 0.13974490761756897, + 0.12125503644347191, + 0.12908870726823807, + 0.11349999407927196, + 0.1323202376564344, + 0.11006155734260877 + ], + [ + 0.1212860532104969, + 0.1411466325322787, + 0.13094734400510788, + 0.11165967707832654, + 0.12907678882280985, + 0.11349489291508992, + 0.12447169050574303, + 0.12791667133569717 + ], + [ + 0.12637867281834284, + 0.13124189898371696, + 0.12512758374214172, + 0.11593564351399739, + 0.1268130255242189, + 0.12352017934123675, + 0.11832461133599281, + 0.13265814632177353 + ], + [ + 0.11763261258602142, + 0.1365760937333107, + 0.12776738653580347, + 0.11447113379836082, + 0.13463561236858368, + 0.12397951756914456, + 0.11094118778904279, + 0.1339962345858415 + ], + [ + 0.12347735464572906, + 0.12095268939932187, + 0.13229694838325182, + 0.12163240710894267, + 0.1248765562971433, + 0.1319022923707962, + 0.11414576694369316, + 0.1307157687842846 + ], + [ + 0.11588245009382565, + 0.14677994698286057, + 0.11366373797257741, + 0.12813347453872362, + 0.11937368288636208, + 0.1387942706545194, + 0.11703516667087872, + 0.1203370342652003 + ], + [ + 0.129832544674476, + 0.1231203240652879, + 0.13753188028931618, + 0.11606638630231221, + 0.12083546941479047, + 0.12472109248240788, + 0.11878509074449539, + 0.12910698105891547 + ], + [ + 0.12559781968593597, + 0.1246491049726804, + 0.13683745513359705, + 0.11824160069227219, + 0.11790597438812256, + 0.1283221865693728, + 0.11877486854791641, + 0.1296707640091578 + ], + [ + 0.12176170075933139, + 0.12655443822344145, + 0.12426854421695073, + 0.13259323686361313, + 0.12473662073413531, + 0.12401948869228363, + 0.1188626450796922, + 0.12720309322079024 + ], + [ + 0.1201527367035548, + 0.12396369501948357, + 0.12388773386677106, + 0.13527012492219606, + 0.11901206771532695, + 0.12229867279529572, + 0.1266723374525706, + 0.1287424030403296 + ], + [ + 0.11328563715020816, + 0.13253434002399445, + 0.12043049062291782, + 0.12306247279047966, + 0.1247171846528848, + 0.1245497539639473, + 0.12263259167472522, + 0.13878730436166128 + ], + [ + 0.12252657363812129, + 0.13365487133463225, + 0.12510309368371964, + 0.12738022953271866, + 0.11276384443044662, + 0.12918535619974136, + 0.12390327453613281, + 0.12548253933588663 + ], + [ + 0.12353649859627087, + 0.12630368024110794, + 0.12104234596093495, + 0.1243001086016496, + 0.12809344629446665, + 0.12785459061463675, + 0.11856027816732724, + 
0.1303088180720806 + ], + [ + 0.11664886275927226, + 0.12162637089689572, + 0.12251690030097961, + 0.12297582750519116, + 0.1262002351383368, + 0.1313761187096437, + 0.1278794047733148, + 0.13077604894836745 + ], + [ + 0.12413136909405391, + 0.12647772828737894, + 0.12584826598564783, + 0.12634766598542532, + 0.12213055541117986, + 0.1243738519648711, + 0.12425843502084415, + 0.1264319084584713 + ], + [ + 0.12163005024194717, + 0.13785199324289957, + 0.12464820469419162, + 0.1260561135907968, + 0.12239678824941318, + 0.12511069824298224, + 0.12235800176858902, + 0.11994792148470879 + ], + [ + 0.12900185585021973, + 0.12147142613927524, + 0.12872044493754706, + 0.12204616144299507, + 0.11932169521848361, + 0.1307222085694472, + 0.11773818358778954, + 0.13097780818740526 + ], + [ + 0.12296098222335179, + 0.1283737557629744, + 0.12514369189739227, + 0.12623755385478339, + 0.12537351250648499, + 0.13383686915040016, + 0.11299574499328931, + 0.12507766236861548 + ], + [ + 0.11777730286121368, + 0.12679906810323396, + 0.12904758751392365, + 0.12999804938832918, + 0.1216236191491286, + 0.12349878624081612, + 0.1296217106282711, + 0.12163363893826802 + ], + [ + 0.11433746044834454, + 0.12620082373420397, + 0.133262999355793, + 0.1294241026043892, + 0.1217850757141908, + 0.1280668651064237, + 0.12287040924032529, + 0.12405202413598697 + ], + [ + 0.12194746732711792, + 0.13035698359211287, + 0.12385719145337741, + 0.12690474092960358, + 0.12017568325002988, + 0.12677250429987907, + 0.11918358132243156, + 0.13080160692334175 + ], + [ + 0.11847771952549617, + 0.12834900741775832, + 0.12210531781117122, + 0.12877421453595161, + 0.1250902165969213, + 0.12721621866027513, + 0.11974218487739563, + 0.13024490078290304 + ], + [ + 0.12808279568950334, + 0.12343575308720271, + 0.12495943655570348, + 0.12048824007312457, + 0.12543574596444765, + 0.12858771656950316, + 0.11856227368116379, + 0.13044780989487967 + ], + [ + 0.1258824703594049, + 0.12397458652655284, + 0.1190020168821017, + 0.12439048538605373, + 0.1278511422375838, + 0.13025615364313126, + 0.12203493341803551, + 0.1266079843044281 + ], + [ + 0.11983991290132205, + 0.1234966367483139, + 0.12777634213368097, + 0.12969054654240608, + 0.12375900521874428, + 0.12586947157979012, + 0.1168056031068166, + 0.1327622483174006 + ], + [ + 0.11898627628882726, + 0.1275850273668766, + 0.12886973842978477, + 0.1331743225455284, + 0.11779008060693741, + 0.12448523069421451, + 0.1182781308889389, + 0.13083097835381827 + ], + [ + 0.1235158604880174, + 0.12121846651037534, + 0.12857254594564438, + 0.12210141991575559, + 0.12765698259075484, + 0.12384665384888649, + 0.11834277088443439, + 0.13474509492516518 + ], + [ + 0.12135905275742213, + 0.12286672741174698, + 0.12445233389735222, + 0.12475640575091045, + 0.12828450401624045, + 0.12452368686596553, + 0.12025310347477595, + 0.13350394368171692 + ], + [ + 0.12949753304322562, + 0.12319017698367436, + 0.1279755396147569, + 0.12648415192961693, + 0.12297424549857776, + 0.12110309923688571, + 0.11881105974316597, + 0.12996396919091543 + ], + [ + 0.1270089695851008, + 0.12242921565969785, + 0.1289629489183426, + 0.12572596967220306, + 0.12833455950021744, + 0.1189367746313413, + 0.11713863412539165, + 0.1314626932144165 + ], + [ + 0.1219395101070404, + 0.12280197317401569, + 0.12985733027259508, + 0.12590260182817778, + 0.12031196430325508, + 0.12728221590320268, + 0.1217620534201463, + 0.13014214237531027 + ], + [ + 0.11909300088882446, + 0.1192251704633236, + 0.1298979806403319, + 0.12453357130289078, + 
0.12398965160051982, + 0.12804956237475076, + 0.12516246115167937, + 0.13004836812615395 + ], + [ + 0.1225983922680219, + 0.12321280563871066, + 0.12887195373574892, + 0.127829788873593, + 0.12427625680963199, + 0.12272781257828076, + 0.12278018270929654, + 0.12770257890224457 + ], + [ + 0.11967643350362778, + 0.12622199083367983, + 0.12935252984364828, + 0.1275608316063881, + 0.12684360146522522, + 0.11924727633595467, + 0.12436794117093086, + 0.12672916303078333 + ], + [ + 0.12299285208185513, + 0.1286910424629847, + 0.1278941072523594, + 0.12266997123757999, + 0.12076893076300621, + 0.12417774150768916, + 0.12182774891455968, + 0.13097737729549408 + ], + [ + 0.1256938117245833, + 0.12827380249897638, + 0.13194265713294348, + 0.12549796079595885, + 0.11927563076217969, + 0.11971810708443324, + 0.1183239755531152, + 0.13127383217215538 + ], + [ + 0.12895027299722037, + 0.12499265124400456, + 0.12619781494140625, + 0.1231600505610307, + 0.12525979429483414, + 0.12456691016753514, + 0.12507760773102441, + 0.12179466336965561 + ] + ], + "load_balancing_losses": [ + 0.06011377274990082, + 0.06007998958230019, + 0.06010848172008991, + 0.060275179147720334, + 0.060526183992624286, + 0.060488412901759145, + 0.060212722048163414, + 0.06003621481359005, + 0.05998955182731151, + 0.06000234745442867, + 0.06004708334803581, + 0.060166075080633166, + 0.06036680713295937, + 0.06075631454586983, + 0.060696480050683024, + 0.06080991327762604, + 0.06104581765830517, + 0.06085282489657402, + 0.06125308088958263, + 0.06112743839621544, + 0.06076779589056969, + 0.06055579483509064, + 0.06095332615077496, + 0.060868193954229356, + 0.0607514038681984, + 0.06071991883218288, + 0.06068772189319134, + 0.060609794408082965, + 0.06040669232606888, + 0.06033438295125961, + 0.060364256426692006, + 0.06039164587855339, + 0.060282951965928075, + 0.06028437204658985, + 0.06029377542436123, + 0.06033190116286278, + 0.06015363447368145, + 0.06024842411279678, + 0.06008845344185829, + 0.0602034155279398, + 0.060144612193107606, + 0.06013825386762619, + 0.06014916449785233, + 0.060094399750232695, + 0.06013817973434925, + 0.06013930328190327, + 0.060143918916583064, + 0.06004832312464714, + 0.060126136243343356, + 0.06007350608706474 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.13007571548223495, + 0.12459517021973927, + 0.12555722147226334, + 0.1227981982131799, + 0.1259324513375759, + 0.12495209897557895, + 0.12609715511401495, + 0.11999174828330676 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_3.0/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_3.0/metrics.json new file mode 100644 index 0000000..4c6988a --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_3.0/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_3.0", + "description": "Soft routing (high exploration)", + "temperature": 3.0, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 24.333604981115766, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 
330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.070771646499634, + 4.4658355712890625, + 3.1863182544708253, + 1.8671497702598572, + 0.7618836134672164, + 0.3560397356748581, + 0.22378090918064117, + 0.16204893290996553, + 0.11390570029616356, + 0.09375766590237618, + 0.04976499788463116, + 0.017738218512386085, + 0.014075941033661365, + 0.011220329627394676, + 0.010278751188889147, + 0.010544700501486658, + 0.009907942684367298, + 0.00829516933299601, + 0.008591719483956695, + 0.0079573642462492, + 0.009623420517891645, + 0.0011046744999475776, + 0.0008535285945981741, + 0.0008028231910429895, + 0.000838369713164866, + 0.0008390388684347272, + 0.0007787981070578098, + 0.0005580487166298553, + 0.0005401516682468355, + 0.00047382304910570383, + 0.0003791542141698301, + 0.0003408290926017798, + 0.0002655111558851786, + 0.0002623992506414652, + 0.00026915718626696616, + 0.00025941182975657283, + 0.00024518739082850515, + 0.0002390200039371848, + 0.00020633784733945504, + 0.00021585724898613988, + 0.00018516206910135224, + 0.00019820798916043715, + 0.00014632271049777045, + 0.00015526828065048903, + 0.00016252591012744233, + 0.00016448518144898116, + 0.00015370284527307376, + 0.00015332770854001865, + 0.000159021420404315, + 0.00015425651508849114 + ], + "val_losses": [ + 10.773820947842548, + 10.737065254588853, + 10.738510714824123, + 11.276096684351405, + 12.461579117252633, + 14.960970652819523, + 16.71199899006227, + 19.370140817055855, + 20.906520917643086, + 22.884732330645775, + 23.880305280112545, + 24.815036322118537, + 25.04758594903845, + 25.19368578098688, + 24.96726939197985, + 24.822999334166834, + 24.61766196476697, + 23.85886660060276, + 23.757654554009857, + 24.15690976065376, + 24.168269457328446, + 23.75375895955116, + 23.952565034792197, + 24.036626707959933, + 24.009804985127264, + 23.878146309734653, + 23.96884984262419, + 24.33934579047213, + 24.289113418794773, + 24.37410675426254, + 24.168671712437284, + 24.16249053469816, + 24.214974278696012, + 24.185493186168873, + 24.30271199903724, + 24.37023561160893, + 24.38985477487948, + 24.41771251826741, + 24.328815601739784, + 24.408404784994495, + 24.313341551871687, + 24.470443058350902, + 24.49363865532218, + 24.39710278730089, + 24.362548106971982, + 24.45015034490255, + 24.334769589319666, + 24.412221315471527, + 24.430267428340844, + 24.333604981115766 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016402398124649928, + 0.016395483117008846, + 0.016395483117008846, + 0.01640931313229101, + 
0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 47754.135421397994, + 46030.76492265698, + 46097.34867344071, + 78912.63901834222, + 258223.07374624163, + 3143887.5105988095, + 18110437.546556007, + 258431394.19507945, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.0289901336034139, + 0.0700024406115214, + 0.10710208813349406, + 0.14583977460861205, + 0.18310449918111166, + 0.22192178964614867, + 0.2593983252843221, + 0.29810545841852826, + 0.3354979435602824, + 0.3743580420811971, + 0.41166659196217853, + 0.4501521587371826, + 0.4871056834856669, + 0.5254133065541585, + 0.5624098141988119, + 0.6006175398826599, + 0.6376061836878458, + 0.6758050799369812, + 0.7127941687901814, + 0.7602055470148722, + 0.7975546201070149, + 0.8364224751790365, + 0.8735642353693645, + 0.9119723240534464, + 0.9488789399464925, + 0.9871223052342732, + 1.0240062673886616, + 1.0622629165649413, + 1.0991905848185222, + 1.1373241623242696, + 1.1741138378779092, + 1.2125749150911966, + 1.2495089809099833, + 1.2876089096069336, + 1.3242868026097616, + 1.3624726017316182, + 1.4092057347297668, + 1.4482465306917827, + 1.4855604648590088, + 1.5238747358322144, + 1.5607340931892395, + 1.5986498514811198, + 1.6353978355725607, + 1.6733405033747355, + 1.7099307775497437, + 1.7478296200434367, + 1.7843554139137268, + 1.822295618057251, + 1.858902668952942, + 1.8967880765597025 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, 
+ 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "expert_utilizations": [ + [ + 0.12472607443730037, + 0.11413564657171567, + 0.13908488924304643, + 0.1281932753821214, + 0.13386016835769018, + 0.12641903261343637, + 0.1205245020488898, + 0.11305620521306992 + ], + [ + 0.13059944783647856, + 0.11570687095324199, + 0.13596310590704283, + 0.12195512404044469, + 0.12098198384046555, + 0.1288033090531826, + 0.13332026079297066, + 0.1126696728169918 + ], + [ + 0.13310648004213968, + 0.11378111317753792, + 0.13221190248926482, + 0.12288440515597661, + 0.11897570143143336, + 0.12832148373126984, + 0.13374246781071028, + 0.11697621022661527 + ], + [ + 0.1394077849884828, + 0.11538318917155266, + 0.134284737209479, + 0.12809398770332336, + 0.10721301039059956, + 0.12393237774570783, + 0.13685639947652817, + 0.11482828731338184 + ], + [ + 0.13685966158906618, + 0.12397508695721626, + 0.14025520657499632, + 0.12401387343804042, + 0.10191544517874718, + 0.12110727156201999, + 0.13862955197691917, + 0.11324367672204971 + ], + [ + 0.12534980227549872, + 0.12722217664122581, + 0.13806622475385666, + 0.12110455582539241, + 0.11363568156957626, + 0.12403056770563126, + 0.13495965550343195, + 0.11563108240564664 + ], + [ + 0.12274332468708356, + 0.12289928272366524, + 0.13027398536602655, + 0.12471852327386539, + 0.12615439295768738, + 0.12561199193199477, + 0.13090622425079346, + 0.11669203266501427 + ], + [ + 0.1232175740102927, + 0.12426386525233586, + 0.12014101569851239, + 0.13513945788145065, + 0.12810809661944708, + 0.12434400618076324, + 0.1212994356950124, + 0.12348632390300433 + ], + [ + 0.12157885109384854, + 0.12558400382598242, + 0.11725193013747533, + 0.13487297420700392, + 0.12607504551609358, + 0.11985043187936147, + 0.120205357670784, + 0.134581179668506 + ], + [ + 0.12353533630569775, + 0.12738284096121788, + 0.12230070928732555, + 0.12665949513514838, + 0.1298290230333805, + 0.12155631929636002, + 0.12370743975043297, + 0.1250285990536213 + ], + [ + 0.12152615437904994, + 0.1275811865925789, + 0.12720437347888947, + 0.1247339037557443, + 0.12945388381679854, + 0.12587590888142586, + 0.12071902925769488, + 0.12290534128745396 + ], + [ + 0.12080265084902446, + 0.12322263792157173, + 0.11961399391293526, + 0.12743914624055228, + 0.12629839777946472, + 0.1301287425061067, + 0.12313347185651462, + 0.12936074286699295 + ], + [ + 0.12029729162653287, + 0.12463215241829555, + 0.11208489785591762, + 0.13320803021391234, + 0.12465452154477437, + 0.12730018546183905, + 
0.1255464864273866, + 0.13227621465921402 + ], + [ + 0.1260697841644287, + 0.129348153869311, + 0.13164451097448668, + 0.12049078692992528, + 0.12975971028208733, + 0.11292146270473798, + 0.12621330842375755, + 0.1235520566503207 + ], + [ + 0.12189025556047757, + 0.12929208452502886, + 0.11865086480975151, + 0.1249602623283863, + 0.14065316567818323, + 0.11410342777768771, + 0.13052026430765787, + 0.11992945025364558 + ], + [ + 0.12586038062969843, + 0.13115117947260538, + 0.11517486969629924, + 0.12366710354884465, + 0.13030601044495901, + 0.13022703925768533, + 0.12174310038487117, + 0.1218700793882211 + ], + [ + 0.11899892116586368, + 0.13124312832951546, + 0.11121102049946785, + 0.12012842794259389, + 0.1286408267915249, + 0.13079561789830527, + 0.12574169288078943, + 0.13324011489748955 + ], + [ + 0.1250427340467771, + 0.11346900463104248, + 0.12365476787090302, + 0.1342919630308946, + 0.12021505584319432, + 0.13343259071310362, + 0.12706448510289192, + 0.12282917276024818 + ], + [ + 0.12427872171004613, + 0.131193016966184, + 0.1217530903716882, + 0.12653622776269913, + 0.1193309302131335, + 0.12380938977003098, + 0.1282565457125505, + 0.12484185521801312 + ], + [ + 0.1143236222366492, + 0.12730790053804716, + 0.11941495165228844, + 0.12616252899169922, + 0.13114960491657257, + 0.13083002467950186, + 0.11563719560702641, + 0.13517393916845322 + ], + [ + 0.1267846537133058, + 0.11718155071139336, + 0.12345526988307635, + 0.12220315014322598, + 0.1286692793170611, + 0.13959895198543867, + 0.11128036181131999, + 0.1308265527089437 + ], + [ + 0.12507840866843858, + 0.11014386266469955, + 0.10971656814217567, + 0.13400843491156897, + 0.1357493412991365, + 0.1227389710644881, + 0.11575853079557419, + 0.1468056639035543 + ], + [ + 0.1254200960199038, + 0.11712440351645152, + 0.13074298078815141, + 0.11806165426969528, + 0.12761933108170828, + 0.11636410281062126, + 0.13046537091334662, + 0.13420184950033823 + ], + [ + 0.12422571827967961, + 0.12452367569009463, + 0.12181266273061435, + 0.12265633915861447, + 0.12762916584809622, + 0.1210183451573054, + 0.13047051429748535, + 0.1276633602877458 + ], + [ + 0.11422779783606529, + 0.1318163312971592, + 0.12139534081021945, + 0.12642008066177368, + 0.13844073315461478, + 0.12200620646278064, + 0.12803960343201956, + 0.11765366916855176 + ], + [ + 0.12974936266740164, + 0.1217629003028075, + 0.12491700425744057, + 0.12390923251708348, + 0.1273606816927592, + 0.1283036358654499, + 0.11900316302975018, + 0.12499380484223366 + ], + [ + 0.12174949049949646, + 0.12175241112709045, + 0.12835455189148584, + 0.13113053888082504, + 0.12289684886733691, + 0.12553331131736437, + 0.1254726933936278, + 0.12310993919769923 + ], + [ + 0.1341996267437935, + 0.12266441062092781, + 0.11392356579502423, + 0.11675774430235226, + 0.1304306797683239, + 0.12313696617881457, + 0.1318251850704352, + 0.12706159676114717 + ], + [ + 0.12960617616772652, + 0.12072737763325374, + 0.12123511855800946, + 0.11654793843626976, + 0.13726886610190073, + 0.12161464740832646, + 0.1268019030491511, + 0.12619773422678313 + ], + [ + 0.11878873283664386, + 0.12847328806916872, + 0.12323109308878581, + 0.11415188759565353, + 0.1305048738916715, + 0.12411610285441081, + 0.12619938825567564, + 0.13453439498941103 + ], + [ + 0.12038205688198407, + 0.12674840042988458, + 0.12475752706329028, + 0.12413328886032104, + 0.124527208507061, + 0.12337127700448036, + 0.12241227055589358, + 0.13366775959730148 + ], + [ + 0.12543733417987823, + 0.12838186571995416, + 0.1288215033710003, + 
0.12376301859815915, + 0.1185281698902448, + 0.12431250636776288, + 0.12520815059542656, + 0.12554720789194107 + ], + [ + 0.12892396996418634, + 0.12538310140371323, + 0.13369227573275566, + 0.11626814678311348, + 0.12961831440528235, + 0.12391503279407819, + 0.12424274161458015, + 0.11795617764194806 + ], + [ + 0.12109444290399551, + 0.1258570837477843, + 0.12462745482722919, + 0.11526937037706375, + 0.12431987623373668, + 0.12976568813125292, + 0.12840891629457474, + 0.13065694396694502 + ], + [ + 0.1110860879222552, + 0.12308529267708461, + 0.12841816619038582, + 0.11615587895115216, + 0.12611126651366553, + 0.13598316038648287, + 0.13183288152019182, + 0.12732703487078348 + ], + [ + 0.12529016161958376, + 0.1262428561846415, + 0.12788847088813782, + 0.1214534193277359, + 0.12399440507094066, + 0.12381287167469661, + 0.12001719201604526, + 0.13130039225021997 + ], + [ + 0.12347214917341869, + 0.13019327819347382, + 0.13151004165410995, + 0.12433807800213496, + 0.12341948101917903, + 0.12001078203320503, + 0.11783748244245847, + 0.129218477755785 + ], + [ + 0.12051367883880933, + 0.12791825334231058, + 0.12174551313122113, + 0.12101834019025166, + 0.1284767103691896, + 0.1314557728668054, + 0.12189535299936931, + 0.12697615971167883 + ], + [ + 0.12003429606556892, + 0.12793711945414543, + 0.12780164554715157, + 0.12430100639661153, + 0.12168940529227257, + 0.1298987716436386, + 0.12141961480180423, + 0.12691790610551834 + ], + [ + 0.12238635495305061, + 0.119320809841156, + 0.12571924676497778, + 0.11988828579584758, + 0.12770364060997963, + 0.1271932733555635, + 0.12603302051623663, + 0.13175515085458755 + ], + [ + 0.12517683332165083, + 0.11392779151598613, + 0.12567486613988876, + 0.12419225523869197, + 0.12861550723512968, + 0.1282950701812903, + 0.1245880052447319, + 0.12952945878108343 + ], + [ + 0.1179804690182209, + 0.1259866108496984, + 0.12425537407398224, + 0.12656096617380777, + 0.12519927819569907, + 0.12634536375602087, + 0.12193843349814415, + 0.13173329333464304 + ], + [ + 0.12156252314647038, + 0.12046961113810539, + 0.12292546530564626, + 0.12384568527340889, + 0.13546372205018997, + 0.12384220709403355, + 0.12038928767045338, + 0.13150126487016678 + ], + [ + 0.12343370790282886, + 0.12398527686794598, + 0.12138952563206355, + 0.1260326678554217, + 0.12304265424609184, + 0.13104653855164847, + 0.12314933662613232, + 0.12792008370161057 + ], + [ + 0.1251525121430556, + 0.12411284695068996, + 0.12043577805161476, + 0.12568721051017442, + 0.12337924167513847, + 0.1260602151354154, + 0.12819369385639826, + 0.12697827319304147 + ], + [ + 0.12485796461502711, + 0.1229457954565684, + 0.12437445173660915, + 0.11730960259834926, + 0.12504501268267632, + 0.12840055922667185, + 0.12571208675702414, + 0.131354292233785 + ], + [ + 0.12224547689159711, + 0.12357236569126447, + 0.1270434372127056, + 0.11915260801712672, + 0.12662151207526526, + 0.12594799200693765, + 0.12649253259102503, + 0.12892387186487517 + ], + [ + 0.12354755898316701, + 0.1295084444185098, + 0.12089034045735995, + 0.11584710453947385, + 0.12965192894140878, + 0.12971324225266775, + 0.1265015428264936, + 0.12433961530526479 + ], + [ + 0.12104954570531845, + 0.13032419855395952, + 0.11714073270559311, + 0.11382553974787395, + 0.13202049831549326, + 0.12688606729110083, + 0.13040306667486826, + 0.1283501274883747 + ], + [ + 0.12074750413497289, + 0.12424988051255544, + 0.12933322538932165, + 0.11879227310419083, + 0.1261460855603218, + 0.1288937491675218, + 0.1303452787299951, + 0.12149178112546603 + ] + ], + 
"load_balancing_losses": [ + 0.06007528156042099, + 0.06005345806479454, + 0.060072893276810646, + 0.060200298205018044, + 0.060406743735075, + 0.0604095533490181, + 0.06020508445799351, + 0.06007457599043846, + 0.06000114642083645, + 0.05994063504040241, + 0.059976843371987346, + 0.059980519115924835, + 0.06015676185488701, + 0.06035235337913036, + 0.060568489506840704, + 0.06051530539989471, + 0.06067498400807381, + 0.0605585515499115, + 0.06078568547964096, + 0.06079221293330193, + 0.06021268852055073, + 0.06074596792459488, + 0.06085770018398762, + 0.0606797955930233, + 0.060736792534589766, + 0.0604791846126318, + 0.0604643277823925, + 0.06046038530766964, + 0.060452453792095184, + 0.060344332829117775, + 0.060186026617884636, + 0.06023854948580265, + 0.06029247641563416, + 0.06022992692887783, + 0.06025746092200279, + 0.06021227538585663, + 0.06017339937388897, + 0.06014360003173351, + 0.06011454164981842, + 0.0601342748850584, + 0.06016431413590908, + 0.06013965830206871, + 0.0601209469139576, + 0.06010842472314835, + 0.06012655310332775, + 0.06019891314208507, + 0.06007558926939964, + 0.06005396507680416, + 0.06007737554609775, + 0.0600309357047081 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.12003081788619359, + 0.12314234549800555, + 0.13216516996423402, + 0.11812149236599605, + 0.12567463144659996, + 0.12962264940142632, + 0.13125648846228918, + 0.11998618890841801 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_5.0/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_5.0/metrics.json new file mode 100644 index 0000000..47675a9 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_5.0/metrics.json @@ -0,0 +1,1108 @@ +{ + "experiment_name": "temp_5.0", + "description": "Very soft routing (maximum exploration)", + "temperature": 5.0, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 24.41567771426359, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500 + ], + "train_losses": [ + 5.0703545093536375, + 4.46462049484253, + 3.1875943422317503, + 1.8689485669136048, + 0.7624981820583343, + 0.3557029962539673, + 0.2238013669848442, + 0.16209573149681092, + 0.11361046060919762, + 0.09386456608772278, + 0.049562677182257174, + 0.017827144358307123, + 0.01399796400219202, + 0.011256490927189588, + 0.010376389836892486, + 0.010919028520584106, + 0.00978595931082964, + 0.00829839431680739, + 0.008575531095266343, + 0.008257886162027717, + 0.009540112875401973, + 0.0010036894993390888, + 0.0008184442878700793, + 0.0008429156441707164, + 0.0008803114527836442, + 0.0008313788217492402, + 0.0007687594450544565, + 0.0005616549198748544, + 0.0005631086620269343, + 0.0004606483154930174, + 0.0003698965738294646, + 0.0003490874063572846, + 0.00026360198389738796, + 0.0002687782354769297, + 0.000263003256986849, + 0.0002515418105758727, + 0.00023252369719557465, + 0.00021666392131010072, + 
0.00022820265294285492, + 0.00019479463517200202, + 0.0001990688091609627, + 0.00019492459978209808, + 0.0001447171875042841, + 0.00014930147590348498, + 0.0001529595916508697, + 0.0001651784434216097, + 0.00015199912595562637, + 0.00015204865485429764, + 0.00015086211351444945, + 0.00015498396824114024 + ], + "val_losses": [ + 10.77370401658776, + 10.73704313473651, + 10.738431044265154, + 11.275021037448843, + 12.464150378223865, + 14.961996253724655, + 16.705523750386053, + 19.377588400149936, + 20.914736158013763, + 22.872542505971957, + 23.90016095461357, + 24.86539844068116, + 25.179874635837947, + 25.196459733134024, + 25.069449859457386, + 24.451877735528846, + 24.376162006661243, + 24.299318380996112, + 24.177560920850127, + 24.64056517125861, + 24.486083384537444, + 24.40486630955349, + 24.20979449268786, + 24.13727837822041, + 24.034373441770303, + 24.04908610906702, + 24.064054421738263, + 24.30540413064586, + 24.411597814661032, + 24.517229396968343, + 24.385016047069968, + 24.275730341988822, + 24.260663029161865, + 24.3975143971797, + 24.44620602880687, + 24.551363348539642, + 24.47933273180635, + 24.458823160232168, + 24.432253086103568, + 24.425183110860548, + 24.376090356402177, + 24.40755343352948, + 24.450850308152052, + 24.356389264757137, + 24.278506922637614, + 24.444314943185542, + 24.36549594461286, + 24.401875229690607, + 24.39038821243987, + 24.41567771426359 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.016395483117008846, + 0.01640931313229101, + 0.01640931313229101, + 0.016423143147573177, + 0.016430058155214262, + 0.016416228139932095, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 47748.55179687827, + 46029.74674019473, + 46093.676218200104, + 78827.80251784841, + 258887.88699553953, + 3147113.5385011053, + 17993546.978445563, + 260363268.44495198, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 
485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.02915708621342977, + 0.07085682153701782, + 0.10821998516718546, + 0.14691590865453083, + 0.18436036507288614, + 0.22336636384328207, + 0.26151129802068074, + 0.3005089044570923, + 0.3379643321037292, + 0.37702012062072754, + 0.4151902476946513, + 0.4547288537025452, + 0.4927114129066467, + 0.5311727643013, + 0.5689235965410868, + 0.607405177752177, + 0.6446005662282308, + 0.6833065748214722, + 0.7214114586512248, + 0.7605320572853088, + 0.7976165970166524, + 0.8361191431681315, + 0.8727133115132649, + 0.9199727018674214, + 0.9567810455958049, + 0.9949237823486328, + 1.0317928274472554, + 1.0700676361719768, + 1.1069627324740092, + 1.1451308131217957, + 1.1818441987037658, + 1.2199201305707297, + 1.2566996812820435, + 1.294764471054077, + 1.3315291047096252, + 1.369658148288727, + 1.4064087470372517, + 1.4442970395088195, + 1.4807718992233276, + 1.5186456759770712, + 1.5552038351694744, + 1.5929227948188782, + 1.6296249906222025, + 1.6675133983294168, + 1.7040600140889486, + 1.742207896709442, + 1.7789579431215923, + 1.8170294404029845, + 1.8537912567456563, + 1.8919958829879762 + ], + "learning_rates": [ + 0.005600000000000001, + 0.014000000000000002, + 0.019600000000000003, + 0.028000000000000004, + 0.033600000000000005, + 0.042, + 0.04760000000000001, + 0.05600000000000001, + 0.06160000000000001, + 0.07, + 0.06999724420604547, + 0.06998277760623918, + 0.06996624706137415, + 0.06993112925739141, + 0.06990084203752397, + 0.06984511143016786, + 0.0698011006539283, + 0.06972481818379972, + 0.06966713197632568, + 0.06957038105718576, + 0.06949908249749817, + 0.06938196892505649, + 0.06929713597708359, + 0.06915978781331218, + 0.06906151324063706, + 0.06890408067373649, + 0.06879247193816149, + 0.06861512711833227, + 0.06849030626237115, + 0.06829324311356999, + 0.06815534662699614, + 0.06793878063488303, + 0.06778795930547993, + 0.06755212728178762, + 0.06738854603046465, + 0.06713370585404863, + 0.06695754355450242, + 0.06668397388935406, + 0.06649542317247278, + 0.06620342316300441, + 0.06600269020622872, + 0.06569257915016359, + 0.06547988345203457, + 0.06515200045125948, + 0.06492757459140026, + 0.06458227818116263, + 0.064346367565956, + 0.06398403532281084, + 0.06373689791705092, + 0.0633579260459864 + ], + "temperatures": [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 
+ ], + "expert_utilizations": [ + [ + 0.12477028618256251, + 0.11404932911197345, + 0.13935748487710953, + 0.12813226381937662, + 0.13374890138705572, + 0.12626479069391885, + 0.12055238708853722, + 0.11312436188260715 + ], + [ + 0.13119290272394815, + 0.11558721214532852, + 0.136261152724425, + 0.1220968986550967, + 0.12160027399659157, + 0.1273151934146881, + 0.13316109155615172, + 0.11278505126635234 + ], + [ + 0.1353981780509154, + 0.11397860447565715, + 0.1327712076405684, + 0.12281746293107669, + 0.11620801190535228, + 0.12524349863330522, + 0.13464910412828127, + 0.1189337025086085 + ], + [ + 0.14023700108130774, + 0.11708694944779079, + 0.138208935658137, + 0.12674926221370697, + 0.10306349024176598, + 0.11930984631180763, + 0.13831660275657973, + 0.1170276589691639 + ], + [ + 0.1316507632533709, + 0.12857574100295702, + 0.14245687425136566, + 0.12692434216539064, + 0.10080775370200475, + 0.12150231500466664, + 0.13474519674976668, + 0.11333677048484485 + ], + [ + 0.12569407746195793, + 0.12786443655689558, + 0.14190090323487917, + 0.12532584369182587, + 0.11291309570272763, + 0.12561137105027834, + 0.1333636517326037, + 0.10732638587554295 + ], + [ + 0.1255204752087593, + 0.11991419394810994, + 0.13823805997769037, + 0.12669343501329422, + 0.13079335168004036, + 0.12262205655376117, + 0.12988103677829108, + 0.10633717104792595 + ], + [ + 0.1259683407843113, + 0.11897451058030128, + 0.1251400721569856, + 0.1286942002673944, + 0.13305794447660446, + 0.12700254345933595, + 0.12287053341666858, + 0.1182916226486365 + ], + [ + 0.12419281775752704, + 0.12015323961774509, + 0.12121061359842618, + 0.13156378517548242, + 0.13171553860108057, + 0.12572013835112253, + 0.11582103744149208, + 0.1296226109067599 + ], + [ + 0.12288110703229904, + 0.12715152526895204, + 0.12072817857066791, + 0.12859910478194556, + 0.1297885850071907, + 0.1257924847304821, + 0.11638227726022403, + 0.12867651879787445 + ], + [ + 0.12377662460009257, + 0.12659810235102972, + 0.12307258571187656, + 0.12743108347058296, + 0.127451424797376, + 0.12984498466054598, + 0.11708961427211761, + 0.12473536531130473 + ], + [ + 0.12474415575464566, + 0.12836556136608124, + 0.12529647474487624, + 0.1215812514225642, + 0.13081897919376692, + 0.11882359037796657, + 0.12321120873093605, + 0.12715857600172362 + ], + [ + 0.12364306549231212, + 0.12744436288873354, + 0.1247934103012085, + 0.12345421314239502, + 0.13450921699404716, + 0.11015255004167557, + 0.12663327530026436, + 0.1293696959813436 + ], + [ + 0.12619193394978842, + 0.12955861166119576, + 0.12013716374834378, + 0.12541617453098297, + 0.11452159409721692, + 0.12571229909857115, + 0.12131353467702866, + 0.13714846596121788 + ], + [ + 0.13448402285575867, + 0.12641586239139238, + 0.11578949044148128, + 0.11991504828135173, + 0.1161196914811929, + 0.1305411420762539, + 0.12215424701571465, + 0.13458027069767317 + ], + [ + 0.1281722771624724, + 0.11752036462227504, + 0.10578508178393047, + 0.12392654518286388, + 0.12420964613556862, + 0.12668093790610632, + 0.12555691972374916, + 0.14814799278974533 + ], + [ + 0.12360658993323644, + 0.1218559555709362, + 0.10403751581907272, + 0.12938379247983298, + 0.12880768130222955, + 0.12601970384518305, + 0.1302609828611215, + 0.13602754473686218 + ], + [ + 0.13646909967064857, + 0.1149950623512268, + 0.13262039919694266, + 0.11723950877785683, + 0.1256739484767119, + 0.12072555844982465, + 0.1152287336687247, + 0.13704746340711912 + ], + [ + 0.12861392895380655, + 0.1274326853454113, + 0.13041726127266884, + 0.11785081898172696, + 
0.12248246123393376, + 0.11831739420692126, + 0.12649325281381607, + 0.1283919500807921 + ], + [ + 0.12174006924033165, + 0.12623372798164686, + 0.11065865804751714, + 0.11125793059666951, + 0.1325134076178074, + 0.13039778793851534, + 0.12924037873744965, + 0.13795779645442963 + ], + [ + 0.12914384404818216, + 0.12221771851181984, + 0.11726376911004384, + 0.12032396097977956, + 0.12339333444833755, + 0.13300885880986849, + 0.12421010062098503, + 0.13043818126122156 + ], + [ + 0.12202925359209378, + 0.11807028700908025, + 0.12775781378149986, + 0.12321480611960094, + 0.1315752975642681, + 0.12511195242404938, + 0.13026975467801094, + 0.12197061503926913 + ], + [ + 0.12298161660631497, + 0.11879849185546239, + 0.11511528367797534, + 0.12830017631252608, + 0.12728476276000342, + 0.13300958896676698, + 0.12759954979022345, + 0.12691031520565352 + ], + [ + 0.12301204353570938, + 0.12908517072598139, + 0.11999612798293431, + 0.12100109457969666, + 0.13396340608596802, + 0.11687351142366727, + 0.13174517949422201, + 0.12432324141263962 + ], + [ + 0.11311099429925282, + 0.12907075633605322, + 0.1263370414574941, + 0.1322266993423303, + 0.12884623557329178, + 0.12473848462104797, + 0.12211047857999802, + 0.12355910986661911 + ], + [ + 0.12983701253930727, + 0.12173312405745189, + 0.12207641204198201, + 0.13177673518657684, + 0.12330907459060352, + 0.11995165422558784, + 0.12373720233639081, + 0.12757856274644533 + ], + [ + 0.1268964260816574, + 0.13022598127524057, + 0.12393082181612651, + 0.13038464014728865, + 0.11740660543243091, + 0.1237058552602927, + 0.12605544552206993, + 0.121393999705712 + ], + [ + 0.1321516645451387, + 0.12313180044293404, + 0.11901085451245308, + 0.11343496044476827, + 0.12958736717700958, + 0.12710683171947798, + 0.12996351098020872, + 0.12561276679237685 + ], + [ + 0.1285351775586605, + 0.12518409018715224, + 0.13269146780172983, + 0.11589425678054492, + 0.12341672430435817, + 0.12130612383286159, + 0.1360117495059967, + 0.11696016912659009 + ], + [ + 0.12472530454397202, + 0.12946436057488123, + 0.11869195848703384, + 0.12517083808779716, + 0.12870808690786362, + 0.12260618433356285, + 0.11867646003762881, + 0.13195657481749853 + ], + [ + 0.13187258938948312, + 0.13282287244995436, + 0.12957224746545157, + 0.11315961306293805, + 0.12411983435352643, + 0.11693492780129115, + 0.12212342272202174, + 0.12939426054557165 + ], + [ + 0.12550069764256477, + 0.12219651664296786, + 0.12362342452009518, + 0.13058074191212654, + 0.12213251739740372, + 0.12225718796253204, + 0.12663770591219267, + 0.127070972075065 + ], + [ + 0.13131801038980484, + 0.11635175347328186, + 0.136438408245643, + 0.12553071106473604, + 0.12676986927787462, + 0.12003661319613457, + 0.11848867187897365, + 0.1250657377143701 + ], + [ + 0.12001736462116241, + 0.13127816965182623, + 0.11939862991372745, + 0.11824446047345798, + 0.13444430381059647, + 0.12069116160273552, + 0.13165192057689032, + 0.12427375962336858 + ], + [ + 0.11951538051168124, + 0.1316730615993341, + 0.12135189895828564, + 0.12558163205782572, + 0.12968073785305023, + 0.12074824174245198, + 0.12602591266234717, + 0.12542292227347693 + ], + [ + 0.12548970058560371, + 0.12411254271864891, + 0.12327447533607483, + 0.12189250066876411, + 0.12929068505764008, + 0.12880051011840501, + 0.12393552685777347, + 0.12320382023851077 + ], + [ + 0.12608481322725615, + 0.1245619294544061, + 0.12383013094464938, + 0.11925172184904416, + 0.1281214877963066, + 0.12769856055577597, + 0.1278823340932528, + 0.12256879111131032 + ], + [ + 
0.12150382002194722, + 0.1248937485118707, + 0.13154955705006918, + 0.11728510136405627, + 0.12400428826610248, + 0.12869583194454512, + 0.12158907825748126, + 0.13047833864887556 + ], + [ + 0.11881603921453159, + 0.12522679443160692, + 0.13028604164719582, + 0.12374808887640636, + 0.12735786040623984, + 0.12905522932608923, + 0.12412923698623975, + 0.12138048683603604 + ], + [ + 0.12545411909619966, + 0.11841143295168877, + 0.13537941997249922, + 0.1208376574019591, + 0.1318199560046196, + 0.11679320285717647, + 0.1276612567404906, + 0.12364272276560466 + ], + [ + 0.12642036005854607, + 0.12033715471625328, + 0.12685072173674902, + 0.12487638990084331, + 0.1268667442103227, + 0.1203000657260418, + 0.1293960710366567, + 0.12495224922895432 + ], + [ + 0.12674545248349509, + 0.12066101158658664, + 0.12826942279934883, + 0.12532779574394226, + 0.12852359314759573, + 0.12342178573211034, + 0.11662036925554276, + 0.130430335799853 + ], + [ + 0.12158407146732013, + 0.11633847281336784, + 0.13018065442641577, + 0.1270310121277968, + 0.12587824712196985, + 0.12058176348606746, + 0.11985754345854123, + 0.13854802151521048 + ], + [ + 0.12396118665734927, + 0.12650743623574576, + 0.13001176590720812, + 0.1261710487306118, + 0.11839381108681361, + 0.12594938899079958, + 0.12432178979118665, + 0.12468335404992104 + ], + [ + 0.12187394003073375, + 0.12943923597534499, + 0.13239886611700058, + 0.12699544181426367, + 0.11759350324670474, + 0.12479124466578166, + 0.12372298041979472, + 0.12318453565239906 + ], + [ + 0.12472069387634595, + 0.12557107582688332, + 0.12597007428606352, + 0.11917857204874356, + 0.12682970985770226, + 0.12139171361923218, + 0.1329143779973189, + 0.12342354903618495 + ], + [ + 0.12313155457377434, + 0.12367554629842441, + 0.12050254518787067, + 0.12245697155594826, + 0.12575539201498032, + 0.1250599498550097, + 0.13582346588373184, + 0.12359433496991794 + ], + [ + 0.119864322245121, + 0.12844805916150412, + 0.12570173541704813, + 0.11907673502961795, + 0.12188619375228882, + 0.12204952786366145, + 0.13584869230786958, + 0.127124502013127 + ], + [ + 0.12247722099224727, + 0.12226896360516548, + 0.12860757857561111, + 0.12310665970047314, + 0.1166485386590163, + 0.1250074878334999, + 0.13339037944873175, + 0.1284929725031058 + ], + [ + 0.12247613941629727, + 0.12451584140459697, + 0.1243399182955424, + 0.12536592657367387, + 0.12524844457705817, + 0.11946872249245644, + 0.13208768516778946, + 0.1264970786869526 + ] + ], + "load_balancing_losses": [ + 0.06004501059651375, + 0.06003277376294136, + 0.0600448913872242, + 0.06012629866600037, + 0.06028372347354889, + 0.06031692922115326, + 0.06019878387451172, + 0.06008903346955776, + 0.06002595722675323, + 0.060004842653870585, + 0.060019636526703835, + 0.06000459119677544, + 0.059986402094364163, + 0.060028502345085145, + 0.06013297997415066, + 0.06017942801117897, + 0.06042850837111473, + 0.06041756384074688, + 0.06054118573665619, + 0.06055377274751663, + 0.060345054045319556, + 0.060216189920902254, + 0.06054187603294849, + 0.06058353632688522, + 0.06030330806970596, + 0.06021143570542335, + 0.060424667224287985, + 0.06046537086367607, + 0.0602728120982647, + 0.06023858301341534, + 0.060200073570013043, + 0.06016919799149036, + 0.06030133403837681, + 0.06024600118398667, + 0.06014646291732788, + 0.06020813770592213, + 0.060055486485362054, + 0.0601214550435543, + 0.06014898307621479, + 0.06020488478243351, + 0.06015055291354656, + 0.060126836970448495, + 0.06008899882435799, + 0.06003185734152794, + 0.060166549682617185, + 
0.06011587493121624, + 0.060002363100647924, + 0.06005332544445992, + 0.06007154993712902, + 0.06007159687578678 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.12206076458096504, + 0.12474725147088368, + 0.12385034188628197, + 0.12561731040477753, + 0.1259287049372991, + 0.11743786682685216, + 0.13355866322914758, + 0.1267988681793213 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 500, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/results/temp_best_long/metrics.json b/experiments/exp10_routing_temperature_specialization/results/temp_best_long/metrics.json new file mode 100644 index 0000000..40dc85a --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/results/temp_best_long/metrics.json @@ -0,0 +1,2158 @@ +{ + "experiment_name": "temp_best_long", + "description": "Best temperature from ablation, trained for 1000 steps", + "temperature": 2.0, + "temperature_schedule": null, + "final_metrics": { + "val_loss": 22.979948994127685, + "val_accuracy": 0.01640931313229101, + "val_perplexity": 485165195.4097903 + }, + "history": { + "steps": [ + 10, + 20, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 130, + 140, + 150, + 160, + 170, + 180, + 190, + 200, + 210, + 220, + 230, + 240, + 250, + 260, + 270, + 280, + 290, + 300, + 310, + 320, + 330, + 340, + 350, + 360, + 370, + 380, + 390, + 400, + 410, + 420, + 430, + 440, + 450, + 460, + 470, + 480, + 490, + 500, + 510, + 520, + 530, + 540, + 550, + 560, + 570, + 580, + 590, + 600, + 610, + 620, + 630, + 640, + 650, + 660, + 670, + 680, + 690, + 700, + 710, + 720, + 730, + 740, + 750, + 760, + 770, + 780, + 790, + 800, + 810, + 820, + 830, + 840, + 850, + 860, + 870, + 880, + 890, + 900, + 910, + 920, + 930, + 940, + 950, + 960, + 970, + 980, + 990, + 1000 + ], + "train_losses": [ + 5.093000078201294, + 4.768654918670654, + 3.9902720928192137, + 3.085107660293579, + 2.056787669658661, + 1.195764034986496, + 0.6177730023860931, + 0.36190438866615293, + 0.23436530232429503, + 0.1760684847831726, + 0.1099536381661892, + 0.06502041332423687, + 0.04747045300900936, + 0.03471943717449903, + 0.02919358089566231, + 0.02650294080376625, + 0.02317241672426462, + 0.02013303805142641, + 0.020495779812335968, + 0.01874320739880204, + 0.022934645228087903, + 0.003164438437670469, + 0.002522177342325449, + 0.0021018866798840465, + 0.001861788157839328, + 0.0015394110116176308, + 0.001337145664729178, + 0.001207011571386829, + 0.0011392907123081385, + 0.0010924837493803351, + 0.0009832238603848964, + 0.000799700984498486, + 0.0005536529410164803, + 0.0005476086429553106, + 0.000539558957098052, + 0.0004916821722872555, + 0.0005005413171602413, + 0.00043973037682007996, + 0.0004396916250698268, + 0.0004265427385689691, + 0.0004060846142238006, + 0.0004098861507372931, + 0.00032479470246471465, + 0.0003291622153483331, + 0.0003475018165772781, + 0.0003396520740352571, + 0.0003362289018696174, + 0.00032715669076424094, + 0.0003169335541315377, + 0.0003185467416187748, + 0.00031889035599306224, + 0.0003233567462302744, + 0.00029575721418950707, + 0.0002802553819492459, + 0.00027922074368689207, + 0.000278554558462929, + 0.0002860859502106905, + 0.0002909263246692717, + 0.0002977533295052126, + 0.00029731441172771157, + 0.00029416689067147673, + 0.00031793613743502647, + 0.00030841972038615496, + 0.0002592450924566947, + 0.00025148740824079143, + 
0.0002429259751806967, + 0.00024965459160739554, + 0.00024422139103990047, + 0.0002538585657021031, + 0.0002462363641825505, + 0.00024257804034277797, + 0.00024414215731667356, + 0.0002476524401572533, + 0.00021220916241873055, + 0.00017282165645156057, + 0.0001666264855884947, + 0.00017115303344326094, + 0.0001683504568063654, + 0.00015985368518158793, + 0.0001629199687158689, + 0.0001645225565880537, + 0.00016069567354861646, + 0.00016084629896795377, + 0.0001505965177784674, + 0.00011141209470224566, + 0.00010772274545161054, + 0.00010467169049661606, + 0.00010468966866028495, + 9.94078982330393e-05, + 9.385020166519098e-05, + 9.300635792897083e-05, + 8.85863133589737e-05, + 8.859239023877308e-05, + 8.375489778700285e-05, + 7.243949949042872e-05, + 6.130908877821639e-05, + 5.828483190271072e-05, + 5.5807465469115416e-05, + 5.580250763159711e-05, + 5.3706197286373934e-05 + ], + "val_losses": [ + 10.779226397457055, + 10.748181060009205, + 10.73444930908958, + 10.744837343060928, + 10.880311969312256, + 11.666932308210502, + 12.594663788488813, + 14.2920420953326, + 15.44162052343254, + 17.06974983215332, + 18.043296557854426, + 19.346303575872955, + 20.09925075975829, + 20.998059343533466, + 21.436998697557215, + 22.022583890719464, + 22.24292656726635, + 22.13121087306801, + 22.117182633902075, + 21.749672771763887, + 21.676656305158097, + 21.619962503547804, + 21.672284965380342, + 21.41208668900884, + 21.388349438724585, + 21.755900608777157, + 21.722522432307052, + 21.830109053702742, + 22.00259015248437, + 21.99168669629855, + 21.99317664560918, + 22.122004525821538, + 22.182502948774466, + 22.197347122030628, + 22.324309291772202, + 22.30476748732712, + 22.345670235030642, + 22.411979169811882, + 22.317876155300613, + 22.39727671997286, + 22.320257759768214, + 22.417020932524448, + 22.42570321719975, + 22.37258204807241, + 22.345258369041417, + 22.381934155844966, + 22.30150601552148, + 22.318089131331696, + 22.328738202475826, + 22.27325452258646, + 22.282489311568728, + 22.220852551948898, + 22.21031966731742, + 22.1268543809547, + 22.091828558554496, + 22.06091240522297, + 22.0406348966457, + 21.995526256493882, + 21.993306689043347, + 21.93322995189222, + 21.895074709565396, + 21.815620934584114, + 21.793328706451526, + 21.735666066092232, + 21.737147712033543, + 21.680026475616565, + 21.662327163211028, + 21.658950576512638, + 21.63663097826415, + 21.610892959702564, + 21.608446202092793, + 21.590824848350284, + 21.571639037385005, + 21.558307229840715, + 21.551746981725255, + 21.570144066962254, + 21.580016092361074, + 21.605436129620557, + 21.6218087091884, + 21.654367837804788, + 21.67288459400406, + 21.709339303599652, + 21.741782245703384, + 21.77811714954174, + 21.83077843871639, + 21.895473952006956, + 21.942066610491317, + 22.016616524740158, + 22.06737811489577, + 22.147896379127097, + 22.199941635131836, + 22.30382027642887, + 22.357646746685987, + 22.448525115373698, + 22.524168156061073, + 22.619052967839863, + 22.696267084182363, + 22.803448754570088, + 22.87490337223552, + 22.979948994127685 + ], + "val_accuracies": [ + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 
0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.016402398124649928, + 0.016402398124649928, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101, + 0.01640931313229101 + ], + "val_perplexities": [ + 48012.96691436932, + 46545.28832196924, + 45910.50831127818, + 46389.913961249054, + 53120.16919607174, + 116649.88693831676, + 294980.2808433114, + 1610478.8181656147, + 5084059.036467911, + 25899904.31220081, + 68565260.50898427, + 252343944.96528533, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 
485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903, + 485165195.4097903 + ], + "elapsed_times": [ + 0.029030096530914307, + 0.07002029021581015, + 0.10727337996164958, + 0.1459606965382894, + 0.18312939008076987, + 0.22184558312098185, + 0.25910626649856566, + 0.2977186719576518, + 0.3417368332544963, + 0.38481602668762205, + 0.4217920541763306, + 0.4601393461227417, + 0.49705619414647423, + 0.535248609383901, + 0.5723066449165344, + 0.6111922820409139, + 0.6481263995170593, + 0.6868864059448242, + 0.7241862018903097, + 0.7622060139973958, + 0.798847230275472, + 0.836992343266805, + 0.8737150669097901, + 0.911770248413086, + 0.9485302805900574, + 0.9868516643842061, + 1.0237465461095174, + 1.074126406510671, + 1.1108299652735392, + 1.1488507509231567, + 1.1854864796002706, + 1.2234553734461466, + 1.2602906227111816, + 1.298561453819275, + 1.3354615052541097, + 1.3737420121828714, + 1.4106024146080016, + 1.4487469871838887, + 1.4856273452440898, + 1.5238646507263183, + 1.5608189821243286, + 1.5988837758700052, + 1.6360863049825032, + 1.6740562121073406, + 1.7107356746991476, + 1.7488200108210246, + 1.7853419502576193, + 1.8298504829406739, + 1.8710852225621541, + 1.9093469738960267, + 1.9464125792185465, + 1.9845890243848165, + 2.0217219591140747, + 2.0598358392715452, + 2.096575343608856, + 2.1349153955777487, + 2.1719194849332175, + 2.2100890477498374, + 2.2469960689544677, + 2.2852455457051595, + 2.3221314509709674, + 2.3600756168365478, + 2.3964901089668276, + 2.4345386107762654, + 2.4711479981740316, + 2.50909903049469, + 2.5457919398943583, + 2.5836928288141885, + 2.6202367146809897, + 2.6581838647524516, + 2.704783296585083, + 2.742840536435445, + 2.7794804175694785, + 2.817513891061147, + 2.854218284289042, + 2.8923499743143717, + 2.9290514906247456, + 2.967230180899302, + 3.00408988793691, + 3.0421733299891156, + 3.078792174657186, + 3.1172967751820884, + 3.1541523774464926, + 3.1922855496406557, + 3.2289458870887757, + 3.2672651171684266, + 3.31540470123291, + 3.353548498948415, + 3.390299900372823, + 3.4284762422243755, + 3.465256949265798, + 3.5034289717674256, + 3.5402917146682737, + 3.578429226080577, + 3.6151761531829836, + 3.6533300161361693, + 3.690035370985667, + 3.7280458490053814, + 3.764844044049581, + 3.803180734316508 + ], + "learning_rates": [ + 0.0028000000000000004, + 0.007000000000000001, + 0.009800000000000001, + 0.014000000000000002, + 0.016800000000000002, + 0.021, + 0.023800000000000005, + 0.028000000000000004, + 0.030800000000000004, + 0.035, + 0.03780000000000001, + 0.042, + 0.044800000000000006, + 0.049, + 0.051800000000000006, + 0.05600000000000001, + 0.058800000000000005, + 0.06300000000000001, + 0.0658, + 0.07, + 0.06999931104397708, + 0.06999569410726278, + 0.06999156063482156, + 0.06998277760623918, + 0.06997520074742977, + 0.0699612540281716, + 0.06995023585443154, + 0.06993112925739141, + 0.06991667278097936, + 0.06989241152971011, + 0.0698745207028824, + 0.06984511143016786, + 0.06982379114409791, + 0.06978924189013962, + 0.06976449797358074, + 0.06972481818379972, + 0.06969665740149165, + 0.06965185792394626, + 0.06962028797476569, + 0.06957038105718576, + 
0.06953541057204157, + 0.06948040985848004, + 0.0694420483979537, + 0.06938196892505649, + 0.06934022697678827, + 0.06927508516968339, + 0.06922997414550512, + 0.06915978781331218, + 0.0691113200461275, + 0.06903610837708891, + 0.06898429711750137, + 0.06890408067373649, + 0.06884894008642714, + 0.06876374079831073, + 0.06870528595816558, + 0.06861512711833227, + 0.06855337400632094, + 0.06845828026329735, + 0.06839324576210407, + 0.06829324311356999, + 0.06822494500297802, + 0.06812006078865909, + 0.0680485177406899, + 0.06793878063488303, + 0.06786401220869158, + 0.06774945221242577, + 0.06767147884895325, + 0.06755212728178762, + 0.06747097029817303, + 0.06734685978963431, + 0.06726254137338662, + 0.06713370585404863, + 0.06704624905698091, + 0.06691272374918822, + 0.06682215248111556, + 0.06668397388935406, + 0.0665903129115568, + 0.06644751881247378, + 0.06635079373092805, + 0.06620342316300441, + 0.0661036604213817, + 0.06595175367425923, + 0.065848980546697, + 0.06569257915016359, + 0.0655868237338087, + 0.06542597044644449, + 0.06531726165377183, + 0.06515200045125948, + 0.0650403680021675, + 0.06487074406526966, + 0.06475621847895532, + 0.06458227818116263, + 0.06446489076777771, + 0.06428668166263082, + 0.064166464514722, + 0.06398403532281084, + 0.06386102130654603, + 0.06367442190219014, + 0.06354864464837305, + 0.0633579260459864 + ], + "temperatures": [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0 + ], + "routing_entropies": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "selection_confidences": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + 
], + "expert_utilizations": [ + [ + 0.12587271258234978, + 0.11336542790134747, + 0.13968932007749876, + 0.1283344104886055, + 0.13466989994049072, + 0.1230706771214803, + 0.12242491915822029, + 0.11257241790493329 + ], + [ + 0.12790738294521967, + 0.11341678972045581, + 0.13681673755248389, + 0.12429658075173695, + 0.12479822958509128, + 0.12900208309292793, + 0.13055740421017012, + 0.11320456365744273 + ], + [ + 0.1315304860472679, + 0.11547527462244034, + 0.1354562764366468, + 0.12157857045531273, + 0.11887108534574509, + 0.12922349696358046, + 0.13387440890073776, + 0.11399018143614133 + ], + [ + 0.1349102978905042, + 0.11592956632375717, + 0.1323829690615336, + 0.12167617057760556, + 0.11587324862678845, + 0.13046928370992342, + 0.13412983218828836, + 0.11462841058770816 + ], + [ + 0.13805674389004707, + 0.11615101123849551, + 0.1328147860864798, + 0.12568429360787073, + 0.11441744615634282, + 0.12479598820209503, + 0.1369024564822515, + 0.11117703840136528 + ], + [ + 0.13957551370064417, + 0.11980107550819714, + 0.13397935032844543, + 0.12755162393053374, + 0.10671305904785792, + 0.12231272707382838, + 0.138968446602424, + 0.1110979715983073 + ], + [ + 0.13679840912421545, + 0.12388345350821812, + 0.13687013710538545, + 0.12710880984862646, + 0.10531436279416084, + 0.11917128413915634, + 0.1376905602713426, + 0.11316274975736935 + ], + [ + 0.12815066302816072, + 0.12166611105203629, + 0.13248287638028464, + 0.12471334387858708, + 0.1203408141930898, + 0.12519187231858572, + 0.13138747587800026, + 0.11606660857796669 + ], + [ + 0.12674027432998022, + 0.12228081127007802, + 0.12772589673598608, + 0.12556237479050955, + 0.12813247616092363, + 0.12443717569112778, + 0.1255374439060688, + 0.1195833111802737 + ], + [ + 0.12667261933286986, + 0.12313628941774368, + 0.12201519062121709, + 0.12767006332675615, + 0.12785578767458597, + 0.12491566936175029, + 0.124739158898592, + 0.12299499288201332 + ], + [ + 0.12720459575454393, + 0.12650738780697188, + 0.11843281239271164, + 0.12679838513334593, + 0.12315842012564342, + 0.12766795232892036, + 0.12299713864922523, + 0.1272331103682518 + ], + [ + 0.13045632218321165, + 0.1252553661664327, + 0.12120022252202034, + 0.12310117607315381, + 0.12651676187912622, + 0.12502328554789224, + 0.12333596249421437, + 0.12511067713300386 + ], + [ + 0.13135796909530958, + 0.12432100251317024, + 0.12366848190625508, + 0.1238104763130347, + 0.1256855713824431, + 0.12326287850737572, + 0.12316071490446727, + 0.12473269924521446 + ], + [ + 0.12877081210414568, + 0.12667947510878244, + 0.1253932664791743, + 0.1269336218635241, + 0.12209556500116985, + 0.12317002192139626, + 0.1223796047270298, + 0.1245774229367574 + ], + [ + 0.1260360168914, + 0.12652320290605226, + 0.12297584488987923, + 0.12539793302615485, + 0.12492821241418521, + 0.12554357076684633, + 0.12422364577651024, + 0.12437134981155396 + ], + [ + 0.12443829948703448, + 0.12506015598773956, + 0.12344298884272575, + 0.12388104945421219, + 0.12065098558863004, + 0.12475194285313289, + 0.12070238962769508, + 0.1370719646414121 + ], + [ + 0.1268621807297071, + 0.12457372123996417, + 0.12497630839546521, + 0.1185834879676501, + 0.1249399296939373, + 0.12670695905884108, + 0.1163462686041991, + 0.13701090837518373 + ], + [ + 0.1329279032846292, + 0.12702622388799986, + 0.1296961655219396, + 0.1118754893541336, + 0.11725708469748497, + 0.1308150738477707, + 0.1196103369196256, + 0.13079148282607397 + ], + [ + 0.1309177614748478, + 0.1189967580139637, + 0.11863441268603007, + 0.11681319649020831, + 
0.12706602861483893, + 0.13702987631162009, + 0.117781197031339, + 0.13276052350799242 + ], + [ + 0.12277690693736076, + 0.11891919374465942, + 0.12566564852992693, + 0.11949326718846957, + 0.12425324196616809, + 0.1295070710281531, + 0.1293593980371952, + 0.13002505029241243 + ], + [ + 0.11898445462187131, + 0.11600970476865768, + 0.10575242092212041, + 0.13355291013916334, + 0.1342888871828715, + 0.1275617095331351, + 0.12069196502367656, + 0.14315771808226904 + ], + [ + 0.1227885050078233, + 0.12779962395628294, + 0.13178966442743936, + 0.11417382086316745, + 0.13049409414331117, + 0.11783487473924954, + 0.12698261191447577, + 0.12813656156261763 + ], + [ + 0.12639433642228445, + 0.13114009176691374, + 0.12804179141918817, + 0.09782435745000839, + 0.12737402195731798, + 0.12271647527813911, + 0.11701905354857445, + 0.14948963870604834 + ], + [ + 0.12785115341345468, + 0.10892670104900996, + 0.13894320403536162, + 0.11245918770631154, + 0.12459361180663109, + 0.12091729417443275, + 0.11895841732621193, + 0.14735020448764166 + ], + [ + 0.13274119546016058, + 0.12144281342625618, + 0.1272424471875032, + 0.13589521621664366, + 0.11758510395884514, + 0.11564349258939426, + 0.1225752371052901, + 0.12687425563732782 + ], + [ + 0.12746076782544455, + 0.1267718387146791, + 0.115023884922266, + 0.12166303644577663, + 0.12298926090200742, + 0.12116577103734016, + 0.13183269649744034, + 0.1330925188958645 + ], + [ + 0.13183709606528282, + 0.12204498425126076, + 0.12194284920891126, + 0.1254388727247715, + 0.12497304876645406, + 0.1322875109811624, + 0.11500612770517667, + 0.12646927932898203 + ], + [ + 0.13040753826498985, + 0.12467105810840924, + 0.1184080330034097, + 0.1153053380548954, + 0.13790381451447806, + 0.1199042908847332, + 0.11747441937526067, + 0.1359252631664276 + ], + [ + 0.12704912945628166, + 0.12799829492966333, + 0.11730604991316795, + 0.11303821206092834, + 0.1310229760905107, + 0.1253361739218235, + 0.11859673137466113, + 0.1396522099773089 + ], + [ + 0.12515039866169295, + 0.12648878370722136, + 0.12777609005570412, + 0.12139808386564255, + 0.1209163765112559, + 0.1248055932422479, + 0.1221737911303838, + 0.1312906543413798 + ], + [ + 0.12447518731156985, + 0.12218866869807243, + 0.12958581621448198, + 0.1188336784640948, + 0.1265310992797216, + 0.12554291387399039, + 0.12446958695848782, + 0.1283728281656901 + ], + [ + 0.12457960347334544, + 0.126409778992335, + 0.11854103828469913, + 0.13325187812248865, + 0.12352979555726051, + 0.12201112632950147, + 0.1196229246755441, + 0.13205363601446152 + ], + [ + 0.11999240145087242, + 0.13260090599457422, + 0.12094186122218768, + 0.12161210055152576, + 0.1285305693745613, + 0.12508359303077063, + 0.12076066558559735, + 0.13047767927249274 + ], + [ + 0.12820471823215485, + 0.12854821359117827, + 0.12272358313202858, + 0.12432362015048663, + 0.12306931614875793, + 0.12038508802652359, + 0.12575785939892134, + 0.1269873840113481 + ], + [ + 0.13002181177337965, + 0.13211728632450104, + 0.1216458020110925, + 0.11764348795016606, + 0.1218843547006448, + 0.12328926473855972, + 0.12160226330161095, + 0.1317954733967781 + ], + [ + 0.1314155968526999, + 0.12376571198304494, + 0.13107706606388092, + 0.11471213524540265, + 0.13038118183612823, + 0.11950700109203656, + 0.1195463923116525, + 0.12959467495481172 + ], + [ + 0.1242268905043602, + 0.12467419356107712, + 0.13129907473921776, + 0.11222259079416592, + 0.12788565953572592, + 0.11881897350152333, + 0.1260842184225718, + 0.13478817666570345 + ], + [ + 0.1274895096818606, + 
0.12022375439604123, + 0.1254719433685144, + 0.1179546279211839, + 0.1243738941848278, + 0.12324430048465729, + 0.12151822696129481, + 0.13972350458304086 + ], + [ + 0.13156223545471826, + 0.11741525307297707, + 0.12317869688073795, + 0.12224439159035683, + 0.11778135473529498, + 0.1286664940416813, + 0.12102767080068588, + 0.1381236786643664 + ], + [ + 0.12006788204113643, + 0.12958076347907385, + 0.1284430573383967, + 0.12037194768587749, + 0.12476476033528645, + 0.12210047493378322, + 0.125126543144385, + 0.12954435249169668 + ], + [ + 0.12297759825984637, + 0.12560815239946047, + 0.12864662831028303, + 0.12362178415060043, + 0.12236574913064639, + 0.12753341843684515, + 0.12292222554485004, + 0.12632423266768456 + ], + [ + 0.1274275134007136, + 0.12966816375652948, + 0.12023578584194183, + 0.12366743509968121, + 0.1233125552535057, + 0.12050872792800267, + 0.12191376214226086, + 0.1332658144334952 + ], + [ + 0.12485626339912415, + 0.12677678962548575, + 0.11963416263461113, + 0.1198193368812402, + 0.12705819184581438, + 0.11949774747093518, + 0.12504937748114267, + 0.1373079059024652 + ], + [ + 0.1171618103981018, + 0.12857511763771376, + 0.1260807402431965, + 0.11889946336547534, + 0.13355845337112746, + 0.11888961369792621, + 0.12340328097343445, + 0.13343130300442377 + ], + [ + 0.12101152042547862, + 0.12455817436178525, + 0.1273962805668513, + 0.12189114466309547, + 0.12678203855951628, + 0.11659996708234151, + 0.12582438811659813, + 0.1359362838168939 + ], + [ + 0.12317622949679692, + 0.12338119745254517, + 0.12543142711122832, + 0.11815710614124934, + 0.12725298975904784, + 0.12580565238992372, + 0.12787324686845145, + 0.12892194092273712 + ], + [ + 0.1209084043900172, + 0.1186300627887249, + 0.12189604962865512, + 0.12188772981365521, + 0.1307782679796219, + 0.12706713750958443, + 0.12832620119055113, + 0.13050592069824538 + ], + [ + 0.1260350135465463, + 0.12675354753931364, + 0.12731152648727098, + 0.11999985451499622, + 0.11963030075033505, + 0.1257828064262867, + 0.12410260736942291, + 0.13038412109017372 + ], + [ + 0.12144877761602402, + 0.13025116920471191, + 0.12315368031462033, + 0.12004957223931949, + 0.12899880359570184, + 0.12411915759245555, + 0.12217831859985988, + 0.12980028738578162 + ], + [ + 0.1291740077237288, + 0.1252783089876175, + 0.13256158431371054, + 0.12126804515719414, + 0.12133379280567169, + 0.11881930877765019, + 0.12261790533860524, + 0.1289468171695868 + ], + [ + 0.1291037363310655, + 0.12572022527456284, + 0.12957237412532172, + 0.11967806269725163, + 0.12997588763634363, + 0.11865894868969917, + 0.12191799034674962, + 0.12537255013982454 + ], + [ + 0.12742587054769197, + 0.12574236219127974, + 0.12274153530597687, + 0.1220607968668143, + 0.12481518959005673, + 0.1198483295738697, + 0.12524557610352835, + 0.13212011257807413 + ], + [ + 0.1290092058479786, + 0.12212285026907921, + 0.12419220308462779, + 0.12421460077166557, + 0.12363921975096066, + 0.12011488651235898, + 0.1233511430521806, + 0.13335565477609634 + ], + [ + 0.12538858503103256, + 0.12466845909754436, + 0.12154690672953923, + 0.12301296989123027, + 0.12981624777118364, + 0.12704829623301825, + 0.1248041180272897, + 0.12371418873469035 + ], + [ + 0.11902124434709549, + 0.12671059494217238, + 0.1260922352472941, + 0.1271773725748062, + 0.12644306818644205, + 0.12542754784226418, + 0.12590078388651213, + 0.12322693939010303 + ], + [ + 0.12317654117941856, + 0.12843865901231766, + 0.12069269766410191, + 0.125810240705808, + 0.1277270627518495, + 0.12764985983570418, + 
0.12175220002730687, + 0.1247525264819463 + ], + [ + 0.1211647018790245, + 0.12715468431512514, + 0.12056185429294904, + 0.12519296631217003, + 0.12480061997969945, + 0.12839860344926515, + 0.12541193763415018, + 0.12731440116961798 + ], + [ + 0.12123912448684375, + 0.12248460327585538, + 0.12326325600345929, + 0.12683571502566338, + 0.12672854959964752, + 0.12482267990708351, + 0.12549365808566412, + 0.12913217147191366 + ], + [ + 0.12050784503420194, + 0.12555688992142677, + 0.12008845185240109, + 0.12465623517831166, + 0.13099781175454459, + 0.12576152632633844, + 0.11882188667853673, + 0.1336091235280037 + ], + [ + 0.12568799406290054, + 0.11993959173560143, + 0.1274260220428308, + 0.11817048738400142, + 0.12619078904390335, + 0.12464018786946933, + 0.12417953833937645, + 0.1337651622792085 + ], + [ + 0.12325371926029523, + 0.1211228979130586, + 0.13070687154928842, + 0.11808653796712558, + 0.13001534715294838, + 0.12158843502402306, + 0.12363952646652858, + 0.13158642997344336 + ], + [ + 0.1260810730357965, + 0.12453075995047887, + 0.1290532574057579, + 0.12438120692968369, + 0.11917600656549136, + 0.1295322080453237, + 0.12189143026868503, + 0.1253538504242897 + ], + [ + 0.12526876603563628, + 0.12080891554554303, + 0.12967888390024504, + 0.12297213574250539, + 0.11888084560632706, + 0.13118606433272362, + 0.12140955651799838, + 0.12979459141691527 + ], + [ + 0.12424750874439876, + 0.12485102439920108, + 0.1268209877113501, + 0.11906190340717633, + 0.1285974308848381, + 0.12277497847874959, + 0.12632345284024873, + 0.12732247014840445 + ], + [ + 0.125874575227499, + 0.12523056070009866, + 0.1264867732922236, + 0.12128748868902524, + 0.12441011145710945, + 0.12246322259306908, + 0.12810595333576202, + 0.12614111105600992 + ], + [ + 0.12473135565718015, + 0.12540381277600923, + 0.1253200632830461, + 0.12537001942594847, + 0.1260132429500421, + 0.12043848757942517, + 0.12780407443642616, + 0.12491871540745099 + ], + [ + 0.12605776265263557, + 0.12424845372637112, + 0.12497357651591301, + 0.12922962506612143, + 0.12503756955266, + 0.12192104011774063, + 0.1254057822128137, + 0.12312595546245575 + ], + [ + 0.12352173278729121, + 0.1265761541823546, + 0.12423497438430786, + 0.12461677193641663, + 0.12539410715301832, + 0.1255429039398829, + 0.12238302826881409, + 0.1277301013469696 + ], + [ + 0.12309405331810315, + 0.12791564936439195, + 0.12622039516766867, + 0.12361625457803409, + 0.12056272476911545, + 0.12604241694013277, + 0.12378428628047307, + 0.12876397867997488 + ], + [ + 0.12611348181962967, + 0.12320082634687424, + 0.1253701572616895, + 0.12443139652411143, + 0.12695778285463652, + 0.12241215258836746, + 0.12573333705464998, + 0.12578065693378448 + ], + [ + 0.12688816090424856, + 0.1288478635251522, + 0.12740991016228995, + 0.11879989504814148, + 0.12534291048844656, + 0.1229795292019844, + 0.12383322914441426, + 0.12589828670024872 + ], + [ + 0.12476491928100586, + 0.1262869065006574, + 0.12492788831392924, + 0.12399974713722865, + 0.12469915300607681, + 0.12713435913125673, + 0.1245567575097084, + 0.12363003566861153 + ], + [ + 0.12631421784559885, + 0.12580405548214912, + 0.12628123412529627, + 0.12935657799243927, + 0.12092787648240726, + 0.12676102295517921, + 0.1189101127286752, + 0.12564465776085854 + ], + [ + 0.12526675313711166, + 0.12217352539300919, + 0.12484640503923099, + 0.12589062377810478, + 0.12966536233822504, + 0.12522071475783983, + 0.12327616289258003, + 0.12366023659706116 + ], + [ + 0.12656376014153162, + 0.11716300249099731, + 0.12596874559919038, + 
0.1257962646583716, + 0.1269005909562111, + 0.12683834383885065, + 0.12561031058430672, + 0.1251587631801764 + ], + [ + 0.13022523125012717, + 0.1282994436721007, + 0.12454549099008243, + 0.12071131666501363, + 0.12538232281804085, + 0.11908584336439769, + 0.12621022264162698, + 0.12553991625706354 + ], + [ + 0.13221148898204169, + 0.12223595753312111, + 0.12724651272098222, + 0.12031713624795277, + 0.12249614670872688, + 0.11974877988298734, + 0.12654854729771614, + 0.12919519593318304 + ], + [ + 0.11990595608949661, + 0.1272318698465824, + 0.1267327405512333, + 0.12448954333861668, + 0.12569350625077882, + 0.12164032459259033, + 0.12683004885911942, + 0.12747578074534735 + ], + [ + 0.12444956476489703, + 0.12797370428840318, + 0.12600189199050268, + 0.12382493292291959, + 0.12690917402505875, + 0.11933111399412155, + 0.12926588455835977, + 0.12224351490537326 + ], + [ + 0.1255095216135184, + 0.1255613019069036, + 0.12129526461164157, + 0.1223294772207737, + 0.12632360309362411, + 0.12481639658411343, + 0.1274109147489071, + 0.12675329546133676 + ], + [ + 0.12528312702973685, + 0.12439442550142606, + 0.11931468298037846, + 0.12152635306119919, + 0.12659601246317229, + 0.1257651410996914, + 0.12814742450912794, + 0.12897261853019396 + ], + [ + 0.1214146117369334, + 0.1273880017300447, + 0.12608121211330095, + 0.12650746976335844, + 0.12330925837159157, + 0.12363646055261295, + 0.12210099523266156, + 0.12956176698207855 + ], + [ + 0.12062123914559682, + 0.1259618860979875, + 0.12593835219740868, + 0.12335928405324618, + 0.13069554915030798, + 0.12232162182529767, + 0.12119729444384575, + 0.12990452721714973 + ], + [ + 0.1253475584089756, + 0.1269937021036943, + 0.12334655970335007, + 0.12731938436627388, + 0.12488362938165665, + 0.12304140130678813, + 0.12077892074982326, + 0.12828861673672995 + ], + [ + 0.12448840960860252, + 0.12420830751458804, + 0.12420420721173286, + 0.12860180685917535, + 0.12120072667797406, + 0.12527909502387047, + 0.12215096006790797, + 0.12986628090341887 + ], + [ + 0.12123213832577069, + 0.12265237296621005, + 0.127204945931832, + 0.12560922528306642, + 0.1270847866932551, + 0.12552902350823084, + 0.12612696488698324, + 0.1245603288213412 + ], + [ + 0.11809998626510303, + 0.12428157900770505, + 0.1253818174203237, + 0.1257936954498291, + 0.12949112678567568, + 0.1271564637621244, + 0.12387922530372937, + 0.12591588869690895 + ], + [ + 0.126338808486859, + 0.12426133205493291, + 0.12641060600678125, + 0.12206799164414406, + 0.12264598409334819, + 0.12538964301347733, + 0.12488857160011928, + 0.1279968520005544 + ], + [ + 0.12697053079803786, + 0.12228543559710185, + 0.1261767658094565, + 0.12111446509758632, + 0.12292597194512685, + 0.12427658587694168, + 0.12521553163727125, + 0.1310344859957695 + ], + [ + 0.12766745934883753, + 0.1253199279308319, + 0.12255134185155232, + 0.12318125118811925, + 0.12039463594555855, + 0.12757552415132523, + 0.12357263018687566, + 0.12973700960477194 + ], + [ + 0.1242472343146801, + 0.1297652112940947, + 0.12095560381809871, + 0.124796728293101, + 0.12095990777015686, + 0.12902158498764038, + 0.12186569223801295, + 0.1283878001073996 + ], + [ + 0.12422218297918637, + 0.12716689084966978, + 0.1249363124370575, + 0.12150920430819194, + 0.12222799534598987, + 0.12615740795930228, + 0.12329558655619621, + 0.1304841972887516 + ], + [ + 0.12710294996698698, + 0.12885230034589767, + 0.12337821473677953, + 0.1203689177831014, + 0.12481469785173734, + 0.1250703458984693, + 0.12101439014077187, + 0.12939795975883803 + ], + [ + 
0.12462215994795163, + 0.12458508337537448, + 0.12500757724046707, + 0.1232065645356973, + 0.12567366287112236, + 0.12203507497906685, + 0.12521824489037195, + 0.1296513962248961 + ], + [ + 0.12396224463979404, + 0.12786907578508058, + 0.12520382553339005, + 0.12352350726723671, + 0.12572820608814558, + 0.12161949028571446, + 0.12249657387534778, + 0.129596841832002 + ], + [ + 0.12068009624878566, + 0.12379990518093109, + 0.12593792006373405, + 0.1276268350581328, + 0.1261026325325171, + 0.12102650726834933, + 0.12821885446707407, + 0.12660702566305795 + ], + [ + 0.1233363188803196, + 0.12685688709219298, + 0.1282887620230516, + 0.12687948221961656, + 0.1214559090634187, + 0.12196332216262817, + 0.1233159638941288, + 0.127903134872516 + ], + [ + 0.12056848406791687, + 0.12428977464636166, + 0.12211137513319652, + 0.12405076622962952, + 0.1270131046573321, + 0.12591566642125449, + 0.1266051009297371, + 0.1294455093642076 + ], + [ + 0.12070096284151077, + 0.12888250251611075, + 0.1211340253551801, + 0.12761419887344042, + 0.12578162799278894, + 0.12303286169966061, + 0.12564598148067793, + 0.12720761448144913 + ], + [ + 0.1237399031718572, + 0.12495135888457298, + 0.12168942764401436, + 0.12188419699668884, + 0.12582296008865038, + 0.12607480337222418, + 0.126584326227506, + 0.1292528100311756 + ] + ], + "load_balancing_losses": [ + 0.06011591739952564, + 0.06009032502770424, + 0.06007496602833271, + 0.060119301453232764, + 0.06025106981396675, + 0.06042454689741135, + 0.06050819158554077, + 0.06040439046919346, + 0.06020762734115124, + 0.0600824985653162, + 0.06001702025532722, + 0.059980913251638415, + 0.05997035503387451, + 0.0599874921143055, + 0.06003512442111969, + 0.06003240235149861, + 0.06010275147855282, + 0.06043860539793968, + 0.060852956026792526, + 0.06072245985269546, + 0.0610693134367466, + 0.06144227758049965, + 0.06105490066111088, + 0.061231859028339386, + 0.06127622313797474, + 0.060751666128635404, + 0.060541939362883566, + 0.06071356795728207, + 0.06042528674006462, + 0.06035650297999382, + 0.060394616797566414, + 0.06033426821231842, + 0.06030041724443436, + 0.06031496711075306, + 0.06021142601966858, + 0.060243581607937816, + 0.060279525443911554, + 0.06031738966703415, + 0.060062835738062856, + 0.060149386525154114, + 0.06020931825041771, + 0.06027392894029617, + 0.0601655475795269, + 0.06007538698613644, + 0.06006893813610077, + 0.06012786403298378, + 0.06015343964099884, + 0.06010826490819454, + 0.060064369812607765, + 0.06001862846314907, + 0.06011522077023983, + 0.06011498793959617, + 0.06005149334669113, + 0.060055335983633995, + 0.06012761406600475, + 0.06004953421652317, + 0.06012034453451633, + 0.060062723234295845, + 0.060141066834330556, + 0.05999366156756878, + 0.060102271288633345, + 0.05999068170785904, + 0.060199663043022156, + 0.06013624221086502, + 0.060072819516062734, + 0.060046962648630145, + 0.06006497293710709, + 0.060109606757760045, + 0.060041088983416556, + 0.060026617348194124, + 0.06006389781832695, + 0.0600794717669487, + 0.06007534712553024, + 0.06000873222947121, + 0.06010126173496246, + 0.06002997867763042, + 0.06010465919971466, + 0.06000690646469593, + 0.06006423607468605, + 0.06008252203464508, + 0.06006045490503311, + 0.06007501110434532, + 0.06004559099674225, + 0.06002157405018806, + 0.060094448551535604, + 0.060062019526958464, + 0.06004556752741337, + 0.05999722816050053, + 0.06003735214471817, + 0.060049128532409665, + 0.060085409134626386, + 0.06008372902870178, + 0.05999944508075714, + 0.05999916717410088, + 
0.06007374040782452, + 0.05997239574790001, + 0.06008260659873486, + 0.059988987445831296, + 0.060002564638853076, + 0.060079767182469365 + ] + }, + "routing_stats": { + "avg_entropy": 0.0, + "avg_confidence": 0.0, + "expert_utilization": [ + 0.12387550994753838, + 0.12439893558621407, + 0.12115537996093433, + 0.12058991814653079, + 0.12618273124098778, + 0.12656908358136812, + 0.1266754890481631, + 0.13055273766318956 + ], + "num_layers": 0 + }, + "config": { + "max_steps": 1000, + "batch_size": 24, + "num_experts": 8, + "expert_top_k": 2 + } +} \ No newline at end of file diff --git a/experiments/exp10_routing_temperature_specialization/run_experiment.py b/experiments/exp10_routing_temperature_specialization/run_experiment.py new file mode 100644 index 0000000..a375d69 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/run_experiment.py @@ -0,0 +1,327 @@ +""" +Main script to run temperature routing experiments +""" +import argparse +import json +import os +import sys +import random +from pathlib import Path +import torch +from torch.utils.data import DataLoader + +# Fix tokenizer parallelism warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Add project root to path +script_dir = Path(__file__).resolve().parent +project_root = script_dir.parent.parent +sys.path.insert(0, str(script_dir)) +sys.path.insert(0, str(project_root)) + +from configs.moe_config import MoEModelConfig +from configs.dataset_config import DataConfig +from utils.helpers import set_seed +from utils.logger import setup_logging +from config import ( + get_experiment_config, + list_experiments, + TEMPERATURE_ABLATION, + TEMPERATURE_SCHEDULES, + ALL_EXPERIMENTS, +) +from tracking_trainer import train_with_temperature_tracking +from temperature_model import create_temperature_moe_model + + +def prepare_data(config: MoEModelConfig): + """Prepare train and validation data loaders""" + print("Loading dataset with Hugging Face Datasets API...") + data_cfg = DataConfig( + dataset_path="HuggingFaceTB/smollm-corpus", + dataset_name="cosmopedia-v2", + tokenizer_name="HuggingFaceTB/SmolLM-135M", + seq_length=config.max_seq_len, + num_samples=config.num_documents, + cache_dir="./hf_cache", + ) + + # Split documents BEFORE tokenization to prevent data leakage + from datasets import load_dataset, Dataset + print("Loading raw dataset and splitting documents...") + raw_dataset = load_dataset( + data_cfg.dataset_path, + data_cfg.dataset_name, + split=data_cfg.split, + cache_dir=data_cfg.cache_dir, + streaming=True, + ) + + # Take samples and split into train/val + raw_samples = list(raw_dataset.take(data_cfg.num_samples)) + random.shuffle(raw_samples) + num_val = int(len(raw_samples) * 0.1) + num_train = len(raw_samples) - num_val + + raw_train = Dataset.from_list(raw_samples[:num_train]) + raw_val = Dataset.from_list(raw_samples[num_train:]) + print(f"Split into {len(raw_train):,} train docs and {len(raw_val):,} val docs") + + # Now tokenize each split separately + from data.loader import setup_tokenizer, tokenize_and_chunk, finalize_dataset + tokenizer = setup_tokenizer(data_cfg) + config.vocab_size = tokenizer.vocab_size + + print("Tokenizing train set...") + train_ds = tokenize_and_chunk(raw_train, tokenizer, data_cfg) + train_ds = finalize_dataset(train_ds, data_cfg) + + print("Tokenizing validation set...") + val_ds = tokenize_and_chunk(raw_val, tokenizer, data_cfg) + val_ds = finalize_dataset(val_ds, data_cfg) + + print(f"Train sequences: {len(train_ds):,}, Val sequences: {len(val_ds):,}") + + 
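# Shared DataLoader settings; pin_memory only speeds host-to-device copies when CUDA is present, and + # persistent_workers keeps the two worker processes alive between epochs instead of re-spawning them. + 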
loader_args = dict( + batch_size=config.batch_size, + num_workers=2, + pin_memory=torch.cuda.is_available(), + persistent_workers=True, + ) + train_loader = DataLoader(train_ds, shuffle=True, **loader_args) + val_loader = DataLoader(val_ds, shuffle=False, **loader_args) + + return train_loader, val_loader + + +def run_single_experiment(exp_name: str, output_dir: str = "./results"): + """Run a single temperature experiment""" + logger = setup_logging(log_dir="./logs") + logger.info(f"Running experiment: {exp_name}") + + set_seed(42) + + # Get experiment configuration + temp_config = get_experiment_config(exp_name) + + # Create model config + model_config = MoEModelConfig() + model_config.max_steps = temp_config.max_steps + + # Prepare data + train_loader, val_loader = prepare_data(model_config) + + # Create model with temperature-aware routing + model = create_temperature_moe_model(model_config) + + # Create experiment output directory + exp_output_dir = Path(output_dir) / exp_name + exp_output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n{'='*80}") + print(f"Experiment: {exp_name}") + print(f"Description: {temp_config.description}") + print(f"Temperature: {temp_config.temperature}") + print(f"Schedule: {temp_config.temperature_schedule or 'constant'}") + print(f"Steps: {temp_config.max_steps}") + print(f"Output: {exp_output_dir}") + print(f"{'='*80}\n") + + # Train model + model, metrics, history, routing_stats = train_with_temperature_tracking( + model=model, + config=model_config, + temp_config=temp_config, + train_loader=train_loader, + val_loader=val_loader, + output_dir=str(exp_output_dir), + ) + + # Save results + results = { + 'experiment_name': exp_name, + 'description': temp_config.description, + 'temperature': temp_config.temperature, + 'temperature_schedule': temp_config.temperature_schedule, + 'final_metrics': metrics, + 'history': history, + 'routing_stats': routing_stats, + 'config': { + 'max_steps': temp_config.max_steps, + 'batch_size': model_config.batch_size, + 'num_experts': model_config.num_experts, + 'expert_top_k': model_config.expert_top_k, + } + } + + # Save metrics + metrics_file = exp_output_dir / "metrics.json" + with open(metrics_file, 'w') as f: + json.dump(results, f, indent=2, default=str) + print(f"\n✅ Results saved to {metrics_file}") + + # Save model checkpoint + model_file = exp_output_dir / "model.pt" + torch.save({ + 'model_state_dict': model.state_dict(), + 'config': model_config, + 'metrics': metrics, + }, model_file) + print(f"✅ Model saved to {model_file}") + + return results + + +def run_multiple_experiments(exp_names: list, output_dir: str = "./results"): + """Run multiple experiments sequentially""" + logger = setup_logging(log_dir="./logs") + logger.info(f"Running {len(exp_names)} experiments") + + results = {} + + for i, exp_name in enumerate(exp_names): + print(f"\n{'='*80}") + print(f"Running experiment {i+1}/{len(exp_names)}: {exp_name}") + print(f"{'='*80}\n") + + try: + result = run_single_experiment(exp_name, output_dir) + results[exp_name] = result + + print(f"\n✅ Experiment '{exp_name}' completed successfully") + print(f" Final loss: {result['final_metrics']['val_loss']:.4f}") + print(f" Final accuracy: {result['final_metrics']['val_accuracy']:.4f}") + + except Exception as e: + import traceback + print(f"\n❌ Experiment '{exp_name}' failed with error: {e}") + print("\nFull traceback:") + traceback.print_exc() + logger.error(f"Experiment '{exp_name}' failed: {e}") + logger.error(traceback.format_exc()) + continue + + # Save 
summary + summary_file = Path(output_dir) / "experiment_summary.json" + with open(summary_file, 'w') as f: + json.dump({ + 'experiments': list(results.keys()), + 'num_completed': len(results), + 'num_requested': len(exp_names), + }, f, indent=2) + print(f"\n📁 Experiment summary saved to {summary_file}") + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Run routing temperature experiments for MoE training", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # List all available experiments + python run_experiment.py --list + + # Run single temperature + python run_experiment.py --experiment temp_1.0 + + # Run temperature ablation + python run_experiment.py --ablation + + # Run temperature schedules + python run_experiment.py --schedules + + # Run all experiments + python run_experiment.py --all + """ + ) + + parser.add_argument( + '--experiment', '-e', + type=str, + help='Single experiment to run' + ) + parser.add_argument( + '--experiments', + nargs='+', + help='Multiple experiments to run (space-separated)' + ) + parser.add_argument( + '--ablation', + action='store_true', + help='Run all temperature ablation experiments' + ) + parser.add_argument( + '--schedules', + action='store_true', + help='Run all temperature schedule experiments' + ) + parser.add_argument( + '--all', + action='store_true', + help='Run all experiments' + ) + parser.add_argument( + '--list', '-l', + action='store_true', + help='List all available experiments and exit' + ) + parser.add_argument( + '--output-dir', '-o', + default='./results', + help='Output directory for results (default: ./results)' + ) + parser.add_argument( + '--temperature', '-t', + type=float, + help='Run with specific temperature (creates temp_X.X experiment)' + ) + + args = parser.parse_args() + + if args.list: + list_experiments() + return + + # Determine which experiments to run + exp_names = [] + + if args.all: + exp_names = list(ALL_EXPERIMENTS.keys()) + print(f"Running all {len(exp_names)} experiments...") + elif args.ablation: + exp_names = list(TEMPERATURE_ABLATION.keys()) + print(f"Running temperature ablation ({len(exp_names)} experiments)...") + elif args.schedules: + exp_names = list(TEMPERATURE_SCHEDULES.keys()) + print(f"Running temperature schedules ({len(exp_names)} experiments)...") + elif args.experiments: + exp_names = args.experiments + elif args.experiment: + exp_names = [args.experiment] + elif args.temperature: + # Create custom temperature experiment + exp_name = f"temp_{args.temperature}" + print(f"Running custom temperature experiment: {exp_name}") + run_single_experiment(exp_name, args.output_dir) + return + else: + parser.print_help() + print("\n❌ No experiments specified. 
Use --list to see available experiments.") + return + + # Run experiments + if len(exp_names) == 1: + run_single_experiment(exp_names[0], args.output_dir) + else: + run_multiple_experiments(exp_names, args.output_dir) + + print(f"\n{'='*80}") + print(f"✅ All experiments completed!") + print(f"{'='*80}\n") + + +if __name__ == "__main__": + main() + diff --git a/experiments/exp10_routing_temperature_specialization/temperature_model.py b/experiments/exp10_routing_temperature_specialization/temperature_model.py new file mode 100644 index 0000000..8659221 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/temperature_model.py @@ -0,0 +1,201 @@ +""" +Model creation with temperature-aware MoE components +""" +import sys +import torch +import torch.nn as nn +import math +from pathlib import Path + +# Add project root to path +project_root = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(project_root)) + +from configs.moe_config import MoEModelConfig +from models.layers import MultiHeadAttention, MultiHeadLatentAttention +from temperature_moe import TemperatureMoE + + +class TemperatureMoETransformerBlock(nn.Module): + """Transformer block with temperature-aware MoE""" + + def __init__( + self, + d_model: int, + n_heads: int, + d_ff: int, + use_mla: bool, + qk_rope_dim: int | None, + qk_nope_dim: int | None, + kv_lora_rank: int | None, + v_dim: int | None, + max_seq_len: int, + num_experts: int = 8, + top_k: int = 2, + dropout: float = 0.1, + ): + super().__init__() + + # Attention layer (reuse from main models) + if use_mla: + self.attention = MultiHeadLatentAttention( + d_model, + n_heads, + qk_rope_dim, + qk_nope_dim, + kv_lora_rank, + v_dim, + max_seq_len, + dropout, + ) + else: + self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout) + + # Temperature-aware MoE layer + self.feed_forward = TemperatureMoE(d_model, d_ff, num_experts, top_k, dropout) + + # Normalization layers + self.norm1 = nn.RMSNorm(d_model) + self.norm2 = nn.RMSNorm(d_model) + self.dropout = nn.Dropout(dropout) + + def set_temperature(self, temperature: float): + """Set temperature for MoE routing""" + self.feed_forward.set_temperature(temperature) + + def forward(self, x, return_routing_stats=False): + # Self-attention + attn_out = self.attention(self.norm1(x)) + x = x + self.dropout(attn_out) + + # MoE feed-forward + ff_out, aux_loss, routing_stats = self.feed_forward( + self.norm2(x), + return_routing_stats=return_routing_stats + ) + x = x + self.dropout(ff_out) + + if return_routing_stats: + return x, aux_loss, routing_stats + return x, aux_loss + + +class TemperatureMoEModel(nn.Module): + """Complete MoE LLM with temperature-aware routing""" + + def __init__(self, config: MoEModelConfig): + super().__init__() + self.config = config + + # Token embeddings + self.token_embedding = nn.Embedding(config.vocab_size, config.d_model) + self.position_dropout = nn.Dropout(config.dropout) + + # Transformer blocks with temperature-aware MoE + self.transformer_blocks = nn.ModuleList( + [ + TemperatureMoETransformerBlock( + config.d_model, + config.n_heads, + config.d_ff, + config.use_mla, + config.qk_rope_dim, + config.qk_nope_dim, + config.kv_lora_rank, + config.v_dim, + config.max_seq_len, + config.num_experts, + config.expert_top_k, + config.dropout, + ) + for i in range(config.n_layers) + ] + ) + + # Output layers + self.norm = nn.RMSNorm(config.d_model) + self.output_dropout = nn.Dropout(config.dropout) + + # Language modeling head (tied with embeddings) + 
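# Tying makes lm_head.weight and token_embedding.weight the same tensor, + # saving one vocab_size x d_model matrix and acting as a regularizer for small models. + 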
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.lm_head.weight = self.token_embedding.weight + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def set_temperature(self, temperature: float): + """Set routing temperature for all MoE layers""" + for block in self.transformer_blocks: + block.set_temperature(temperature) + + def forward(self, x, return_aux_loss=True, return_routing_stats=False): + # Token embeddings + x = self.token_embedding(x) * math.sqrt(self.config.d_model) + x = self.position_dropout(x) + + # Collect auxiliary losses and routing stats + aux_losses = [] + routing_stats_list = [] + + # Pass through transformer blocks + for block in self.transformer_blocks: + if return_routing_stats: + x, aux_loss, routing_stats = block(x, return_routing_stats=True) + if routing_stats is not None: + routing_stats_list.append(routing_stats) + else: + x, aux_loss = block(x, return_routing_stats=False) + + if aux_loss is not None and return_aux_loss: + aux_losses.append(aux_loss) + + # Output projection + x = self.norm(x) + x = self.output_dropout(x) + logits = self.lm_head(x) + + # Combine auxiliary losses + total_aux_loss = sum(aux_losses) if aux_losses else None + + if return_routing_stats: + return logits, total_aux_loss, routing_stats_list + + if return_aux_loss: + return logits, total_aux_loss + return logits + + +def create_temperature_moe_model(config: MoEModelConfig) -> TemperatureMoEModel: + """ + Create a temperature-aware MoE model. + + Args: + config: Model configuration + + Returns: + TemperatureMoEModel instance + """ + model = TemperatureMoEModel(config) + + # Count parameters + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print(f"\n{'='*80}") + print(f"Temperature-aware MoE Model Created") + print(f"{'='*80}") + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,}") + print(f"Number of experts: {config.num_experts}") + print(f"Top-k routing: {config.expert_top_k}") + print(f"{'='*80}\n") + + return model + diff --git a/experiments/exp10_routing_temperature_specialization/temperature_moe.py b/experiments/exp10_routing_temperature_specialization/temperature_moe.py new file mode 100644 index 0000000..d06945e --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/temperature_moe.py @@ -0,0 +1,165 @@ +""" +Temperature-aware Mixture of Experts implementation +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional +from temperature_router import TemperatureRouter + + +class Expert(nn.Module): + """Single expert network (essentially a FeedForward layer)""" + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1): + super().__init__() + self.linear1 = nn.Linear(d_model, d_ff, bias=False) + self.linear2 = nn.Linear(d_ff, d_model, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.linear2(self.dropout(F.silu(self.linear1(x)))) + + +class TemperatureMoE(nn.Module): + """ + Mixture of Experts layer with temperature-controlled routing. + + This version extends the standard MoE with: + 1. Temperature-scaled routing + 2. 
Detailed routing statistics tracking + 3. Expert specialization analysis + """ + + def __init__( + self, + d_model: int, + d_ff: int, + num_experts: int = 8, + top_k: int = 2, + dropout: float = 0.1, + load_balancing_weight: float = 0.01 + ): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.load_balancing_weight = load_balancing_weight + + # Create experts + self.experts = nn.ModuleList([ + Expert(d_model, d_ff, dropout) for _ in range(num_experts) + ]) + + # Create temperature-aware router + self.router = TemperatureRouter(d_model, num_experts, top_k) + + # Expert statistics (accumulated over training) + self.expert_activation_counts = torch.zeros(num_experts) + self.expert_activation_history = [] + + def set_temperature(self, temperature: float): + """Set the routing temperature""" + self.router.set_temperature(temperature) + + def forward( + self, + x: torch.Tensor, + return_routing_stats: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[dict]]: + """ + Args: + x: Input tensor [batch_size, seq_len, d_model] + return_routing_stats: Whether to return routing statistics + + Returns: + - output: MoE output [batch_size, seq_len, d_model] + - aux_loss: Load balancing auxiliary loss (only during training) + - routing_stats: Routing statistics (if return_routing_stats=True) + """ + batch_size, seq_len, d_model = x.shape + + # Get routing decisions + router_weights, expert_indices, router_probs, routing_stats = self.router(x) + + # Initialize output tensor + output = torch.zeros_like(x) + + # Track expert activations + expert_hits = torch.zeros(self.num_experts, device=x.device) + + # Process each expert + for expert_idx in range(self.num_experts): + # Find tokens routed to this expert + expert_mask = (expert_indices == expert_idx).any(dim=-1) # [batch_size, seq_len] + + if expert_mask.any(): + # Count activations + expert_hits[expert_idx] = expert_mask.sum().item() + + # Get tokens for this expert + expert_input = x[expert_mask] # [num_tokens, d_model] + + # Apply expert + expert_output = self.experts[expert_idx](expert_input) + + # Get weights for this expert + mask_for_expert = (expert_indices == expert_idx) # [batch, seq, top_k] + positions = mask_for_expert[expert_mask].float().argmax(dim=-1) + expert_weights = router_weights[expert_mask].gather( + -1, positions.unsqueeze(-1) + ).squeeze(-1) + + # Add weighted expert output to result + output[expert_mask] += expert_weights.unsqueeze(-1) * expert_output + + # Update activation counts + with torch.no_grad(): + self.expert_activation_counts += expert_hits.cpu() + + # Compute load balancing loss during training + aux_loss = None + if self.training: + aux_loss = self._compute_load_balancing_loss(router_probs, expert_indices) + + # Return routing stats if requested + if return_routing_stats: + routing_stats['expert_hits'] = expert_hits.cpu().numpy().tolist() + return output, aux_loss, routing_stats + + return output, aux_loss, None + + def _compute_load_balancing_loss( + self, + router_probs: torch.Tensor, + expert_indices: torch.Tensor + ) -> torch.Tensor: + """ + Compute auxiliary loss to ensure balanced expert usage. + This encourages the router to distribute tokens evenly across experts. 
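+ This is the Switch Transformer style objective aux = N * sum_i(f_i * P_i), + where f_i is the fraction of routed tokens assigned to expert i, P_i is the + mean router probability for expert i, and N is the number of experts; it is + minimized when both distributions are uniform at 1/N.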
+ """ + # Compute the fraction of tokens routed to each expert + expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).float() + tokens_per_expert = expert_mask.sum(dim=[0, 1, 2]) / expert_mask.sum() + + # Compute the average probability of routing to each expert + router_prob_mean = router_probs.mean(dim=[0, 1]) + + # Load balancing loss encourages uniform distribution + aux_loss = torch.sum(tokens_per_expert * router_prob_mean) * self.num_experts + + return aux_loss * self.load_balancing_weight + + def get_expert_stats(self) -> dict: + """Get expert activation statistics""" + total_activations = self.expert_activation_counts.sum() + return { + 'total_activations': total_activations.item(), + 'expert_counts': self.expert_activation_counts.numpy().tolist(), + 'expert_distribution': (self.expert_activation_counts / total_activations).numpy().tolist() if total_activations > 0 else None, + } + + def reset_stats(self): + """Reset expert statistics""" + self.expert_activation_counts = torch.zeros(self.num_experts) + self.expert_activation_history = [] + self.router.reset_stats() + diff --git a/experiments/exp10_routing_temperature_specialization/temperature_router.py b/experiments/exp10_routing_temperature_specialization/temperature_router.py new file mode 100644 index 0000000..c8223e0 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/temperature_router.py @@ -0,0 +1,131 @@ +""" +Temperature-aware router for MoE experiments +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple + + +class TemperatureRouter(nn.Module): + """ + Router that selects top-k experts with temperature-scaled softmax. + + Temperature controls the sharpness of the routing distribution: + - Low temperature (< 1.0): Sharp, confident routing (exploitation) + - Temperature = 1.0: Standard softmax (baseline) + - High temperature (> 1.0): Soft, exploratory routing (exploration) + - Very high temperature (>> 1.0): Nearly uniform routing + """ + + def __init__(self, d_model: int, num_experts: int, top_k: int = 2): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.gate = nn.Linear(d_model, num_experts, bias=False) + self.noise_std = 0.1 # Standard deviation for noise during training + + # Temperature is set dynamically during forward pass + self.current_temperature = 1.0 + + # Statistics tracking + self.expert_counts = None + self.routing_entropy_history = [] + self.selection_confidence_history = [] + + def set_temperature(self, temperature: float): + """Set the current routing temperature""" + self.current_temperature = max(temperature, 0.01) # Avoid division by zero + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: + """ + Args: + x: Input tensor [batch_size, seq_len, d_model] + + Returns: + - router_weights: Softmax weights for selected experts [batch_size, seq_len, top_k] + - expert_indices: Indices of selected experts [batch_size, seq_len, top_k] + - router_probs: Full probability distribution over experts (for load balancing loss) + - routing_stats: Dictionary with routing statistics + """ + batch_size, seq_len, d_model = x.shape + + # Compute router logits + router_logits = self.gate(x) # [batch_size, seq_len, num_experts] + + # Add noise during training for exploration + if self.training and self.noise_std > 0: + noise = torch.randn_like(router_logits) * self.noise_std + router_logits = router_logits + noise + + # Apply temperature scaling + scaled_logits = 
router_logits / self.current_temperature + + # Get full probability distribution (for load balancing loss and analysis) + router_probs = F.softmax(scaled_logits, dim=-1) + + # Select top-k experts + top_k_logits, top_k_indices = torch.topk(scaled_logits, self.top_k, dim=-1) + top_k_weights = F.softmax(top_k_logits, dim=-1) + + # Compute routing statistics + routing_stats = self._compute_routing_stats( + router_probs, top_k_weights, top_k_indices + ) + + return top_k_weights, top_k_indices, router_probs, routing_stats + + def _compute_routing_stats( + self, + router_probs: torch.Tensor, + top_k_weights: torch.Tensor, + top_k_indices: torch.Tensor + ) -> dict: + """ + Compute routing statistics for analysis. + + Returns dictionary with: + - routing_entropy: Average entropy of routing distribution + - selection_confidence: Average confidence in top-1 expert + - expert_utilization: Fraction of tokens routed to each expert + """ + with torch.no_grad(): + # Routing entropy: measure of routing diversity + # High entropy = more uniform routing, low entropy = sharp routing + entropy = -torch.sum(router_probs * torch.log(router_probs + 1e-10), dim=-1) + avg_entropy = entropy.mean().item() + + # Selection confidence: how strongly the top expert is preferred + top1_confidence = top_k_weights[:, :, 0].mean().item() + + # Record per-call history so get_routing_summary() and the trainer's + # collect_routing_stats() report real values instead of empty lists + self.routing_entropy_history.append(avg_entropy) + self.selection_confidence_history.append(top1_confidence) + + # Expert utilization: how many tokens each expert processes + expert_mask = F.one_hot(top_k_indices, num_classes=self.num_experts).float() + expert_usage = expert_mask.sum(dim=[0, 1, 2]) / (expert_mask.sum() + 1e-10) + + # Update running statistics + if self.expert_counts is None: + self.expert_counts = expert_usage.cpu() + else: + self.expert_counts = 0.9 * self.expert_counts + 0.1 * expert_usage.cpu() + + return { + 'routing_entropy': avg_entropy, + 'selection_confidence': top1_confidence, + 'expert_utilization': expert_usage.cpu().numpy().tolist(), + 'temperature': self.current_temperature, + } + + def get_routing_summary(self) -> dict: + """Get summary statistics for the entire training run""" + return { + 'final_expert_counts': self.expert_counts.numpy().tolist() if self.expert_counts is not None else None, + 'routing_entropy_history': self.routing_entropy_history, + 'selection_confidence_history': self.selection_confidence_history, + } + + def reset_stats(self): + """Reset routing statistics""" + self.expert_counts = None + self.routing_entropy_history = [] + self.selection_confidence_history = [] + diff --git a/experiments/exp10_routing_temperature_specialization/tracking_trainer.py b/experiments/exp10_routing_temperature_specialization/tracking_trainer.py new file mode 100644 index 0000000..74d8259 --- /dev/null +++ b/experiments/exp10_routing_temperature_specialization/tracking_trainer.py @@ -0,0 +1,296 @@ +""" +Custom trainer with routing statistics tracking +""" +import sys +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import Optional + +# Add project root to path +project_root = Path(__file__).resolve().parent.parent.parent +sys.path.insert(0, str(project_root)) + +from configs.moe_config import MoEModelConfig +from optimizers.muon import Muon +from training.evaluation import evaluate_model +from utils.logger import setup_logging +from config import TemperatureConfig + + +def setup_optimizer(model: nn.Module, config: MoEModelConfig): + """Set up the hybrid Muon/AdamW optimizer pair with optimal settings from exp9""" + # Separate parameters by dimensionality + muon_params = [] + adamw_params = [] + + for name, param in 
model.named_parameters(): + if param.requires_grad: + # Use Muon for 2D parameters (weight matrices) + if param.ndim >= 2 and 'embedding' not in name.lower() and 'norm' not in name.lower(): + muon_params.append(param) + else: + adamw_params.append(param) + + # Create two separate optimizers + muon_optimizer = Muon( + muon_params, + lr=config.muon_lr, + momentum=config.muon_momentum, + nesterov=True, + ns_steps=5, + ) + + adamw_optimizer = torch.optim.AdamW( + adamw_params, + lr=config.adamw_lr, + weight_decay=config.weight_decay, + betas=(0.9, 0.95), + eps=1e-8, + ) + + # Return list of optimizers (like exp9) + return [muon_optimizer, adamw_optimizer] + + +def get_lr_schedulers(optimizers, config: MoEModelConfig, warmup_steps: int, total_steps: int): + """Create cosine learning rate schedules with warmup for all optimizers""" + import math + + def lr_lambda(step): + if step < warmup_steps: + # Linear warmup + return step / max(1, warmup_steps) + else: + # Cosine decay to a 10% LR floor + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress)) + + return [torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda) for opt in optimizers] + + +def train_with_temperature_tracking( + model: nn.Module, + config: MoEModelConfig, + temp_config: TemperatureConfig, + train_loader, + val_loader, + output_dir: str = "." +): + """ + Train model with temperature-controlled routing and comprehensive tracking. + + Args: + model: MoE model with temperature-aware routing + config: Model configuration + temp_config: Temperature experiment configuration + train_loader: Training data loader + val_loader: Validation data loader + output_dir: Directory to save results + + Returns: + model: Trained model + metrics: Final evaluation metrics + history: Training history with routing statistics + routing_stats: Final routing statistics aggregated across MoE layers + """ + logger = setup_logging(log_dir=Path(output_dir) / "logs") + logger.info(f"Training with temperature config: {temp_config.name}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + # Setup optimizers and schedulers (hybrid Muon + AdamW) + optimizers = setup_optimizer(model, config) + + warmup_steps = int(config.warmup_ratio * config.max_steps) + schedulers = get_lr_schedulers(optimizers, config, warmup_steps, config.max_steps) + + # Use AMP for mixed precision training + scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp) + + # Training state + global_step = 0 + start_time = time.time() + + # History tracking + history = { + 'steps': [], + 'train_losses': [], + 'val_losses': [], + 'val_accuracies': [], + 'val_perplexities': [], + 'elapsed_times': [], + 'learning_rates': [], + 'temperatures': [], + 'routing_entropies': [], + 'selection_confidences': [], + 'expert_utilizations': [], + 'load_balancing_losses': [], + } + + logger.info(f"Starting training for {config.max_steps} steps") + logger.info(f"Temperature: {temp_config.temperature}, Schedule: {temp_config.temperature_schedule}") + + model.train() + accumulated_loss = 0.0 + accumulated_aux_loss = 0.0 + for opt in optimizers: + opt.zero_grad() + + train_iter = iter(train_loader) + + for step in range(config.max_steps): + # Update temperature based on schedule + current_temp = temp_config.get_temperature_at_step(step) + + # Set temperature for all MoE layers + for module in model.modules(): + if hasattr(module, 'set_temperature'): + module.set_temperature(current_temp) + + # Get next batch + try: + batch = next(train_iter) + except StopIteration: + train_iter = 
iter(train_loader) + batch = next(train_iter) + + input_ids = batch['input_ids'].to(device) + labels = batch['labels'].to(device) + + # Forward pass with AMP + with torch.cuda.amp.autocast(enabled=config.use_amp): + logits, aux_loss = model(input_ids, return_aux_loss=True) + + # Compute language modeling loss + lm_loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + labels.view(-1), + ignore_index=-100 + ) + + # Total loss = LM loss + load balancing loss + loss = lm_loss + if aux_loss is not None: + loss = loss + aux_loss + + # Backward pass with gradient scaling + scaled_loss = loss / config.gradient_accumulation_steps + scaler.scale(scaled_loss).backward() + + accumulated_loss += lm_loss.item() + if aux_loss is not None: + accumulated_aux_loss += aux_loss.item() + + # Update weights after accumulation steps + if (step + 1) % config.gradient_accumulation_steps == 0: + # Unscale, clip, and step all optimizers + for opt in optimizers: + scaler.unscale_(opt) + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) + for opt in optimizers: + scaler.step(opt) + scaler.update() + for sched in schedulers: + sched.step() + for opt in optimizers: + opt.zero_grad() + global_step += 1 + + # Evaluation + if (step + 1) % config.eval_every == 0: + elapsed_time = (time.time() - start_time) / 60 + current_lr = schedulers[0].get_last_lr()[0] # Use Muon LR for logging + + # Evaluate on validation set + model.eval() + val_metrics = evaluate_model(model, val_loader, config) + model.train() + + # Collect routing statistics from MoE layers + routing_stats = collect_routing_stats(model) + + # Log progress + avg_train_loss = accumulated_loss / config.eval_every + avg_aux_loss = accumulated_aux_loss / config.eval_every + + logger.info( + f"Step {step+1}/{config.max_steps} | " + f"Train Loss: {avg_train_loss:.4f} | " + f"Val Loss: {val_metrics['val_loss']:.4f} | " + f"Val Acc: {val_metrics['val_accuracy']:.4f} | " + f"LR: {current_lr:.6f} | " + f"Temp: {current_temp:.2f} | " + f"Entropy: {routing_stats['avg_entropy']:.3f} | " + f"Time: {elapsed_time:.2f}m" + ) + + # Update history + history['steps'].append(step + 1) + history['train_losses'].append(avg_train_loss) + history['val_losses'].append(val_metrics['val_loss']) + history['val_accuracies'].append(val_metrics['val_accuracy']) + history['val_perplexities'].append(val_metrics['val_perplexity']) + history['elapsed_times'].append(elapsed_time) + history['learning_rates'].append(current_lr) + history['temperatures'].append(current_temp) + history['routing_entropies'].append(routing_stats['avg_entropy']) + history['selection_confidences'].append(routing_stats['avg_confidence']) + history['expert_utilizations'].append(routing_stats['expert_utilization']) + history['load_balancing_losses'].append(avg_aux_loss) + + # Reset accumulators + accumulated_loss = 0.0 + accumulated_aux_loss = 0.0 + + # Final evaluation + model.eval() + final_metrics = evaluate_model(model, val_loader, config) + final_routing_stats = collect_routing_stats(model) + + logger.info(f"Training complete!") + logger.info(f"Final validation loss: {final_metrics['val_loss']:.4f}") + logger.info(f"Final validation accuracy: {final_metrics['val_accuracy']:.4f}") + logger.info(f"Final routing entropy: {final_routing_stats['avg_entropy']:.3f}") + + return model, final_metrics, history, final_routing_stats + + +def collect_routing_stats(model: nn.Module) -> dict: + """Collect routing statistics from all MoE layers""" + entropies = [] + confidences = [] + utilizations = [] + + for 
module in model.modules(): + if hasattr(module, 'router') and hasattr(module.router, 'routing_entropy_history'): + router = module.router + + # Get latest stats + if hasattr(router, 'routing_entropy_history') and router.routing_entropy_history: + entropies.append(router.routing_entropy_history[-1]) + if hasattr(router, 'selection_confidence_history') and router.selection_confidence_history: + confidences.append(router.selection_confidence_history[-1]) + if hasattr(router, 'expert_counts') and router.expert_counts is not None: + utilizations.append(router.expert_counts.numpy().tolist()) + + # Aggregate statistics across layers + avg_entropy = sum(entropies) / len(entropies) if entropies else 0.0 + avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0 + + # Average expert utilization across layers + if utilizations: + import numpy as np + avg_utilization = np.mean(utilizations, axis=0).tolist() + else: + avg_utilization = [] + + return { + 'avg_entropy': avg_entropy, + 'avg_confidence': avg_confidence, + 'expert_utilization': avg_utilization, + 'num_layers': len(entropies), + } + diff --git a/requirements.txt b/requirements.txt index f28baa7..a377013 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ transformers torchtune torchao matplotlib +seaborn # lm-eval # Single T4 GPU training