1616import chz
1717import tinker
1818from tinker .lib .public_interfaces import APIFuture
19+
1920from tinker_cookbook import checkpoint_utils
2021from tinker_cookbook .display import colorize_example
2122from tinker_cookbook .eval .evaluators import (
3132from tinker_cookbook .utils import ml_log
3233from tinker_cookbook .utils .lr_scheduling import compute_schedule_lr_multiplier
3334from tinker_cookbook .utils .misc_utils import timed
35+ from tinker_cookbook .utils .trace import get_scope_context , scope , trace_init
3436
3537logger = logging .getLogger (__name__ )
3638
@@ -72,6 +74,8 @@ class Config:
7274 wandb_project : str | None = None
7375 wandb_name : str | None = None
7476
77+ enable_trace : bool = False
78+
7579
7680@dataclass
7781class SubmittedBatch :
@@ -87,6 +91,7 @@ class SubmittedBatch:
8791 infrequent_eval_metrics : dict [str , float ] | None = None
8892
8993
94+ @scope
9095async def run_evals (
9196 evaluators : list [Evaluator ],
9297 training_client : tinker .TrainingClient ,
@@ -102,29 +107,42 @@ async def run_evals(
102107 checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next
103108 to the same-step training metrics.
104109 """
110+ context = get_scope_context ()
111+ context .attributes ["step" ] = step
112+
105113 metrics = {}
106114 sampling_client = None
107115
108- for evaluator in evaluators :
116+ @scope
117+ async def run_evaluator (evaluator : Evaluator ) -> dict [str , float ]:
118+ context = get_scope_context ()
119+ context .attributes ["step" ] = step
120+ context .attributes ["evaluator_name" ] = type (evaluator ).__name__
109121 if isinstance (evaluator , TrainingClientEvaluator ):
110- eval_metrics = await evaluator (training_client )
122+ context .attributes ["evaluator_type" ] = "TrainingClientEvaluator"
123+ return await evaluator (training_client )
111124 elif isinstance (evaluator , SamplingClientEvaluator ):
125+ context .attributes ["evaluator_type" ] = "SamplingClientEvaluator"
112126 # Create sampling client lazily, only when needed
127+ nonlocal sampling_client
113128 if sampling_client is None :
114129 # Snapshot the current pre-step weights and create a new sampling client.
115130 sampling_client = await training_client .save_weights_and_get_sampling_client_async (
116131 f"evals_step_{ step } "
117132 )
118- eval_metrics = await evaluator (sampling_client )
133+ return await evaluator (sampling_client )
119134 else :
120135 raise ValueError (f"Unknown evaluator type: { type (evaluator )} " )
121136
137+ for evaluator in evaluators :
138+ eval_metrics = await run_evaluator (evaluator )
122139 # Add test/ prefix to all metrics
123140 metrics .update ({f"test/{ k } " : v for k , v in eval_metrics .items ()})
124141
125142 return metrics
126143
127144
145+ @scope
128146async def main (config : Config ):
129147 """Run the standard supervised learning loop used by the supervised recipes.
130148
@@ -156,6 +174,18 @@ async def main(config: Config):
156174 config = config ,
157175 do_configure_logging_module = True ,
158176 )
177+ if config .enable_trace :
178+ # Get and rename the current (main) task
179+ current_task = asyncio .current_task ()
180+ if current_task is not None :
181+ current_task .set_name ("main" )
182+ trace_events_path = os .path .join (config .log_path , "trace_events.jsonl" )
183+ logger .info (f"Tracing is enabled. Trace events will be saved to { trace_events_path } " )
184+ logger .info (
185+ f"Run `python tinker_cookbook/utils/trace.py { trace_events_path } trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/"
186+ )
187+ trace_init (output_file = os .path .join (config .log_path , "trace_events.jsonl" ))
188+
159189 service_client = tinker .ServiceClient (base_url = config .base_url )
160190 load_state_path : str | None = (
161191 resume_info ["state_path" ] if resume_info else config .load_checkpoint_path
@@ -192,8 +222,12 @@ async def main(config: Config):
192222 f"Training for { n_batches } batches x { config .num_epochs } epochs = { n_batches * config .num_epochs } steps"
193223 )
194224
225+ @scope
195226 async def submit_batch (epoch_idx : int , batch_idx : int ) -> SubmittedBatch :
196227 step = epoch_idx * n_batches + batch_idx
228+ context = get_scope_context ()
229+ context .attributes ["step" ] = step
230+
197231 batch_start_time = time .time ()
198232 metrics : dict [str , int | float | str ] = {"epoch" : epoch_idx }
199233 metrics ["progress" ] = step / progress_denominator
@@ -250,7 +284,11 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
250284 infrequent_eval_metrics = infrequent_eval_metrics ,
251285 )
252286
287+ @scope
253288 async def finish_batch (submitted : SubmittedBatch ):
289+ context = get_scope_context ()
290+ context .attributes ["step" ] = submitted .step
291+
254292 metrics = submitted .metrics
255293 metrics ["progress" ] = min ((submitted .step + 1 ) / progress_denominator , 1.0 )
256294
0 commit comments