A MinGRU language model with 18 million parameters.
The MinGRU model is a simplified version of the traditional Gated Recurrent Unit (GRU), designed to reduce complexity and improve efficiency. Its gates depend only on the current input, not on the previous hidden state, so training can be parallelized across the sequence with a parallel scan. This is much faster than training a traditional GRU one step at a time. It also drops the tanh activation on the candidate hidden state, further streamlining computation.
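To make the parallel-training point concrete, here is a minimal sketch of the minGRU recurrence for a single hidden dimension in plain Python. It follows the paper's formulation h_t = (1 - z_t) * h_{t-1} + z_t * h~_t, where the gate z_t and candidate h~_t depend only on x_t. The scalar weights `w_z` and `w_h` are hypothetical stand-ins for the linear layers, and this is not the repo's code; the paper also uses a log-space scan for numerical stability, which is omitted here. Because each step is an affine map of h_{t-1}, and affine maps compose associatively, a prefix scan gives the same outputs in about log2(T) parallel rounds.

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def min_gru_sequential(z, h_tilde, h0=0.0):
    """Step-by-step minGRU recurrence for a single hidden dimension."""
    h, out = h0, []
    for z_t, ht_t in zip(z, h_tilde):
        h = (1.0 - z_t) * h + z_t * ht_t  # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t
        out.append(h)
    return out


def combine(left, right):
    """Compose two affine steps h -> a*h + b: apply `left` first, then `right`.
    This operator is associative, which is what makes a parallel scan possible."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def min_gru_scan(z, h_tilde, h0=0.0):
    """Same recurrence via a Hillis-Steele inclusive scan: about log2(T) rounds,
    and the updates inside one round are independent (parallelizable)."""
    elems = [(1.0 - z_t, z_t * ht_t) for z_t, ht_t in zip(z, h_tilde)]
    step = 1
    while step < len(elems):
        # The comprehension reads the previous round's list before rebinding it.
        elems = [elems[i] if i < step else combine(elems[i - step], elems[i])
                 for i in range(len(elems))]
        step *= 2
    return [a * h0 + b for a, b in elems]


# Usage: hypothetical scalar weights stand in for the gate and candidate layers.
random.seed(0)
w_z, w_h = 0.7, 1.3
x = [random.uniform(-2.0, 2.0) for _ in range(10)]
z = [sigmoid(w_z * x_t) for x_t in x]  # z_t depends only on x_t
h_tilde = [w_h * x_t for x_t in x]     # no tanh on the candidate state
seq = min_gru_sequential(z, h_tilde)
par = min_gru_scan(z, h_tilde)
print(all(abs(s - p) < 1e-9 for s, p in zip(seq, par)))  # True
```

The sequential version is what you would use at generation time (carrying `h` forward one token at a time), while the scan form is what makes training on whole sequences fast.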
Parameter | Value | Description |
---|---|---|
--dim | 512 | Model dimension |
--num_tokens | 256 | Vocabulary size (256, one token per byte/character code) |
--num_layers | 6 | Number of layers |
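Assuming `--num_tokens` is the vocabulary size, 256 is enough to give every possible byte value its own token, so text can be fed to the model character by character without a learned tokenizer. The sketch below assumes UTF-8 byte encoding; `encode` and `decode` are hypothetical helpers, not functions taken from the repo's `data.py`.

```python
NUM_TOKENS = 256  # matches the --num_tokens default


def encode(text: str) -> list[int]:
    """Map text to token ids: one id per UTF-8 byte, each in [0, 255]."""
    return list(text.encode("utf-8"))


def decode(ids: list[int]) -> str:
    """Inverse of encode; invalid byte sequences become U+FFFD."""
    return bytes(ids).decode("utf-8", errors="replace")


ids = encode("once upon a time")
print(ids[:4])                                # [111, 110, 99, 101]
print(all(0 <= i < NUM_TOKENS for i in ids))  # True
print(decode(ids))                            # once upon a time
```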
Here is an example generated by the model after training:
Prompt: "once upon a time"
You can try the pre-trained model in this Hugging Face Space app: MinGru
And you can find the pre-trained model here: MinGru-model
The model was trained on two NVIDIA T4 GPUs using distributed data parallel (DDP).
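In a DDP setup, each GPU runs its own replica of the model on a different shard of the data, and gradients are averaged across processes after each backward pass. The plain-Python sketch below mimics how PyTorch's `DistributedSampler` splits indices when shuffling is off: it pads the index list so it divides evenly, then gives rank `r` every `world_size`-th index starting at `r`. It is for illustration only and is not the repo's `train_ddp.py`.

```python
import math


def shard_indices(dataset_len: int, world_size: int, rank: int) -> list[int]:
    """Indices one rank sees per epoch (no shuffling), DistributedSampler-style."""
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets the same number of samples.
    indices += (indices * math.ceil(total_size / dataset_len))[: total_size - dataset_len]
    return indices[rank:total_size:world_size]


# Two GPUs (e.g. two T4s) splitting a 7-sample dataset:
print(shard_indices(7, world_size=2, rank=0))  # [0, 2, 4, 6]
print(shard_indices(7, world_size=2, rank=1))  # [1, 3, 5, 0]
```

The padding is why a sample can occasionally be seen twice per epoch when the dataset size is not a multiple of the number of GPUs.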
Hyperparameter | Type | Default Value | Description |
---|---|---|---|
--batch_size | int | 204 | Batch size for training |
--lr | float | 4e-3 | Learning rate for training the model |
--wd | float | 1e-2 | Weight decay for the optimizer |
--epochs | int | 30 | Total number of epochs |
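The hyperparameters above map naturally onto command-line flags. A minimal `argparse` sketch using the names, types, and defaults from the table, plus the two required data paths from the training command, might look like this. It is an illustrative reconstruction, not the actual contents of `train.py`, and the file names in the usage line are hypothetical.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train a MinGRU language model")
    # Both data paths are required by the training command.
    parser.add_argument("--path_to_train_data", type=str, required=True)
    parser.add_argument("--path_to_test_data", type=str, required=True)
    # Defaults mirror the hyperparameter table.
    parser.add_argument("--batch_size", type=int, default=204, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=4e-3, help="Learning rate")
    parser.add_argument("--wd", type=float, default=1e-2, help="Weight decay for the optimizer")
    parser.add_argument("--epochs", type=int, default=30, help="Total number of epochs")
    return parser


# Hypothetical file names, used only to show parsing.
args = build_parser().parse_args(
    ["--path_to_train_data", "train.bin", "--path_to_test_data", "test.bin"]
)
print(args.batch_size, args.lr, args.wd, args.epochs)  # 204 0.004 0.01 30
```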
The model was trained on the TinyStories dataset.
Before we begin, note that the model was trained on two T4 GPUs on Kaggle, so `train_ddp.py` is set up for that environment; elsewhere, you can still train the model with `train.py`.
- First, clone this repo and install the requirements:

  ```bash
  git clone https://github.com/dame-cell/MinGru.git
  cd MinGru
  pip install -r requirements.txt
  cd mingru
  ```

- Next, prepare the dataset (this should take about a minute):

  ```bash
  python data.py
  ```

- Then train the model (both data paths are required):

  ```bash
  python train.py --path_to_train_data PATH_TO_TRAIN_DATA --path_to_test_data PATH_TO_TEST_DATA --batch_size 204
  ```
We updated the model architecture by adding a causal depthwise convolution, which captures local temporal dependencies.
The updated architecture appears to resolve the loss spike that occurred near the end of training.
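A causal depthwise convolution gives each channel its own small kernel and left-pads the sequence, so the output at step t mixes only inputs from steps t-K+1 through t and never looks ahead. The plain-Python sketch below shows that idea; the kernel size and weights are hypothetical, and it does not show where the convolution sits inside the repo's updated block.

```python
def causal_depthwise_conv1d(x, kernels):
    """x: list of T time steps, each a list of C channel values.
    kernels: C kernels (one per channel), each a list of K weights;
    kernels[c][-1] multiplies the current step, kernels[c][0] the oldest.
    Returns T x C outputs; step t only uses x[t-K+1 .. t] (zeros before t=0)."""
    T, C, K = len(x), len(kernels), len(kernels[0])
    out = []
    for t in range(T):
        row = []
        for c in range(C):
            acc = 0.0
            for j in range(K):
                src = t - (K - 1) + j  # left padding: negative indices count as zeros
                if src >= 0:
                    acc += kernels[c][j] * x[src][c]
            row.append(acc)
        out.append(row)
    return out


x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # T=3 steps, C=2 channels
kernels = [[0.5, 1.0], [0.0, 1.0]]           # K=2; channel 1 is an identity filter
print(causal_depthwise_conv1d(x, kernels))
# [[1.0, 10.0], [2.5, 20.0], [4.0, 30.0]]
```

"Depthwise" means channels are never mixed by this layer, which keeps it cheap (C * K weights); mixing across channels is left to the rest of the block.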
Metric | Best Value |
---|---|
Best Updated Train Loss | 0.74 |
Best Updated Val Loss | 0.77 |
Best Updated Perplexity | 2.16 |
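Perplexity is the exponential of the mean per-token cross-entropy. Assuming the reported losses are mean cross-entropy in nats (natural log), the best perplexity in the table matches the best validation loss:

```python
import math


def perplexity(mean_cross_entropy_nats: float) -> float:
    """Perplexity = exp(mean per-token cross-entropy, natural log)."""
    return math.exp(mean_cross_entropy_nats)


best_val_loss = 0.77  # from the table above
print(round(perplexity(best_val_loss), 2))  # 2.16
```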
For more details, see the wandb report: wandb
```bibtex
@inproceedings{Feng2024WereRA,
  title = {Were RNNs All We Needed?},
  author = {Leo Feng and Frederick Tung and Mohamed Osama Ahmed and Yoshua Bengio and Hossein Hajimirsadegh},
  year = {2024},
  url = {https://api.semanticscholar.org/CorpusID:273025630}
}
```
Thanks to lucidrains for the reference code.