Skip to content

Commit

Permalink
add support for DVC_STUDIO_TOKEN and DVC_STUDIO_URL envvar (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Apr 18, 2023
1 parent fd9aac6 commit d36060c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/dvc_studio_client/env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
DVC_STUDIO_CLIENT_LOGLEVEL = "DVC_STUDIO_CLIENT_LOGLEVEL"
DVC_STUDIO_TOKEN = "DVC_STUDIO_TOKEN" # nosec B105
DVC_STUDIO_URL = "DVC_STUDIO_URL"
STUDIO_ENDPOINT = "STUDIO_ENDPOINT"
STUDIO_REPO_URL = "STUDIO_REPO_URL"
STUDIO_TOKEN = "STUDIO_TOKEN" # nosec B105
16 changes: 13 additions & 3 deletions src/dvc_studio_client/post_live_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import lru_cache
from os import getenv
from typing import Any, Dict, Literal, Optional
from urllib.parse import urljoin

import requests
from requests.exceptions import RequestException
Expand All @@ -10,12 +11,16 @@

from .env import (
DVC_STUDIO_CLIENT_LOGLEVEL,
DVC_STUDIO_TOKEN,
DVC_STUDIO_URL,
STUDIO_ENDPOINT,
STUDIO_REPO_URL,
STUDIO_TOKEN,
)
from .schema import SCHEMAS_BY_TYPE

STUDIO_URL = "https://studio.iterative.ai"

logger = logging.getLogger(__name__)
logger.setLevel(getenv(DVC_STUDIO_CLIENT_LOGLEVEL, "INFO").upper())

Expand Down Expand Up @@ -49,9 +54,11 @@ def get_studio_repo_url() -> Optional[str]:


def get_studio_token_and_repo_url():
studio_token = getenv(STUDIO_TOKEN, None)
studio_token = getenv(DVC_STUDIO_TOKEN) or getenv(STUDIO_TOKEN)
if studio_token is None:
logger.debug("STUDIO_TOKEN not found. Skipping `post_studio_live_metrics`")
logger.debug(
f"{DVC_STUDIO_TOKEN} not found. Skipping `post_studio_live_metrics`"
)
return None, None

studio_repo_url = getenv(STUDIO_REPO_URL, None)
Expand Down Expand Up @@ -175,9 +182,12 @@ def post_live_metrics(
logger.debug(f"post_studio_live_metrics `{event_type=}`")
logger.debug(f"JSON body `{body=}`")

base_url = getenv(DVC_STUDIO_URL) or STUDIO_URL
path = getenv(STUDIO_ENDPOINT) or "api/live"
url = urljoin(base_url, path)
try:
response = requests.post(
getenv(STUDIO_ENDPOINT, "https://studio.iterative.ai/api/live"),
url,
json=body,
headers={
"Content-type": "application/json",
Expand Down
35 changes: 24 additions & 11 deletions tests/test_post_live_metrics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import logging
import os

import pytest
from dulwich.porcelain import clone, init
from requests import RequestException

from dvc_studio_client.env import STUDIO_ENDPOINT, STUDIO_REPO_URL, STUDIO_TOKEN
from dvc_studio_client.env import (
DVC_STUDIO_TOKEN,
DVC_STUDIO_URL,
STUDIO_REPO_URL,
STUDIO_TOKEN,
)
from dvc_studio_client.post_live_metrics import (
_get_remote_url,
get_studio_token_and_repo_url,
Expand All @@ -20,16 +26,23 @@ def test_get_url(monkeypatch, tmp_path_factory):
assert _get_remote_url() == source


@pytest.mark.parametrize("var", [DVC_STUDIO_TOKEN, STUDIO_TOKEN])
def test_studio_token_envvar(monkeypatch, var):
monkeypatch.setenv(var, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
assert get_studio_token_and_repo_url() == ("FOO_TOKEN", "FOO_REPO_URL")


def test_post_live_metrics_skip_on_missing_token(caplog):
with caplog.at_level(logging.DEBUG, logger="dvc_studio_client.post_live_metrics"):
assert post_live_metrics("start", "current_rev", "fooname", "fooclient") is None
assert caplog.records[0].message == (
"STUDIO_TOKEN not found. Skipping `post_studio_live_metrics`"
"DVC_STUDIO_TOKEN not found. Skipping `post_studio_live_metrics`"
)


def test_post_live_metrics_skip_on_schema_error(caplog, monkeypatch):
monkeypatch.setenv(STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
with caplog.at_level(logging.DEBUG, logger="dvc_studio_client.post_live_metrics"):
assert post_live_metrics("start", "bad_hash", "fooname", "fooclient") is None
Expand All @@ -40,8 +53,8 @@ def test_post_live_metrics_skip_on_schema_error(caplog, monkeypatch):


def test_post_live_metrics_start_event(mocker, monkeypatch):
monkeypatch.setenv(STUDIO_ENDPOINT, "https://0.0.0.0")
monkeypatch.setenv(STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_URL, "https://0.0.0.0")
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
Expand All @@ -56,7 +69,7 @@ def test_post_live_metrics_start_event(mocker, monkeypatch):
)

mocked_post.assert_called_with(
"https://0.0.0.0",
"https://0.0.0.0/api/live",
json={
"type": "start",
"repo_url": "FOO_REPO_URL",
Expand All @@ -80,7 +93,7 @@ def test_post_live_metrics_start_event(mocker, monkeypatch):
)

mocked_post.assert_called_with(
"https://0.0.0.0",
"https://0.0.0.0/api/live",
json={
"type": "start",
"repo_url": "FOO_REPO_URL",
Expand All @@ -98,15 +111,15 @@ def test_post_live_metrics_start_event(mocker, monkeypatch):


def test_post_live_metrics_data_skip_if_no_step(caplog, monkeypatch):
monkeypatch.setenv(STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

assert post_live_metrics("data", "f" * 40, "fooname", "fooclient") is None
assert caplog.records[0].message == ("Missing `step` in `data` event.")


def test_post_live_metrics_data(mocker, monkeypatch):
monkeypatch.setenv(STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
Expand Down Expand Up @@ -189,7 +202,7 @@ def test_post_live_metrics_data(mocker, monkeypatch):


def test_post_live_metrics_done(mocker, monkeypatch):
monkeypatch.setenv(STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
Expand Down Expand Up @@ -264,7 +277,7 @@ def test_post_live_metrics_done(mocker, monkeypatch):


def test_post_live_metrics_bad_response(mocker, monkeypatch):
monkeypatch.setenv(STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")

mocked_response = mocker.MagicMock()
Expand Down

0 comments on commit d36060c

Please sign in to comment.