Skip to content

Commit

Permalink
add GCP workload observability feature
Browse files Browse the repository at this point in the history
  • Loading branch information
jcyang43 committed Jan 15, 2025
1 parent ed5bb31 commit 9254b60
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 5 deletions.
5 changes: 5 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,11 @@ monitor_goodput: True
goodput_upload_interval_seconds: 60
enable_pathways_goodput: False

# GCP workload monitoring
report_heartbeat_metric_for_gcp_monitoring: False
heartbeat_reporting_interval_in_seconds: 5
report_performance_metric_for_gcp_monitoring: False

# Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md
# Set to True for GCE, False if running via XPK
use_vertex_tensorboard: False
Expand Down
203 changes: 203 additions & 0 deletions MaxText/monitoring/gcp_workload_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import datetime
import os
import time
import queue
import threading

import max_logging
import requests # type: ignore[pyi-error]
import jax

from google.api import metric_pb2, monitored_resource_pb2
from google.api_core.exceptions import GoogleAPIError
from google.cloud import monitoring_v3
from urllib3.util.retry import Retry


_METADATA_SERVER_URL = "http://metadata.google.internal/computeMetadata/v1/"
_METADATA_HEADERS = {"Metadata-Flavor": "Google"}


class GCPWorkloadMonitor:
  """Interface for reporting metrics to GCP for monitoring.

  On construction the monitor resolves the zone and project id through
  get_node_zone() / get_gcp_project_id(), derives a workload id from the run
  name plus a timestamp, and creates a Cloud Monitoring MetricServiceClient.
  Heartbeat and performance metrics are then reported from daemon background
  threads that keep running until `termination_event` is set.
  """

  def __init__(self, run_name: str):
    """Sets up identifiers and the monitoring client.

    Args:
      run_name: Name of the run; 'maxtext-unnamed' is used when it is empty.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%f")
    self.workload_id = f"{run_name if run_name else 'maxtext-unnamed'}-{timestamp}"
    self.zone = get_node_zone()
    self.project_id = get_gcp_project_id()
    self.client = monitoring_v3.MetricServiceClient()
    self.heartbeat_reporting_started = False
    self.performance_reporting_started = False
    self.termination_event = threading.Event()
    # Instance id is fetched lazily on the first heartbeat and then reused,
    # so the metadata server is not queried on every heartbeat.
    self._instance_id = None

  def __del__(self):
    # __init__ may have raised before termination_event was assigned, in which
    # case there is nothing to signal.
    event = getattr(self, "termination_event", None)
    if event is not None:
      event.set()

  def start_heartbeat_reporting_thread(self, interval: int):
    """Starts a thread that reports heartbeat every {interval} seconds until termination event is set.

    Args:
      interval: Seconds to wait between two heartbeats.

    Raises:
      RuntimeError: If the heartbeat reporting thread was already started.
    """
    if self.heartbeat_reporting_started:
      raise RuntimeError("Heartbeat reporting thread already started")
    max_logging.log("Starting background thread for reporting heartbeat for workload observability")
    self.heartbeat_reporting_started = True
    threading.Thread(target=self._report_heartbeat_thread, args=(interval,), daemon=True).start()

  def start_performance_reporting_thread(self, metrics_queue: queue.Queue):
    """Starts a thread that reports performance metric sent to metrics_queue until termination event is set.

    Args:
      metrics_queue: Queue from which performance values are consumed.

    Raises:
      RuntimeError: If the performance reporting thread was already started.
    """
    if self.performance_reporting_started:
      raise RuntimeError("Performance reporting thread already started")
    max_logging.log("Starting background thread for reporting performance for workload observability")
    self.performance_reporting_started = True
    threading.Thread(target=self._report_performance_thread, args=(metrics_queue,), daemon=True).start()

  def _report_heartbeat_thread(self, interval: int):
    """Reports heartbeat metric to GCP every {interval} seconds until termination event is set."""
    local_rank = os.getenv("LOCAL_RANK", "0")
    global_rank = str(jax.process_index())
    while not self.termination_event.is_set():
      self._report_heartbeat(local_rank, global_rank)
      # Wait on the event instead of sleeping so the loop stops as soon as
      # termination is requested rather than after a full interval.
      self.termination_event.wait(interval)

  def _report_performance_thread(self, metrics_queue: queue.Queue):
    """Reports performance metric to GCP whenever new metric arrives at the metrics_queue until termination event is set."""
    while not self.termination_event.is_set():
      try:
        # adding a timeout of 1s to ensure we don't block indefinitely and miss the stop event
        performance_metric = metrics_queue.get(timeout=1)
      except queue.Empty:
        continue
      self._report_performance(performance_metric)

  @staticmethod
  def _current_time_interval():
    """Returns a TimeInterval whose end_time is the current wall-clock time."""
    now = time.time()
    seconds = int(now)
    nanos = int((now - seconds) * 10**9)
    return monitoring_v3.TimeInterval(end_time={"seconds": seconds, "nanos": nanos})

  def _write_time_series(self, series):
    """Sends a single TimeSeries to Google Cloud Monitoring for this project."""
    self.client.create_time_series(
        request={"name": f"projects/{self.project_id}", "time_series": [series]},
        timeout=30,
    )

  def _report_heartbeat(self, local_rank: str, global_rank: str):
    """Reports heartbeat metric for the process specified by the given local rank & global rank.

    Failures are logged and never raised, so the reporting thread keeps running.
    """
    try:
      if self._instance_id is None:
        self._instance_id = _get_gcp_metadata(category="instance", attribute="id")
      if self._instance_id is None:
        # Skip this heartbeat rather than build a series with a missing
        # instance_id label; the lookup is retried on the next heartbeat.
        max_logging.log("Failed to send heartbeat to GCP: instance id could not be retrieved from metadata server")
        return

      # Create a TimeSeries object for the heartbeat metric
      series = monitoring_v3.TimeSeries(
          metric=metric_pb2.Metric(
              type="compute.googleapis.com/workload_process/heartbeat",
              labels={
                  "local_rank": local_rank,
                  "instance_id": self._instance_id,
              },
          ),
          resource=monitored_resource_pb2.MonitoredResource(
              type="compute.googleapis.com/WorkloadProcess",
              labels={
                  "project_id": self.project_id,
                  "location": self.zone,
                  "workload_id": self.workload_id,
                  "replica_id": "0",
                  "process_id": global_rank,
              },
          ),
          points=[
              monitoring_v3.Point(
                  interval=self._current_time_interval(),
                  value=monitoring_v3.TypedValue(bool_value=True),
              ),
          ],
      )

      # Send data to Google Cloud Monitoring
      self._write_time_series(series)
      max_logging.log("Heartbeat metric successfully sent to GCP.")
    except GoogleAPIError as e:
      max_logging.log(f"Failed to send heartbeat to GCP: {e}")
    except Exception as e:  # pylint: disable=broad-exception-caught
      # Best-effort reporting: an unexpected error must not kill the thread.
      max_logging.log(f"Unexpected error while sending heartbeat to GCP: {e}")

  def _report_performance(self, performance_metric):
    """Reports performance metric to GCP.

    Args:
      performance_metric: Value sent as the double_value of the data point.

    Failures are logged and never raised, so the reporting thread keeps running.
    """
    try:
      # Create a TimeSeries object for the performance metric
      series = monitoring_v3.TimeSeries(
          metric=metric_pb2.Metric(
              type="compute.googleapis.com/workload/performance",
          ),
          resource=monitored_resource_pb2.MonitoredResource(
              type="compute.googleapis.com/Workload",
              labels={
                  "location": self.zone,
                  "workload_id": self.workload_id,
                  "replica_id": "0",
              },
          ),
          points=[
              monitoring_v3.Point(
                  interval=self._current_time_interval(),
                  value=monitoring_v3.TypedValue(double_value=performance_metric),
              ),
          ],
      )

      # Send data to Google Cloud Monitoring
      self._write_time_series(series)
      max_logging.log("Performance metric successfully sent to GCP.")
    except GoogleAPIError as e:
      max_logging.log(f"Failed to send performance to GCP: {e}")
    except Exception as e:  # pylint: disable=broad-exception-caught
      # Best-effort reporting: an unexpected error must not kill the thread.
      max_logging.log(f"Unexpected error while sending performance to GCP: {e}")


def _get_gcp_metadata(category: str, attribute: str, timeout=5, retries=3):
  """
  Fetch the specified attribute from GCP metadata server.

  Args:
    category (str): The high-level metadata category (ex: 'instance', 'project').
    attribute (str): The attribute to fetch under this category (ex: 'id', 'zone').
    timeout (int): Timeout for the request in seconds.
    retries (int): Number of retry attempts for transient failures.

  Returns:
    str: The metadata value as a string, or None if the request fails.
  """
  target_url = f"{_METADATA_SERVER_URL}{category}/{attribute}"

  retry_strategy = Retry(
      total=retries,
      backoff_factor=0.5,
      # Retry on the following status codes
      status_forcelist=[429, 500, 502, 503, 504],
  )

  # Use the session as a context manager so it is closed on every exit path
  # instead of being left open after each call.
  with requests.Session() as session:
    session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retry_strategy))
    try:
      response = session.get(target_url, headers=_METADATA_HEADERS, timeout=timeout)
      response.raise_for_status()
      return response.text
    except requests.exceptions.RequestException as e:
      # Best-effort lookup: callers receive None and decide how to proceed.
      max_logging.log(f"Failed to retrieve metadata for {category}/{attribute}: {e}")
      return None


def get_gcp_project_id():
  """Looks up the id of the current GCP project from the metadata server.

  Returns:
    The value returned by _get_gcp_metadata for 'project/project-id'
    (None when that lookup fails).
  """
  project_id = _get_gcp_metadata(category="project", attribute="project-id")
  return project_id


def get_node_zone():
  """Looks up the zone of the GCE instance from the metadata server.

  Returns:
    The last '/'-separated component of the 'instance/zone' metadata value,
    or None when that value is missing or empty.
  """
  full_zone_path = _get_gcp_metadata("instance", "zone")
  if not full_zone_path:
    return None
  # The metadata value looks like "projects/123456789/zones/us-central1-a";
  # only the trailing zone name is wanted.
  return full_zone_path.rpartition("/")[2]
22 changes: 17 additions & 5 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import sys
import functools
import time
import queue

from typing import Sequence, Optional
from absl import app
Expand Down Expand Up @@ -52,6 +53,8 @@
from input_pipeline.input_pipeline_interface import create_data_iterator
from layers import models

from monitoring.gcp_workload_monitor import GCPWorkloadMonitor

import jax.numpy as jnp
from jax import random
from jax.sharding import Mesh
Expand Down Expand Up @@ -859,6 +862,15 @@ def train_loop(config, state=None):
example_batch = None
last_step_completion = datetime.datetime.now()

performance_metric_queue = None
if config.report_heartbeat_metric_for_gcp_monitoring or config.report_performance_metric_for_gcp_monitoring:
gcp_workload_monitor = GCPWorkloadMonitor(config.run_name)
if config.report_heartbeat_metric_for_gcp_monitoring:
gcp_workload_monitor.start_heartbeat_reporting_thread(config.heartbeat_reporting_interval_in_seconds)
if config.report_performance_metric_for_gcp_monitoring:
performance_metric_queue = queue.Queue()
gcp_workload_monitor.start_performance_reporting_thread(performance_metric_queue)

for step in np.arange(start_step, config.steps):
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
Expand All @@ -875,11 +887,11 @@ def train_loop(config, state=None):
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, nextrng)

new_time = datetime.datetime.now()
record_scalar_metrics(
metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens
)
last_step_completion = new_time
step_time_delta = datetime.datetime.now() - last_step_completion
record_scalar_metrics(metrics, step_time_delta, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
if performance_metric_queue:
performance_metric_queue.put(step_time_delta.total_seconds())
last_step_completion = datetime.datetime.now()

if checkpoint_manager is not None:
state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
Expand Down
26 changes: 26 additions & 0 deletions getting_started/GCP_Workload_Observability.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Enable GCP Workload Observability
This guide provides an overview on how to enable GCP workload observability for your MaxText workload.

## Overview
Google offers a monitoring and alerting feature that is well suited for critical MaxText workloads sensitive to infrastructure changes.
Once enabled, metrics will be automatically sent to [Cloud Monarch](https://research.google/pubs/monarch-googles-planet-scale-in-memory-time-series-database/) for monitoring.
If a metric hits its pre-defined threshold, the Google Cloud on-call team will be alerted to see if any action is needed.

The feature currently supports heartbeat and performance (training step time in seconds) metrics. In the near future, support for the goodput metric will also be added.
Users should work with their Customer Engineer (CE) and the Google team to define appropriate thresholds for the performance metrics.

This guide lays out how to enable the feature for your MaxText workload.

## Enabling GCP Workload Observability
Users can control which metrics are reported via config:

### Heartbeat metric
- This metric will be a boolean flag.
- To turn on this metric, set `report_heartbeat_metric_for_gcp_monitoring` to `True`
- To control the frequency of heartbeat reporting (default is every 5 seconds), set `heartbeat_reporting_interval_in_seconds` to your desired value.

### Performance metric
- This metric will be a double, capturing the training step time in seconds.
- To turn on this metric, set `report_performance_metric_for_gcp_monitoring` to `True`

For an example, please refer to [base.yml](../MaxText/configs/base.yml).
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ datasets
gcsfs
google-cloud-aiplatform==1.61.0
google-cloud-storage
google-cloud-monitoring
google-api-core
google-api-python-client
grain-nightly
flax>=0.8.0
ml-collections
Expand Down
3 changes: 3 additions & 0 deletions requirements_with_jax_stable_stack.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream
jsonlines
pathwaysutils@git+https://github.com/google/pathways-utils.git
google-cloud-monitoring
google-api-core
google-api-python-client

0 comments on commit 9254b60

Please sign in to comment.