Skip to content

Commit 5f5ce26

Browse files
authored
[tinker-cookbook] supervised: add tracing annotations (#88)
1 parent 0adfc25 commit 5f5ce26

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

tinker_cookbook/checkpoint_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,27 @@
77
import tinker
88

99
from tinker_cookbook.utils.file_utils import read_jsonl
10+
from tinker_cookbook.utils.trace import get_scope_context, scope
1011

1112
CHECKPOINTS_BASE_NAME = "checkpoints.jsonl"
1213

1314
logger = logging.getLogger(__name__)
1415

1516

17+
@scope
1618
def load_checkpoints_file(log_dir: str) -> list[dict[str, Any]]:
1719
checkpoint_path = os.path.join(log_dir, CHECKPOINTS_BASE_NAME)
1820
if not os.path.exists(checkpoint_path):
1921
logger.info(f"No checkpoints found at {checkpoint_path}")
2022
return []
2123

2224
logger.info(f"Reading checkpoints from {checkpoint_path}")
25+
context = get_scope_context()
26+
context.attributes["checkpoint_path"] = checkpoint_path
2327
return read_jsonl(checkpoint_path)
2428

2529

30+
@scope
2631
def get_last_checkpoint(log_dir: str, required_key: str = "state_path") -> dict[str, Any] | None:
2732
"""
2833
Get the last checkpoint from the checkpoints.jsonl file in the specified log directory.
@@ -49,6 +54,7 @@ def get_last_checkpoint(log_dir: str, required_key: str = "state_path") -> dict[
4954
return None
5055

5156

57+
@scope
5258
async def save_checkpoint_async(
5359
training_client: tinker.TrainingClient,
5460
name: str,
@@ -72,6 +78,8 @@ async def save_checkpoint_async(
7278

7379
results = {k: await v.result_async() for k, v in futures.items()}
7480
paths = {k + "_path": v.path for k, v in results.items()}
81+
context = get_scope_context()
82+
context.attributes.update(paths)
7583
logger.info(f"Saved checkpoints: {paths}")
7684
full_dict = {"name": name, **loop_state, **paths}
7785
with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f:
@@ -80,6 +88,7 @@ async def save_checkpoint_async(
8088
return paths
8189

8290

91+
@scope
8392
def save_checkpoint(
8493
training_client: tinker.TrainingClient,
8594
name: str,

tinker_cookbook/supervised/train.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import chz
1717
import tinker
1818
from tinker.lib.public_interfaces import APIFuture
19+
1920
from tinker_cookbook import checkpoint_utils
2021
from tinker_cookbook.display import colorize_example
2122
from tinker_cookbook.eval.evaluators import (
@@ -31,6 +32,7 @@
3132
from tinker_cookbook.utils import ml_log
3233
from tinker_cookbook.utils.lr_scheduling import compute_schedule_lr_multiplier
3334
from tinker_cookbook.utils.misc_utils import timed
35+
from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
3436

3537
logger = logging.getLogger(__name__)
3638

@@ -72,6 +74,8 @@ class Config:
7274
wandb_project: str | None = None
7375
wandb_name: str | None = None
7476

77+
enable_trace: bool = False
78+
7579

7680
@dataclass
7781
class SubmittedBatch:
@@ -87,6 +91,7 @@ class SubmittedBatch:
8791
infrequent_eval_metrics: dict[str, float] | None = None
8892

8993

94+
@scope
9095
async def run_evals(
9196
evaluators: list[Evaluator],
9297
training_client: tinker.TrainingClient,
@@ -102,29 +107,42 @@ async def run_evals(
102107
checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next
103108
to the same-step training metrics.
104109
"""
110+
context = get_scope_context()
111+
context.attributes["step"] = step
112+
105113
metrics = {}
106114
sampling_client = None
107115

108-
for evaluator in evaluators:
116+
@scope
117+
async def run_evaluator(evaluator: Evaluator) -> dict[str, float]:
118+
context = get_scope_context()
119+
context.attributes["step"] = step
120+
context.attributes["evaluator_name"] = type(evaluator).__name__
109121
if isinstance(evaluator, TrainingClientEvaluator):
110-
eval_metrics = await evaluator(training_client)
122+
context.attributes["evaluator_type"] = "TrainingClientEvaluator"
123+
return await evaluator(training_client)
111124
elif isinstance(evaluator, SamplingClientEvaluator):
125+
context.attributes["evaluator_type"] = "SamplingClientEvaluator"
112126
# Create sampling client lazily, only when needed
127+
nonlocal sampling_client
113128
if sampling_client is None:
114129
# Snapshot the current pre-step weights and create a new sampling client.
115130
sampling_client = await training_client.save_weights_and_get_sampling_client_async(
116131
f"evals_step_{step}"
117132
)
118-
eval_metrics = await evaluator(sampling_client)
133+
return await evaluator(sampling_client)
119134
else:
120135
raise ValueError(f"Unknown evaluator type: {type(evaluator)}")
121136

137+
for evaluator in evaluators:
138+
eval_metrics = await run_evaluator(evaluator)
122139
# Add test/ prefix to all metrics
123140
metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})
124141

125142
return metrics
126143

127144

145+
@scope
128146
async def main(config: Config):
129147
"""Run the standard supervised learning loop used by the supervised recipes.
130148
@@ -156,6 +174,18 @@ async def main(config: Config):
156174
config=config,
157175
do_configure_logging_module=True,
158176
)
177+
if config.enable_trace:
178+
# Get and rename the current (main) task
179+
current_task = asyncio.current_task()
180+
if current_task is not None:
181+
current_task.set_name("main")
182+
trace_events_path = os.path.join(config.log_path, "trace_events.jsonl")
183+
logger.info(f"Tracing is enabled. Trace events will be saved to {trace_events_path}")
184+
logger.info(
185+
f"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/"
186+
)
187+
trace_init(output_file=os.path.join(config.log_path, "trace_events.jsonl"))
188+
159189
service_client = tinker.ServiceClient(base_url=config.base_url)
160190
load_state_path: str | None = (
161191
resume_info["state_path"] if resume_info else config.load_checkpoint_path
@@ -192,8 +222,12 @@ async def main(config: Config):
192222
f"Training for {n_batches} batches x {config.num_epochs} epochs = {n_batches * config.num_epochs} steps"
193223
)
194224

225+
@scope
195226
async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
196227
step = epoch_idx * n_batches + batch_idx
228+
context = get_scope_context()
229+
context.attributes["step"] = step
230+
197231
batch_start_time = time.time()
198232
metrics: dict[str, int | float | str] = {"epoch": epoch_idx}
199233
metrics["progress"] = step / progress_denominator
@@ -250,7 +284,11 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
250284
infrequent_eval_metrics=infrequent_eval_metrics,
251285
)
252286

287+
@scope
253288
async def finish_batch(submitted: SubmittedBatch):
289+
context = get_scope_context()
290+
context.attributes["step"] = submitted.step
291+
254292
metrics = submitted.metrics
255293
metrics["progress"] = min((submitted.step + 1) / progress_denominator, 1.0)
256294

0 commit comments

Comments
 (0)