diff --git a/README.md b/README.md index 77f9b33..5dc1e40 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,25 @@ ucode configure --workspaces https://first.databricks.com,https://second.databri When multiple workspaces are provided, `ucode` logs into and saves state for each workspace. Launch commands such as `ucode codex` use the first workspace in the list. +### Custom Claude model endpoints + +By default `ucode` discovers Claude models named `databricks-claude--` on the +AI Gateway anthropic route and picks the newest per family (`opus`, `sonnet`, `haiku`). If your +workspace exposes Claude through custom serving endpoints — for example Bedrock-backed models +named `acme-bedrock-claude-opus-4-8` — pass `--model-prefix` with the shared prefix before the +family token: + +```bash +ucode configure --agents claude --model-prefix acme-bedrock-claude- +``` + +This scopes discovery to those endpoints (the built-in `databricks-claude-*` hosted models no +longer match the prefix, so they are excluded), picks the newest version per family as the +default, and keeps every matching endpoint selectable. `ucode` writes the matching set to Claude +Code's `availableModels` allowlist so the `/model` picker is restricted to them, and persists the +prefix to state so later `ucode configure` runs reuse it without the flag. `ucode status` shows +the resolved prefix, per-family defaults, and the number of selectable models. + ### MCP servers (optional) ```bash diff --git a/src/ucode/agents/claude.py b/src/ucode/agents/claude.py index 0fe4ec7..2f79a89 100644 --- a/src/ucode/agents/claude.py +++ b/src/ucode/agents/claude.py @@ -17,6 +17,7 @@ write_json_file, ) from ucode.databricks import ( + CLAUDE_FAMILIES, build_auth_shell_command, build_tool_base_url, get_databricks_token, @@ -81,12 +82,38 @@ def _web_search_mcp_entry(workspace: str, search_model: str, profile: str | None } +def _build_available_models(claude_models: dict[str, str], claude_allowed: list[str]) -> list[str]: + """Allowlist written to Claude Code's ``availableModels`` setting. + + Combines the discovered endpoint ids (every allowed version) with the + family aliases (``opus``/``sonnet``/``haiku``) that have a default, so the + friendly family picker entries stay visible and resolve through the + ``ANTHROPIC_DEFAULT_*_MODEL`` env vars. + + Including aliases intentionally loosens strictness — Claude Code's matcher + treats an ``opus`` alias as "any opus version", so a same-family model would + pass the allowlist. We accept that to keep the default picker entries usable; + the Databricks gateway remains the hard backstop. With an empty allowlist we + fall back to the flat defaults' values (old state shape / back-compat). + """ + ids = [m for m in claude_allowed if isinstance(m, str) and m] + if not ids: + ids = [v for v in claude_models.values() if isinstance(v, str) and v] + aliases = [family for family in CLAUDE_FAMILIES if claude_models.get(family)] + out: list[str] = [] + for item in [*aliases, *sorted(set(ids))]: + if item not in out: + out.append(item) + return out + + def render_overlay( workspace: str, model: str, claude_models: dict[str, str] | None = None, disable_web_search: bool = False, profile: str | None = None, + claude_allowed: list[str] | None = None, ) -> tuple[dict, list[list[str]]]: """Return (overlay, managed_key_paths) for Claude settings.json. @@ -121,6 +148,14 @@ def render_overlay( overlay: dict = {"apiKeyHelper": build_auth_shell_command(workspace, profile), "env": env} keys: list[list[str]] = [["apiKeyHelper"]] + [["env", k] for k in env] + # Restrict the `/model` picker to the discovered models via Claude Code's + # `availableModels` allowlist, so hosted models that don't match the + # configured prefix aren't selectable. + available_models = _build_available_models(claude_models or {}, claude_allowed or []) + if available_models: + overlay["availableModels"] = available_models + keys.append(["availableModels"]) + # Disable Claude Code's built-in WebSearch (it routes through Anthropic's # hosted infra and fails through the Databricks gateway). The replacement # `web_search` MCP server is registered separately via the claude CLI. @@ -197,6 +232,7 @@ def write_tool_config(state: dict, model: str) -> dict: state.get("claude_models") or {}, disable_web_search=web_search_model is not None, profile=state.get("profile"), + claude_allowed=state.get("claude_allowed_models") or [], ) existing = read_json_safe(CLAUDE_SETTINGS_PATH) merged = deep_merge_dict(existing, overlay) diff --git a/src/ucode/cli.py b/src/ucode/cli.py index 5f11f09..f72d97b 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -28,6 +28,8 @@ from ucode.agents.pi import PI_SETTINGS_BACKUP_PATH, PI_SETTINGS_PATH from ucode.config_io import restore_file, set_dry_run from ucode.databricks import ( + CLAUDE_FAMILIES, + DEFAULT_CLAUDE_MODEL_PREFIX, build_shared_base_urls, discover_claude_models, discover_codex_models, @@ -39,6 +41,7 @@ get_databricks_token, install_databricks_cli, normalize_workspace_url, + resolve_claude_model_prefix, run_databricks_login, ) from ucode.mcp import ( @@ -47,7 +50,7 @@ purge_cross_workspace_mcp_residue, revert_mcp_configs, ) -from ucode.state import STATE_PATH, clear_state, load_state, save_state +from ucode.state import STATE_PATH, clear_state, load_full_state, load_state, save_state from ucode.ui import ( console, heading, @@ -146,6 +149,7 @@ def configure_shared_state( profile: str | None = None, tools: list[str] | None = None, force_login: bool = False, + model_prefix: str | None = None, ) -> dict: """Log into Databricks, enforce AI Gateway v2, fetch model lists, persist state. @@ -178,12 +182,22 @@ def configure_shared_state( want_gemini = fetch_all or "gemini" in tools or "opencode" in tools or "pi" in tools want_codex = fetch_all or "codex" in tools or "copilot" in tools or "pi" in tools + # Resolve the Claude endpoint name prefix (--model-prefix flag > this + # workspace's persisted value > built-in default) so discovery scopes to the + # right family of endpoints — e.g. Bedrock-backed `acme-bedrock-claude-*` + # instead of hosted `databricks-claude-*`. + existing_ws_state = load_full_state().get("workspaces", {}).get(workspace, {}) + claude_model_prefix = resolve_claude_model_prefix(existing_ws_state, override=model_prefix) + claude_reason: str | None = None gemini_reason: str | None = None codex_reason: str | None = None + claude_allowed: list[str] = [] with spinner("Fetching available models..."): if want_claude: - claude_models, claude_reason = discover_claude_models(workspace, token) + claude_models, claude_allowed, claude_reason = discover_claude_models( + workspace, token, claude_model_prefix + ) else: claude_models = {} if want_gemini: @@ -196,7 +210,7 @@ def configure_shared_state( codex_models = [] opencode_models: dict[str, list[str]] = {} if claude_models: - opencode_models["anthropic"] = list(claude_models.values()) + opencode_models["anthropic"] = claude_allowed or list(claude_models.values()) if gemini_models: opencode_models["gemini"] = gemini_models @@ -210,6 +224,8 @@ def configure_shared_state( state["base_urls"] = build_shared_base_urls(workspace) if want_claude: state["claude_models"] = claude_models + state["claude_allowed_models"] = claude_allowed + state["claude_model_prefix"] = claude_model_prefix if want_gemini: state["gemini_models"] = gemini_models if want_codex: @@ -236,13 +252,20 @@ def _configure_shared_workspace_states( tools: list[str] | None, *, force_login: bool, + model_prefix: str | None = None, ) -> list[dict]: if not workspaces: raise RuntimeError("At least one workspace must be provided.") states: list[dict] = [] for workspace, profile in workspaces: states.append( - configure_shared_state(workspace, profile=profile, tools=tools, force_login=force_login) + configure_shared_state( + workspace, + profile=profile, + tools=tools, + force_login=force_login, + model_prefix=model_prefix, + ) ) return states @@ -251,6 +274,7 @@ def configure_workspace_command( tool: str | None = None, selected_tools: list[str] | None = None, workspaces: list[tuple[str, str | None]] | None = None, + model_prefix: str | None = None, ) -> int: if tool is not None and selected_tools is not None: raise RuntimeError("Use either --agent or --agents, not both.") @@ -258,7 +282,9 @@ def configure_workspace_command( workspace_entries = workspaces or [_prompt_for_configuration(tool)] if tool is not None: - states = _configure_shared_workspace_states(workspace_entries, [tool], force_login=True) + states = _configure_shared_workspace_states( + workspace_entries, [tool], force_login=True, model_prefix=model_prefix + ) state = states[0] state = configure_single_tool(tool, state) spec = TOOL_SPECS[tool] @@ -285,7 +311,9 @@ def configure_workspace_command( raise RuntimeError(f"{spec['display']} validation failed — config reverted.") return 0 - states = _configure_shared_workspace_states(workspace_entries, selected_tools, force_login=True) + states = _configure_shared_workspace_states( + workspace_entries, selected_tools, force_login=True, model_prefix=model_prefix + ) state = states[0] save_state(state) @@ -362,6 +390,23 @@ def status() -> int: if profile: print_kv("CLI profile", profile) + claude_models = state.get("claude_models") or {} + if claude_models: + print_heading("Claude Models") + prefix = state.get("claude_model_prefix") + if isinstance(prefix, str) and prefix and prefix != DEFAULT_CLAUDE_MODEL_PREFIX: + print_kv("Endpoint prefix", prefix) + allowed = [m for m in (state.get("claude_allowed_models") or []) if isinstance(m, str)] + for family in CLAUDE_FAMILIES: + default = claude_models.get(family) + if not default: + continue + extras = [m for m in allowed if m != default and f"-{family}-" in m] + label = f"{default} (+{len(extras)} more)" if extras else default + print_kv(family.capitalize(), label) + if allowed: + print_kv("Selectable models", str(len(allowed))) + print_heading("Coding Agents") for tool, spec in TOOL_SPECS.items(): configured = tool in configured_tools @@ -591,6 +636,17 @@ def configure( help="Configure a comma-separated list of workspaces without prompting.", ), ] = None, + model_prefix: Annotated[ + str | None, + typer.Option( + "--model-prefix", + help=( + "Prefix for custom Claude endpoint names before the family token " + "(e.g. acme-bedrock-claude-). Scopes discovery to those endpoints and " + "excludes hosted databricks-claude-* models. Persisted for later runs." + ), + ), + ] = None, ) -> None: """Configure workspace URL and AI Gateway.""" if ctx.invoked_subcommand is not None: @@ -605,24 +661,30 @@ def configure( tool = normalize_tool(agent) install_tool_binary(tool, strict=True, update_existing=True) if workspace_entries is None: - configure_workspace_command(tool) + configure_workspace_command(tool, model_prefix=model_prefix) else: - configure_workspace_command(tool, workspaces=workspace_entries) + configure_workspace_command( + tool, workspaces=workspace_entries, model_prefix=model_prefix + ) elif agents is not None: selected_tools = _parse_agents_option(agents) if workspace_entries is None: - configure_workspace_command(selected_tools=selected_tools) + configure_workspace_command( + selected_tools=selected_tools, model_prefix=model_prefix + ) else: configure_workspace_command( - selected_tools=selected_tools, workspaces=workspace_entries + selected_tools=selected_tools, + workspaces=workspace_entries, + model_prefix=model_prefix, ) else: # Tool binaries are installed after the user picks which agents # they want, in configure_workspace_command. if workspace_entries is None: - configure_workspace_command() + configure_workspace_command(model_prefix=model_prefix) else: - configure_workspace_command(workspaces=workspace_entries) + configure_workspace_command(workspaces=workspace_entries, model_prefix=model_prefix) except RuntimeError as exc: print_err(str(exc)) raise typer.Exit(1) from None diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index 5ed537c..e5c2474 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -804,17 +804,53 @@ def build_auth_shell_command(workspace: str, profile: str | None = None) -> str: ) -def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], str | None]: +DEFAULT_CLAUDE_MODEL_PREFIX = "databricks-claude-" +CLAUDE_FAMILIES: tuple[str, ...] = ("opus", "sonnet", "haiku") + + +def resolve_claude_model_prefix(state: dict | None = None, override: str | None = None) -> str: + """Prefix shared by Claude endpoint names before the ```` token. + + Endpoint names are expected to follow ``-``, e.g. + ``databricks-claude-opus-4-8`` or ``acme-bedrock-claude-opus-4-8``. + + Resolution order: an explicit ``override`` (the ``--model-prefix`` flag, + wins), then a persisted ``claude_model_prefix`` state key, then the default + ``databricks-claude-`` (which preserves the built-in hosted behaviour). + Setting a Bedrock prefix (e.g. ``acme-bedrock-claude-``) scopes discovery to + those endpoints and excludes ``databricks-claude-*`` hosted models, since + they no longer share the configured prefix. + """ + if isinstance(override, str) and override.strip(): + return override.strip() + if isinstance(state, dict): + persisted = state.get("claude_model_prefix") + if isinstance(persisted, str) and persisted.strip(): + return persisted.strip() + return DEFAULT_CLAUDE_MODEL_PREFIX + + +def discover_claude_models( + workspace: str, token: str, prefix: str | None = None +) -> tuple[dict[str, str], list[str], str | None]: """Discover Claude families on this workspace's AI Gateway. - Returns (models_by_family, reason). reason is None on success; otherwise it - describes why the dict is empty (HTTP error, network error, or no models - matching the expected naming convention). + Endpoint names are expected to follow ``-``. For + each family the newest matching id becomes the default; every matching id is + kept in the allowlist so additional versions stay selectable. + + :param prefix: Name prefix before the ```` token. Defaults to + :func:`resolve_claude_model_prefix` (env / persisted / built-in). + :returns: ``(defaults, allowed, reason)``. ``reason`` is None on success; + otherwise it describes why ``defaults`` is empty (HTTP error, network + error, or no models matching ``{opus,sonnet,haiku}-*``). """ + if prefix is None: + prefix = resolve_claude_model_prefix() hostname = workspace_hostname(workspace) payload, reason = _http_get_json(f"https://{hostname}/ai-gateway/anthropic/v1/models", token) if payload is None: - return {}, reason + return {}, [], reason data = cast(dict, payload) if isinstance(payload, dict) else {} raw_ids = [ @@ -823,29 +859,35 @@ def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], if isinstance(m.get("id"), str) and not m["id"].endswith("-anthropic") ] - result: dict[str, str] = {} - for family, key in [("opus", "opus"), ("sonnet", "sonnet"), ("haiku", "haiku")]: + defaults: dict[str, str] = {} + allowed: list[str] = [] + for family in CLAUDE_FAMILIES: candidates = sorted( - [m for m in raw_ids if f"databricks-claude-{family}-" in m], + [m for m in raw_ids if m.startswith(f"{prefix}{family}-")], reverse=True, ) if candidates: - result[key] = candidates[0] - if result: - return result, None + defaults[family] = candidates[0] + allowed.extend(candidates) + if defaults: + return defaults, allowed, None if not raw_ids: - return {}, "AI Gateway returned no Claude model ids" + return {}, [], "AI Gateway returned no Claude model ids" sample = ", ".join(raw_ids[:5]) - return {}, ( - "AI Gateway returned model ids but none matched " - f"`databricks-claude-{{opus,sonnet,haiku}}-*` (got: {sample})" + return ( + {}, + [], + ( + "AI Gateway returned model ids but none matched " + f"`{prefix}{{opus,sonnet,haiku}}-*` (got: {sample})" + ), ) def fetch_ai_gateway_claude_models(workspace: str, token: str) -> dict[str, str]: - """Backwards-compatible wrapper that discards the diagnostic reason.""" - models, _ = discover_claude_models(workspace, token) - return models + """Backwards-compatible wrapper that returns only the family defaults.""" + defaults, _allowed, _reason = discover_claude_models(workspace, token) + return defaults def discover_endpoints_with_api_type( diff --git a/tests/test_agent_claude.py b/tests/test_agent_claude.py index ea33c63..5b9db73 100644 --- a/tests/test_agent_claude.py +++ b/tests/test_agent_claude.py @@ -92,6 +92,43 @@ def test_managed_keys_include_env_entries(self): assert len(env_keys) > 0 +class TestRenderOverlayAvailableModels: + def test_allowlist_combines_ids_and_family_aliases(self): + models = { + "opus": "acme-bedrock-claude-opus-4-8", + "sonnet": "acme-bedrock-claude-sonnet-4-6", + } + allowed = [ + "acme-bedrock-claude-opus-4-8", + "acme-bedrock-claude-opus-4-7", + "acme-bedrock-claude-sonnet-4-6", + ] + overlay, keys = claude.render_overlay( + WS, "s4", claude_models=models, claude_allowed=allowed + ) + available = overlay["availableModels"] + # family aliases for families that have a default + assert "opus" in available + assert "sonnet" in available + assert "haiku" not in available # no haiku default + # every allowed id, including the extra opus version + assert "acme-bedrock-claude-opus-4-7" in available + assert "acme-bedrock-claude-opus-4-8" in available + assert ["availableModels"] in keys + + def test_falls_back_to_defaults_when_no_allowlist(self): + models = {"opus": "acme-bedrock-claude-opus-4-8"} + overlay, _ = claude.render_overlay(WS, "s4", claude_models=models) + available = overlay["availableModels"] + assert "acme-bedrock-claude-opus-4-8" in available + assert "opus" in available + + def test_no_available_models_key_when_empty(self): + overlay, keys = claude.render_overlay(WS, "s4") + assert "availableModels" not in overlay + assert ["availableModels"] not in keys + + class TestRenderOverlayUserAgent: def _ua(self, monkeypatch) -> str: monkeypatch.setattr(claude, "ucode_version", lambda: "0.1.0") diff --git a/tests/test_cli.py b/tests/test_cli.py index e2c95e2..fa767af 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -225,6 +225,40 @@ def test_status_treats_available_tools_as_configured_agents(self): assert "https://example.databricks.com/ai-gateway/anthropic" not in result.output assert "https://example.databricks.com/ai-gateway/gemini" not in result.output + def test_shows_claude_models_with_bedrock_prefix_and_extra_versions(self): + state = { + **MINIMAL_STATE, + "claude_model_prefix": "acme-bedrock-claude-", + "claude_models": { + "opus": "acme-bedrock-claude-opus-4-8", + "sonnet": "acme-bedrock-claude-sonnet-4-6", + }, + "claude_allowed_models": [ + "acme-bedrock-claude-opus-4-8", + "acme-bedrock-claude-opus-4-7", + "acme-bedrock-claude-sonnet-4-6", + ], + } + with patch("ucode.cli.load_state", return_value=state): + result = runner.invoke(app, ["status"]) + + assert result.exit_code == 0, result.output + output = _strip_ansi(result.output) + assert "Claude Models" in output + assert "acme-bedrock-claude-" in output # endpoint prefix surfaced (non-default) + assert "acme-bedrock-claude-opus-4-8" in output + assert "(+1 more)" in output # the extra opus version + assert "Selectable models" in output + + def test_default_prefix_not_shown_in_status(self): + with patch("ucode.cli.load_state", return_value=MINIMAL_STATE): + result = runner.invoke(app, ["status"]) + + assert result.exit_code == 0, result.output + output = _strip_ansi(result.output) + # the built-in default prefix is noise — only non-default prefixes show + assert "Endpoint prefix" not in output + class TestRevert: def test_reverts_mcp_configs_before_clearing_state(self): @@ -380,7 +414,7 @@ def test_no_flag_calls_configure_all(self): ): result = runner.invoke(app, ["configure"]) assert result.exit_code == 0, result.output - mock_cfg.assert_called_once_with() + mock_cfg.assert_called_once_with(model_prefix=None) def test_agents_flag_calls_configure_with_tools(self): with ( @@ -391,7 +425,7 @@ def test_agents_flag_calls_configure_with_tools(self): result = runner.invoke(app, ["configure", "--agents", "claude,codex"]) assert result.exit_code == 0, result.output mock_install.assert_not_called() - mock_cfg.assert_called_once_with(selected_tools=["claude", "codex"]) + mock_cfg.assert_called_once_with(selected_tools=["claude", "codex"], model_prefix=None) def test_agents_flag_normalizes_aliases_and_dedupes(self): with ( @@ -401,7 +435,7 @@ def test_agents_flag_normalizes_aliases_and_dedupes(self): ): result = runner.invoke(app, ["configure", "--agents", " claude-code, codex,claude "]) assert result.exit_code == 0, result.output - mock_cfg.assert_called_once_with(selected_tools=["claude", "codex"]) + mock_cfg.assert_called_once_with(selected_tools=["claude", "codex"], model_prefix=None) def test_workspaces_flag_calls_configure_with_workspaces(self): with ( @@ -422,7 +456,8 @@ def test_workspaces_flag_calls_configure_with_workspaces(self): workspaces=[ ("https://first.databricks.com", None), ("https://second.databricks.com", None), - ] + ], + model_prefix=None, ) def test_agents_and_workspaces_flags_call_configure_with_both(self): @@ -437,7 +472,9 @@ def test_agents_and_workspaces_flags_call_configure_with_both(self): ) assert result.exit_code == 0, result.output mock_cfg.assert_called_once_with( - selected_tools=["claude", "codex"], workspaces=[("https://first.com", None)] + selected_tools=["claude", "codex"], + workspaces=[("https://first.com", None)], + model_prefix=None, ) def test_agent_and_workspaces_flags_call_configure_with_both(self): @@ -452,7 +489,9 @@ def test_agent_and_workspaces_flags_call_configure_with_both(self): ) assert result.exit_code == 0, result.output mock_install.assert_called_once_with("claude", strict=True, update_existing=True) - mock_cfg.assert_called_once_with("claude", workspaces=[("https://first.com", None)]) + mock_cfg.assert_called_once_with( + "claude", workspaces=[("https://first.com", None)], model_prefix=None + ) def test_agent_flag_calls_configure_with_tool(self): with ( @@ -463,7 +502,7 @@ def test_agent_flag_calls_configure_with_tool(self): result = runner.invoke(app, ["configure", "--agent", "claude"]) assert result.exit_code == 0, result.output mock_install.assert_called_once_with("claude", strict=True, update_existing=True) - mock_cfg.assert_called_once_with("claude") + mock_cfg.assert_called_once_with("claude", model_prefix=None) def test_agent_flag_normalizes_alias(self): with ( @@ -473,7 +512,7 @@ def test_agent_flag_normalizes_alias(self): ): result = runner.invoke(app, ["configure", "--agent", "claude-code"]) assert result.exit_code == 0, result.output - mock_cfg.assert_called_once_with("claude") + mock_cfg.assert_called_once_with("claude", model_prefix=None) def test_upgrade_runs_uv_tool_install(self): with patch("subprocess.run") as mock_run: @@ -611,7 +650,9 @@ def test_multiple_workspaces_configure_all_and_use_first(self, monkeypatch): } configured_shared: list[tuple[str, str | None, tuple[str, ...] | None, bool]] = [] - def fake_configure_shared_state(workspace, profile=None, tools=None, force_login=False): + def fake_configure_shared_state( + workspace, profile=None, tools=None, force_login=False, model_prefix=None + ): configured_shared.append( (workspace, profile, tuple(tools) if tools is not None else None, force_login) ) @@ -662,7 +703,9 @@ def _stub_external_deps(monkeypatch): monkeypatch.setattr(cli_mod, "find_profile_name_for_host", lambda w: None) monkeypatch.setattr(cli_mod, "get_databricks_token", lambda w, p: "token") monkeypatch.setattr(cli_mod, "ensure_ai_gateway_v2", lambda w, t: None) - monkeypatch.setattr(cli_mod, "discover_claude_models", lambda w, t: ({}, None)) + monkeypatch.setattr( + cli_mod, "discover_claude_models", lambda w, t, prefix=None: ({}, [], None) + ) monkeypatch.setattr(cli_mod, "discover_gemini_models", lambda w, t: ([], None)) monkeypatch.setattr(cli_mod, "discover_codex_models", lambda w, t: ([], None)) monkeypatch.setattr(cli_mod, "build_shared_base_urls", lambda w: {}) @@ -704,3 +747,69 @@ def test_skips_purge_when_workspace_unchanged(self, monkeypatch): cli_mod.configure_shared_state("https://same.databricks.com") assert purge_calls == [] + + +class TestConfigureSharedStateModelPrefix: + """The --model-prefix override threads into Claude discovery and persists.""" + + @staticmethod + def _stub(monkeypatch, capture): + import ucode.cli as cli_mod + + monkeypatch.setattr(cli_mod, "normalize_workspace_url", lambda w: w) + monkeypatch.setattr(cli_mod, "run_databricks_login", lambda w, p: None) + monkeypatch.setattr(cli_mod, "ensure_databricks_auth", lambda w, p=None: None) + monkeypatch.setattr(cli_mod, "find_profile_name_for_host", lambda w: None) + monkeypatch.setattr(cli_mod, "get_databricks_token", lambda w, p: "token") + monkeypatch.setattr(cli_mod, "ensure_ai_gateway_v2", lambda w, t: None) + monkeypatch.setattr(cli_mod, "load_state", lambda: {}) + monkeypatch.setattr(cli_mod, "load_full_state", lambda: {"workspaces": {}}) + monkeypatch.setattr(cli_mod, "build_shared_base_urls", lambda w: {}) + monkeypatch.setattr(cli_mod, "discover_gemini_models", lambda w, t: ([], None)) + monkeypatch.setattr(cli_mod, "discover_codex_models", lambda w, t: ([], None)) + + def fake_discover(w, t, prefix=None): + capture["prefix"] = prefix + return {"opus": f"{prefix}opus-4-8"}, [f"{prefix}opus-4-8"], None + + monkeypatch.setattr(cli_mod, "discover_claude_models", fake_discover) + + def test_flag_overrides_discovery_prefix_and_persists(self, monkeypatch): + import ucode.cli as cli_mod + + capture: dict = {} + self._stub(monkeypatch, capture) + + state = cli_mod.configure_shared_state( + "https://ws.databricks.com", model_prefix="acme-bedrock-claude-" + ) + + assert capture["prefix"] == "acme-bedrock-claude-" + assert state["claude_model_prefix"] == "acme-bedrock-claude-" + assert state["claude_models"] == {"opus": "acme-bedrock-claude-opus-4-8"} + assert state["claude_allowed_models"] == ["acme-bedrock-claude-opus-4-8"] + + def test_persisted_prefix_used_when_no_flag(self, monkeypatch): + import ucode.cli as cli_mod + + capture: dict = {} + self._stub(monkeypatch, capture) + monkeypatch.setattr( + cli_mod, + "load_full_state", + lambda: {"workspaces": {"https://ws.databricks.com": {"claude_model_prefix": "acme-"}}}, + ) + + cli_mod.configure_shared_state("https://ws.databricks.com") + + assert capture["prefix"] == "acme-" + + def test_default_prefix_when_nothing_set(self, monkeypatch): + import ucode.cli as cli_mod + + capture: dict = {} + self._stub(monkeypatch, capture) + + cli_mod.configure_shared_state("https://ws.databricks.com") + + assert capture["prefix"] == "databricks-claude-" diff --git a/tests/test_databricks.py b/tests/test_databricks.py index fffdba5..59b5c26 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -125,10 +125,73 @@ def test_selects_opus_4_8_when_advertised(self, monkeypatch): } monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None)) - models, reason = db_mod.discover_claude_models(WS, "token") + models, allowed, reason = db_mod.discover_claude_models(WS, "token") assert reason is None assert models["opus"] == "databricks-claude-opus-4-8" + # newest per family is the default; older opus stays in the allowlist + assert "databricks-claude-opus-4-7" in allowed + assert "databricks-claude-opus-4-8" in allowed + assert "databricks-claude-sonnet-4-6" in allowed + + def test_custom_prefix_scopes_to_bedrock_and_excludes_hosted(self, monkeypatch): + payload = { + "data": [ + {"id": "acme-bedrock-claude-opus-4-8"}, + {"id": "acme-bedrock-claude-sonnet-4-6"}, + {"id": "acme-bedrock-claude-haiku-4-5"}, + {"id": "databricks-claude-opus-4-7"}, # hosted — must be excluded + ] + } + monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None)) + + models, allowed, reason = db_mod.discover_claude_models(WS, "token", "acme-bedrock-claude-") + + assert reason is None + assert models == { + "opus": "acme-bedrock-claude-opus-4-8", + "sonnet": "acme-bedrock-claude-sonnet-4-6", + "haiku": "acme-bedrock-claude-haiku-4-5", + } + assert "databricks-claude-opus-4-7" not in allowed + assert all(m.startswith("acme-bedrock-claude-") for m in allowed) + + def test_no_match_reports_prefix_in_reason(self, monkeypatch): + payload = {"data": [{"id": "databricks-claude-opus-4-8"}]} + monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None)) + + models, allowed, reason = db_mod.discover_claude_models(WS, "token", "acme-bedrock-claude-") + + assert models == {} + assert allowed == [] + assert "acme-bedrock-claude-" in reason + + +class TestResolveClaudeModelPrefix: + def test_default_when_unset(self): + assert db_mod.resolve_claude_model_prefix() == "databricks-claude-" + + def test_override_wins_over_persisted(self): + assert ( + db_mod.resolve_claude_model_prefix( + {"claude_model_prefix": "other-"}, override="acme-bedrock-claude-" + ) + == "acme-bedrock-claude-" + ) + + def test_persisted_state_used_when_no_override(self): + assert ( + db_mod.resolve_claude_model_prefix({"claude_model_prefix": "acme-bedrock-claude-"}) + == "acme-bedrock-claude-" + ) + + def test_blank_override_falls_through_to_persisted(self): + assert ( + db_mod.resolve_claude_model_prefix( + {"claude_model_prefix": "acme-bedrock-claude-"}, override=" " + ) + == "acme-bedrock-claude-" + ) class TestBuildAuthShellCommand: