diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index e784535e6..6ec504b99 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -460,6 +460,11 @@ monitor_goodput: True
 goodput_upload_interval_seconds: 60
 enable_pathways_goodput: False
 
+# GCP workload monitoring
+report_heartbeat_metric_for_gcp_monitoring: False
+heartbeat_reporting_interval_in_seconds: 5
+report_performance_metric_for_gcp_monitoring: False
+
 # Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md
 # Set to True for GCE, False if running via XPK
 use_vertex_tensorboard: False
diff --git a/MaxText/monitoring/gcp_workload_monitor.py b/MaxText/monitoring/gcp_workload_monitor.py
new file mode 100644
index 000000000..b6c87acf7
--- /dev/null
+++ b/MaxText/monitoring/gcp_workload_monitor.py
@@ -0,0 +1,219 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import datetime
+import os
+import time
+import queue
+import threading
+
+import max_logging
+import requests  # type: ignore[pyi-error]
+import jax
+
+from google.api import metric_pb2, monitored_resource_pb2
+from google.api_core.exceptions import GoogleAPIError
+from google.cloud import monitoring_v3
+from urllib3.util.retry import Retry
+
+
+_METADATA_SERVER_URL = "http://metadata.google.internal/computeMetadata/v1/"
+_METADATA_HEADERS = {"Metadata-Flavor": "Google"}
+
+
+class GCPWorkloadMonitor:
+  """Interface for reporting metrics to GCP for monitoring."""
+
+  def __init__(self, run_name: str):
+    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%f")
+    self.workload_id = f"{run_name if run_name else 'maxtext-unnamed'}-{timestamp}"
+    self.zone = get_node_zone()
+    self.project_id = get_gcp_project_id()
+    self.client = monitoring_v3.MetricServiceClient()
+    self.heartbeat_reporting_started = False
+    self.performance_reporting_started = False
+    self.termination_event = threading.Event()
+
+  def __del__(self):
+    self.termination_event.set()
+
+  def start_heartbeat_reporting_thread(self, interval: int):
+    """Starts a thread that reports heartbeat every {interval} seconds until termination event is set."""
+    if self.heartbeat_reporting_started:
+      raise RuntimeError("Heartbeat reporting thread already started")
+    max_logging.log("Starting background thread for reporting heartbeat for workload observability")
+    self.heartbeat_reporting_started = True
+    t = threading.Thread(target=self._report_heartbeat_thread, args=(interval,))
+    t.daemon = True
+    t.start()
+
+  def start_performance_reporting_thread(self, metrics_queue: queue.Queue):
+    """Starts a thread that reports performance metric sent to metrics_queue until termination event is set."""
+    if self.performance_reporting_started:
+      raise RuntimeError("Performance reporting thread already started")
+    max_logging.log("Starting background thread for reporting performance for workload observability")
+    self.performance_reporting_started = True
+    t = threading.Thread(target=self._report_performance_thread, args=(metrics_queue,))
+    t.daemon = True
+    t.start()
+
+  def _report_heartbeat_thread(self, interval: int):
+    """Reports heartbeat metric to GCP every {interval} seconds until termination event is set."""
+    local_rank = os.getenv("LOCAL_RANK", "0")
+    global_rank = jax.process_index()
+    while not self.termination_event.is_set():
+      self._report_heartbeat(local_rank, str(global_rank))
+      time.sleep(interval)
+
+  def _report_performance_thread(self, metrics_queue: queue.Queue):
+    """Reports performance metric to GCP whenever a new metric arrives at the metrics_queue until termination event is set."""
+    while not self.termination_event.is_set():
+      try:
+        # adding a timeout of 1s to ensure we don't block indefinitely and miss the stop event
+        performance_metric = metrics_queue.get(timeout=1)
+        self._report_performance(performance_metric)
+      except queue.Empty:
+        continue
+
+  def _report_heartbeat(self, local_rank: str, global_rank: str):
+    """Reports heartbeat metric for the process specified by the given local rank & global rank."""
+    try:
+      now = time.time()
+      seconds = int(now)
+      nanos = int((now - seconds) * 10**9)
+
+      # Create a TimeSeries object for the heartbeat metric
+      series = monitoring_v3.TimeSeries(
+          metric=metric_pb2.Metric(
+              type="compute.googleapis.com/workload_process/heartbeat",
+              labels={
+                  "local_rank": local_rank,
+                  "instance_id": _get_gcp_metadata(category="instance", attribute="id"),
+              },
+          ),
+          resource=monitored_resource_pb2.MonitoredResource(
+              type="compute.googleapis.com/WorkloadProcess",
+              labels={
+                  "project_id": self.project_id,
+                  "location": self.zone,
+                  "workload_id": self.workload_id,
+                  "replica_id": "0",
+                  "process_id": global_rank,
+              },
+          ),
+          points=[
+              monitoring_v3.Point(
+                  interval=monitoring_v3.TimeInterval(end_time={"seconds": seconds, "nanos": nanos}),
+                  value=monitoring_v3.TypedValue(bool_value=True),
+              ),
+          ],
+      )
+
+      # Send data to Google Cloud Monitoring
+      self.client.create_time_series(
+          request={"name": f"projects/{self.project_id}", "time_series": [series]},
+          timeout=30,
+      )
+      max_logging.log("Heartbeat metric successfully sent to GCP.")
+    except GoogleAPIError as e:
+      max_logging.log(f"Failed to send heartbeat to GCP: {e}")
+    except Exception as e:
+      max_logging.log(f"Unexpected error while sending heartbeat to GCP: {e}")
+
+  def _report_performance(self, performance_metric):
+    """Reports performance metric to GCP."""
+    try:
+      now = time.time()
+      seconds = int(now)
+      nanos = int((now - seconds) * 10**9)
+
+      # Create a TimeSeries object for the performance metric
+      series = monitoring_v3.TimeSeries(
+          metric=metric_pb2.Metric(
+              type="compute.googleapis.com/workload/performance",
+          ),
+          resource=monitored_resource_pb2.MonitoredResource(
+              type="compute.googleapis.com/Workload",
+              labels={
+                  "location": self.zone,
+                  "workload_id": self.workload_id,
+                  "replica_id": "0",
+              },
+          ),
+          points=[
+              monitoring_v3.Point(
+                  interval=monitoring_v3.TimeInterval(end_time={"seconds": seconds, "nanos": nanos}),
+                  value=monitoring_v3.TypedValue(double_value=performance_metric),
+              ),
+          ],
+      )
+
+      # Send data to Google Cloud Monitoring
+      self.client.create_time_series(
+          request={"name": f"projects/{self.project_id}", "time_series": [series]},
+          timeout=30,
+      )
+      max_logging.log("Performance metric successfully sent to GCP.")
+    except GoogleAPIError as e:
+      max_logging.log(f"Failed to send performance to GCP: {e}")
+    except Exception as e:
+      max_logging.log(f"Unexpected error while sending performance to GCP: {e}")
+
+
+def _get_gcp_metadata(category: str, attribute: str, timeout=5, retries=3):
+  """
+  Fetch the specified attribute from GCP metadata server.
+
+  Args:
+    category (str): The high-level metadata category (ex: 'instance', 'project').
+    attribute (str): The attribute to fetch under this category (ex: 'id', 'zone').
+    timeout (int): Timeout for the request in seconds.
+    retries (int): Number of retry attempts for transient failures.
+
+  Returns:
+    str: The metadata value as a string, or None if the request fails.
+  """
+  target_url = f"{_METADATA_SERVER_URL}{category}/{attribute}"
+
+  session = requests.Session()
+  retry_strategy = Retry(
+      total=retries,
+      backoff_factor=0.5,
+      # Retry on the following status codes
+      status_forcelist=[429, 500, 502, 503, 504],
+  )
+  adapter = requests.adapters.HTTPAdapter(max_retries=retry_strategy)
+  session.mount("http://", adapter)
+
+  try:
+    response = session.get(target_url, headers=_METADATA_HEADERS, timeout=timeout)
+    response.raise_for_status()
+    return response.text
+  except requests.exceptions.RequestException as e:
+    max_logging.log(f"Failed to retrieve metadata for {category}/{attribute}: {e}")
+    return None
+
+
+def get_gcp_project_id():
+  """Returns the project id of the current GCP project."""
+  return _get_gcp_metadata("project", "project-id")
+
+
+def get_node_zone():
+  """Returns the zone of the GCE instance."""
+  zone_path = _get_gcp_metadata("instance", "zone")
+  # example zone_path: "projects/123456789/zones/us-central1-a"
+  return zone_path.rsplit("/", 1)[-1] if zone_path else None
diff --git a/MaxText/train.py b/MaxText/train.py
index 711a5e82d..83d52399f 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -25,6 +25,7 @@ import sys
 import functools
 import time
+import queue
 from typing import Sequence, Optional
 
 from absl import app
@@ -52,6 +53,8 @@ from input_pipeline.input_pipeline_interface import create_data_iterator
 from layers import models
 
+from monitoring.gcp_workload_monitor import GCPWorkloadMonitor
+
 import jax.numpy as jnp
 from jax import random
 from jax.sharding import Mesh
@@ -859,6 +862,15 @@ def train_loop(config, state=None):
   example_batch = None
   last_step_completion = datetime.datetime.now()
 
+  performance_metric_queue = None
+  if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
+    gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
+    if config.report_heartbeat_metric_for_gcp_monitoring:
+      gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
+    if config.report_performance_metric_for_gcp_monitoring:
+      performance_metric_queue = queue.Queue()
+      gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)
+
   for step in np.arange(start_step, config.steps):
     if step == first_profiling_step or prof.should_activate_periodic_profile(step):
       optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
@@ -875,11 +887,11 @@ def train_loop(config, state=None):
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       state, metrics = p_train_step(state, example_batch, nextrng)
 
-    new_time = datetime.datetime.now()
-    record_scalar_metrics(
-        metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens
-    )
-    last_step_completion = new_time
+    step_time_delta = datetime.datetime.now() - last_step_completion
+    record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
+    if performance_metric_queue:
+      performance_metric_queue.put(step_time_delta.total_seconds())
+    last_step_completion = datetime.datetime.now()
 
     if checkpoint_manager is not None:
       state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
diff --git a/getting_started/GCP_Workload_Observability.md b/getting_started/GCP_Workload_Observability.md
new file mode 100644
index 000000000..78e183889
--- /dev/null
+++ b/getting_started/GCP_Workload_Observability.md
@@ -0,0 +1,26 @@
+# Enable GCP Workload Observability
+This guide provides an overview of how to enable GCP workload observability for your MaxText workload.
+
+## Overview
+Google offers a monitoring and alerting feature that is well suited for critical MaxText workloads sensitive to infrastructure changes.
+Once enabled, metrics are automatically sent to [Cloud Monarch](https://research.google/pubs/monarch-googles-planet-scale-in-memory-time-series-database/) for monitoring.
+If a metric hits its pre-defined threshold, the Google Cloud on-call team is alerted to determine whether any action is needed.
+
+The feature currently supports heartbeat and performance (training step time in seconds) metrics. In the near future, support for the goodput metric will also be added.
+Users should work with their Customer Engineer (CE) and the Google team to define appropriate thresholds for the performance metric.
+
+This guide lays out how to enable the feature for your MaxText workload.
+
+## Enabling GCP Workload Observability
+Users can control which metrics they want to report via the config:
+
+### Heartbeat metric
+- This metric is a boolean flag.
+- To turn on this metric, set `report_heartbeat_metric_for_gcp_monitoring` to `True`.
+- To control the frequency of heartbeat reporting (default is every 5 seconds), set `heartbeat_reporting_interval_in_seconds` to your desired value.
+
+### Performance metric
+- This metric is a double, capturing the training step time in seconds.
+- To turn on this metric, set `report_performance_metric_for_gcp_monitoring` to `True`.
+
+For an example, please refer to [base.yml](../MaxText/configs/base.yml).
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 19aa0b8e8..1c6a75b9f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,9 @@ datasets
 gcsfs
 google-cloud-aiplatform==1.61.0
 google-cloud-storage
+google-cloud-monitoring
+google-api-core
+google-api-python-client
 grain-nightly
 flax>=0.8.0
 ml-collections
diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt
index b1391098c..1576109d6 100644
--- a/requirements_with_jax_stable_stack.txt
+++ b/requirements_with_jax_stable_stack.txt
@@ -19,3 +19,6 @@ mlperf-logging@git+https://github.com/mlperf/logging.git
 google-jetstream
 jsonlines
 pathwaysutils@git+https://github.com/google/pathways-utils.git
+google-cloud-monitoring
+google-api-core
+google-api-python-client
\ No newline at end of file
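
As a quick illustration of how the new `GCPWorkloadMonitor` is meant to be driven (mirroring the wiring added to `train.py` above), here is a minimal usage sketch. It is not part of the patch: it assumes the process runs on a GCP VM with metadata-server access and with `MaxText/` on `PYTHONPATH`, and the run name and step-time values are made up for illustration.

```python
# Minimal sketch: drive GCPWorkloadMonitor outside of train.py.
# Assumptions: GCP VM with metadata-server access, MaxText/ on PYTHONPATH;
# run name and step times below are hypothetical.
import queue
import time

from monitoring.gcp_workload_monitor import GCPWorkloadMonitor

monitor = GCPWorkloadMonitor(run_name="observability-demo")  # hypothetical run name

# Heartbeat: a daemon thread reports liveness every 5 seconds until the monitor goes away.
monitor.start_heartbeat_reporting_thread(interval=5)

# Performance: the caller pushes per-step times (seconds) into a queue; a daemon thread
# forwards each value to Cloud Monitoring as a workload/performance time series.
performance_metric_queue = queue.Queue()
monitor.start_performance_reporting_thread(performance_metric_queue)

for fake_step_time_seconds in (1.8, 1.7, 1.9):  # placeholder step times
  performance_metric_queue.put(fake_step_time_seconds)
  time.sleep(1)
```

This matches the `train_loop` integration: the queue is only created when `report_performance_metric_for_gcp_monitoring` is enabled, and each step pushes `step_time_delta.total_seconds()` onto it.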