diff --git a/src/dvc_studio_client/env.py b/src/dvc_studio_client/env.py
index ee7d28e..c06b0e0 100644
--- a/src/dvc_studio_client/env.py
+++ b/src/dvc_studio_client/env.py
@@ -1,5 +1,7 @@
 DVC_STUDIO_CLIENT_LOGLEVEL = "DVC_STUDIO_CLIENT_LOGLEVEL"
+DVC_STUDIO_OFFLINE = "DVC_STUDIO_OFFLINE"
 DVC_STUDIO_TOKEN = "DVC_STUDIO_TOKEN"  # nosec B105
+DVC_STUDIO_REPO_URL = "DVC_STUDIO_REPO_URL"
 DVC_STUDIO_URL = "DVC_STUDIO_URL"
 STUDIO_ENDPOINT = "STUDIO_ENDPOINT"
 STUDIO_REPO_URL = "STUDIO_REPO_URL"
diff --git a/src/dvc_studio_client/post_live_metrics.py b/src/dvc_studio_client/post_live_metrics.py
index 56059f2..7e4b195 100644
--- a/src/dvc_studio_client/post_live_metrics.py
+++ b/src/dvc_studio_client/post_live_metrics.py
@@ -1,4 +1,5 @@
 import logging
+import re
 from functools import lru_cache
 from os import getenv
 from typing import Any, Dict, Literal, Optional
