A comprehensive LSTM (Long Short-Term Memory) neural network library implemented in Rust with complete training capabilities, multiple optimizers, and advanced regularization.
```mermaid
graph TD
    A["Input Sequence<br/>(x₁, x₂, ..., xₙ)"] --> B["LSTM Layer 1"]
    B --> C["LSTM Layer 2"]
    C --> D["Output Layer"]
    D --> E["Predictions<br/>(y₁, y₂, ..., yₙ)"]
    F["Hidden State h₀"] --> B
    G["Cell State c₀"] --> B
    B --> H["Hidden State h₁"]
    B --> I["Cell State c₁"]
    H --> C
    I --> C
    style A fill:#e1f5fe
    style E fill:#e8f5e8
    style B fill:#fff3e0
    style C fill:#fff3e0
```
- LSTM, BiLSTM & GRU Networks with multi-layer support
- Complete Training System with backpropagation through time (BPTT)
- Multiple Optimizers: SGD, Adam, RMSprop with comprehensive learning rate scheduling
- Advanced Learning Rate Scheduling: 12 different schedulers including OneCycle, Warmup, Cyclical, and Polynomial
- Loss Functions: MSE, MAE, Cross-entropy with softmax
- Advanced Dropout: Input, recurrent, output dropout, variational dropout, and zoneout
- Schedule Visualization: ASCII visualization of learning rate schedules
- Model Persistence: Save/load models in JSON or binary format
- Peephole LSTM variant, whose gates also receive the cell state as input
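In the usual peephole formulation, the input and forget gates also see the previous cell state, and the output gate sees the freshly updated one, each through an element-wise weight vector. The equations below follow that common convention; the peephole vectors $p_i, p_f, p_o$ are notation introduced here, and the README does not state exactly which peephole variant rust-lstm implements.

```math
\begin{aligned}
i_t &= \sigma\left(W_i \cdot [h_{t-1}, x_t] + p_i \odot C_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_f \cdot [h_{t-1}, x_t] + p_f \odot C_{t-1} + b_f\right) \\
\tilde{C}_t &= \tanh\left(W_C \cdot [h_{t-1}, x_t] + b_C\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma\left(W_o \cdot [h_{t-1}, x_t] + p_o \odot C_t + b_o\right) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
```

Setting $p_i = p_f = p_o = 0$ recovers the standard LSTM cell.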
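Add to your Cargo.toml:

```toml
[dependencies]
rust-lstm = "0.5.0"
```

The examples below also import ndarray directly, so add ndarray to [dependencies] as well, using a version compatible with rust-lstm.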
```rust
use ndarray::Array2;
use rust_lstm::models::lstm_network::LSTMNetwork;

fn main() {
    // Create LSTM network: input_size = 3, hidden_size = 10, num_layers = 2
    let mut network = LSTMNetwork::new(3, 10, 2);

    // One time step of input as a column vector (input_size x 1)
    let input = Array2::from_shape_vec((3, 1), vec![0.5, 0.1, -0.3]).unwrap();
    // Initial hidden and cell states (hidden_size x 1)
    let hx = Array2::<f64>::zeros((10, 1));
    let cx = Array2::<f64>::zeros((10, 1));

    // Forward pass
    let (output, _) = network.forward(&input, &hx, &cx);
    println!("Output: {:?}", output);
}
```
```rust
use rust_lstm::{LSTMNetwork, create_basic_trainer, TrainingConfig};

fn main() {
    // Create a network with input and recurrent dropout
    let network = LSTMNetwork::new(1, 10, 2)
        .with_input_dropout(0.2, true)
        .with_recurrent_dropout(0.3, true);

    // Set up the trainer (uses the SGD optimizer and MSE loss by default)
    let mut trainer = create_basic_trainer(network, 0.001)
        .with_config(TrainingConfig {
            epochs: 100,
            clip_gradient: Some(1.0),
            ..Default::default()
        });

    // `train_data` and `validation_data` are assumed to be built elsewhere:
    // each is a slice of (input_sequence, target_sequence) tuples, where both
    // sequences are Vec<Array2<f64>>.
    trainer.train(&train_data, Some(&validation_data));
}
```
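During BPTT, gradients flowing back through many time steps can grow very large, which is what `clip_gradient: Some(1.0)` guards against. The README does not say whether rust-lstm clips by value or by norm; the sketch below assumes clipping by global L2 norm, a common choice for recurrent networks. `clip_by_global_norm` is a hypothetical helper working on plain `Vec<f64>` gradients, not part of the crate's API.

```rust
/// Hypothetical helper (not part of rust-lstm): rescale all gradients so that
/// their combined L2 norm is at most `max_norm`. Returns the norm before clipping.
fn clip_by_global_norm(grads: &mut [Vec<f64>], max_norm: f64) -> f64 {
    // Global norm = square root of the sum of squares over every gradient entry.
    let total_norm = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f64>()
        .sqrt();
    if total_norm > max_norm {
        // One shared scale factor keeps the overall gradient direction unchanged.
        let scale = max_norm / total_norm;
        for g in grads.iter_mut() {
            for x in g.iter_mut() {
                *x *= scale;
            }
        }
    }
    total_norm
}

fn main() {
    // Two parameter tensors whose gradients have a global norm of 5 (a 3-4-5 triangle).
    let mut grads = vec![vec![3.0], vec![4.0]];
    let norm = clip_by_global_norm(&mut grads, 1.0);
    println!("norm before clipping: {}, clipped gradients: {:?}", norm, grads);
}
```

Scaling every gradient by the same factor, rather than clamping entries individually, preserves the descent direction while bounding the step size.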
```rust
use rust_lstm::layers::bilstm_network::{BiLSTMNetwork, CombineMode};

// BiLSTM with concatenated outputs (output_size = 2 * hidden_size)
let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers);

// Process sequence with both past and future context
let outputs = bilstm.forward_sequence(&sequence);
```
```mermaid
graph TD
    A["Input Sequence<br/>(x₁, x₂, x₃, x₄)"] --> B["Forward LSTM"]
    A --> C["Backward LSTM"]
    B --> D["Forward Hidden States<br/>(h₁→, h₂→, h₃→, h₄→)"]
    C --> E["Backward Hidden States<br/>(h₁←, h₂←, h₃←, h₄←)"]
    D --> F["Combine Layer<br/>(Concat/Sum/Average)"]
    E --> F
    F --> G["BiLSTM Output<br/>(combined representations)"]
    style A fill:#e1f5fe
    style B fill:#fff3e0
    style C fill:#fff3e0
    style F fill:#f3e5f5
    style G fill:#e8f5e8
```
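The combine step merges, at each time step t, the forward state hₜ→ with the backward state hₜ← (the backward LSTM runs over the reversed sequence, so its outputs are re-aligned to the original order before combining). A minimal sketch of the three modes on plain vectors follows; `Combine` and `combine` are illustrative stand-ins, not rust-lstm's `CombineMode`. Concat doubles the width, matching the `new_concat` comment that `output_size = 2 * hidden_size`.

```rust
/// Illustrative stand-in for a BiLSTM combine mode (not rust-lstm's own type).
#[derive(Clone, Copy, Debug)]
enum Combine {
    Concat,
    Sum,
    Average,
}

/// Merge the forward and backward hidden states for one time step.
fn combine(forward: &[f64], backward: &[f64], mode: Combine) -> Vec<f64> {
    assert_eq!(forward.len(), backward.len(), "both directions share hidden_size");
    match mode {
        // [h_fwd ; h_bwd]: output width is 2 * hidden_size.
        Combine::Concat => forward.iter().chain(backward.iter()).copied().collect(),
        // Element-wise sum: width stays hidden_size.
        Combine::Sum => forward.iter().zip(backward).map(|(f, b)| f + b).collect(),
        // Element-wise mean: width stays hidden_size.
        Combine::Average => forward
            .iter()
            .zip(backward)
            .map(|(f, b)| (f + b) / 2.0)
            .collect(),
    }
}

fn main() {
    let h_fwd = [0.5, -1.0];
    let h_bwd = [1.5, 3.0];
    for mode in [Combine::Concat, Combine::Sum, Combine::Average] {
        println!("{:?}: {:?}", mode, combine(&h_fwd, &h_bwd, mode));
    }
}
```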
```rust
use rust_lstm::models::gru_network::GRUNetwork;

// Create GRU network (alternative to LSTM)
let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers)
    .with_input_dropout(0.2, true)
    .with_recurrent_dropout(0.3, true);

// Forward pass
let (output, _) = gru.forward(&input, &hidden_state);
```
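Unlike an LSTM, a GRU carries only a hidden state, which is why `gru.forward` takes a single `hidden_state`. Each step uses a reset gate rₜ and an update gate zₜ: hₜ = (1 - zₜ) * hₜ₋₁ + zₜ * h̃ₜ, with candidate h̃ₜ = tanh(W·[rₜ * hₜ₋₁, xₜ]). The sketch below runs one step on plain `Vec<f64>` values rather than rust-lstm's ndarray types; `GruWeights` and `gru_step` are hypothetical names, the weight layout over [hₜ₋₁, xₜ] is an assumption, and biases are included even though the GRU equations in this README omit them.

```rust
/// Logistic sigmoid.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// act(W · v + b), computed row by row; `w` is (hidden x len(v)).
fn gate(w: &[Vec<f64>], b: &[f64], v: &[f64], act: fn(f64) -> f64) -> Vec<f64> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| act(row.iter().zip(v).map(|(wi, vi)| wi * vi).sum::<f64>() + bias))
        .collect()
}

/// Hypothetical parameter layout: every matrix acts on the concatenation [h_prev, x].
struct GruWeights {
    w_r: Vec<Vec<f64>>,
    b_r: Vec<f64>,
    w_z: Vec<Vec<f64>>,
    b_z: Vec<f64>,
    w_h: Vec<Vec<f64>>,
    b_h: Vec<f64>,
}

impl GruWeights {
    /// All-zero parameters, which make the arithmetic easy to check by hand.
    fn zeros(hidden: usize, input: usize) -> Self {
        let m = || vec![vec![0.0; hidden + input]; hidden];
        let b = || vec![0.0; hidden];
        GruWeights { w_r: m(), b_r: b(), w_z: m(), b_z: b(), w_h: m(), b_h: b() }
    }
}

/// One GRU time step: returns h_t given x_t and h_{t-1}.
fn gru_step(p: &GruWeights, x: &[f64], h_prev: &[f64]) -> Vec<f64> {
    // [h_{t-1}, x_t]
    let hx: Vec<f64> = h_prev.iter().chain(x).copied().collect();
    let r = gate(&p.w_r, &p.b_r, &hx, sigmoid); // reset gate
    let z = gate(&p.w_z, &p.b_z, &hx, sigmoid); // update gate
    // [r_t * h_{t-1}, x_t]: the reset gate filters the old state before the candidate sees it
    let rh_x: Vec<f64> = h_prev
        .iter()
        .zip(&r)
        .map(|(h, rk)| h * rk)
        .chain(x.iter().copied())
        .collect();
    let h_tilde = gate(&p.w_h, &p.b_h, &rh_x, f64::tanh); // candidate state
    // h_t = (1 - z_t) * h_{t-1} + z_t * h~_t
    h_prev
        .iter()
        .zip(&z)
        .zip(&h_tilde)
        .map(|((h, zk), ht)| (1.0 - zk) * h + zk * ht)
        .collect()
}

fn main() {
    let p = GruWeights::zeros(2, 1);
    let h = gru_step(&p, &[1.0], &[0.8, -0.4]);
    println!("h_t = {:?}", h);
}
```

With all-zero weights both gates sit at 0.5 and the candidate is 0, so the step simply halves the previous state, which is a quick sanity check for the wiring.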
```mermaid
graph LR
    subgraph "LSTM Cell"
        A1["Input xₜ"] --> B1["Forget Gate<br/>fₜ = σ(Wf·[hₜ₋₁,xₜ] + bf)"]
        A1 --> C1["Input Gate<br/>iₜ = σ(Wi·[hₜ₋₁,xₜ] + bi)"]
        A1 --> D1["Candidate Values<br/>C̃ₜ = tanh(WC·[hₜ₋₁,xₜ] + bC)"]
        A1 --> E1["Output Gate<br/>oₜ = σ(Wo·[hₜ₋₁,xₜ] + bo)"]
        B1 --> F1["Cell State<br/>Cₜ = fₜ * Cₜ₋₁ + iₜ * C̃ₜ"]
        C1 --> F1
        D1 --> F1
        F1 --> G1["Hidden State<br/>hₜ = oₜ * tanh(Cₜ)"]
        E1 --> G1
    end
    subgraph "GRU Cell"
        A2["Input xₜ"] --> B2["Reset Gate<br/>rₜ = σ(Wr·[hₜ₋₁,xₜ])"]
        A2 --> C2["Update Gate<br/>zₜ = σ(Wz·[hₜ₋₁,xₜ])"]
        A2 --> D2["Candidate State<br/>h̃ₜ = tanh(W·[rₜ*hₜ₋₁,xₜ])"]
        B2 --> D2
        C2 --> E2["Hidden State<br/>hₜ = (1-zₜ)*hₜ₋₁ + zₜ*h̃ₜ"]
        D2 --> E2
    end
    style B1 fill:#ffcdd2
    style C1 fill:#c8e6c9
    style D1 fill:#fff3e0
    style E1 fill:#e1f5fe
    style B2 fill:#ffcdd2
    style C2 fill:#c8e6c9
    style D2 fill:#fff3e0
```
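The LSTM cell equations above map almost line for line onto code. This from-scratch sketch runs one time step on plain `Vec<f64>` values; it is not rust-lstm's implementation (which works on ndarray arrays), and `LstmWeights`, `lstm_step`, and the per-gate weight layout over [hₜ₋₁, xₜ] are illustrative assumptions.

```rust
/// Logistic sigmoid.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// act(W · v + b), computed row by row; `w` is (hidden x len(v)).
fn gate(w: &[Vec<f64>], b: &[f64], v: &[f64], act: fn(f64) -> f64) -> Vec<f64> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| act(row.iter().zip(v).map(|(wi, vi)| wi * vi).sum::<f64>() + bias))
        .collect()
}

/// Hypothetical layout: each gate owns a (hidden x (hidden + input)) matrix over [h_{t-1}, x_t].
struct LstmWeights {
    w_f: Vec<Vec<f64>>,
    b_f: Vec<f64>,
    w_i: Vec<Vec<f64>>,
    b_i: Vec<f64>,
    w_c: Vec<Vec<f64>>,
    b_c: Vec<f64>,
    w_o: Vec<Vec<f64>>,
    b_o: Vec<f64>,
}

impl LstmWeights {
    /// All-zero parameters, handy for checking the arithmetic by hand.
    fn zeros(hidden: usize, input: usize) -> Self {
        let m = || vec![vec![0.0; hidden + input]; hidden];
        let b = || vec![0.0; hidden];
        LstmWeights {
            w_f: m(), b_f: b(), w_i: m(), b_i: b(),
            w_c: m(), b_c: b(), w_o: m(), b_o: b(),
        }
    }
}

/// One LSTM time step, returning (h_t, C_t).
fn lstm_step(p: &LstmWeights, x: &[f64], h_prev: &[f64], c_prev: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let hx: Vec<f64> = h_prev.iter().chain(x).copied().collect();
    let f = gate(&p.w_f, &p.b_f, &hx, sigmoid); // forget gate
    let i = gate(&p.w_i, &p.b_i, &hx, sigmoid); // input gate
    let c_tilde = gate(&p.w_c, &p.b_c, &hx, f64::tanh); // candidate values
    let o = gate(&p.w_o, &p.b_o, &hx, sigmoid); // output gate
    // C_t = f_t * C_{t-1} + i_t * C~_t (element-wise)
    let c: Vec<f64> = (0..c_prev.len())
        .map(|k| f[k] * c_prev[k] + i[k] * c_tilde[k])
        .collect();
    // h_t = o_t * tanh(C_t)
    let h: Vec<f64> = (0..c.len()).map(|k| o[k] * c[k].tanh()).collect();
    (h, c)
}

fn main() {
    let p = LstmWeights::zeros(2, 3);
    let (h, c) = lstm_step(&p, &[0.5, 0.1, -0.3], &[0.0, 0.0], &[1.0, -1.0]);
    println!("h_t = {:?}\nC_t = {:?}", h, c);
}
```

The additive update of Cₜ is what lets gradients flow across many steps: with the forget gate near 1 and the input gate near 0, the cell state is carried forward almost unchanged.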
The library includes 12 different learning rate schedulers with visualization capabilities:
```rust
use rust_lstm::{
    LSTMNetwork, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    ScheduledOptimizer, PolynomialLR, CyclicalLR, WarmupScheduler,
    LRScheduleVisualizer, Adam
};

// Create a network
let network = LSTMNetwork::new(1, 10, 2);

// Step decay: reduce LR by 50% every 10 epochs.
// Each trainer takes ownership of its network, so clone for all but the last one.
let mut step_trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5);

// OneCycle policy: warmup followed by annealing over a single cycle
let mut one_cycle_trainer = create_one_cycle_trainer(network.clone(), 0.1, 100);

// Cosine annealing with warm restarts
let mut cosine_trainer = create_cosine_annealing_trainer(network, 0.01, 20, 1e-6);

// Advanced combination: warmup followed by cyclical scheduling
let base_scheduler = CyclicalLR::new(0.001, 0.01, 10);
let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001);
let optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01);

// Polynomial decay, printed as an ASCII chart
let poly_scheduler = PolynomialLR::new(100, 2.0, 0.001);
LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 100, 60, 10);
```
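The warmup-plus-cyclical combination can be reasoned about with two small functions. This sketch assumes `CyclicalLR::new(0.001, 0.01, 10)` means (base LR, max LR, step size) in the classic triangular mode and that `WarmupScheduler::new(5, base_scheduler, 0.0001)` means (warmup epochs, base schedule, starting LR); the README does not spell these out, and `triangular` and `with_warmup` are hypothetical helpers. The crude ASCII chart it prints is only in the spirit of `LRScheduleVisualizer`, not its actual output.

```rust
/// Triangular cyclical LR: rises linearly from `base_lr` to `max_lr` over
/// `step_size` epochs, falls back over the next `step_size`, then repeats.
fn triangular(epoch: usize, base_lr: f64, max_lr: f64, step_size: usize) -> f64 {
    let pos = epoch % (2 * step_size);
    let frac = if pos < step_size {
        pos as f64 / step_size as f64
    } else {
        (2 * step_size - pos) as f64 / step_size as f64
    };
    base_lr + (max_lr - base_lr) * frac
}

/// Linear warmup wrapper (assumed convention): ramp from `start_lr` to the base
/// schedule's first value over `warmup` epochs, then run the base schedule from its epoch 0.
fn with_warmup(epoch: usize, warmup: usize, start_lr: f64, base: impl Fn(usize) -> f64) -> f64 {
    if epoch < warmup {
        let target = base(0);
        start_lr + (target - start_lr) * epoch as f64 / warmup as f64
    } else {
        base(epoch - warmup)
    }
}

fn main() {
    // Mirrors the README combination under the assumed argument meanings.
    let schedule = |epoch: usize| with_warmup(epoch, 5, 0.0001, |e| triangular(e, 0.001, 0.01, 10));
    for epoch in 0..30 {
        let lr = schedule(epoch);
        // Crude ASCII chart: one '#' per 0.0005 of learning rate.
        let bar = "#".repeat((lr / 0.0005).round() as usize);
        println!("{epoch:>2} {lr:.5} {bar}");
    }
}
```

Here warmup shifts the base schedule so it starts fresh once warmup ends; other implementations ramp toward the base value at the end of warmup without shifting, so check the crate's behavior before relying on either.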
- ConstantLR: No scheduling (baseline)
- StepLR: Step decay at regular intervals
- MultiStepLR: Multi-step decay at specific milestones
- ExponentialLR: Exponential decay each epoch
- CosineAnnealingLR: Smooth cosine oscillation
- CosineAnnealingWarmRestarts: Cosine with periodic restarts
- OneCycleLR: One cycle policy for super-convergence
- ReduceLROnPlateau: Adaptive reduction on validation plateaus
- LinearLR: Linear interpolation between rates
- PolynomialLR ✨: Polynomial decay with configurable power
- CyclicalLR ✨: Triangular, triangular2, and exponential range modes
- WarmupScheduler ✨: Gradual warmup wrapper for any base scheduler
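Several of these schedulers have simple closed forms. The functions below use standard textbook definitions as a reference point; their names and parameter orders are assumptions for illustration and may not match rust-lstm's constructors (for example, what each argument of `PolynomialLR::new(100, 2.0, 0.001)` means).

```rust
use std::f64::consts::PI;

/// StepLR: multiply the base rate by `gamma` once every `step_size` epochs.
fn step_lr(base: f64, epoch: usize, step_size: usize, gamma: f64) -> f64 {
    base * gamma.powi((epoch / step_size) as i32)
}

/// ExponentialLR: multiply by `gamma` every epoch.
fn exponential_lr(base: f64, epoch: usize, gamma: f64) -> f64 {
    base * gamma.powi(epoch as i32)
}

/// CosineAnnealingLR: follow a half cosine from `base` down to `eta_min` over `t_max`
/// epochs; past `t_max` the closed form keeps oscillating with period 2 * t_max.
fn cosine_lr(base: f64, epoch: usize, t_max: usize, eta_min: f64) -> f64 {
    eta_min + 0.5 * (base - eta_min) * (1.0 + (PI * epoch as f64 / t_max as f64).cos())
}

/// PolynomialLR: decay from `base` to `end_lr` as (1 - epoch / total)^power,
/// then hold `end_lr` once `total` epochs have passed.
fn polynomial_lr(base: f64, epoch: usize, total: usize, power: f64, end_lr: f64) -> f64 {
    let remaining = 1.0 - epoch.min(total) as f64 / total as f64;
    end_lr + (base - end_lr) * remaining.powf(power)
}

/// LinearLR: interpolate linearly from `start` to `end` over `total` epochs.
fn linear_lr(start: f64, end: f64, epoch: usize, total: usize) -> f64 {
    start + (end - start) * epoch.min(total) as f64 / total as f64
}

fn main() {
    println!("epoch    step     exp  cosine    poly  linear");
    for epoch in (0..=100).step_by(10) {
        println!(
            "{:>5} {:.5} {:.5} {:.5} {:.5} {:.5}",
            epoch,
            step_lr(0.01, epoch, 10, 0.5),
            exponential_lr(0.01, epoch, 0.95),
            cosine_lr(0.01, epoch, 100, 1e-6),
            polynomial_lr(0.01, epoch, 100, 2.0, 0.001),
            linear_lr(0.01, 0.001, epoch, 100),
        );
    }
}
```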
The crate is organized into the following modules:

- layers: LSTM and GRU cells (standard, peephole, bidirectional) with dropout
- models: High-level network architectures (LSTM, BiLSTM, GRU)
- training: Training utilities with automatic train/eval mode switching
- optimizers: SGD, Adam, RMSprop with scheduling
- loss: MSE, MAE, Cross-entropy loss functions
- schedulers: Learning rate scheduling algorithms
Run examples to see the library in action:
```bash
# Basic usage and training
cargo run --example basic_usage
cargo run --example training_example
cargo run --example multi_layer_lstm
cargo run --example time_series_prediction

# Advanced architectures
cargo run --example gru_example        # GRU vs LSTM comparison
cargo run --example bilstm_example     # Bidirectional LSTM
cargo run --example dropout_example    # Dropout demo

# Learning and scheduling
cargo run --example learning_rate_scheduling   # Basic schedulers
cargo run --example advanced_lr_scheduling     # Advanced schedulers with visualization

# Real-world applications
cargo run --example stock_prediction
cargo run --example weather_prediction
cargo run --example text_classification_bilstm
cargo run --example text_generation_advanced
cargo run --example real_data_example

# Analysis and debugging
cargo run --example model_inspection
```
- Input Dropout: Applied to inputs before computing gates
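- Recurrent Dropout: Applied to the recurrent hidden state, optionally with a variational mask shared across time steps
- Output Dropout: Applied to layer outputs
- Zoneout: RNN-specific regularization that randomly keeps some state units at their previous-step values instead of updating them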
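The dropout variants differ mainly in where the mask is applied and how often it is resampled. The sketch below shows inverted dropout (kept units are scaled by 1/(1 - p), so nothing changes at evaluation time), a variational mask sampled once and reused at every time step, and zoneout with its usual evaluation-time expectation z * hₜ₋₁ + (1 - z) * hₜ. It uses a tiny built-in xorshift generator to stay dependency-free; all names are hypothetical and this is not rust-lstm's code.

```rust
/// Tiny xorshift64 PRNG so the sketch needs no external crates (seed must be nonzero).
struct XorShift(u64);

impl XorShift {
    /// Uniform sample in [0, 1).
    fn next_f64(&mut self) -> f64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        (x >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Inverted dropout mask: each unit is dropped with probability `p`; survivors are
/// scaled by 1 / (1 - p) so the expected activation is unchanged.
fn dropout_mask(n: usize, p: f64, rng: &mut XorShift) -> Vec<f64> {
    (0..n)
        .map(|_| if rng.next_f64() < p { 0.0 } else { 1.0 / (1.0 - p) })
        .collect()
}

/// Zoneout: during training each unit keeps its previous value with probability `z`;
/// at evaluation time the expected mix z * h_prev + (1 - z) * h_new is used instead.
fn zoneout(h_prev: &[f64], h_new: &[f64], z: f64, training: bool, rng: &mut XorShift) -> Vec<f64> {
    h_prev
        .iter()
        .zip(h_new)
        .map(|(&hp, &hn)| {
            if training {
                if rng.next_f64() < z { hp } else { hn }
            } else {
                z * hp + (1.0 - z) * hn
            }
        })
        .collect()
}

fn main() {
    let mut rng = XorShift(0x9E37_79B9_7F4A_7C15);
    // Variational dropout: one mask per sequence, reused at every time step.
    let mask = dropout_mask(4, 0.3, &mut rng);
    for step in 0..3 {
        let h = [1.0, 2.0, 3.0, 4.0]; // stand-in hidden state for this step
        let dropped: Vec<f64> = h.iter().zip(&mask).map(|(a, m)| a * m).collect();
        println!("step {}: {:?}", step, dropped);
    }
    // Zoneout at evaluation time mixes old and new states deterministically.
    let mixed = zoneout(&[0.0, 0.0], &[1.0, 1.0], 0.5, false, &mut rng);
    println!("zoneout (eval): {:?}", mixed);
}
```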
- SGD: Stochastic gradient descent with momentum
- Adam: Adaptive moment estimation with bias correction
- RMSprop: Root mean square propagation
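The three update rules are short enough to write out. This sketch uses the standard formulations (classical momentum, RMSprop with a running average of squared gradients, and Adam with bias correction) on plain slices; `SgdMomentum`, `RmsProp`, and `AdamSketch` are hypothetical types, and the defaults (alpha = 0.99, beta1 = 0.9, beta2 = 0.999, eps = 1e-8) are common choices, not confirmed rust-lstm defaults.

```rust
/// SGD with classical momentum: v <- mu * v + g, then theta <- theta - lr * v.
struct SgdMomentum { lr: f64, momentum: f64, velocity: Vec<f64> }

impl SgdMomentum {
    fn new(lr: f64, momentum: f64, n: usize) -> Self {
        Self { lr, momentum, velocity: vec![0.0; n] }
    }
    fn step(&mut self, params: &mut [f64], grads: &[f64]) {
        for ((p, g), v) in params.iter_mut().zip(grads).zip(self.velocity.iter_mut()) {
            *v = self.momentum * *v + g;
            *p -= self.lr * *v;
        }
    }
}

/// RMSprop: divide each step by the root of a running average of squared gradients.
struct RmsProp { lr: f64, alpha: f64, eps: f64, sq_avg: Vec<f64> }

impl RmsProp {
    fn new(lr: f64, n: usize) -> Self {
        Self { lr, alpha: 0.99, eps: 1e-8, sq_avg: vec![0.0; n] }
    }
    fn step(&mut self, params: &mut [f64], grads: &[f64]) {
        for ((p, &g), s) in params.iter_mut().zip(grads).zip(self.sq_avg.iter_mut()) {
            *s = self.alpha * *s + (1.0 - self.alpha) * g * g;
            *p -= self.lr * g / (s.sqrt() + self.eps);
        }
    }
}

/// Adam: momentum (first moment) plus RMS scaling (second moment), both bias-corrected.
struct AdamSketch { lr: f64, beta1: f64, beta2: f64, eps: f64, m: Vec<f64>, v: Vec<f64>, t: i32 }

impl AdamSketch {
    fn new(lr: f64, n: usize) -> Self {
        Self { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, m: vec![0.0; n], v: vec![0.0; n], t: 0 }
    }
    fn step(&mut self, params: &mut [f64], grads: &[f64]) {
        self.t += 1;
        // Bias corrections undo the zero initialization of m and v.
        let bc1 = 1.0 - self.beta1.powi(self.t);
        let bc2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            let g = grads[i];
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}

fn main() {
    // Minimize f(x) = (x - 3)^2, whose gradient is 2 (x - 3), with each optimizer.
    let grad = |x: f64| 2.0 * (x - 3.0);
    let (mut a, mut b, mut c) = ([0.0], [0.0], [0.0]);
    let mut sgd = SgdMomentum::new(0.01, 0.9, 1);
    let mut rms = RmsProp::new(0.01, 1);
    let mut adam = AdamSketch::new(0.1, 1);
    for _ in 0..500 {
        // Compute gradients before taking the mutable borrows for the updates.
        let (ga, gb, gc) = ([grad(a[0])], [grad(b[0])], [grad(c[0])]);
        sgd.step(&mut a, &ga);
        rms.step(&mut b, &gb);
        adam.step(&mut c, &gc);
    }
    println!("SGD+momentum: {:.4}, RMSprop: {:.4}, Adam: {:.4}", a[0], b[0], c[0]);
}
```

Thanks to bias correction, Adam's first step moves each parameter by roughly `lr` in the direction opposite its gradient, regardless of the gradient's scale.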
- MSELoss: Mean squared error for regression
- MAELoss: Mean absolute error for robust regression
- CrossEntropyLoss: Numerically stable softmax cross-entropy for classification
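A minimal sketch of the three losses on plain slices. The cross-entropy uses the log-sum-exp trick, which is what makes softmax cross-entropy numerically stable: subtracting the largest logit before exponentiating prevents overflow. The function names are hypothetical and do not reflect the API of rust-lstm's `MSELoss`, `MAELoss`, or `CrossEntropyLoss`.

```rust
/// Mean squared error.
fn mse(pred: &[f64], target: &[f64]) -> f64 {
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f64>() / pred.len() as f64
}

/// Mean absolute error: grows linearly with the residual, so outliers weigh less than under MSE.
fn mae(pred: &[f64], target: &[f64]) -> f64 {
    pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum::<f64>() / pred.len() as f64
}

/// Softmax cross-entropy for one example with an integer class label.
fn softmax_cross_entropy(logits: &[f64], label: usize) -> f64 {
    // Subtracting the largest logit keeps every exp() argument <= 0.
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|z| (z - max).exp()).sum::<f64>().ln();
    // -log softmax(logits)[label] = logsumexp(logits) - logits[label]
    log_sum_exp - logits[label]
}

fn main() {
    println!("MSE = {}", mse(&[1.0, 2.0, 3.0], &[1.0, 2.0, 5.0]));
    println!("MAE = {}", mae(&[1.0, 2.0, 3.0], &[1.0, 2.0, 5.0]));
    // A naive exp(1000.0) would overflow to infinity; the stable form stays finite.
    println!("CE  = {}", softmax_cross_entropy(&[1000.0, 0.0], 1));
}
```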
- StepLR: Decay by factor every N epochs
- OneCycleLR: One cycle policy (warmup + annealing)
- CosineAnnealingLR: Smooth cosine annealing (CosineAnnealingWarmRestarts adds periodic restarts)
- ReduceLROnPlateau: Reduce when validation loss plateaus
- PolynomialLR: Polynomial decay with configurable power
- CyclicalLR: Triangular oscillation with multiple modes
- WarmupScheduler: Gradual increase wrapper for any scheduler
- LinearLR: Linear interpolation between learning rates
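ReduceLROnPlateau is the only scheduler in this list driven by validation loss rather than the epoch count. Below is a minimal sketch with a reduction factor, a patience window, and a floor on the learning rate. `ReduceOnPlateau` is a hypothetical type; its rule (reduce once the loss has failed to improve for more than `patience` consecutive epochs, then reset the counter) follows a common convention and omits the threshold and cooldown options some implementations add.

```rust
/// Hypothetical minimal ReduceLROnPlateau.
struct ReduceOnPlateau {
    lr: f64,
    factor: f64,
    patience: usize,
    min_lr: f64,
    best: f64,
    bad_epochs: usize,
}

impl ReduceOnPlateau {
    fn new(lr: f64, factor: f64, patience: usize, min_lr: f64) -> Self {
        Self { lr, factor, patience, min_lr, best: f64::INFINITY, bad_epochs: 0 }
    }

    /// Call once per epoch with the validation loss; returns the LR for the next epoch.
    fn step(&mut self, val_loss: f64) -> f64 {
        if val_loss < self.best {
            self.best = val_loss;
            self.bad_epochs = 0;
        } else {
            self.bad_epochs += 1;
            if self.bad_epochs > self.patience {
                self.lr = (self.lr * self.factor).max(self.min_lr);
                self.bad_epochs = 0; // start counting again after a reduction
            }
        }
        self.lr
    }
}

fn main() {
    let mut sched = ReduceOnPlateau::new(0.01, 0.5, 2, 1e-5);
    // Illustrative validation losses that stall after a few epochs.
    let losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69, 0.69, 0.69, 0.69];
    for (epoch, &loss) in losses.iter().enumerate() {
        let lr = sched.step(loss);
        println!("epoch {epoch}: val_loss = {loss:.2}, next lr = {lr}");
    }
}
```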
Run the test suite:

```bash
cargo test
```
The library includes comprehensive examples that demonstrate its capabilities:
Run the learning rate scheduling examples to see different scheduler behaviors:

```bash
cargo run --example learning_rate_scheduling   # Compare basic schedulers
cargo run --example advanced_lr_scheduling     # Advanced schedulers with ASCII visualization
```

Compare LSTM vs GRU performance:

```bash
cargo run --example gru_example
```

Test the library with practical examples:

```bash
cargo run --example stock_prediction             # Stock price predictions
cargo run --example weather_prediction           # Weather forecasting
cargo run --example text_classification_bilstm   # Classification accuracy
```
The examples output training metrics, loss values, and predictions that you can analyze or plot with external tools.
- v0.4.0: Advanced learning rate scheduling with 12 different schedulers, warmup support, cyclical learning rates, polynomial decay, and ASCII visualization
- v0.3.0: Bidirectional LSTM networks with flexible combine modes
- v0.2.0: Complete training system with BPTT and comprehensive dropout
- v0.1.0: Initial LSTM implementation with forward pass
Contributions are welcome! Please submit issues, feature requests, or pull requests.
MIT License - see the LICENSE file for details.