post_live_metrics: Add _post_in_chunks. (#80)
* post_live_metrics: Add `_post_in_chunks`.

* post_live_metrics: Add MAX_NUMBER_OF_PLOTS
daavoo authored Sep 1, 2023
1 parent 0e23376 commit 8d15544
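For context, a minimal sketch (not part of this commit) of the call path the change affects: a `post_live_metrics` "data" event carrying both metrics and plots. The plots layout follows the docstring shown in the diff below; the SHA, experiment name, and client name are placeholders, and the Studio token and repo URL are assumed to come from the `DVC_STUDIO_TOKEN` / `STUDIO_REPO_URL` environment variables.

```python
from dvc_studio_client.post_live_metrics import post_live_metrics

# Placeholder values; the token and repo URL are read from the environment.
ok = post_live_metrics(
    "data",
    "f" * 40,         # baseline_sha (placeholder)
    "my-experiment",  # experiment name (placeholder)
    "dvclive",        # client name (placeholder)
    step=0,
    metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1.0}}},
    plots={
        "dvclive/plots/metrics/foo.tsv": {"data": [{"step": 0, "foo": 1.0}]},
        "dvclive/plots/images/bar.png": {"image": "base64-string"},
    },
)
# Previously the whole body was sent as a single request; with `_post_in_chunks`,
# "data" events that include plots are split: metrics/params go first, followed
# by the plots that fit within the size and count limits.
```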
Showing 2 changed files with 270 additions and 43 deletions.
115 changes: 91 additions & 24 deletions src/dvc_studio_client/post_live_metrics.py
@@ -1,3 +1,4 @@
import json
from os import getenv
from typing import Any, Dict, Literal, Optional
from urllib.parse import urljoin
@@ -12,6 +13,13 @@
from .env import DVC_STUDIO_TOKEN, STUDIO_ENDPOINT, STUDIO_TOKEN
from .schema import SCHEMAS_BY_TYPE

# Studio PROD and DEV have a hardcoded limit of 30MB for the request body.
MAX_REQUEST_SIZE = 29000000
# Studio backend treats files larger than 10MB as big files and discards them when parsing commits.
MAX_PLOT_SIZE = 10000000
# Studio backend limits the number of files inside a plot directory.
MAX_NUMBER_OF_PLOTS = 200


def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
studio_token = studio_token or getenv(DVC_STUDIO_TOKEN) or getenv(STUDIO_TOKEN)
@@ -22,6 +30,82 @@ def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
return config.get("token"), config.get("repo_url")


def _single_post(url, body, token):
try:
response = requests.post(
url,
json=body,
headers={
"Content-type": "application/json",
"Authorization": f"token {token}",
},
timeout=(30, 5),
)
except RequestException as e:
logger.warning(f"Failed to post to Studio: {e}")
return False

message = response.content.decode()
logger.debug(
f"post_to_studio: {response.status_code=}" f", {message=}" if message else ""
)

if response.status_code != 200:
logger.warning(f"Failed to post to Studio: {message}")
return False

return True


def _post_in_chunks(url, body, token):
plots = body.pop("plots")

# First, post only metrics and params
if not _single_post(url, body, token):
return False
body.pop("metrics", None)
body.pop("params", None)

# Studio backend has a limitation on the size of the request body.
# So we try to send as many plots as possible without exceeding the limit.
body["plots"] = {}
total_size = 0
for n, (plot_name, plot_data) in enumerate(plots.items()):
if n >= MAX_NUMBER_OF_PLOTS:
logger.warning(
f"Number of plots exceeds Studio limit ({MAX_NUMBER_OF_PLOTS}). "
"Some plots will not be sent."
)
break

if "data" in plot_data:
size = len(json.dumps(plot_data["data"]).encode("utf-8"))
elif "image" in plot_data:
size = len(plot_data["image"])

if size > MAX_PLOT_SIZE:
logger.warning(
f"Size of plot exceeds Studio limit ({MAX_PLOT_SIZE}). "
f"{plot_name} will not be sent."
)
continue

total_size += size
if total_size > MAX_REQUEST_SIZE:
logger.warning(
f"Total size of plots exceeds Studio limit ({MAX_REQUEST_SIZE}). "
"Some plots will not be sent."
)
break
body["plots"][plot_name] = plot_data

if body["plots"]:
if not _single_post(url, body, token):
return False

return True


def post_live_metrics( # noqa: C901
event_type: Literal["start", "data", "done"],
baseline_sha: str,
@@ -98,6 +182,9 @@ def post_live_metrics( # noqa: C901
plots={
"dvclive/plots/metrics/foo.tsv": {
"data": [{"step": 0, "foo": 1.0}]
},
"dvclive/plots/images/bar.png": {
"image": "base64-string"
}
}
```
@@ -156,7 +243,6 @@ def post_live_metrics( # noqa: C901
body["step"] = step
if plots:
body["plots"] = plots

elif event_type == "done":
if experiment_rev:
body["experiment_rev"] = experiment_rev
@@ -172,31 +258,12 @@ def post_live_metrics( # noqa: C901
return None

logger.debug(f"post_studio_live_metrics `{event_type=}`")
logger.debug(f"JSON body `{body=}`")

path = getenv(STUDIO_ENDPOINT) or "api/live"
url = urljoin(config["url"], path)
try:
response = requests.post(
url,
json=body,
headers={
"Content-type": "application/json",
"Authorization": f"token {config['token']}",
},
timeout=(30, 5),
)
except RequestException as e:
logger.warning(f"Failed to post to Studio: {e}")
return False

message = response.content.decode()
logger.debug(
f"post_to_studio: {response.status_code=}" f", {message=}" if message else ""
)
token = config["token"]

if response.status_code != 200:
logger.warning(f"Failed to post to Studio: {message}")
return False
if body["type"] != "data" or "plots" not in body:
return _single_post(url, body, token)

return True
return _post_in_chunks(url, body, token)
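As a reading aid, here is a small sketch (not from the commit) of the size accounting `_post_in_chunks` performs: "data" plots are measured by their UTF-8-encoded JSON, "image" plots by the length of the already-encoded string, and plots are accumulated until the 29 MB request budget is exceeded. The payloads and the `estimated_plot_size` helper below are hypothetical.

```python
import json

MAX_REQUEST_SIZE = 29_000_000  # request-body budget, as defined above
MAX_PLOT_SIZE = 10_000_000     # per-plot cap, as defined above


def estimated_plot_size(plot_data: dict) -> int:
    """Mirror the measurement used in `_post_in_chunks` (sketch only)."""
    if "data" in plot_data:
        return len(json.dumps(plot_data["data"]).encode("utf-8"))
    if "image" in plot_data:
        return len(plot_data["image"])
    return 0


# Four ~9 MB images: each one is under MAX_PLOT_SIZE, but only the first three
# fit together under MAX_REQUEST_SIZE (27 MB <= 29 MB < 36 MB).
plots = {f"dvclive/plots/images/{i}.png": {"image": "x" * 9_000_000} for i in range(4)}

total, kept = 0, {}
for name, plot_data in plots.items():
    size = estimated_plot_size(plot_data)
    if size > MAX_PLOT_SIZE:
        continue  # a single oversized plot is skipped with a warning
    total += size
    if total > MAX_REQUEST_SIZE:
        break  # budget exhausted; remaining plots are dropped with a warning
    kept[name] = plot_data

print(len(kept))  # 3
```

This mirrors the arithmetic exercised by `test_post_in_chunks` below, where three of four ~9 MB images make it into the plots request.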
198 changes: 179 additions & 19 deletions tests/test_post_live_metrics.py
@@ -12,6 +12,7 @@
STUDIO_TOKEN,
)
from dvc_studio_client.post_live_metrics import (
MAX_NUMBER_OF_PLOTS,
get_studio_token_and_repo_url,
post_live_metrics,
)
@@ -190,6 +191,7 @@ def test_post_live_metrics_data(mocker, monkeypatch):
timeout=(30, 5),
)

mocked_post = mocker.patch("requests.post", return_value=mocked_response)
assert post_live_metrics(
"data",
"f" * 40,
@@ -199,25 +201,50 @@
metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
plots={"dvclive/plots/metrics/foo.tsv": {"data": [{"step": 0, "foo": 1.0}]}},
)
mocked_post.assert_called_with(
"https://studio.iterative.ai/api/live",
json={
"type": "data",
"repo_url": "FOO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "fooname",
"client": "fooclient",
"step": 0,
"metrics": {"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
"plots": {
"dvclive/plots/metrics/foo.tsv": {"data": [{"step": 0, "foo": 1.0}]}
},
},
headers={
"Authorization": "token FOO_TOKEN",
"Content-type": "application/json",
},
timeout=(30, 5),

assert mocked_post.has_calls(
[
mocker.call(
"https://studio.iterative.ai/api/live",
json={
"type": "data",
"repo_url": "FOO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "fooname",
"client": "fooclient",
"step": 0,
"metrics": {
"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}
},
},
headers={
"Authorization": "token FOO_TOKEN",
"Content-type": "application/json",
},
timeout=(30, 5),
),
mocker.call(
"https://studio.iterative.ai/api/live",
json={
"type": "data",
"repo_url": "FOO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "fooname",
"client": "fooclient",
"step": 0,
"plots": {
"dvclive/plots/metrics/foo.tsv": {
"data": [{"step": 0, "foo": 1.0}]
}
},
},
headers={
"Authorization": "token FOO_TOKEN",
"Content-type": "application/json",
},
timeout=(30, 5),
),
]
)


@@ -431,3 +458,136 @@ def test_get_studio_token_and_repo_url_skip_repo_url(monkeypatch):
token, repo_url = get_studio_token_and_repo_url()
assert token is None
assert repo_url is None # Skipped call to get_repo_url


def test_post_in_chunks(mocker, monkeypatch):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
mocked_response.status_code = 200

mocked_image = mocker.MagicMock("foo")
mocked_image.__len__.return_value = 9000000

mocked_post = mocker.patch("requests.post", return_value=mocked_response)
assert post_live_metrics(
"data",
"f" * 40,
"fooname",
"fooclient",
step=0,
metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
plots={"dvclive/plots/images/foo.png": {"image": mocked_image}},
)
# 1 call for metrics and params, 1 call for plots
assert mocked_post.call_count == 2

# 3.png will not be sent because it would push the total request size over the limit.
mocked_post = mocker.patch("requests.post", return_value=mocked_response)
assert post_live_metrics(
"data",
"f" * 40,
"fooname",
"fooclient",
step=0,
metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
plots={
"dvclive/plots/images/0.png": {"image": mocked_image},
"dvclive/plots/images/1.png": {"image": mocked_image},
"dvclive/plots/images/2.png": {"image": mocked_image},
"dvclive/plots/images/3.png": {"image": mocked_image},
},
)
assert mocked_post.call_count == 2
assert mocked_post.has_calls(
[
mocker.call(
"https://studio.iterative.ai/api/live",
json={
"type": "data",
"repo_url": "FOO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "fooname",
"client": "fooclient",
"step": 0,
"metrics": {
"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}
},
},
headers={
"Authorization": "token FOO_TOKEN",
"Content-type": "application/json",
},
timeout=(30, 5),
),
mocker.call(
"https://studio.iterative.ai/api/live",
json={
"type": "data",
"repo_url": "FOO_REPO_URL",
"baseline_sha": "f" * 40,
"name": "fooname",
"client": "fooclient",
"step": 0,
"plots": {"dvclive/plots/images/foo.png": {"image": mocked_image}},
},
headers={
"Authorization": "token FOO_TOKEN",
"Content-type": "application/json",
},
timeout=(30, 5),
),
]
)


def test_post_in_chunks_skip_large_single_plot(mocker, monkeypatch):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
mocked_response.status_code = 200

mocked_image = mocker.MagicMock("foo")
mocked_image.__len__.return_value = 29200000

mocked_post = mocker.patch("requests.post", return_value=mocked_response)
assert post_live_metrics(
"data",
"f" * 40,
"fooname",
"fooclient",
step=0,
metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
plots={"dvclive/plots/images/foo.png": {"image": mocked_image}},
)
assert mocked_post.call_count == 1


def test_post_in_chunks_max_number_of_plots(mocker, monkeypatch):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
mocked_response.status_code = 200

plots = {}
for i in range(MAX_NUMBER_OF_PLOTS + 2):
plots[f"dvclive/plots/images/{i}.png"] = {
"data": [{"step": i, "foo": float(i)}]
}
mocked_post = mocker.patch("requests.post", return_value=mocked_response)
assert post_live_metrics(
"data",
"f" * 40,
"fooname",
"fooclient",
step=0,
metrics={"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
plots=plots,
)
assert mocked_post.call_count == 2
assert (
len(mocked_post.call_args_list[-1][1]["json"]["plots"]) == MAX_NUMBER_OF_PLOTS
)
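For quick reference, a sketch of the two request bodies in the chunked path for the simple case in `test_post_live_metrics_data` above; the field values are the test's placeholders, and the request headers and timeout are unchanged from the pre-chunking behavior.

```python
# First request: everything except plots (here, the metrics and the step).
first_body = {
    "type": "data",
    "repo_url": "FOO_REPO_URL",
    "baseline_sha": "f" * 40,
    "name": "fooname",
    "client": "fooclient",
    "step": 0,
    "metrics": {"dvclive/metrics.json": {"data": {"step": 0, "foo": 1}}},
}

# Second request: only the plots that fit the size and count limits.
second_body = {
    "type": "data",
    "repo_url": "FOO_REPO_URL",
    "baseline_sha": "f" * 40,
    "name": "fooname",
    "client": "fooclient",
    "step": 0,
    "plots": {
        "dvclive/plots/metrics/foo.tsv": {"data": [{"step": 0, "foo": 1.0}]}
    },
}
```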
