-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add GCP workload observability feature
- Loading branch information
Showing
5 changed files
with
255 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
import datetime | ||
import os | ||
import time | ||
import queue | ||
import threading | ||
|
||
import max_logging | ||
import requests # type: ignore[pyi-error] | ||
import jax | ||
|
||
from google.api import metric_pb2, monitored_resource_pb2 | ||
from google.api_core.exceptions import GoogleAPIError | ||
from google.cloud import monitoring_v3 | ||
from urllib3.util.retry import Retry | ||
|
||
|
||
_METADATA_SERVER_URL = "http://metadata.google.internal/computeMetadata/v1/" | ||
_METADATA_HEADERS = {"Metadata-Flavor": "Google"} | ||
|
||
|
||
class GCPWorkloadMonitor: | ||
"""Interface for reporting metrics to GCP for monitoring.""" | ||
|
||
def __init__(self, run_name: str): | ||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%f") | ||
self.workload_id = f"{run_name if run_name else 'maxtext-unnamed'}-{timestamp}" | ||
self.zone = get_node_zone() | ||
self.project_id = get_gcp_project_id() | ||
self.client = monitoring_v3.MetricServiceClient() | ||
self.heartbeat_reporting_started = False | ||
self.performance_reporting_started = False | ||
self.termination_event = threading.Event() | ||
|
||
def __del__(self): | ||
self.termination_event.set() | ||
|
||
def start_heartbeat_reporting_thread(self, interval: int): | ||
"""Starts a thread that reports heartbeat every {interval} seconds until termination event is set.""" | ||
if self.heartbeat_reporting_started: | ||
raise RuntimeError("Heartbeat reporting thread already started") | ||
max_logging.log("Starting background thread for reporting heartbeat for workload observability") | ||
self.heartbeat_reporting_started = True | ||
t = threading.Thread(target=self._report_heartbeat_thread, args=(interval,)) | ||
t.daemon = True | ||
t.start() | ||
|
||
def start_performance_reporting_thread(self, metrics_queue: queue.Queue): | ||
"""Starts a thread that reports performance metric sent to metrics_queue until termination event is set.""" | ||
if self.performance_reporting_started: | ||
raise RuntimeError("Performance reporting thread already started") | ||
max_logging.log("Starting background thread for reporting performance for workload observability") | ||
self.performance_reporting_started = True | ||
t = threading.Thread(target=self._report_performance_thread, args=(metrics_queue,)) | ||
t.daemon = True | ||
t.start() | ||
|
||
def _report_heartbeat_thread(self, interval: int): | ||
"""Reports heartbeat metric to GCP every {interval} seconds until termination event is set.""" | ||
local_rank = os.getenv("LOCAL_RANK", "0") | ||
global_rank = jax.process_index() | ||
while not self.termination_event.is_set(): | ||
self._report_heartbeat(local_rank, str(global_rank)) | ||
time.sleep(interval) | ||
|
||
def _report_performance_thread(self, metrics_queue: queue.Queue): | ||
"""Reports performance metric to GCP whenever new metric arrives at the metrics_queue until termination event is set.""" | ||
while not self.termination_event.is_set(): | ||
try: | ||
# adding a timeout of 1s to ensure we don't block indefinitely and miss the stop event | ||
performance_metric = metrics_queue.get(timeout=1) | ||
self._report_performance(performance_metric) | ||
except queue.Empty: | ||
continue | ||
|
||
def _report_heartbeat(self, local_rank: str, global_rank: str): | ||
"""Reports heartbeat metric for the process specified by the given local rank & global rank.""" | ||
try: | ||
now = time.time() | ||
seconds = int(now) | ||
nanos = int((now - seconds) * 10**9) | ||
|
||
# Create a TimeSeries object for the heartbeat metric | ||
series = monitoring_v3.TimeSeries( | ||
metric=metric_pb2.Metric( | ||
type="compute.googleapis.com/workload_process/heartbeat", | ||
labels={ | ||
"local_rank": local_rank, | ||
"instance_id": _get_gcp_metadata(category="instance", attribute="id"), | ||
}, | ||
), | ||
resource=monitored_resource_pb2.MonitoredResource( | ||
type="compute.googleapis.com/WorkloadProcess", | ||
labels={ | ||
"project_id": self.project_id, | ||
"location": self.zone, | ||
"workload_id": self.workload_id, | ||
"replica_id": "0", | ||
"process_id": global_rank, | ||
}, | ||
), | ||
points=[ | ||
monitoring_v3.Point( | ||
interval=monitoring_v3.TimeInterval(end_time={"seconds": seconds, "nanos": nanos}), | ||
value=monitoring_v3.TypedValue(bool_value=True), | ||
), | ||
], | ||
) | ||
|
||
# Send data to Google Cloud Monitoring | ||
self.client.create_time_series( | ||
request={"name": f"projects/{self.project_id}", "time_series": [series]}, | ||
timeout=30, | ||
) | ||
max_logging.log("Heartbeat metric successfully sent to GCP.") | ||
except GoogleAPIError as e: | ||
max_logging.log(f"Failed to send heartbeat to GCP: {e}") | ||
except Exception as e: | ||
max_logging.log(f"Unexpected error while sending heartbeat to GCP: {e}") | ||
|
||
def _report_performance(self, performance_metric): | ||
"""Reports performance metric to GCP.""" | ||
try: | ||
now = time.time() | ||
seconds = int(now) | ||
nanos = int((now - seconds) * 10**9) | ||
|
||
# Create a TimeSeries object for the performance metric | ||
series = monitoring_v3.TimeSeries( | ||
metric=metric_pb2.Metric( | ||
type="compute.googleapis.com/workload/performance", | ||
), | ||
resource=monitored_resource_pb2.MonitoredResource( | ||
type="compute.googleapis.com/Workload", | ||
labels={ | ||
"location": self.zone, | ||
"workload_id": self.workload_id, | ||
"replica_id": "0", | ||
}, | ||
), | ||
points=[ | ||
monitoring_v3.Point( | ||
interval=monitoring_v3.TimeInterval(end_time={"seconds": seconds, "nanos": nanos}), | ||
value=monitoring_v3.TypedValue(double_value=performance_metric), | ||
), | ||
], | ||
) | ||
|
||
# Send data to Google Cloud Monitoring | ||
self.client.create_time_series( | ||
request={"name": f"projects/{self.project_id}", "time_series": [series]}, | ||
timeout=30, | ||
) | ||
max_logging.log("Performance metric successfully sent to GCP.") | ||
except GoogleAPIError as e: | ||
max_logging.log(f"Failed to send performance to GCP: {e}") | ||
except Exception as e: | ||
max_logging.log(f"Unexpected error while sending performance to GCP: {e}") | ||
|
||
|
||
def _get_gcp_metadata(category: str, attribute: str, timeout=5, retries=3): | ||
""" | ||
Fetch the specified attribute from GCP metadata server. | ||
Args: | ||
category (str): The high-level metadata category (ex: 'instance', 'project'). | ||
attribute (str): The attribute to fetch under this category (ex: 'id', 'zone'). | ||
timeout (int): Timeout for the request in seconds. | ||
retries (int): Number of retry attempts for transient failures. | ||
Returns: | ||
str: The metadata value as a string, or None if the request fails. | ||
""" | ||
target_url = f"{_METADATA_SERVER_URL}{category}/{attribute}" | ||
|
||
session = requests.Session() | ||
retry_strategy = Retry( | ||
total=retries, | ||
backoff_factor=0.5, | ||
# Retry on the following status codes | ||
status_forcelist=[429, 500, 502, 503, 504], | ||
) | ||
adapter = requests.adapters.HTTPAdapter(max_retries=retry_strategy) | ||
session.mount("http://", adapter) | ||
|
||
try: | ||
response = session.get(target_url, headers=_METADATA_HEADERS, timeout=timeout) | ||
response.raise_for_status() | ||
return response.text | ||
except requests.exceptions.RequestException as e: | ||
max_logging.log(f"Failed to retrieve metadata for {category}/{attribute}: {e}") | ||
return None | ||
|
||
|
||
def get_gcp_project_id(): | ||
"""Returns the project id of the current GCP project.""" | ||
return _get_gcp_metadata("project", "project-id") | ||
|
||
|
||
def get_node_zone(): | ||
"""Returns the zone of the GCE instance.""" | ||
zone_path = _get_gcp_metadata("instance", "zone") | ||
# example zone_path: "projects/123456789/zones/us-central1-a" | ||
return zone_path.rsplit("/", 1)[-1] if zone_path else None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Enable GCP Workload Observabiltiy | ||
This guide provides an overview on how to enable GCP workload observability for your MaxText workload. | ||
|
||
## Overview | ||
Google offers a monitoring and alerting feature that is well suited for critical MaxText workloads sensitive to infrastructure changes. | ||
Once enabled, metrics will be automatically sent to [Cloud Monarch](https://research.google/pubs/monarch-googles-planet-scale-in-memory-time-series-database/) for monitoring. | ||
If a metric hits its pre-defined threshold, the Google Cloud on-call team will be alerted to see if any action is needed. | ||
|
||
The feature currently supports heartbeat and performance (training step time in seconds) metrics. In the near future, support for the goodput metric will also be added. | ||
Users should work with their Customer Engineer (CE) and the Google team to define appropriate thresholds for the performance metrics. | ||
|
||
This guide layouts how to enable the feature for your MaxText workload. | ||
|
||
## Enabling GCP Workload Observabiltiy | ||
User can control which metric they want to report via config: | ||
|
||
### Heartbeat metric | ||
- This metric will be a boolean flag. | ||
- To turn on this metric, set `report_heartbeat_metric_for_gcp_monitoring` to `True` | ||
- To control the frequency of heartbeat reporting (default is every 5 seconds), set `heartbeat_reporting_interval_in_seconds` to your desired value. | ||
|
||
### Performance metric | ||
- This metric will be a double, capturing the training step time in seconds. | ||
- To turn on this metric, set `report_performance_metric_for_gcp_monitoring` to `True` | ||
|
||
For an example, please refer to [base.yml](../MaxText/configs/base.yml). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters