Add fire finetuning #553

Draft
wants to merge 762 commits into base: master

Changes from 1 commit
Commits (762)
973d868
Add specific vizier running file for optimization
Jun 13, 2024
a4bb9c8
Add working run_vizier modification of run_exp
Jun 13, 2024
c3b56c7
Convert boolean action to new format for compat
Jun 13, 2024
f832845
Remove temporary test file
Jun 15, 2024
f524534
Streamline run_vizier.py
Jun 15, 2024
a8a955d
Remove leading spaces from train.py
Jun 15, 2024
ce306f9
Add configuration json file to out_dir
Jun 15, 2024
2a14bd8
Add saving of best validation loss and iter file
Jun 15, 2024
f050079
Copy meta.pkl to out_dir
Jun 15, 2024
f13233c
Add check for meta.pkl from out_dir to sample.py
Jun 15, 2024
5fc97ef
Add fast method for obtaining best validation loss
Jun 15, 2024
fc9c93f
Suppress warnings in the run_vizier
Jun 15, 2024
8b27d8e
Add comments to ckpt saving and end action list
Jun 15, 2024
8673a3a
Add --fast option for inspect checkpoints
Jun 15, 2024
ee8f4df
Add scripts compat with python-codes-25k
Jun 15, 2024
447ae47
Merge pull request #186 from klei22/add_vizier_optimization
gkielian Jun 16, 2024
6196fc9
Merge branch 'add_scripts_for_python_codes_dataset' into HEAD
Jun 16, 2024
dcfb5bc
Merge pull request #187 from klei22/add_more_datasets
gkielian Jun 16, 2024
48a36cc
Remove duplicate block
gkielian Jun 17, 2024
641401e
Merge pull request #188 from gkielian/main
klei22 Jun 17, 2024
aeca808
Add options random init mean and std to train.py
Jun 18, 2024
6277d35
Clean imports section of model.py
Jun 18, 2024
45a5b2e
Add polymorphic interface for linear variations
Jun 18, 2024
6e70517
Simplify linear wrapper to inheritance based
Jun 18, 2024
df5bd39
Merge branch 'add_kan_and_hyperparams' into origin_main
Jun 18, 2024
01aef4b
Don't set WPE if not selected
Jun 18, 2024
0e9ac28
Merge pull request #189 from klei22/add_linear_wrapper_and_kan_feedback
gkielian Jun 18, 2024
e72f663
Fix bug separating shuffle moveset with moveset
Jun 16, 2024
881d7a6
Add trial argparse arg
Jun 16, 2024
1c76156
Upgrade progress bar
Jun 16, 2024
0777f5b
Add create_datasets.sh for experiments
Jun 22, 2024
40070e8
Add latest iteration of parameter exploration
Jun 22, 2024
5516e22
Add quantized krmsnorm
Jun 23, 2024
b3ae142
Add exploration for vizier
Jun 23, 2024
0c81ccf
Add option for no quantization
Jun 23, 2024
f4ef50f
Update ints to categories
Jun 23, 2024
a081c0e
Fix configuration param names
Jun 24, 2024
78c1165
Add fix for exploration json specifying krmsnorm
Jun 24, 2024
939e7cd
Add initial nan detection to train.py
Jun 24, 2024
437dd7e
Merge pull request #191 from klei22/rubiks_cube_improvements_2
gkielian Jun 24, 2024
4b079e1
golden gen for decoder
Buck008 Jun 25, 2024
03698f7
update for golden gen
Buck008 Jun 28, 2024
d17fdb3
Add scripts compatible with the Newswire Dataset
Jul 1, 2024
083f37c
Add snac_converter.py for mp3 <-> text conversion
Jul 4, 2024
a75244d
Update prepare.py to include direct numeric tokens
Jul 5, 2024
bc9d8df
Add optional token boundary to sample.py
Jul 5, 2024
7530e41
Update sample.py to use write instead of append
Jul 5, 2024
d0b9ed5
Add option for snac processing all files in a dir
Jul 5, 2024
c6cc4fe
Add prepare.py softlink to snac folder
Jul 5, 2024
178fe66
Add example of how to create a listening sample
Jul 5, 2024
ba4dec5
Add helper files
Jul 5, 2024
468e21e
Update get_dataset.sh with tokenization
Jul 5, 2024
c551d5f
Add sampling and training scripts
Jul 5, 2024
cead197
Add progress bars to audio directory processing
Jul 5, 2024
ad7192b
Add split_mp3s.py to help reduce mp3 size for gpu
Jul 5, 2024
f5d725c
Organize imports
Jul 5, 2024
584ab7a
Add sample.sh and train.sh for easy demo of vc
Jul 5, 2024
9196ee6
Add README.md
Jul 5, 2024
3644b37
Remove stray char from script
Jul 9, 2024
99e820b
Add dependency installation instructions
Jul 9, 2024
5c44f27
Merge pull request #198 from klei22/add_snac_tokenization
gkielian Jul 9, 2024
3395f2d
Update get_dataset script to use jq
Jul 9, 2024
66a73be
Check for download directory before saving
Jul 9, 2024
c6cb631
Remove unused bash setting +x
Jul 9, 2024
2bdaa38
Merge pull request #196 from klei22/add_newswire
gkielian Jul 9, 2024
21d6ee4
Merge pull request #194 from Buck008/golden_gen
gkielian Jul 9, 2024
cfde3e9
Update train.py
gkielian Jul 9, 2024
d8a6b4b
Merge pull request #192 from klei22/add_quantized_krmsnorm
gkielian Jul 9, 2024
e9be482
Add scripts compatible with smollm-corpus
Jul 27, 2024
d6215a3
Substantive commit on implementing Mixture of Experts (MoE) architecture
djlisbonne Jul 29, 2024
104e3d9
Update typing extensions module
Jul 29, 2024
eea13a9
Merge pull request #204 from klei22/update_typing_extensions
klei22 Jul 29, 2024
cd761ae
Altering Block class to use MoE layer instead of basic MLP
djlisbonne Jul 29, 2024
28235b3
Small change to cmd line flags
djlisbonne Jul 29, 2024
4f713b3
More MoE flags and slight renaming
djlisbonne Jul 30, 2024
69a1ab7
Bug fix
djlisbonne Jul 30, 2024
33854fd
added kRMSNormWithRecompute module
mmoffatt2 Jul 30, 2024
6f46df9
added gpt_conf
mmoffatt2 Jul 30, 2024
7bad0c3
Comment updates
djlisbonne Jul 30, 2024
bd68e6e
Store and load support for GPTConfig
djlisbonne Jul 30, 2024
4744d67
Merge branch 'save_gptconf' into add_moe
djlisbonne Jul 30, 2024
6e9aeda
got rid of unnecessary krmsnorm module
mmoffatt2 Jul 31, 2024
b931379
moved quantize embedding code to new PR
mmoffatt2 Jul 31, 2024
6bd98a6
Adding argparse arg in train.py to load and save params to json
djlisbonne Jul 31, 2024
2736855
Merge branch 'save_gptconf' into add_moe
djlisbonne Jul 31, 2024
4c2f911
Added partial code for snac tokens
xinyixuu Jul 31, 2024
a4a19cf
got rid of unnecessary linear quant
mmoffatt2 Aug 1, 2024
94a34a9
initial feedback changes
mmoffatt2 Aug 2, 2024
23ece16
Added submodule
xinyixuu Aug 2, 2024
6661182
Update path to submodule
xinyixuu Aug 2, 2024
8c5aa0c
update minor issue on sample_whisper_snac.py
xinyixuu Aug 2, 2024
3ce65e4
Resolved merge conflict in data/snac/sample_whisper_snac.py
xinyixuu Aug 2, 2024
d41bbbb
Update .gitmodules
xinyixuu Aug 2, 2024
e81ed92
Added submodule
xinyixuu Aug 2, 2024
6b4a662
added bash script to install whisper.cpp
xinyixuu Aug 2, 2024
9929ad3
Moved moe router into dedicated variations file
djlisbonne Aug 2, 2024
db2921a
Merge pull request #208 from mmoffatt2/krmsnorm-recompute
gkielian Aug 3, 2024
dccaf76
Replace counters with flag for recomputed
gkielian Aug 3, 2024
238b0b7
Merge pull request #209 from djlisbonne/save_gptconf
gkielian Aug 3, 2024
31e032b
Merge pull request #203 from klei22/add_compat_with_smollm_corpus
gkielian Aug 3, 2024
02b7537
Fix bug with MoE layer frequency
djlisbonne Aug 4, 2024
f818f9a
add test files for debug
xinyixuu Aug 5, 2024
99fcad2
add bash script that run the whisper
xinyixuu Aug 6, 2024
71faffc
Add guards to sample_whisper_snac.py
gkielian Aug 6, 2024
f572743
Add graceful end of file handling
gkielian Aug 6, 2024
196c4f0
Add formatting for stdout
gkielian Aug 6, 2024
4ecc0cb
Add small formatting edits
gkielian Aug 6, 2024
f51c438
Add tempfix for compile error from rv64 flags
gkielian Aug 6, 2024
e3e7acc
Merge branch 'master' into quantization_embedding
gkielian Aug 6, 2024
3ac1d49
Merge pull request #210 from mmoffatt2/quantization_embedding
gkielian Aug 6, 2024
61b478a
Merge pull request #215 from gkielian/add_krmsnorm_recompute_flag
klei22 Aug 6, 2024
f757848
Merge pull request #1 from gkielian/add_snac_tokens_patch
xinyixuu Aug 6, 2024
bd9fc88
Merge pull request #2 from xinyixuu/master
xinyixuu Aug 6, 2024
8e847ae
updated to add snac patch
xinyixuu Aug 6, 2024
6cca569
Delete data/snac/.tmux.conf
gkielian Aug 6, 2024
fcb8735
Merge pull request #212 from xinyixuu/add_snac_tokens
gkielian Aug 6, 2024
4c63f46
Remove temp wav file
gkielian Aug 6, 2024
dc6cfb9
Add .wav and .mp3 files to .gitignore
gkielian Aug 6, 2024
9ce521e
Move whisper_snac.sh to snac dir and adjust paths
gkielian Aug 7, 2024
7ed18f9
Moved MoE layer construction to create_shared_param_group to save ove…
djlisbonne Aug 7, 2024
7d238c3
Removed print statements
djlisbonne Aug 7, 2024
5191acd
Merge pull request #205 from djlisbonne/add_moe
gkielian Aug 7, 2024
424d360
moved activation quantization to new PR
mmoffatt2 Aug 5, 2024
9fdd822
undo gpt_conf changes
mmoffatt2 Aug 5, 2024
80e4fee
updated act names
mmoffatt2 Aug 8, 2024
669377c
moved quantized linears to new PR
mmoffatt2 Aug 1, 2024
00b0f2e
got rid of prints
mmoffatt2 Aug 1, 2024
ad2dbf3
initial PR feedback
mmoffatt2 Aug 2, 2024
3623baa
removed warmup_iters changes
mmoffatt2 Aug 2, 2024
089cfa1
added quantization_warmup_iters
mmoffatt2 Aug 2, 2024
16032aa
added quantization_dict
mmoffatt2 Aug 7, 2024
82479c1
added new quantization to embedding
mmoffatt2 Aug 8, 2024
eb0eac4
added new quantization to embedding
mmoffatt2 Aug 8, 2024
10526d5
Merge pull request #218 from gkielian/add_snac_patches_2
klei22 Aug 8, 2024
c7ebec2
Add Model Parameter Section
gkielian Aug 9, 2024
a008fa7
Trim trailing spaces
gkielian Aug 9, 2024
a886556
Add torchinfo to requirements_cpu
gkielian Aug 9, 2024
5f8c620
Add torchinfo installation to README.md
gkielian Aug 9, 2024
9ce06df
changed act inouts to input
mmoffatt2 Aug 9, 2024
cfd8bdf
added linear_variants and quant_methods lists
mmoffatt2 Aug 9, 2024
2ef1c57
Fix typo in logging group
gkielian Aug 9, 2024
b8e43e0
Merge pull request #214 from mmoffatt2/quantization_linear
gkielian Aug 9, 2024
ee5b012
Merge branch 'master' into quantization_activations
gkielian Aug 9, 2024
90f1f91
Update train.py
gkielian Aug 9, 2024
0906b16
Update model.py
gkielian Aug 9, 2024
9e1206a
Merge pull request #216 from mmoffatt2/quantization_activations
gkielian Aug 9, 2024
c12522e
Merge pull request #223 from gkielian/add_model_param_section
klei22 Aug 9, 2024
6cafe38
Add implementation of Rotary Embeddings
Aug 11, 2024
556d99c
Add Rope Length option to Symmetrical Rope
Aug 12, 2024
8c29fef
Add Rope length to Standard Rope
Aug 12, 2024
e99d875
combined linear and activation PRs
mmoffatt2 Aug 8, 2024
151f57c
Add polling to inspect_ckpts.py
Aug 12, 2024
3ed497a
added functionality to save quantized weights/activations
mmoffatt2 Aug 9, 2024
05e86bc
quantized linear/activation feedback
mmoffatt2 Aug 9, 2024
dcd75bc
Got rid of merge mistakes
mmoffatt2 Aug 12, 2024
68249a5
Removing inspect print statement when polling
Aug 12, 2024
7f6ea79
Update rope_sweep with random seed
Aug 12, 2024
ca0dd1f
modified whisper_snac.sh file to run the whole process to get the re…
xinyixuu Aug 12, 2024
a6d4667
Adjust snac whisper script names and paths
gkielian Aug 12, 2024
e890fe0
Merge pull request #3 from gkielian/add_snac_tokens_over_dir
xinyixuu Aug 14, 2024
61b7876
Adjust try catch block and formatting
gkielian Aug 14, 2024
ebca13f
Merge pull request #4 from gkielian/add_snac_tokens_over_dir
xinyixuu Aug 14, 2024
d122ad5
Merge pull request #225 from klei22/add_rope_embeddings
gkielian Aug 14, 2024
a01bb72
added update_activations bool
mmoffatt2 Aug 14, 2024
76516f7
Add ReLUMax variation and sweep
gkielian Aug 14, 2024
90b05a7
refactored fake_quantize_act code
mmoffatt2 Aug 14, 2024
293cc60
fixed bug in ordering of fake_quantize_act
mmoffatt2 Aug 15, 2024
7a0b04f
added more granular qk and pv input act quantization
mmoffatt2 Aug 15, 2024
8706d6c
moved activation buffer creation to a function
mmoffatt2 Aug 15, 2024
bbc09a5
moved all train.py statistics functions to new folder
mmoffatt2 Aug 15, 2024
b957e7e
fixed plot_statistics argument
mmoffatt2 Aug 15, 2024
b08ffcd
moved if statement inside create_statistics
mmoffatt2 Aug 15, 2024
52a15d1
Merge pull request #230 from mmoffatt2/statistics_util
gkielian Aug 16, 2024
eac0c67
Merge pull request #224 from mmoffatt2/quantization_save_weights
gkielian Aug 16, 2024
c1db268
Merge pull request #227 from xinyixuu/add_snac_tokens
gkielian Aug 17, 2024
ee5ad20
Merge pull request #229 from mmoffatt2/matmul_quantization_granularity
gkielian Aug 17, 2024
9beea4b
Merge branch 'master' into add_relumax
klei22 Aug 17, 2024
d3de6f3
Merge pull request #231 from gkielian/add_relumax
klei22 Aug 17, 2024
2483864
Add model printing before training run
gkielian Aug 17, 2024
24ff05e
Adjust plotting to be optional
gkielian Aug 17, 2024
7d2e241
Add logging and plotting options to gptconf
gkielian Aug 17, 2024
3f61101
Trim ending spaces in train.py and gptconf.py
gkielian Aug 17, 2024
0c5be9f
Add ConSmaxV2 and add conditional io logging
gkielian Aug 17, 2024
d03d988
Add formatted model printing functions
gkielian Aug 17, 2024
74b4880
Merge branch 'master' of https://github.com/chipGPT/nanoGPT
gkielian Aug 17, 2024
7a57222
Trim settings for ConSmaxV2 to only defaults
gkielian Aug 19, 2024
54badaa
Support for importing & translating models from HF
djlisbonne Aug 19, 2024
e9e41bc
Cleanup
djlisbonne Aug 20, 2024
3464d3b
Temp remove of resume_gpt_model
gkielian Aug 20, 2024
bfb268c
Merge pull request #233 from djlisbonne/hf_from_pretrained_fix
gkielian Aug 20, 2024
1d4f8a3
Add printing of model summary to sample.py
gkielian Aug 20, 2024
e484371
Add abbreviations for config and output files
gkielian Aug 20, 2024
265ce88
Add gating args for logging statistics
gkielian Aug 20, 2024
26076a8
Add sweep for ConSmaxV2
gkielian Aug 20, 2024
9a16913
Add printout for model param names
gkielian Aug 20, 2024
dd0a852
Add means for logging per head via tensorboard/csv
gkielian Aug 21, 2024
6925207
Add softmax overflow recompute test
gkielian Aug 21, 2024
9e76d1a
added custom_gpt code
mmoffatt2 Aug 21, 2024
b903c04
Initial commit
djlisbonne Aug 21, 2024
273d9cc
simplifying train.py init_from branching and added support for overri…
djlisbonne Aug 21, 2024
c42bffc
Require xmax_guess set if overflow recompute
gkielian Aug 21, 2024
541d31d
Make overflow recompute false by default
gkielian Aug 21, 2024
fd97d3c
changed softmax to softplus
mmoffatt2 Aug 21, 2024
4b4f415
Add initial code emulating hardware
gkielian Aug 21, 2024
2fd45cf
Add latest train and pickle for testing
gkielian Aug 21, 2024
2671c15
Update kv_group to default as none
gkielian Aug 21, 2024
f4d3b29
Merge branch 'master' of https://github.com/chipGPT/nanoGPT
gkielian Aug 21, 2024
b073506
removed test_train.py
mmoffatt2 Aug 22, 2024
df9d975
Merge pull request #237 from mmoffatt2/huggingface_model
gkielian Aug 22, 2024
355149c
added file for uploading to huggingfacehub
mmoffatt2 Aug 22, 2024
4206edf
moved sample to its own file
mmoffatt2 Aug 22, 2024
0c94f78
Merge branch 'ReaLLMASIC:master' into huggingface_model
mmoffatt2 Aug 22, 2024
b9e20db
Merge pull request #238 from mmoffatt2/huggingface_model
gkielian Aug 22, 2024
a5e7592
update to sample.py to properly load pretrained GPT2 model
djlisbonne Aug 22, 2024
ae45d40
Add option to get sample inference after each val
gkielian Aug 22, 2024
02eb15c
Removed exits
djlisbonne Aug 23, 2024
e97b2fd
Adjust imports for inference compatibility
gkielian Aug 23, 2024
4f65cbb
Add start tokens option
gkielian Aug 23, 2024
c4be2b1
Add test script and README for goldenbrick
gkielian Aug 23, 2024
7006be5
Add note and modification for weight export
gkielian Aug 23, 2024
3eaaa69
Update to state_dict translation to correctly assign q,k,v matrices f…
djlisbonne Aug 23, 2024
eb8be6a
Merge pull request #240 from djlisbonne/gptconfig_fix
gkielian Aug 23, 2024
6eaa5f4
Merge branch 'master' into add_numpy_hw_test
klei22 Aug 24, 2024
4dc821a
Remove duplicate save file
klei22 Aug 24, 2024
55864d5
Restore statistic_plots.py
klei22 Aug 24, 2024
a65c5c3
Merge pull request #236 from gkielian/add_numpy_hw_test
klei22 Aug 24, 2024
57ff63e
Merge branch 'master' into add_training_sample_option
klei22 Aug 24, 2024
6c602e9
Merge pull request #241 from gkielian/add_training_sample_option
klei22 Aug 24, 2024
937ffae
Add progress bar to train.py
gkielian Aug 25, 2024
57b68f1
Merge pull request #243 from gkielian/add_progress_bar
klei22 Aug 25, 2024
6645172
Move notebooks to colab folder
Aug 25, 2024
c4ce6ed
Remove data_augmentation folder
Aug 25, 2024
474b73b
Remove data augmentation in favor of HF apis
Aug 25, 2024
405a276
Add original nanoGPT as module instead of hardcopy
Aug 25, 2024
8c8eb9a
Add llm.c as a submodule
Aug 25, 2024
be8d426
Clean images no longer used in README
Aug 25, 2024
fbf5988
Merge pull request #244 from klei22/organize_folders
gkielian Aug 25, 2024
5423961
Add softmax sweep to benchmark softmaxes v context
Aug 25, 2024
da5cc53
v2: manually adding +1 in log_rel & log_pos.
Mars-Cat2023 Aug 25, 2024
515bb90
v3: Adding one argument: --fire_log_bias
Mars-Cat2023 Aug 25, 2024
cddf59d
Add option to just do forward, for testing inference
Aug 25, 2024
8ec274b
v4: Adding 5 new arguments: --fire_num_hidden_layers, mlp_width, init…
Mars-Cat2023 Aug 25, 2024
61833cd
Update train.py
Mars-Cat2023 Aug 26, 2024
7a414a7
Merge pull request #246 from Mars-Cat2023/FIRE
gkielian Aug 26, 2024
f4c0781
Merge pull request #245 from klei22/add_softmax_context_benchmark
gkielian Aug 26, 2024
cbf3d95
Fixed One Bug in FIRE - PR #246 v4
Mars-Cat2023 Aug 29, 2024
981c8dd
Add MLP Expansion factor control and sweep
gkielian Aug 31, 2024
863c54d
Merge pull request #251 from Mars-Cat2023/FIRE
gkielian Sep 2, 2024
37ca368
Merge pull request #252 from gkielian/add_mlp_expansion_factor
klei22 Sep 3, 2024
5a7528b
Add code for finetuning with FIRE
gkielian Sep 6, 2024
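
The FIRE commits above (the --fire_log_bias, --fire_num_hidden_layers, and mlp_width arguments, and the "+1 in log_rel & log_pos" fix) add a learned relative-position bias to attention. As a rough reference, here is a minimal sketch of a FIRE-style bias module, assuming the standard formulation (an MLP over log-scaled relative positions) rather than this PR's exact code; the class and parameter names are hypothetical:

```python
import torch
import torch.nn as nn

class FIREBias(nn.Module):
    """Sketch of a FIRE-style additive attention bias (hypothetical names).

    bias(i, j) = MLP( log(c*(i-j) + 1) / log(c*(i+1) + 1) ), added to the
    attention logits per head; c is a learnable scale.
    """
    def __init__(self, n_head: int, hidden_dim: int = 32,
                 n_hidden_layers: int = 1, init_c: float = 0.1):
        super().__init__()
        layers = [nn.Linear(1, hidden_dim), nn.ReLU()]
        for _ in range(n_hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, n_head))
        self.mlp = nn.Sequential(*layers)
        self.c = nn.Parameter(torch.tensor(init_c))  # learnable log scale

    def forward(self, seq_len: int) -> torch.Tensor:
        pos = torch.arange(seq_len, dtype=torch.float32)
        rel = (pos[:, None] - pos[None, :]).clamp(min=0.0)      # causal i - j
        # The "+1" keeps the log finite at zero distance / position.
        log_rel = torch.log(self.c.abs() * rel + 1.0)
        log_pos = torch.log(self.c.abs() * (pos + 1.0) + 1.0)
        normed = log_rel / log_pos[:, None].clamp(min=1e-6)     # progressive interpolation
        bias = self.mlp(normed.unsqueeze(-1))                   # (T, T, n_head)
        return bias.permute(2, 0, 1)                            # (n_head, T, T)
```

The returned (n_head, T, T) tensor would be added to the attention logits before softmax; the log(c·x + 1) transform is where the "+1" mentioned in the v2 commit enters.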
Add progress bar to train.py
gkielian committed Aug 25, 2024
commit 937ffae346feee0b922d49e79c148e824c2c3a74
262 changes: 136 additions & 126 deletions train.py
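
The diff below wraps the training loop in a rich progress bar. Stripped to its essentials, the pattern this commit introduces looks like the following minimal sketch (variable names are illustrative):

```python
from rich.progress import Progress

max_iters = 1000  # illustrative iteration count

progress = Progress()
with progress:
    # add_task returns a task id used for later updates
    task_id = progress.add_task("[green]Training...", total=max_iters)
    for iter_num in range(max_iters):
        # ... run one optimization step here ...
        progress.update(task_id, advance=1)  # advance the bar by one iteration
```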
@@ -10,6 +10,9 @@
import time

from model_info_util.model_info import print_summary, print_module_structure, print_model_blocks, print_model_tree

from rich.progress import Progress

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
@@ -778,154 +781,161 @@ def train(self):
for head in range(self.args.n_head):
graph_y_labels.append(f"Layer {layer} Head {head}")

while True:
lr = self.get_lr(self.iter_num) if self.args.decay_lr else self.args.learning_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr

if self.iter_num % self.args.eval_interval == 0 and self.master_process:
losses = self.estimate_loss()
print(f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
self.log_metrics(losses, lr, running_mfu, self.iter_num)

if math.isnan(losses["val"]):
checkpoint = {
'model': self.raw_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'model_args': self.model_args,
'iter_num': self.iter_num,
'best_val_loss': self.best_val_loss,
'nan_iter_num' : 0,
'nan' : True,
'config': vars(self.args),
}
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
if losses['val'] < self.best_val_loss or self.args.always_save_checkpoint:
if losses['val'] < self.best_val_loss:
self.iter_num_best_val_loss = self.iter_num
self.best_val_loss = losses['val']
# Save best validation loss
with open(os.path.join(self.args.out_dir, 'best_val_loss_and_iter.txt'), "w") as best_loss_file:
best_loss_file.write(str(self.best_val_loss.item())+","+str(self.iter_num))
# Reset early exit counter
num_steps_with_worse_loss = 0
if self.iter_num > 0:
# Create progress bar
progress = Progress()
with progress:
task_id = progress.add_task("[green]Training...", total=(self.args.max_iters - self.iter_num))
while True:
lr = self.get_lr(self.iter_num) if self.args.decay_lr else self.args.learning_rate
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr

if self.iter_num % self.args.eval_interval == 0 and self.master_process:
losses = self.estimate_loss()
print(f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
self.log_metrics(losses, lr, running_mfu, self.iter_num)

if math.isnan(losses["val"]):
checkpoint = {
'model': self.raw_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'model_args': self.model_args,
'iter_num': self.iter_num,
'best_val_loss': self.best_val_loss,
'nan_iter_num' : None,
'nan' : None,
'nan_iter_num' : 0,
'nan' : True,
'config': vars(self.args),
}
print(f"saving checkpoint to {self.args.out_dir}")
# Save checkpoint
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
# Try new checkpoint if better val loss
if self.args.max_sample_tokens:
self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)
elif self.args.sample_each_eval:
# Try model inference (e.g. exploring inference from overfitting)
if self.args.max_sample_tokens:
self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)

if self.args.patience is not None and num_steps_with_worse_loss >= self.args.patience:
print(f"Early Stopping: loss has not decreased in {self.args.patience + 1} steps")
if losses['val'] < self.best_val_loss or self.args.always_save_checkpoint:
if losses['val'] < self.best_val_loss:
self.iter_num_best_val_loss = self.iter_num
self.best_val_loss = losses['val']
# Save best validation loss
with open(os.path.join(self.args.out_dir, 'best_val_loss_and_iter.txt'), "w") as best_loss_file:
best_loss_file.write(str(self.best_val_loss.item())+","+str(self.iter_num))
# Reset early exit counter
num_steps_with_worse_loss = 0
if self.iter_num > 0:
checkpoint = {
'model': self.raw_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'model_args': self.model_args,
'iter_num': self.iter_num,
'best_val_loss': self.best_val_loss,
'nan_iter_num' : None,
'nan' : None,
'config': vars(self.args),
}
print(f"saving checkpoint to {self.args.out_dir}")
# Save checkpoint
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
# Try new checkpoint if better val loss
if self.args.max_sample_tokens:
self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)
elif self.args.sample_each_eval:
# Try model inference (e.g. exploring inference from overfitting)
if self.args.max_sample_tokens:
self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)

if self.args.patience is not None and num_steps_with_worse_loss >= self.args.patience:
print(f"Early Stopping: loss has not decreased in {self.args.patience + 1} steps")
break
if losses['val'] > self.best_val_loss:
num_steps_with_worse_loss += 1

if self.iter_num == 0 and self.args.eval_only:
break
if losses['val'] > self.best_val_loss:
num_steps_with_worse_loss += 1

if self.iter_num == 0 and self.args.eval_only:
break

for micro_step in range(self.args.gradient_accumulation_steps):
if self.ddp:
self.model.require_backward_grad_sync = (micro_step == self.args.gradient_accumulation_steps - 1)

with self.ctx:
logits, loss = self.model(self.X, self.Y)
loss = loss / self.args.gradient_accumulation_steps

self.X, self.Y = self.get_batch('train')

self.scaler.scale(loss).backward()

if self.args.grad_clip != 0.0:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)

self.scaler.step(self.optimizer)
self.scaler.update()

self.optimizer.zero_grad(set_to_none=True)

t1 = time.time()
dt = t1 - t0
t0 = t1
if self.iter_num % self.args.log_interval == 0 and self.master_process:
lossf = loss.item() * self.args.gradient_accumulation_steps
if local_iter_num >= 5:
mfu = self.raw_model.estimate_mfu(self.args.batch_size * self.args.gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f}%")
if math.isnan(lossf):
if self.args.save_nan_checkpoint:
for micro_step in range(self.args.gradient_accumulation_steps):
if self.ddp:
self.model.require_backward_grad_sync = (micro_step == self.args.gradient_accumulation_steps - 1)

with self.ctx:
logits, loss = self.model(self.X, self.Y)
loss = loss / self.args.gradient_accumulation_steps

self.X, self.Y = self.get_batch('train')

self.scaler.scale(loss).backward()

if self.args.grad_clip != 0.0:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)

self.scaler.step(self.optimizer)
self.scaler.update()

self.optimizer.zero_grad(set_to_none=True)

t1 = time.time()
dt = t1 - t0
t0 = t1
if self.iter_num % self.args.log_interval == 0 and self.master_process:
lossf = loss.item() * self.args.gradient_accumulation_steps
if local_iter_num >= 5:
mfu = self.raw_model.estimate_mfu(self.args.batch_size * self.args.gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f} ms, mfu {running_mfu*100:.2f}%")
if math.isnan(lossf):
if self.args.save_nan_checkpoint:
checkpoint = {
'model': self.raw_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'model_args': self.model_args,
'iter_num': self.iter_num_best_val_loss,
'best_val_loss': self.best_val_loss,
'nan_iter_num' : self.iter_num,
'nan' : True,
'config': vars(self.args),
}
print(f"saving checkpoint to {self.args.out_dir}")
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
sys.exit("Exiting training loss is NaN")
self.log_metrics_non_validation(lossf, running_mfu, self.iter_num)


if self.args.create_statistics:
create_statistics(self, graph_y_labels)


self.iter_num += 1
local_iter_num += 1

# Update progress bar
progress.update(task_id, advance=1)

# End of training actions
if self.iter_num > self.args.max_iters:
if self.args.only_save_checkpoint_at_end:
checkpoint = {
'model': self.raw_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'model_args': self.model_args,
'iter_num': self.iter_num_best_val_loss,
'iter_num': self.iter_num,
'best_val_loss': self.best_val_loss,
'nan_iter_num' : self.iter_num,
'nan' : True,
'nan_iter_num' : None,
'nan' : None,
'config': vars(self.args),
}
print(f"saving checkpoint to {self.args.out_dir}")
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
sys.exit("Exiting training loss is NaN")
self.log_metrics_non_validation(lossf, running_mfu, self.iter_num)


if self.args.create_statistics:
create_statistics(self, graph_y_labels)


self.iter_num += 1
local_iter_num += 1

# End of training actions
if self.iter_num > self.args.max_iters:
if self.args.only_save_checkpoint_at_end:
checkpoint = {
'model': self.raw_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'model_args': self.model_args,
'iter_num': self.iter_num,
'best_val_loss': self.best_val_loss,
'nan_iter_num' : None,
'nan' : None,
'config': vars(self.args),
}
print(f"saving checkpoint to {self.args.out_dir}")
torch.save(checkpoint, os.path.join(self.args.out_dir, 'ckpt.pt'))
# Sample if set
if self.args.max_sample_tokens:
self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)
break

if self.args.plot_statistics:
plot_statistics(self.args, self.stats, graph_y_labels)
# Sample if set
if self.args.max_sample_tokens:
self.sample_and_print(self.args.max_sample_tokens, start_tokens=self.args.sample_start_tokens)
break

if self.args.tensorboard_log:
self.writer.flush()
self.writer.close()
if self.args.plot_statistics:
plot_statistics(self.args, self.stats, graph_y_labels)

if self.args.wandb_log and self.master_process:
import wandb
wandb.log({"finished": True})
wandb.finish()
if self.args.tensorboard_log:
self.writer.flush()
self.writer.close()

if self.args.wandb_log and self.master_process:
import wandb
wandb.log({"finished": True})
wandb.finish()

def main():
args, model_group, training_group, logging_group = parse_args()