- Python >=3.11
- uv
- Download the M5 dataset following the instructions in the README
- Run `./dataset/M5/extract.sh`
- Run `uv sync` to download the required dependencies
- Run `uv run pretrainm5.py` for pre-training (at least 10 epochs)
- Run `uv run trainm5.py` for training (at least 100 epochs)
- Training requires at least 70 GB of VRAM with mixed precision
- Set `SCALE_PREC = False` in `trainm5.py` to use FP16 and run on GPUs with less than 40 GB of VRAM
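Before running the steps above, the Python and uv requirements can be verified with a small pre-flight check. The sketch below is a hypothetical helper, not part of the repository; `check_prerequisites` is an illustrative name, and it only checks the interpreter version and whether `uv` is on `PATH` (it does not check GPU memory).

```python
import shutil
import sys


def check_prerequisites(min_python=(3, 11)):
    """Hypothetical helper: return a list of problems; an empty list means all checks passed."""
    problems = []
    # Compare only (major, minor) of the running interpreter against the minimum.
    if sys.version_info[:2] < min_python:
        problems.append(
            f"Python {min_python[0]}.{min_python[1]}+ required, found "
            f"{sys.version_info.major}.{sys.version_info.minor}"
        )
    # `uv sync` and `uv run` need the uv executable to be discoverable on PATH.
    if shutil.which("uv") is None:
        problems.append("uv not found on PATH")
    return problems
```

For example, `print(check_prerequisites() or "all prerequisites found")` prints either the list of problems or a confirmation message.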