@@ -73,6 +73,7 @@
 from ray.util.placement_group import PlacementGroup, placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from rich.pretty import pprint
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler
 from transformers.integrations import HfDeepSpeedConfig
@@ -122,7 +123,6 @@
     is_beaker_job,
     launch_ai2_evals_on_weka,
     maybe_get_beaker_config,
-    maybe_update_beaker_description_with_wandb_url,
     maybe_use_ai2_hf_entity,
     maybe_use_ai2_wandb_entity,
     ray_get_with_progress,
@@ -382,6 +382,8 @@ class Args:
382382 """The beaker evaluation tasks to launch"""
383383 oe_eval_max_length : int = 4096
384384 """the max generation length for evaluation for oe-eval"""
385+ oe_eval_beaker_image : Optional [str ] = None
386+ """the docker image for evaluation for oe-eval"""
385387 eval_priority : Literal ["low" , "normal" , "high" , "urgent" ] = "normal"
386388 """the priority of auto-launched evaluation jobs"""
387389
@@ -1078,6 +1080,7 @@ def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url
             args.stop_strings,
             args.gs_bucket_path,
             args.eval_priority,
+            args.oe_eval_beaker_image,
         )
 
 
@@ -1648,15 +1651,21 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod
         wandb.init(
             project=args.wandb_project_name,
             entity=args.wandb_entity,
+            sync_tensorboard=True,
             config=all_configs,
             name=args.run_name,
             save_code=True,
             tags=[args.exp_name] + get_wandb_tags(),
         )
         wandb_url = wandb.run.get_url()
-        maybe_update_beaker_description_with_wandb_url(wandb_url)
 
-    return beaker_config, wandb_url
+    writer = SummaryWriter(f"runs/{args.run_name}")
+    writer.add_text(
+        "hyperparameters",
+        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+    )
+
+    return beaker_config, writer, wandb_url
 
 
 def setup_datasets(args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokenizer):
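
Note: with sync_tensorboard=True, wandb attaches to TensorBoard's event stream, so everything the SummaryWriter records is mirrored into the wandb run without explicit wandb.log calls. A minimal, self-contained sketch of the pattern this hunk adopts (project and run names are illustrative, not from the patch):

import wandb
from torch.utils.tensorboard import SummaryWriter

# Hypothetical demo values; the patch uses args.wandb_project_name / args.run_name.
wandb.init(project="demo", name="demo-run", sync_tensorboard=True)
writer = SummaryWriter("runs/demo-run")  # local TensorBoard event files
for step in range(3):
    writer.add_scalar("train/loss", 1.0 / (step + 1), step)  # shows up in both TB and wandb
writer.close()
wandb.finish()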
@@ -1936,11 +1945,13 @@ def one_training_step(
     collated_data,
     tokenizer,
     data_thread_metrics,
+    average_metrics,
     episode,
     training_step,
     num_total_tokens,
     start_time,
     train_dataset,
+    writer,
     wandb_url,
     chat_template_name,
 ):
@@ -1975,18 +1986,16 @@ def one_training_step(
         **data_thread_metrics,
         **average_metrics,
     }
-    # Print only scalar metrics
-    scalar_metrics = {k: v for k, v in metrics.items() if isinstance(v, (float, int))}
+    scalar_metrics = {}
+    for key, value in metrics.items():
+        if isinstance(value, float) or isinstance(value, int):
+            writer.add_scalar(key, value, episode)
+            scalar_metrics[key] = value
+        if isinstance(value, np.ndarray) or isinstance(value, list):
+            if len(value) > 0:
+                writer.add_histogram(key, value, episode)
     print_rich_single_line_metrics(scalar_metrics)
 
-    if args.with_tracking:
-        # Convert array/list metrics to wandb histograms for logging
-        for key, value in metrics.items():
-            if isinstance(value, np.ndarray) or isinstance(value, list):
-                if len(value) > 0:
-                    metrics[key] = wandb.Histogram(value)
-        wandb.log(metrics, step=episode)
-
     if args.save_freq > 0 and training_step % args.save_freq == 0 and (args.eval_on_step_0 or training_step > 1):
         with Timer("[Main Thread] 🗡️ Saving model"):
             checkpoint_dir = f"{args.output_dir}_checkpoints"
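
The rewritten loop above replaces the wandb-specific histogram conversion with TensorBoard's native type dispatch: scalar metrics go through writer.add_scalar (and are kept for console printing), while non-empty arrays and lists go through writer.add_histogram. A standalone sketch with toy metric names and values (all illustrative):

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/demo")
metrics = {"loss": 0.42, "lr": 1e-4, "rewards": np.random.randn(128), "kl": [0.1, 0.2, 0.3]}
for key, value in metrics.items():
    if isinstance(value, (float, int)):
        writer.add_scalar(key, value, global_step=0)  # one point per scalar metric
    elif isinstance(value, (np.ndarray, list)) and len(value) > 0:
        writer.add_histogram(key, np.asarray(value), global_step=0)  # distribution per step
writer.close()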
@@ -2036,6 +2045,7 @@ def maybe_evaluate(
     eval_batch: Optional[Batch],
     reward_fn,
     episode,
+    writer,
     eval_pending_queries_map: PendingQueriesMap,
     eval_generation_config,
 ):
@@ -2083,18 +2093,19 @@ def maybe_evaluate(
             **eval_reward_metrics,
         }
         print_rich_single_line_metrics(eval_metrics)
-
+        for key, value in eval_metrics.items():
+            writer.add_scalar(key, value, episode)
         table = {}
         table["prompt"] = tokenizer.batch_decode(eval_batch.queries if eval_batch else [])
         table["response"] = eval_decoded_responses
         table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]]
         table["scores"] = eval_scores
         table["ground_truth"] = eval_batch.ground_truths if eval_batch else []
         df = pd.DataFrame(table)
-
         if args.with_tracking:
-            eval_metrics["sample_completions"] = wandb.Table(dataframe=df)
-            wandb.log(eval_metrics, step=episode)
+            import wandb
+
+            wandb.log({"sample_completions": wandb.Table(dataframe=df)})
         else:
             print_rich_table(df.iloc[:1])
         del table
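
Because scalar eval metrics now reach wandb through the writer (via sync_tensorboard), the only direct wandb call left here is for the sample-completions table, which has no TensorBoard equivalent. A small sketch of that call with placeholder rows:

import pandas as pd
import wandb

wandb.init(project="demo")  # hypothetical project name
df = pd.DataFrame({"prompt": ["2+2="], "response": ["4"], "scores": [1.0], "ground_truth": ["4"]})
wandb.log({"sample_completions": wandb.Table(dataframe=df)})  # interactive table in the run page
wandb.finish()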
@@ -2229,8 +2240,11 @@ async def reward_fn(
 
 def cleanup_judge_clients():
     """Cleans up all LLM judge clients and shutdown Ray."""
-    asyncio.run(cleanup_all_llm_judge_clients())
-    logger.info("✅ LLM judge clients cleaned up")
+    try:
+        asyncio.run(cleanup_all_llm_judge_clients())
+        logger.info("✅ LLM judge clients cleaned up")
+    except Exception as cleanup_error:
+        logger.warning(f"Error during LLM judge cleanup: {cleanup_error}")
     ray.shutdown()
 
 
@@ -2263,7 +2277,12 @@ def cleanup_training_resources(
         queues[0].put(ShutdownSentinel(), timeout=1)
 
     logger.info("Shutting down Ray queues...")
-    [queue.shutdown() for queue in queues]
+    for queue in queues:
+        try:
+            queue.shutdown()
+        except Exception as e:
+            logger.warning(f"Error shutting down Ray queue: {e}")
+
     logger.info("Shutting down thread pool executor...")
     executor.shutdown(wait=True)
 
@@ -2274,7 +2293,7 @@ def cleanup_training_resources(
 def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_samples: int = 32):
     tokenizer = make_tokenizer(tc, model_config)
     args = setup_runtime_variables(args)
-    beaker_config, wandb_url = setup_experiment_tracking(args, tc, model_config)
+    beaker_config, writer, wandb_url = setup_experiment_tracking(args, tc, model_config)
 
     train_dataset, eval_dataset = setup_datasets(args, tc, tokenizer)
     if args.cache_dataset_only:
@@ -2412,11 +2431,13 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
             collated_data,
             tokenizer,
             data_thread_metrics,
+            {},
             episode,
             training_step,
             num_total_tokens,
             start_time,
             train_dataset,
+            writer,
             wandb_url,
             tc.chat_template_name,
         )
@@ -2429,6 +2450,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
             eval_batch,
             reward_fn,
             episode,
+            writer,
             eval_pending_queries_map,
             generation_configs["eval"],
         )