Skip to content

Commit

Permalink
add get_studio_config (#35)
Browse files Browse the repository at this point in the history
* add get_studio_config

* Update src/dvc_studio_client/post_live_metrics.py

Co-authored-by: David de la Iglesia Castro <[email protected]>

* parametrize offline test

* Update src/dvc_studio_client/post_live_metrics.py

* use get_studio_config for get_studio_token_and_repo_url

* check for repo_url in dvc config

* make dvc_studio_config consistent

---------

Co-authored-by: David de la Iglesia Castro <[email protected]>
  • Loading branch information
dberenbaum and daavoo authored May 15, 2023
1 parent 609a3d1 commit b54d601
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 26 deletions.
2 changes: 2 additions & 0 deletions src/dvc_studio_client/env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
DVC_STUDIO_CLIENT_LOGLEVEL = "DVC_STUDIO_CLIENT_LOGLEVEL"
DVC_STUDIO_OFFLINE = "DVC_STUDIO_OFFLINE"
DVC_STUDIO_TOKEN = "DVC_STUDIO_TOKEN" # nosec B105
DVC_STUDIO_REPO_URL = "DVC_STUDIO_REPO_URL"
DVC_STUDIO_URL = "DVC_STUDIO_URL"
STUDIO_ENDPOINT = "STUDIO_ENDPOINT"
STUDIO_REPO_URL = "STUDIO_REPO_URL"
Expand Down
124 changes: 109 additions & 15 deletions src/dvc_studio_client/post_live_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re
from functools import lru_cache
from os import getenv
from typing import Any, Dict, Literal, Optional
Expand All @@ -11,6 +12,8 @@

from .env import (
DVC_STUDIO_CLIENT_LOGLEVEL,
DVC_STUDIO_OFFLINE,
DVC_STUDIO_REPO_URL,
DVC_STUDIO_TOKEN,
DVC_STUDIO_URL,
STUDIO_ENDPOINT,
Expand Down Expand Up @@ -55,17 +58,99 @@ def get_studio_repo_url() -> Optional[str]:

def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
studio_token = studio_token or getenv(DVC_STUDIO_TOKEN) or getenv(STUDIO_TOKEN)
if studio_token is None:
"""Get studio token and repo_url. Kept for backwards compatibility."""
config = get_studio_config(
studio_token=studio_token, studio_repo_url=studio_repo_url
)
return config["token"], config["repo_url"]


def get_studio_config(
dvc_studio_config: Optional[Dict[str, Any]] = None,
offline: bool = False,
studio_token: Optional[str] = None,
studio_repo_url: Optional[str] = None,
studio_url: Optional[str] = None,
) -> Dict[str, Any]:
"""Get studio config options.
Args:
dvc_studio_config (Optional[dict]): Dict returned by dvc.Repo.config["studio"].
offline (bool): Whether offline mode is enabled. Default: false.
studio_token (Optional[str]): Studio access token obtained from the UI.
studio_repo_url (Optional[str]): URL of the Git repository that has been
imported into Studio UI.
studio_url (Optional[str]): Base URL of Studio UI (if self-hosted).
Returns:
Dict:
Config options for posting live metrics.
Keys match the DVC studio config section.
Example:
{
"token": "mytoken",
"repo_url": "[email protected]:iterative/dvc-studio-client.git",
"url": "https://studio.iterative.ai",
}
"""

config = {}
if not dvc_studio_config:
dvc_studio_config = {}

def to_bool(var):
if var is None:
return False
return bool(re.search("1|y|yes|true", str(var), flags=re.I))

offline = (
offline
or to_bool(getenv(DVC_STUDIO_OFFLINE))
or to_bool(dvc_studio_config.get("offline"))
)
if offline:
logger.debug("Offline mode enabled. Skipping `post_studio_live_metrics`")
return {}

studio_token = (
studio_token
or getenv(DVC_STUDIO_TOKEN)
or getenv(STUDIO_TOKEN)
or dvc_studio_config.get("token")
)
if not studio_token:
logger.debug(
f"{DVC_STUDIO_TOKEN} not found. Skipping `post_studio_live_metrics`"
)
return None, None
return {}
config["token"] = studio_token

studio_repo_url = studio_repo_url or getenv(STUDIO_REPO_URL, None)
studio_repo_url = (
studio_repo_url
or getenv(DVC_STUDIO_REPO_URL)
or getenv(STUDIO_REPO_URL)
or dvc_studio_config.get("repo_url")
)
if studio_repo_url is None:
logger.debug(f"`{STUDIO_REPO_URL}` not found. Trying to automatically find it.")
logger.debug(
f"{DVC_STUDIO_REPO_URL} not found. Trying to automatically find it."
)
studio_repo_url = get_studio_repo_url()
return studio_token, studio_repo_url
if studio_repo_url:
config["repo_url"] = studio_repo_url
else:
logger.debug(
f"{DVC_STUDIO_REPO_URL} not found. Skipping `post_studio_live_metrics`"
)
return {}

studio_url = studio_url or getenv(DVC_STUDIO_URL) or dvc_studio_config.get("url")
if studio_url:
config["url"] = studio_url
else:
logger.debug(f"{DVC_STUDIO_URL} not found. Using {STUDIO_URL}.")
config["url"] = STUDIO_URL

return config


def post_live_metrics( # noqa: C901
Expand All @@ -80,13 +165,16 @@ def post_live_metrics( # noqa: C901
params: Optional[Dict[str, Any]] = None,
plots: Optional[Dict[str, Any]] = None,
step: Optional[int] = None,
dvc_studio_config: Optional[Dict[str, Any]] = None,
offline: bool = False,
studio_token: Optional[str] = None,
studio_repo_url: Optional[str] = None,
studio_url: Optional[str] = None,
) -> Optional[bool]:
"""Post `event_type` to Studio's `api/live`.
Requires the environment variable `STUDIO_TOKEN` to be set.
If the environment variable `STUDIO_REPO_URL` is not set, will attempt to
Requires the environment variable `DVC_STUDIO_TOKEN` to be set.
If the environment variable `DVC_STUDIO_REPO_URL` is not set, will attempt to
infer it from `git ls-remote --get-url`.
Args:
Expand Down Expand Up @@ -144,29 +232,36 @@ def post_live_metrics( # noqa: C901
}
}
```
step: (Optional[int]): Current step of the training loop.
step (Optional[int]): Current step of the training loop.
Usually comes from DVCLive `Live.step` property.
Required in when `event_type="data"`.
Defaults to `None`.
dvc_studio_config (Optional[Dict]): DVC config options for Studio.
offline (bool): Whether offline mode is enabled.
studio_token (Optional[str]): Studio access token obtained from the UI.
studio_repo_url (Optional[str]): URL of the Git repository that has been
imported into Studio UI.
studio_url (Optional[str]): Base URL of Studio UI (if self-hosted).
Returns:
Optional[bool]:
`True` - if received status code 200 from Studio.
`False` - if received other status code or RequestException raised.
`None`- if prerequisites weren't met and the request was not sent.
"""
studio_token, studio_repo_url = get_studio_token_and_repo_url(
studio_token, studio_repo_url
config = get_studio_config(
dvc_studio_config=dvc_studio_config,
offline=offline,
studio_token=studio_token,
studio_repo_url=studio_repo_url,
studio_url=studio_url,
)

if any(x is None for x in (studio_token, studio_repo_url)):
if not config:
return None

body = {
"type": event_type,
"repo_url": studio_repo_url,
"repo_url": config["repo_url"],
"baseline_sha": baseline_sha,
"name": name,
"client": client,
Expand Down Expand Up @@ -210,16 +305,15 @@ def post_live_metrics( # noqa: C901
logger.debug(f"post_studio_live_metrics `{event_type=}`")
logger.debug(f"JSON body `{body=}`")

base_url = getenv(DVC_STUDIO_URL) or STUDIO_URL
path = getenv(STUDIO_ENDPOINT) or "api/live"
url = urljoin(base_url, path)
url = urljoin(config["url"], path)
try:
response = requests.post(
url,
json=body,
headers={
"Content-type": "application/json",
"Authorization": f"token {studio_token}",
"Authorization": f"token {config['token']}",
},
timeout=5,
)
Expand Down
125 changes: 114 additions & 11 deletions tests/test_post_live_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from requests import RequestException

from dvc_studio_client.env import (
DVC_STUDIO_OFFLINE,
DVC_STUDIO_REPO_URL,
DVC_STUDIO_TOKEN,
DVC_STUDIO_URL,
STUDIO_REPO_URL,
STUDIO_TOKEN,
)
from dvc_studio_client.post_live_metrics import (
STUDIO_URL,
_get_remote_url,
get_studio_token_and_repo_url,
get_studio_config,
post_live_metrics,
)

Expand All @@ -26,11 +29,111 @@ def test_get_url(monkeypatch, tmp_path_factory):
assert _get_remote_url() == source


@pytest.mark.parametrize("var", [DVC_STUDIO_TOKEN, STUDIO_TOKEN])
def test_studio_token_envvar(monkeypatch, var):
monkeypatch.setenv(var, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
assert get_studio_token_and_repo_url() == ("FOO_TOKEN", "FOO_REPO_URL")
@pytest.mark.parametrize(
"token,repo_url",
[(DVC_STUDIO_TOKEN, DVC_STUDIO_REPO_URL), (STUDIO_TOKEN, STUDIO_REPO_URL)],
)
def test_studio_config_envvar(monkeypatch, token, repo_url):
monkeypatch.setenv(token, "FOO_TOKEN")
monkeypatch.setenv(repo_url, "FOO_REPO_URL")
assert get_studio_config() == {
"token": "FOO_TOKEN",
"repo_url": "FOO_REPO_URL",
"url": STUDIO_URL,
}


def test_studio_config_dvc_studio_config():
dvc_studio_config = {
"token": "FOO_TOKEN",
"repo_url": "FOO_REPO_URL",
"url": "FOO_URL",
}
expected = {
"token": "FOO_TOKEN",
"repo_url": "FOO_REPO_URL",
"url": "FOO_URL",
}
assert get_studio_config(dvc_studio_config=dvc_studio_config) == expected


def test_studio_config_kwarg(monkeypatch):
expected = {
"token": "FOO_TOKEN",
"repo_url": "FOO_REPO_URL",
"url": "FOO_URL",
}
assert (
get_studio_config(
studio_token="FOO_TOKEN",
studio_repo_url="FOO_REPO_URL",
studio_url="FOO_URL",
)
== expected
)


def test_studio_config_envvar_override(monkeypatch):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_URL, "FOO_URL")
monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
dvc_studio_config = {
"token": "BAR_TOKEN",
"url": "BAR_URL",
}
expected = {
"token": "FOO_TOKEN",
"repo_url": "FOO_REPO_URL",
"url": "FOO_URL",
}
assert get_studio_config(dvc_studio_config=dvc_studio_config) == expected


def test_studio_config_kwarg_override(monkeypatch):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
monkeypatch.setenv(DVC_STUDIO_URL, "FOO_URL")
expected = {
"token": "BAR_TOKEN",
"repo_url": "BAR_REPO_URL",
"url": "BAR_URL",
}
assert (
get_studio_config(
studio_token="BAR_TOKEN",
studio_repo_url="BAR_REPO_URL",
studio_url="BAR_URL",
)
== expected
)


@pytest.mark.parametrize(
"val",
("1", "y", "yes", "true", True, 1),
)
def test_studio_config_offline(monkeypatch, val):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")

assert get_studio_config() != {}

assert get_studio_config(offline=val) == {}

monkeypatch.setenv(DVC_STUDIO_OFFLINE, val)
assert get_studio_config() == {}

monkeypatch.setenv(DVC_STUDIO_OFFLINE, val)
assert get_studio_config() == {}

assert get_studio_config(dvc_studio_config={"offline": True}) == {}


def test_studio_config_infer_url(monkeypatch):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")

assert get_studio_config()["url"] == STUDIO_URL


def test_post_live_metrics_skip_on_missing_token(caplog):
Expand All @@ -43,7 +146,8 @@ def test_post_live_metrics_skip_on_missing_token(caplog):

def test_post_live_metrics_skip_on_schema_error(caplog, monkeypatch):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
monkeypatch.setenv(DVC_STUDIO_URL, STUDIO_URL)
with caplog.at_level(logging.DEBUG, logger="dvc_studio_client.post_live_metrics"):
assert post_live_metrics("start", "bad_hash", "fooname", "fooclient") is None
assert caplog.records[0].message == (
Expand Down Expand Up @@ -340,11 +444,10 @@ def test_post_live_metrics_bad_response(mocker, monkeypatch):
)


def test_get_studio_token_and_repo_url_skip_repo_url(monkeypatch):
def test_get_studio_config_skip_repo_url(monkeypatch):
monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
token, repo_url = get_studio_token_and_repo_url()
assert token is None
assert repo_url is None # Skipped call to get_repo_url
config = get_studio_config()
assert config == {} # Skipped call to get_repo_url


def test_post_live_metrics_token_and_repo_url_args(mocker, monkeypatch):
Expand Down

0 comments on commit b54d601

Please sign in to comment.