diff --git a/src/dvc_studio_client/post_live_metrics.py b/src/dvc_studio_client/post_live_metrics.py
index cfd2374..d03da1d 100644
--- a/src/dvc_studio_client/post_live_metrics.py
+++ b/src/dvc_studio_client/post_live_metrics.py
@@ -1,3 +1,4 @@
+import json
 from os import getenv
 from typing import Any, Dict, Literal, Optional
 from urllib.parse import urljoin
@@ -12,6 +13,13 @@
 from .env import DVC_STUDIO_TOKEN, STUDIO_ENDPOINT, STUDIO_TOKEN
 from .schema import SCHEMAS_BY_TYPE
 
+# Studio PROD and DEV have a hardcoded limit of 30MB for the request body.
+MAX_REQUEST_SIZE = 29000000
+# Studio backend discards files larger than 10MB as big files when parsing commits.
+MAX_PLOT_SIZE = 10000000
+# Studio backend limits the number of files inside a plot directory.
+MAX_NUMBER_OF_PLOTS = 200
+
 
 def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
     studio_token = studio_token or getenv(DVC_STUDIO_TOKEN) or getenv(STUDIO_TOKEN)
@@ -22,6 +30,104 @@ def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
     return config.get("token"), config.get("repo_url")
 
 
+def _single_post(url, body, token):
+    """Send one JSON request to Studio and report whether it succeeded.
+
+    Returns `True` only when Studio answers with HTTP 200. Request errors and
+    any other status code are logged as warnings and reported as `False`.
+    """
+    try:
+        response = requests.post(
+            url,
+            json=body,
+            headers={
+                "Content-type": "application/json",
+                "Authorization": f"token {token}",
+            },
+            timeout=(30, 5),
+        )
+    except RequestException as e:
+        logger.warning(f"Failed to post to Studio: {e}")
+        return False
+
+    message = response.content.decode()
+    # Parenthesize the conditional so that the status code is still logged
+    # when the response body is empty.
+    logger.debug(
+        f"post_to_studio: {response.status_code=}"
+        + (f", {message=}" if message else "")
+    )
+
+    if response.status_code != 200:
+        logger.warning(f"Failed to post to Studio: {message}")
+        return False
+
+    return True
+
+
+def _post_in_chunks(url, body, token):
+    """Post metrics and params first, then the plots in a second request.
+
+    The plots request gets its own payload, so the payload of the first
+    request is not modified after it has been sent. Returns `False` as soon
+    as one of the requests fails and `True` otherwise.
+    """
+    plots = body.pop("plots")
+
+    # First, post only metrics and params
+    if not _single_post(url, body, token):
+        return False
+
+    # Build a separate payload for the plots instead of editing `body` again.
+    plots_body = {k: v for k, v in body.items() if k not in ("metrics", "params")}
+    plots_body["plots"] = {}
+
+    # Studio backend has a limitation on the size of the request body.
+    # So we try to send as many plots as possible without exceeding the limit.
+    total_size = 0
+    for plot_name, plot_data in plots.items():
+        # Only plots that are actually sent count towards the limit.
+        if len(plots_body["plots"]) >= MAX_NUMBER_OF_PLOTS:
+            logger.warning(
+                f"Number of plots exceeds Studio limit ({MAX_NUMBER_OF_PLOTS}). "
+                "Some plots will not be sent."
+            )
+            break
+
+        if "data" in plot_data:
+            size = len(json.dumps(plot_data["data"]).encode("utf-8"))
+        elif "image" in plot_data:
+            size = len(plot_data["image"])
+        else:
+            # Without a known payload key the size cannot be measured.
+            logger.warning(
+                "Plot has neither `data` nor `image`. "
+                f"{plot_name} will not be sent."
+            )
+            continue
+
+        if size > MAX_PLOT_SIZE:
+            logger.warning(
+                f"Size of plot exceeds Studio limit ({MAX_PLOT_SIZE}). "
+                f"{plot_name} will not be sent."
+            )
+            continue
+
+        total_size += size
+        if total_size > MAX_REQUEST_SIZE:
+            logger.warning(
+                f"Total size of plots exceeds Studio limit ({MAX_REQUEST_SIZE}). "
+                "Some plots will not be sent."
+            )
+            break
+        plots_body["plots"][plot_name] = plot_data
+
+    if plots_body["plots"]:
+        return _single_post(url, plots_body, token)
+
+    return True
+
+
 def post_live_metrics(  # noqa: C901
     event_type: Literal["start", "data", "done"],
     baseline_sha: str,
@@ -98,6 +204,9 @@ def post_live_metrics(  # noqa: C901
             plots={
                 "dvclive/plots/metrics/foo.tsv": {
                     "data": [{"step": 0, "foo": 1.0}]
+                },
+                "dvclive/plots/images/bar.png": {
+                    "image": "base64-string"
                 }
             }
             ```
@@ -156,7 +265,6 @@ def post_live_metrics(  # noqa: C901
         body["step"] = step
         if plots:
             body["plots"] = plots
-
     elif event_type == "done":
         if experiment_rev:
             body["experiment_rev"] = experiment_rev
@@ -172,31 +280,12 @@ def post_live_metrics(  # noqa: C901
         return None
 
     logger.debug(f"post_studio_live_metrics `{event_type=}`")
-    logger.debug(f"JSON body `{body=}`")
 
     path = getenv(STUDIO_ENDPOINT) or "api/live"
     url = urljoin(config["url"], path)
-    try:
-        response = requests.post(
-            url,
-            json=body,
-            headers={
-                "Content-type": "application/json",
-                "Authorization": f"token {config['token']}",
-            },
-            timeout=(30, 5),
-        )
-    except RequestException as e:
-        logger.warning(f"Failed to post to Studio: {e}")
-        return False
-
-    message = response.content.decode()
-    logger.debug(
-        f"post_to_studio: {response.status_code=}" f", {message=}" if message else ""
-    )
+    token = config["token"]
 
-    if response.status_code != 200:
-        logger.warning(f"Failed to post to Studio: {message}")
-        return False
+    if body["type"] != "data" or "plots" not in body:
+        return _single_post(url, body, token)
 
-    return True
+    return _post_in_chunks(url, body, token)
diff --git a/tests/test_post_live_metrics.py b/tests/test_post_live_metrics.py
index d0732d1..5cc973b 100644
--- a/tests/test_post_live_metrics.py
+++ b/tests/test_post_live_metrics.py
@@ -12,6 +12,7 @@
     STUDIO_TOKEN,
 )
 from dvc_studio_client.post_live_metrics import (
+    MAX_NUMBER_OF_PLOTS,
     get_studio_token_and_repo_url,
     post_live_metrics,
 )
@@ -190,6 +191,7 @@ def test_post_live_metrics_data(mocker, monkeypatch):
         timeout=(30, 5),
     )
 
+    mocked_post = mocker.patch("requests.post", return_value=mocked_response)
     assert post_live_metrics(
         "data",
         "f" * 40,
@@ -199,25 +201,50 @@ def test_post_live_metrics_data(mocker, monkeypatch):
         metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
         plots={"dvclive/plots/metrics/foo.tsv": {"data": [{"step": 0, "foo": 1.0}]}},
     )
-    mocked_post.assert_called_with(
-        "https://studio.iterative.ai/api/live",
-        json={
-            "type": "data",
-            "repo_url": "FOO_REPO_URL",
-            "baseline_sha": "f" * 40,
-            "name": "fooname",
-            "client": "fooclient",
-            "step": 0,
-            "metrics": {"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
-            "plots": {
-                "dvclive/plots/metrics/foo.tsv": {"data": [{"step": 0, "foo": 1.0}]}
-            },
-        },
-        headers={
-            "Authorization": "token FOO_TOKEN",
-            "Content-type": "application/json",
-        },
-        timeout=(30, 5),
+
+    mocked_post.assert_has_calls(
+        [
+            mocker.call(
+                "https://studio.iterative.ai/api/live",
+                json={
+                    "type": "data",
+                    "repo_url": "FOO_REPO_URL",
+                    "baseline_sha": "f" * 40,
+                    "name": "fooname",
+                    "client": "fooclient",
+                    "step": 0,
+                    "metrics": {
+                        "dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}
+                    },
+                },
+                headers={
+                    "Authorization": "token FOO_TOKEN",
+                    "Content-type": "application/json",
+                },
+                timeout=(30, 5),
+            ),
+            mocker.call(
+                "https://studio.iterative.ai/api/live",
+                json={
+                    "type": "data",
+                    "repo_url": "FOO_REPO_URL",
+                    "baseline_sha": "f" * 40,
+                    "name": "fooname",
+                    "client": "fooclient",
+                    "step": 0,
+                    "plots": {
+                        "dvclive/plots/metrics/foo.tsv": {
+                            "data": [{"step": 0, "foo": 1.0}]
+                        }
+                    },
+                },
+                headers={
+                    "Authorization": "token FOO_TOKEN",
+                    "Content-type": "application/json",
+                },
+                timeout=(30, 5),
+            ),
+        ]
     )
 
 
@@ -431,3 +458,143 @@ def test_get_studio_token_and_repo_url_skip_repo_url(monkeypatch):
     token, repo_url = get_studio_token_and_repo_url()
     assert token is None
     assert repo_url is None  # Skipped call to get_repo_url
+
+
+def test_post_in_chunks(mocker, monkeypatch):
+    """Metrics and plots are posted separately; plots stop at the size limit."""
+    monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
+    monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
+
+    mocked_response = mocker.MagicMock()
+    mocked_response.status_code = 200
+
+    mocked_image = mocker.MagicMock("foo")
+    mocked_image.__len__.return_value = 9000000
+
+    mocked_post = mocker.patch("requests.post", return_value=mocked_response)
+    assert post_live_metrics(
+        "data",
+        "f" * 40,
+        "fooname",
+        "fooclient",
+        step=0,
+        metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
+        plots={"dvclive/plots/images/foo.png": {"image": mocked_image}},
+    )
+    # 1 call for metrics and params, 1 call for plots
+    assert mocked_post.call_count == 2
+
+    # 3.png will not be sent because the total size would exceed the limit.
+    mocked_post = mocker.patch("requests.post", return_value=mocked_response)
+    assert post_live_metrics(
+        "data",
+        "f" * 40,
+        "fooname",
+        "fooclient",
+        step=0,
+        metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
+        plots={
+            "dvclive/plots/images/0.png": {"image": mocked_image},
+            "dvclive/plots/images/1.png": {"image": mocked_image},
+            "dvclive/plots/images/2.png": {"image": mocked_image},
+            "dvclive/plots/images/3.png": {"image": mocked_image},
+        },
+    )
+    assert mocked_post.call_count == 2
+    mocked_post.assert_has_calls(
+        [
+            mocker.call(
+                "https://studio.iterative.ai/api/live",
+                json={
+                    "type": "data",
+                    "repo_url": "FOO_REPO_URL",
+                    "baseline_sha": "f" * 40,
+                    "name": "fooname",
+                    "client": "fooclient",
+                    "step": 0,
+                    "metrics": {
+                        "dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}
+                    },
+                },
+                headers={
+                    "Authorization": "token FOO_TOKEN",
+                    "Content-type": "application/json",
+                },
+                timeout=(30, 5),
+            ),
+            mocker.call(
+                "https://studio.iterative.ai/api/live",
+                json={
+                    "type": "data",
+                    "repo_url": "FOO_REPO_URL",
+                    "baseline_sha": "f" * 40,
+                    "name": "fooname",
+                    "client": "fooclient",
+                    "step": 0,
+                    "plots": {
+                        "dvclive/plots/images/0.png": {"image": mocked_image},
+                        "dvclive/plots/images/1.png": {"image": mocked_image},
+                        "dvclive/plots/images/2.png": {"image": mocked_image},
+                    },
+                },
+                headers={
+                    "Authorization": "token FOO_TOKEN",
+                    "Content-type": "application/json",
+                },
+                timeout=(30, 5),
+            ),
+        ]
+    )
+
+
+def test_post_in_chunks_skip_large_single_plot(mocker, monkeypatch):
+    """A plot above the per-plot size limit is dropped; only metrics are sent."""
+    monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
+    monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
+
+    mocked_response = mocker.MagicMock()
+    mocked_response.status_code = 200
+
+    mocked_image = mocker.MagicMock("foo")
+    mocked_image.__len__.return_value = 29200000
+
+    mocked_post = mocker.patch("requests.post", return_value=mocked_response)
+    assert post_live_metrics(
+        "data",
+        "f" * 40,
+        "fooname",
+        "fooclient",
+        step=0,
+        metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
+        plots={"dvclive/plots/images/foo.png": {"image": mocked_image}},
+    )
+    assert mocked_post.call_count == 1
+
+
+def test_post_in_chunks_max_number_of_plots(mocker, monkeypatch):
+    """At most MAX_NUMBER_OF_PLOTS plots are included in the plots request."""
+    monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
+    monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
+
+    mocked_response = mocker.MagicMock()
+    mocked_response.status_code = 200
+
+    plots = {}
+    for i in range(MAX_NUMBER_OF_PLOTS + 2):
+        plots[f"dvclive/plots/images/{i}.png"] = {
+            "data": [{"step": i, "foo": float(i)}]
+        }
+    mocked_post = mocker.patch("requests.post", return_value=mocked_response)
+    assert post_live_metrics(
+        "data",
+        "f" * 40,
+        "fooname",
+        "fooclient",
+        step=0,
+        metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
+        plots=plots,
+    )
+    assert mocked_post.call_count == 2
+    assert (
+        len(mocked_post.call_args_list[-1][1]["json"]["plots"]) == MAX_NUMBER_OF_PLOTS
+    )