@@ -11,6 +12,8 @@
 
 from .env import (
     DVC_STUDIO_CLIENT_LOGLEVEL,
+    DVC_STUDIO_OFFLINE,
+    DVC_STUDIO_REPO_URL,
     DVC_STUDIO_TOKEN,
     DVC_STUDIO_URL,
     STUDIO_ENDPOINT,
@@ -55,17 +58,99 @@ def get_studio_repo_url() -> Optional[str]:
 
 def get_studio_token_and_repo_url(studio_token=None, studio_repo_url=None):
-    studio_token = studio_token or getenv(DVC_STUDIO_TOKEN) or getenv(STUDIO_TOKEN)
-    if studio_token is None:
+    """Get studio token and repo_url. Kept for backwards compatibility."""
+    config = get_studio_config(
+        studio_token=studio_token, studio_repo_url=studio_repo_url
+    )
+    # `config` is empty when posting is skipped; keep the old (None, None) result.
+    return config.get("token"), config.get("repo_url")
+
+
+def get_studio_config(
+    dvc_studio_config: Optional[Dict[str, Any]] = None,
+    offline: bool = False,
+    studio_token: Optional[str] = None,
+    studio_repo_url: Optional[str] = None,
+    studio_url: Optional[str] = None,
+) -> Dict[str, Any]:
+    """Get studio config options.
+
+    Args:
+        dvc_studio_config (Optional[dict]): Dict returned by dvc.Repo.config["studio"].
+        offline (bool): Whether offline mode is enabled. Default: false.
+        studio_token (Optional[str]): Studio access token obtained from the UI.
+        studio_repo_url (Optional[str]): URL of the Git repository that has been
+            imported into Studio UI.
+        studio_url (Optional[str]): Base URL of Studio UI (if self-hosted).
+    Returns:
+        Dict:
+            Config options for posting live metrics (empty if posting is disabled).
+            Keys match the DVC studio config section.
+            Example:
+                {
+                    "token": "mytoken",
+                    "repo_url": "git@github.com:iterative/dvc-studio-client.git",
+                    "url": "https://studio.iterative.ai",
+                }
+    """
+
+    config = {}
+    if not dvc_studio_config:
+        dvc_studio_config = {}
+
+    def to_bool(var):
+        # Match the whole value: a substring search would treat values such as
+        # "10", "only" or "yesterday" as enabled. `None` never matches.
+        return bool(re.fullmatch("1|y|yes|true", str(var).strip(), flags=re.I))
+
+    offline = (
+        offline
+        or to_bool(getenv(DVC_STUDIO_OFFLINE))
+        or to_bool(dvc_studio_config.get("offline"))
+    )
+    if offline:
+        logger.debug("Offline mode enabled. Skipping `post_studio_live_metrics`")
+        return {}
+
+    studio_token = (
+        studio_token
+        or getenv(DVC_STUDIO_TOKEN)
+        or getenv(STUDIO_TOKEN)
+        or dvc_studio_config.get("token")
+    )
+    if not studio_token:
         logger.debug(
             f"{DVC_STUDIO_TOKEN} not found. Skipping `post_studio_live_metrics`"
         )
-        return None, None
+        return {}
+    config["token"] = studio_token
 
-    studio_repo_url = studio_repo_url or getenv(STUDIO_REPO_URL, None)
-    if studio_repo_url is None:
-        logger.debug(f"`{STUDIO_REPO_URL}` not found. Trying to automatically find it.")
+    studio_repo_url = (
+        studio_repo_url
+        or getenv(DVC_STUDIO_REPO_URL)
+        or getenv(STUDIO_REPO_URL)
+        or dvc_studio_config.get("repo_url")
+    )
+    if not studio_repo_url:
+        logger.debug(
+            f"{DVC_STUDIO_REPO_URL} not found. Trying to automatically find it."
+        )
         studio_repo_url = get_studio_repo_url()
-    return studio_token, studio_repo_url
+    if studio_repo_url:
+        config["repo_url"] = studio_repo_url
+    else:
+        logger.debug(
+            f"{DVC_STUDIO_REPO_URL} not found. Skipping `post_studio_live_metrics`"
+        )
+        return {}
+
+    studio_url = studio_url or getenv(DVC_STUDIO_URL) or dvc_studio_config.get("url")
+    if studio_url:
+        config["url"] = studio_url
+    else:
+        logger.debug(f"{DVC_STUDIO_URL} not found. Using {STUDIO_URL}.")
+        config["url"] = STUDIO_URL
+
+    return config
 
 
 def post_live_metrics(  # noqa: C901
@@ -80,13 +165,16 @@ def post_live_metrics(  # noqa: C901
     params: Optional[Dict[str, Any]] = None,
     plots: Optional[Dict[str, Any]] = None,
     step: Optional[int] = None,
+    dvc_studio_config: Optional[Dict[str, Any]] = None,
+    offline: bool = False,
     studio_token: Optional[str] = None,
     studio_repo_url: Optional[str] = None,
+    studio_url: Optional[str] = None,
 ) -> Optional[bool]:
     """Post `event_type` to Studio's `api/live`.
 
-    Requires the environment variable `STUDIO_TOKEN` to be set.
-    If the environment variable `STUDIO_REPO_URL` is not set, will attempt to
+    Requires the environment variable `DVC_STUDIO_TOKEN` to be set.
+    If the environment variable `DVC_STUDIO_REPO_URL` is not set, will attempt to
     infer it from `git ls-remote --get-url`.
 
     Args:
@@ -144,29 +232,36 @@ def post_live_metrics(  # noqa: C901
                 }
             }
             ```
-        step: (Optional[int]): Current step of the training loop.
+        step (Optional[int]): Current step of the training loop.
             Usually comes from DVCLive `Live.step` property.
             Required in when `event_type="data"`.
             Defaults to `None`.
+        dvc_studio_config (Optional[Dict]): DVC config options for Studio.
+        offline (bool): Whether offline mode is enabled.
         studio_token (Optional[str]): Studio access token obtained from the UI.
         studio_repo_url (Optional[str]): URL of the Git repository that has been
             imported into Studio UI.
+        studio_url (Optional[str]): Base URL of Studio UI (if self-hosted).
     Returns:
         Optional[bool]:
             `True` - if received status code 200 from Studio.
             `False` - if received other status code or RequestException raised.
             `None`- if prerequisites weren't met and the request was not sent.
""" - studio_token, studio_repo_url = get_studio_token_and_repo_url( - studio_token, studio_repo_url + config = get_studio_config( + dvc_studio_config=dvc_studio_config, + offline=offline, + studio_token=studio_token, + studio_repo_url=studio_repo_url, + studio_url=studio_url, ) - if any(x is None for x in (studio_token, studio_repo_url)): + if not config: return None body = { "type": event_type, - "repo_url": studio_repo_url, + "repo_url": config["repo_url"], "baseline_sha": baseline_sha, "name": name, "client": client, @@ -210,16 +305,15 @@ def post_live_metrics( # noqa: C901 logger.debug(f"post_studio_live_metrics `{event_type=}`") logger.debug(f"JSON body `{body=}`") - base_url = getenv(DVC_STUDIO_URL) or STUDIO_URL path = getenv(STUDIO_ENDPOINT) or "api/live" - url = urljoin(base_url, path) + url = urljoin(config["url"], path) try: response = requests.post( url, json=body, headers={ "Content-type": "application/json", - "Authorization": f"token {studio_token}", + "Authorization": f"token {config['token']}", }, timeout=5, ) diff --git a/tests/test_post_live_metrics.py b/tests/test_post_live_metrics.py index a781449..6c817fd 100644 --- a/tests/test_post_live_metrics.py +++ b/tests/test_post_live_metrics.py @@ -6,14 +6,17 @@ from requests import RequestException from dvc_studio_client.env import ( + DVC_STUDIO_OFFLINE, + DVC_STUDIO_REPO_URL, DVC_STUDIO_TOKEN, DVC_STUDIO_URL, STUDIO_REPO_URL, STUDIO_TOKEN, ) from dvc_studio_client.post_live_metrics import ( + STUDIO_URL, _get_remote_url, - get_studio_token_and_repo_url, + get_studio_config, post_live_metrics, ) @@ -26,11 +29,111 @@ def test_get_url(monkeypatch, tmp_path_factory): assert _get_remote_url() == source -@pytest.mark.parametrize("var", [DVC_STUDIO_TOKEN, STUDIO_TOKEN]) -def test_studio_token_envvar(monkeypatch, var): - monkeypatch.setenv(var, "FOO_TOKEN") - monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL") - assert get_studio_token_and_repo_url() == ("FOO_TOKEN", "FOO_REPO_URL") 
+@pytest.mark.parametrize(
+    "token,repo_url",
+    [(DVC_STUDIO_TOKEN, DVC_STUDIO_REPO_URL), (STUDIO_TOKEN, STUDIO_REPO_URL)],
+)
+def test_studio_config_envvar(monkeypatch, token, repo_url):
+    monkeypatch.setenv(token, "FOO_TOKEN")
+    monkeypatch.setenv(repo_url, "FOO_REPO_URL")
+    assert get_studio_config() == {
+        "token": "FOO_TOKEN",
+        "repo_url": "FOO_REPO_URL",
+        "url": STUDIO_URL,
+    }
+
+
+def test_studio_config_dvc_studio_config():
+    dvc_studio_config = {
+        "token": "FOO_TOKEN",
+        "repo_url": "FOO_REPO_URL",
+        "url": "FOO_URL",
+    }
+    expected = {
+        "token": "FOO_TOKEN",
+        "repo_url": "FOO_REPO_URL",
+        "url": "FOO_URL",
+    }
+    assert get_studio_config(dvc_studio_config=dvc_studio_config) == expected
+
+
+def test_studio_config_kwarg(monkeypatch):
+    expected = {
+        "token": "FOO_TOKEN",
+        "repo_url": "FOO_REPO_URL",
+        "url": "FOO_URL",
+    }
+    assert (
+        get_studio_config(
+            studio_token="FOO_TOKEN",
+            studio_repo_url="FOO_REPO_URL",
+            studio_url="FOO_URL",
+        )
+        == expected
+    )
+
+
+def test_studio_config_envvar_override(monkeypatch):
+    monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
+    monkeypatch.setenv(DVC_STUDIO_URL, "FOO_URL")
+    monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
+    dvc_studio_config = {
+        "token": "BAR_TOKEN",
+        "url": "BAR_URL",
+    }
+    expected = {
+        "token": "FOO_TOKEN",
+        "repo_url": "FOO_REPO_URL",
+        "url": "FOO_URL",
+    }
+    assert get_studio_config(dvc_studio_config=dvc_studio_config) == expected
+
+
+def test_studio_config_kwarg_override(monkeypatch):
+    monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
+    monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
+    monkeypatch.setenv(DVC_STUDIO_URL, "FOO_URL")
+    expected = {
+        "token": "BAR_TOKEN",
+        "repo_url": "BAR_REPO_URL",
+        "url": "BAR_URL",
+    }
+    assert (
+        get_studio_config(
+            studio_token="BAR_TOKEN",
+            studio_repo_url="BAR_REPO_URL",
+            studio_url="BAR_URL",
+        )
+        == expected
+    )
+
+
+@pytest.mark.parametrize(
+    "val",
+    ("1", "y", "yes", "true", True, 1),
+)
+def test_studio_config_offline(monkeypatch, val):
+    monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
+    monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
+
+    assert get_studio_config() != {}
+
+    assert get_studio_config(offline=val) == {}
+
+    monkeypatch.setenv(DVC_STUDIO_OFFLINE, str(val))
+    assert get_studio_config() == {}
+
+    # Unset the env var; otherwise it alone would satisfy the next assertion.
+    monkeypatch.delenv(DVC_STUDIO_OFFLINE)
+
+    assert get_studio_config(dvc_studio_config={"offline": val}) == {}
+
+
+def test_studio_config_infer_url(monkeypatch):
+    monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
+    monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
+
+    assert get_studio_config()["url"] == STUDIO_URL
 
 
 def test_post_live_metrics_skip_on_missing_token(caplog):
@@ -43,7 +146,8 @@ def test_post_live_metrics_skip_on_missing_token(caplog):
 
 def test_post_live_metrics_skip_on_schema_error(caplog, monkeypatch):
     monkeypatch.setenv(DVC_STUDIO_TOKEN, "FOO_TOKEN")
-    monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
+    monkeypatch.setenv(DVC_STUDIO_REPO_URL, "FOO_REPO_URL")
+    monkeypatch.setenv(DVC_STUDIO_URL, STUDIO_URL)
     with caplog.at_level(logging.DEBUG, logger="dvc_studio_client.post_live_metrics"):
         assert post_live_metrics("start", "bad_hash", "fooname", "fooclient") is None
     assert caplog.records[0].message == (
@@ -340,11 +444,10 @@ def test_post_live_metrics_bad_response(mocker, monkeypatch):
     )
 
 
-def test_get_studio_token_and_repo_url_skip_repo_url(monkeypatch):
+def test_get_studio_config_skip_repo_url(monkeypatch):
     monkeypatch.setenv(STUDIO_REPO_URL, "FOO_REPO_URL")
-    token, repo_url = get_studio_token_and_repo_url()
-    assert token is None
-    assert repo_url is None  # Skipped call to get_repo_url
+    config = get_studio_config()
+    assert config == {}  # Skipped call to get_repo_url
 
 
 def test_post_live_metrics_token_and_repo_url_args(mocker, monkeypatch):