add GCP workload observability feature #1167

Open
wants to merge 1 commit into
base: main

jcyang43
Collaborator

Description

Add option to enable GCP workload monitoring for MaxText workloads.

  • GCP workload monitoring sends performance metrics (heartbeat and training step times) to Cloud Monarch. If a metric crosses its predefined threshold, the on-call engineers are notified so they can decide whether any action is needed. This is ideal for critical workloads that are sensitive to infrastructure changes.
  • Each metric can be turned on or off individually through config options. Examples are included in MaxText/configs/base.yml
  • Documentation can be found at getting_started/GCP_Workload_Monitoring.md
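
As a rough illustration of the per-metric toggles described above, here is a minimal sketch. The class, the flag names, and the `select_metrics` helper are hypothetical and are not the options added by this PR; the real option names live in MaxText/configs/base.yml.

```python
from dataclasses import dataclass


@dataclass
class MonitoringConfig:
    # Hypothetical flag names, for illustration only; the real options
    # are listed in MaxText/configs/base.yml.
    report_heartbeat: bool = False
    report_step_time: bool = False


def select_metrics(config: MonitoringConfig, step_time_seconds: float) -> dict:
    """Return only the metrics whose flag is enabled."""
    selected = {}
    if config.report_heartbeat:
        selected["heartbeat"] = True
    if config.report_step_time:
        selected["step_time_seconds"] = step_time_seconds
    return selected


# With every flag left at False, nothing is selected for reporting.
all_off = select_metrics(MonitoringConfig(), 1.5)
# Enabling a single flag selects only that metric.
step_only = select_metrics(MonitoringConfig(report_step_time=True), 1.5)
```

Keeping one flag per metric means a workload can opt into heartbeat reporting without also reporting step times, and the reverse.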

Tests

Tested on a Trillium TPU and confirmed that metrics are sent to Cloud Monarch successfully when the configs are enabled. No metrics are sent to Cloud Monarch when the configs are set to False.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens
)
last_step_completion = new_time
step_time_delta = datetime.datetime.now() - last_step_completion
Collaborator

Changes like this need to be handled with care: JAX is asynchronous ("lazy") by default, i.e. functions return immediately and do not block until the computation finishes. In this case the results before and after this PR are roughly the same, as confirmed by this diff: https://diff.googleplex.com/#key=TLvLsJisjPpi (LHS main, RHS this PR).

The call that actually blocks is the one that checkpoints or prints an array. Here that is write_metrics, which comes below (not record_scalar_metrics or p_train_step). This is fine: write_metrics runs after last_step_completion is updated, so step_time_delta still has to wait a real train_step's worth of time.
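
To make the timing concern concrete, here is a minimal sketch that mimics asynchronous dispatch with a standard-library thread pool. It is not JAX and not the MaxText training loop; `slow_step` and the 0.05 s sleep are stand-ins I made up for a device computation. The point is only that a timestamp taken right after dispatch measures almost nothing, while one taken after a blocking call covers the real work.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_step(x):
    # Stand-in for a device computation such as a train step.
    time.sleep(0.05)
    return x + 1


with ThreadPoolExecutor(max_workers=1) as pool:
    start = time.perf_counter()
    # submit() returns a future immediately, like an async-dispatched call.
    future = pool.submit(slow_step, 1)
    dispatch_time = time.perf_counter() - start
    # result() blocks until the work is done, like printing or
    # checkpointing an array would.
    result = future.result()
    blocked_time = time.perf_counter() - start
```

In JAX itself the analogous explicit wait is `block_until_ready()` on an array; in the loop discussed here, write_metrics plays that blocking role.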

Collaborator

Awesome, thank you for the documentation!

@jcyang43 jcyang43 force-pushed the feature/gcp-workload-observabiltiy branch 2 times, most recently from 9254b60 to 2f55959 Compare January 15, 2025 20:58
@jcyang43 jcyang43 force-pushed the feature/gcp-workload-observabiltiy branch 4 times, most recently from 35cbe59 to 01df877 Compare January 17, 2025 23:10
@jcyang43 jcyang43 force-pushed the feature/gcp-workload-observabiltiy branch from 01df877 to 669a14c Compare January 17, 2025 23:24