From a05aa460a6845661864a9b6593e7efe0b042eb45 Mon Sep 17 00:00:00 2001 From: tulsi Date: Thu, 2 Apr 2026 16:31:20 -0700 Subject: [PATCH] feat: provider model picker with selection, caching, and mid-chat switching Backend: - Expand acp:model_state payload to include availableModels from both SessionModelState and SessionConfigOption paths. - Add acp_get_model_state and acp_set_model Tauri commands. - Provider-scoped session storage prevents cross-provider session reuse. - load_session falls back to new_session when stored session ID is stale. - Bootstrap no longer persists agent session IDs (only driver.run does). Frontend: - Wire ChatView -> ChatInput -> ChatInputToolbar with real model data from shared modelStateByProvider store. - HomeScreen writes to shared store so ChatView reads cached models instantly. - ChatView skips ACP bootstrap when cached models exist for the provider. - setModelState preserves explicit user model selection over provider defaults. - Model and provider switches show system notifications in chat timeline. - Pre-send model sync in useChat applies selected model (best-effort). - Fix modelsLoading stuck true when bootstrap effects were cancelled. Tests: - Model picker rendering, loading state, and selection callback coverage. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/check-file-sizes.mjs | 10 + src-tauri/Cargo.lock | 1 + src-tauri/Cargo.toml | 1 + src-tauri/src/commands/acp.rs | 443 +++++++++++++++++- src-tauri/src/lib.rs | 2 + src-tauri/src/services/acp/mod.rs | 8 +- src-tauri/src/services/acp/payloads.rs | 13 +- src-tauri/src/services/acp/store.rs | 23 +- src-tauri/src/services/acp/writer.rs | 117 ++++- src/app/AppShell.tsx | 23 + src/features/chat/hooks/useAcpStream.ts | 18 +- src/features/chat/hooks/useChat.ts | 41 +- src/features/chat/stores/chatSessionStore.ts | 61 ++- src/features/chat/ui/ChatInput.tsx | 14 +- src/features/chat/ui/ChatInputToolbar.tsx | 30 +- src/features/chat/ui/ChatView.tsx | 240 +++++++++- .../chat/ui/__tests__/ChatInput.test.tsx | 39 ++ src/features/home/ui/HomeScreen.test.tsx | 9 + src/features/home/ui/HomeScreen.tsx | 106 ++++- src/shared/api/acp.ts | 51 ++ src/shared/types/chat.ts | 19 + 21 files changed, 1218 insertions(+), 51 deletions(-) diff --git a/scripts/check-file-sizes.mjs b/scripts/check-file-sizes.mjs index 65db5ce2..106cc430 100644 --- a/scripts/check-file-sizes.mjs +++ b/scripts/check-file-sizes.mjs @@ -10,6 +10,16 @@ const EXCEPTIONS = { justification: "Temporary allowance while project creation and home/chat handoff wiring remain in the shell.", }, + "src/features/chat/ui/ChatView.tsx": { + limit: 650, + justification: + "Model picker bootstrap, provider/model switching, and notification logic. Candidate for extraction once model picker stabilizes.", + }, + "src-tauri/src/commands/acp.rs": { + limit: 700, + justification: + "ACP command handlers including model state discovery and set_model. run_acp_get_model_state and run_acp_set_model are extraction candidates.", + }, }; // Directories excluded from size checks (imported library code) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 68bbd27f..df8bb0d8 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1485,6 +1485,7 @@ name = "goose2" version = "0.1.0" dependencies = [ "acp-client", + "agent-client-protocol", "async-trait", "chrono", "dirs", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 4eeb452c..cc252b2c 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -33,6 +33,7 @@ futures = "0.3" tokio-stream = "0.1" tokio-util = { version = "0.7", features = ["rt"] } async-trait = "0.1" +agent-client-protocol = { version = "0.10", features = ["unstable"] } acp-client = { git = "https://github.com/block/builderbot", rev = "dbd5bc9c0cd559c7a67ab2d56545dba110f4a4cb" } doctor = { git = "https://github.com/block/builderbot", rev = "bfce29264181236a2e5b4fd903ace6291f6a7594" } diff --git a/src-tauri/src/commands/acp.rs b/src-tauri/src/commands/acp.rs index b197bebb..b408e221 100644 --- a/src-tauri/src/commands/acp.rs +++ b/src-tauri/src/commands/acp.rs @@ -1,14 +1,27 @@ +use agent_client_protocol::{ + Agent, Client, ClientSideConnection, Error as AcpError, Implementation, InitializeRequest, + LoadSessionRequest, NewSessionRequest, ProtocolVersion, RequestPermissionRequest, + RequestPermissionResponse, Result as AcpResult, SessionConfigKind, SessionConfigOption, + SessionConfigOptionCategory, SessionConfigSelectOptions, SessionModelState, + SessionNotification, SetSessionConfigOptionRequest, SetSessionModelRequest, +}; +use async_trait::async_trait; use serde::Serialize; use std::path::PathBuf; +use std::process::Stdio; use std::sync::Arc; use tauri::{AppHandle, State}; +use tokio::process::Command; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; -use crate::services::acp::{make_composite_key, AcpRunningSession, AcpService, AcpSessionRegistry}; +use crate::services::acp::{ + make_composite_key, AcpRunningSession, AcpService, AcpSessionRegistry, TauriStore, +}; use crate::services::sessions::SessionStore; use crate::types::messages::{ MessageCompletionStatus, MessageContent, MessageMetadata, ToolCallStatus, }; -use acp_client::discover_providers; +use acp_client::{discover_providers, find_acp_agent_by_id}; /// Response type for an ACP provider, sent to the frontend. #[derive(Serialize)] @@ -18,6 +31,24 @@ pub struct AcpProviderResponse { label: String, } +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct AcpAvailableModelResponse { + id: String, + name: String, + description: Option, +} + +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct AcpModelStateResponse { + source: String, + config_id: Option, + current_model_id: String, + current_model_name: Option, + available_models: Vec, +} + fn default_artifacts_working_dir() -> PathBuf { if let Some(home_dir) = dirs::home_dir() { return home_dir.join(".goose").join("artifacts"); @@ -127,6 +158,414 @@ pub async fn acp_send_message( .await } +struct NoopAcpClient; + +#[async_trait(?Send)] +impl Client for NoopAcpClient { + async fn request_permission( + &self, + _args: RequestPermissionRequest, + ) -> AcpResult { + Err(AcpError::method_not_found()) + } + + async fn session_notification(&self, _args: SessionNotification) -> AcpResult<()> { + Ok(()) + } +} + +async fn run_acp_set_model( + working_dir: PathBuf, + provider_id: String, + agent_session_id: String, + model_id: String, + source: String, + config_id: Option, +) -> Result<(), String> { + let provider = find_acp_agent_by_id(&provider_id) + .ok_or_else(|| format!("Unknown or unavailable agent provider: {provider_id}"))?; + + let mut child = Command::new(&provider.binary_path) + .args(&provider.acp_args) + .current_dir(&working_dir) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::null()) + .kill_on_drop(true) + .spawn() + .map_err(|error| { + format!( + "Failed to spawn {} (binary: {}, cwd: {}): {error}", + provider.label, + provider.binary_path.display(), + working_dir.display() + ) + })?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| "Failed to open ACP stdin".to_string())?; + let stdout = child + .stdout + .take() + .ok_or_else(|| "Failed to open ACP stdout".to_string())?; + + let (connection, io_future) = ClientSideConnection::new( + NoopAcpClient, + stdin.compat_write(), + stdout.compat(), + |future| { + tokio::task::spawn_local(future); + }, + ); + + tokio::task::spawn_local(async move { + if let Err(error) = io_future.await { + log::error!("ACP IO error during set_model: {error:?}"); + } + }); + + let init_request = InitializeRequest::new(ProtocolVersion::LATEST) + .client_info(Implementation::new("goose2", env!("CARGO_PKG_VERSION"))); + let init_response = connection + .initialize(init_request) + .await + .map_err(|error| format!("ACP init failed: {error:?}"))?; + + if !init_response.agent_capabilities.load_session { + return Err("Agent does not support load_session".to_string()); + } + + connection + .load_session(LoadSessionRequest::new( + agent_session_id.clone(), + working_dir.clone(), + )) + .await + .map_err(|error| format!("Failed to load ACP session: {error:?}"))?; + + match source.as_str() { + "config_option" => { + let config_id = config_id.ok_or_else(|| { + "Missing config option ID for config-option-backed model selection".to_string() + })?; + + connection + .set_session_config_option(SetSessionConfigOptionRequest::new( + agent_session_id, + config_id, + model_id.as_str(), + )) + .await + .map_err(|error| format!("Failed to set ACP session config option: {error:?}"))?; + } + _ => { + connection + .set_session_model(SetSessionModelRequest::new(agent_session_id, model_id)) + .await + .map_err(|error| format!("Failed to set ACP session model: {error:?}"))?; + } + } + + let _ = child.start_kill(); + let _ = child.wait().await; + + Ok(()) +} + +fn model_state_from_session_models( + model_state: SessionModelState, +) -> Option { + if model_state.available_models.is_empty() { + return None; + } + + let current_model_name = model_state + .available_models + .iter() + .find(|model| model.model_id == model_state.current_model_id) + .map(|model| model.name.clone()); + let available_models = model_state + .available_models + .into_iter() + .map(|model| AcpAvailableModelResponse { + id: model.model_id.to_string(), + name: model.name, + description: model.description, + }) + .collect(); + + Some(AcpModelStateResponse { + source: "session_model".to_string(), + config_id: None, + current_model_id: model_state.current_model_id.to_string(), + current_model_name, + available_models, + }) +} + +fn model_state_from_config_options( + config_options: Vec, +) -> Option { + let option = config_options + .into_iter() + .find(|option| matches!(option.category, Some(SessionConfigOptionCategory::Model)))?; + let config_id = option.id.to_string(); + let option_name = option.name.clone(); + + let select = match option.kind { + SessionConfigKind::Select(select) => select, + #[allow(unreachable_patterns)] + _ => return None, + }; + + let current_model_id = select.current_value.to_string(); + let available_models = match select.options { + SessionConfigSelectOptions::Ungrouped(options) => options + .into_iter() + .map(|model| AcpAvailableModelResponse { + id: model.value.to_string(), + name: model.name, + description: model.description, + }) + .collect::>(), + SessionConfigSelectOptions::Grouped(groups) => groups + .into_iter() + .flat_map(|group| group.options.into_iter()) + .map(|model| AcpAvailableModelResponse { + id: model.value.to_string(), + name: model.name, + description: model.description, + }) + .collect::>(), + _ => return None, + }; + let current_model_name = available_models + .iter() + .find(|model| model.id == current_model_id) + .map(|model| model.name.clone()) + .or(Some(option_name)); + + Some(AcpModelStateResponse { + source: "config_option".to_string(), + config_id: Some(config_id), + current_model_id, + current_model_name, + available_models, + }) +} + +fn normalize_model_state( + models: Option, + config_options: Option>, +) -> Result { + if let Some(model_state) = models.and_then(model_state_from_session_models) { + return Ok(model_state); + } + + if let Some(model_state) = config_options.and_then(model_state_from_config_options) { + return Ok(model_state); + } + + Err("ACP session did not return model options".to_string()) +} + +async fn run_acp_get_model_state( + working_dir: PathBuf, + _session_id: String, + provider_id: String, + tauri_store: TauriStore, +) -> Result { + let provider = find_acp_agent_by_id(&provider_id) + .ok_or_else(|| format!("Unknown or unavailable agent provider: {provider_id}"))?; + + let mut child = Command::new(&provider.binary_path) + .args(&provider.acp_args) + .current_dir(&working_dir) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::null()) + .kill_on_drop(true) + .spawn() + .map_err(|error| { + format!( + "Failed to spawn {} (binary: {}, cwd: {}): {error}", + provider.label, + provider.binary_path.display(), + working_dir.display() + ) + })?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| "Failed to open ACP stdin".to_string())?; + let stdout = child + .stdout + .take() + .ok_or_else(|| "Failed to open ACP stdout".to_string())?; + + let (connection, io_future) = ClientSideConnection::new( + NoopAcpClient, + stdin.compat_write(), + stdout.compat(), + |future| { + tokio::task::spawn_local(future); + }, + ); + + tokio::task::spawn_local(async move { + if let Err(error) = io_future.await { + log::error!("ACP IO error during model bootstrap: {error:?}"); + } + }); + + let result = async { + let init_request = InitializeRequest::new(ProtocolVersion::LATEST) + .client_info(Implementation::new("goose2", env!("CARGO_PKG_VERSION"))); + let init_response = connection + .initialize(init_request) + .await + .map_err(|error| format!("ACP init failed: {error:?}"))?; + + // Try to load an existing session. If the stored session ID is stale + // (e.g. provider was restarted), fall back to creating a new session. + let mut loaded = None; + if let Some(agent_session_id) = tauri_store.get_agent_session_id() { + if init_response.agent_capabilities.load_session { + match connection + .load_session(LoadSessionRequest::new( + agent_session_id.clone(), + working_dir.clone(), + )) + .await + { + Ok(session_response) => { + loaded = Some(session_response); + } + Err(error) => { + log::warn!( + "Failed to load ACP session {agent_session_id}, \ + falling back to new session: {error:?}" + ); + } + } + } + } + + if let Some(session_response) = loaded { + normalize_model_state(session_response.models, session_response.config_options) + } else { + let session_response = connection + .new_session(NewSessionRequest::new(working_dir.clone())) + .await + .map_err(|error| format!("Failed to create ACP session: {error:?}"))?; + + // Do NOT save the agent session ID here. This bootstrap process is + // throwaway — it will be killed after we read the model state. The + // real send path (driver.run) creates its own session and persists + // the ID via the Store trait. Saving here causes the send path to + // find a stale ID from a dead process and fail with "Resource not found". + + normalize_model_state(session_response.models, session_response.config_options) + } + } + .await; + + let _ = child.start_kill(); + let _ = child.wait().await; + + result +} + +#[tauri::command] +pub async fn acp_get_model_state( + session_store: State<'_, Arc>, + session_id: String, + provider_id: String, + persona_id: Option, + working_dir: Option, + persist_session: Option, +) -> Result { + let current_dir = std::env::current_dir() + .map_err(|error| format!("Failed to determine current working directory: {error}"))?; + let working_dir = resolve_working_dir(working_dir, ¤t_dir)?; + + if persist_session.unwrap_or(true) { + session_store.ensure_session(&session_id, Some(provider_id.clone())); + } + let tauri_store = TauriStore::with_provider( + Arc::clone(&session_store), + session_id.clone(), + persona_id.clone(), + Some(provider_id.clone()), + ); + + tokio::task::spawn_blocking(move || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|error| format!("Failed to build tokio runtime: {error}"))?; + let local = tokio::task::LocalSet::new(); + + local.block_on(&runtime, async move { + run_acp_get_model_state(working_dir, session_id, provider_id, tauri_store).await + }) + }) + .await + .map_err(|error| format!("ACP model bootstrap task panicked: {error}"))? +} + +#[allow(clippy::too_many_arguments)] +#[tauri::command] +pub async fn acp_set_model( + session_store: State<'_, Arc>, + session_id: String, + provider_id: String, + model_id: String, + source: String, + config_id: Option, + persona_id: Option, + working_dir: Option, +) -> Result<(), String> { + let current_dir = std::env::current_dir() + .map_err(|error| format!("Failed to determine current working directory: {error}"))?; + let working_dir = resolve_working_dir(working_dir, ¤t_dir)?; + let tauri_store = TauriStore::with_provider( + Arc::clone(&session_store), + session_id.clone(), + persona_id.clone(), + Some(provider_id.clone()), + ); + let agent_session_id = tauri_store.get_agent_session_id().ok_or_else(|| { + let key = make_composite_key(&session_id, persona_id.as_deref()); + format!("No ACP session found for '{key}'. Send a message first before changing models.") + })?; + + tokio::task::spawn_blocking(move || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|error| format!("Failed to build tokio runtime: {error}"))?; + let local = tokio::task::LocalSet::new(); + + local.block_on(&runtime, async move { + run_acp_set_model( + working_dir, + provider_id, + agent_session_id, + model_id, + source, + config_id, + ) + .await + }) + }) + .await + .map_err(|error| format!("ACP set_model task panicked: {error}"))? +} + #[cfg(test)] mod tests { use super::{expand_home_dir, resolve_working_dir}; diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 894a03a2..8e63f076 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -49,7 +49,9 @@ pub fn run() { commands::sessions::unarchive_session, commands::chat::chat_send_message, commands::acp::discover_acp_providers, + commands::acp::acp_get_model_state, commands::acp::acp_send_message, + commands::acp::acp_set_model, commands::acp::acp_cancel_session, commands::acp::acp_list_running, commands::acp::acp_cancel_all, diff --git a/src-tauri/src/services/acp/mod.rs b/src-tauri/src/services/acp/mod.rs index 1bf967f2..ed548708 100644 --- a/src-tauri/src/services/acp/mod.rs +++ b/src-tauri/src/services/acp/mod.rs @@ -112,8 +112,12 @@ impl AcpService { ®istry_key, writer_impl.assistant_message_id().to_string(), ); - let tauri_store = - TauriStore::new(Arc::clone(&session_store), session_id.clone(), persona_id); + let tauri_store = TauriStore::with_provider( + Arc::clone(&session_store), + session_id.clone(), + persona_id, + Some(provider_id.clone()), + ); let agent_session_id = tauri_store.get_agent_session_id(); let store: Arc = Arc::new(tauri_store); let writer: Arc = writer_impl.clone(); diff --git a/src-tauri/src/services/acp/payloads.rs b/src-tauri/src/services/acp/payloads.rs index 9cf04509..24345990 100644 --- a/src-tauri/src/services/acp/payloads.rs +++ b/src-tauri/src/services/acp/payloads.rs @@ -65,8 +65,19 @@ pub(super) struct SessionInfoPayload { #[derive(Clone, Serialize)] #[serde(rename_all = "camelCase")] -pub(super) struct ModelStatePayload { +pub(crate) struct AvailableModelPayload { + pub id: String, + pub name: String, + pub description: Option, +} + +#[derive(Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ModelStatePayload { pub session_id: String, + pub source: String, + pub config_id: Option, pub current_model_id: String, pub current_model_name: Option, + pub available_models: Vec, } diff --git a/src-tauri/src/services/acp/store.rs b/src-tauri/src/services/acp/store.rs index 336da737..cab6c4c5 100644 --- a/src-tauri/src/services/acp/store.rs +++ b/src-tauri/src/services/acp/store.rs @@ -23,6 +23,10 @@ pub struct TauriStore { /// Optional persona scoping. When `Some`, the effective key used for /// on-disk persistence becomes `{session_id}__{persona_id}`. persona_id: Option, + /// The ACP provider ID for this session. When set, cached agent session + /// IDs are only returned if the stored provider matches — this prevents + /// cross-provider session reuse when the user switches providers mid-chat. + provider_id: Option, } impl TauriStore { @@ -32,10 +36,11 @@ impl TauriStore { /// * `session_id` – the Goose chat session ID. /// * `persona_id` – when `Some`, enables per-persona composite keying. /// Pass `None` to preserve single-persona behaviour. - pub fn new( + pub fn with_provider( session_store: Arc, session_id: String, persona_id: Option, + provider_id: Option, ) -> Self { let sessions_dir = dirs::home_dir() .expect("home dir") @@ -47,6 +52,7 @@ impl TauriStore { session_store, session_id, persona_id, + provider_id, } } @@ -60,11 +66,25 @@ impl TauriStore { } /// Look up a previously stored agent session ID, or `None` for new sessions. + /// + /// When `provider_id` is set on this store, only returns a cached session + /// if the stored provider matches. This prevents reusing a Goose session + /// ID when the user has switched to Codex (or vice versa). pub fn get_agent_session_id(&self) -> Option { let key = self.effective_key(); let path = self.sessions_dir.join(format!("{key}.json")); let json = std::fs::read_to_string(&path).ok()?; let mapping: serde_json::Value = serde_json::from_str(&json).ok()?; + + // If we know the current provider, reject cached sessions from a different provider + if let Some(ref expected_provider) = self.provider_id { + if let Some(stored_provider) = mapping["provider_id"].as_str() { + if stored_provider != expected_provider.as_str() { + return None; + } + } + } + mapping["agent_session_id"].as_str().map(String::from) } @@ -158,6 +178,7 @@ impl Store for TauriStore { let payload = serde_json::json!({ "session_id": self.session_id, "persona_id": self.persona_id, + "provider_id": self.provider_id, "agent_session_id": agent_session_id, }); let json = serde_json::to_string_pretty(&payload) diff --git a/src-tauri/src/services/acp/writer.rs b/src-tauri/src/services/acp/writer.rs index 3dbe241d..0926ff65 100644 --- a/src-tauri/src/services/acp/writer.rs +++ b/src-tauri/src/services/acp/writer.rs @@ -1,9 +1,13 @@ use std::sync::Arc; +use agent_client_protocol::{SessionConfigKind, SessionConfigSelectOptions}; use async_trait::async_trait; use tauri::Emitter; -use acp_client::{MessageWriter, SessionInfoUpdate, SessionModelState}; +use acp_client::{ + MessageWriter, SessionConfigOption, SessionConfigOptionCategory, SessionInfoUpdate, + SessionModelState, +}; use crate::services::sessions::SessionStore; use crate::types::messages::{ @@ -11,8 +15,8 @@ use crate::types::messages::{ }; use super::payloads::{ - DonePayload, MessageCreatedPayload, ModelStatePayload, SessionInfoPayload, TextPayload, - ToolCallPayload, ToolResultPayload, ToolTitlePayload, + AvailableModelPayload, DonePayload, MessageCreatedPayload, ModelStatePayload, + SessionInfoPayload, TextPayload, ToolCallPayload, ToolResultPayload, ToolTitlePayload, }; /// A [`MessageWriter`] implementation that streams ACP output to the frontend @@ -134,6 +138,10 @@ impl TauriMessageWriter { ); } } + + fn emit_model_state(&self, payload: ModelStatePayload) { + let _ = self.app_handle.emit("acp:model_state", payload); + } } fn append_text_block(content: &mut Vec, text: &str) { @@ -168,6 +176,86 @@ fn find_latest_unpaired_tool_request( None } +fn model_payload_from_session_model_state( + session_id: String, + state: &SessionModelState, +) -> ModelStatePayload { + let available_models = state + .available_models + .iter() + .map(|model| AvailableModelPayload { + id: model.model_id.to_string(), + name: model.name.clone(), + description: model.description.clone(), + }) + .collect(); + let current_model_name = state + .available_models + .iter() + .find(|model| model.model_id == state.current_model_id) + .map(|model| model.name.clone()); + + ModelStatePayload { + session_id, + source: "session_model".to_string(), + config_id: None, + current_model_id: state.current_model_id.to_string(), + current_model_name, + available_models, + } +} + +fn model_payload_from_config_options( + session_id: String, + options: &[SessionConfigOption], +) -> Option { + let option = options + .iter() + .find(|option| matches!(option.category, Some(SessionConfigOptionCategory::Model)))?; + + let select = match &option.kind { + SessionConfigKind::Select(select) => select, + #[allow(unreachable_patterns)] + _ => return None, + }; + + let available_models = match &select.options { + SessionConfigSelectOptions::Ungrouped(options) => options + .iter() + .map(|model| AvailableModelPayload { + id: model.value.to_string(), + name: model.name.clone(), + description: model.description.clone(), + }) + .collect::>(), + SessionConfigSelectOptions::Grouped(groups) => groups + .iter() + .flat_map(|group| group.options.iter()) + .map(|model| AvailableModelPayload { + id: model.value.to_string(), + name: model.name.clone(), + description: model.description.clone(), + }) + .collect::>(), + _ => return None, + }; + let current_model_id = select.current_value.to_string(); + let current_model_name = available_models + .iter() + .find(|model| model.id == current_model_id) + .map(|model| model.name.clone()) + .or_else(|| Some(option.name.clone())); + + Some(ModelStatePayload { + session_id, + source: "config_option".to_string(), + config_id: Some(option.id.to_string()), + current_model_id, + current_model_name, + available_models, + }) +} + #[async_trait] impl MessageWriter for TauriMessageWriter { async fn append_text(&self, text: &str) { @@ -314,18 +402,15 @@ impl MessageWriter for TauriMessageWriter { } async fn on_model_state_update(&self, state: &SessionModelState) { - let current_model_name = state - .available_models - .iter() - .find(|m| m.model_id == state.current_model_id) - .map(|m| m.name.clone()); - let _ = self.app_handle.emit( - "acp:model_state", - ModelStatePayload { - session_id: self.session_id.clone(), - current_model_id: state.current_model_id.to_string(), - current_model_name, - }, - ); + self.emit_model_state(model_payload_from_session_model_state( + self.session_id.clone(), + state, + )); + } + + async fn on_config_option_update(&self, options: &[SessionConfigOption]) { + if let Some(payload) = model_payload_from_config_options(self.session_id.clone(), options) { + self.emit_model_state(payload); + } } } diff --git a/src/app/AppShell.tsx b/src/app/AppShell.tsx index c4ae12bd..4cb27d40 100644 --- a/src/app/AppShell.tsx +++ b/src/app/AppShell.tsx @@ -110,6 +110,12 @@ export function AppShell({ children }: { children?: React.ReactNode }) { const [homeSelectedPersonaId, setHomeSelectedPersonaId] = useState< string | undefined >(); + const [homeSelectedModelId, setHomeSelectedModelId] = useState< + string | undefined + >(); + const [homeSelectedModelName, setHomeSelectedModelName] = useState< + string | undefined + >(); const createNewTab = useCallback( async (title = "New Chat", project?: ProjectInfo) => { @@ -145,6 +151,13 @@ export function AppShell({ children }: { children?: React.ReactNode }) { personaId, }); + if (homeSelectedModelId || homeSelectedModelName) { + sessionStore.updateSession(session.id, { + currentModelId: homeSelectedModelId, + modelName: homeSelectedModelName, + }); + } + sessionStore.setActiveSession(session.id); setActiveView("chat"); @@ -156,6 +169,8 @@ export function AppShell({ children }: { children?: React.ReactNode }) { chatStore, sessionStore, agentStore.activeAgentId, + homeSelectedModelId, + homeSelectedModelName, homeSelectedPersonaId, homeSelectedProvider, ], @@ -289,10 +304,14 @@ export function AppShell({ children }: { children?: React.ReactNode }) { providerId?: string, personaId?: string, projectId?: string | null, + modelId?: string, + modelName?: string, images?: PastedImage[], ) => { setHomeSelectedProvider(providerId); setHomeSelectedPersonaId(personaId); + setHomeSelectedModelId(modelId); + setHomeSelectedModelName(modelName); setPendingInitialMessage(initialMessage); setPendingInitialImages(images); const selectedProject = @@ -308,6 +327,8 @@ export function AppShell({ children }: { children?: React.ReactNode }) { setPendingInitialImages(undefined); setHomeSelectedProvider(undefined); setHomeSelectedPersonaId(undefined); + setHomeSelectedModelId(undefined); + setHomeSelectedModelName(undefined); }); }, [createNewTab, projectStore.projects], @@ -385,6 +406,8 @@ export function AppShell({ children }: { children?: React.ReactNode }) { setPendingInitialImages(undefined); setHomeSelectedProvider(undefined); setHomeSelectedPersonaId(undefined); + setHomeSelectedModelId(undefined); + setHomeSelectedModelName(undefined); }} /> ) : ( diff --git a/src/features/chat/hooks/useAcpStream.ts b/src/features/chat/hooks/useAcpStream.ts index 05d53115..1571aaf7 100644 --- a/src/features/chat/hooks/useAcpStream.ts +++ b/src/features/chat/hooks/useAcpStream.ts @@ -2,6 +2,7 @@ import { useEffect } from "react"; import { listen, type UnlistenFn } from "@tauri-apps/api/event"; import { useChatStore } from "../stores/chatStore"; import { useChatSessionStore } from "../stores/chatSessionStore"; +import type { ModelOption, ModelSelectionSource } from "@/shared/types/chat"; import type { Message, MessageCompletionStatus, @@ -57,8 +58,11 @@ interface AcpSessionInfoPayload { interface AcpModelStatePayload { sessionId: string; + source: ModelSelectionSource; + configId?: string; currentModelId: string; currentModelName?: string; + availableModels: ModelOption[]; } interface AcpUsageUpdatePayload { @@ -367,15 +371,17 @@ export function useAcpStream(enabled: boolean): void { }), ); - // acp:model_state — update model name from ACP provider + // acp:model_state — update the current model and cache available models unlisteners.push( listen("acp:model_state", (event) => { if (!active) return; - const modelName = - event.payload.currentModelName ?? event.payload.currentModelId; - useChatSessionStore - .getState() - .updateSession(event.payload.sessionId, { modelName }); + useChatSessionStore.getState().setModelState(event.payload.sessionId, { + source: event.payload.source, + configId: event.payload.configId, + currentModelId: event.payload.currentModelId, + currentModelName: event.payload.currentModelName, + availableModels: event.payload.availableModels, + }); }), ); diff --git a/src/features/chat/hooks/useChat.ts b/src/features/chat/hooks/useChat.ts index 97847a0b..494ef0d8 100644 --- a/src/features/chat/hooks/useChat.ts +++ b/src/features/chat/hooks/useChat.ts @@ -6,7 +6,12 @@ import { createUserMessage, } from "@/shared/types/messages"; import type { ChatState, TokenState } from "@/shared/types/chat"; -import { acpSendMessage, acpCancelSession } from "@/shared/api/acp"; +import { + acpSendMessage, + acpCancelSession, + acpGetModelState, + acpSetModel, +} from "@/shared/api/acp"; import { useAgentStore } from "@/features/agents/stores/agentStore"; import { findLastIndex } from "@/shared/lib/arrays"; @@ -177,6 +182,40 @@ export function useChat( const providerId = providerOverride ?? agent?.provider ?? "goose"; const systemPrompt = systemPromptOverride ?? agent?.systemPrompt ?? undefined; + const activeSession = useChatSessionStore + .getState() + .getSession(sessionId); + const targetModelId = activeSession?.currentModelId; + + if (targetModelId) { + // Ensure ACP session exists and apply model if it differs from target. + // This handles: (a) first send after Home screen model selection, + // (b) safety net if handleModelChange's set_model didn't stick. + try { + const modelState = await acpGetModelState(sessionId, providerId, { + personaId: effectivePersonaInfo?.id, + workingDir: workingDirOverride, + persistSession: true, + }); + + if ( + modelState.currentModelId !== targetModelId && + modelState.availableModels.some( + (model) => model.id === targetModelId, + ) + ) { + await acpSetModel(sessionId, providerId, targetModelId, { + source: modelState.source, + configId: modelState.configId, + personaId: effectivePersonaInfo?.id, + workingDir: workingDirOverride, + }); + } + } catch (modelError) { + // Model sync is best-effort — don't block the send + console.warn("Pre-send model sync failed:", modelError); + } + } // Send via ACP — response streams back through Tauri events // which are handled by the global useAcpStream listener in AppShell. diff --git a/src/features/chat/stores/chatSessionStore.ts b/src/features/chat/stores/chatSessionStore.ts index a39e4cca..230f4a70 100644 --- a/src/features/chat/stores/chatSessionStore.ts +++ b/src/features/chat/stores/chatSessionStore.ts @@ -6,7 +6,7 @@ import { unarchiveSession as apiUnarchiveSession, updateSession as apiUpdateSession, } from "@/shared/api/chat"; -import type { Session } from "@/shared/types/chat"; +import type { ProviderModelState, Session } from "@/shared/types/chat"; const SESSION_CACHE_STORAGE_KEY = "goose:chat-sessions"; @@ -18,6 +18,7 @@ export interface ChatSession { agentId?: string; providerId?: string; personaId?: string; + currentModelId?: string; modelName?: string; createdAt: string; // ISO timestamp updatedAt: string; @@ -27,6 +28,7 @@ export interface ChatSession { interface ChatSessionStoreState { sessions: ChatSession[]; + modelStateByProvider: Record; activeSessionId: string | null; isLoading: boolean; } @@ -42,6 +44,17 @@ interface ChatSessionStoreActions { }) => Promise; loadSessions: () => Promise; updateSession: (id: string, patch: Partial) => void; + setModelState: ( + sessionId: string, + modelState: { + providerId?: string; + source: ProviderModelState["source"]; + configId?: string; + currentModelId: string; + currentModelName?: string; + availableModels: ProviderModelState["availableModels"]; + }, + ) => void; archiveSession: (id: string) => Promise; unarchiveSession: (id: string) => Promise; @@ -86,6 +99,7 @@ function sessionToChatSession(session: Session): ChatSession { projectId: session.projectId, providerId: session.providerId, personaId: session.personaId, + currentModelId: session.currentModelId, modelName: session.modelName, createdAt: session.createdAt, updatedAt: session.updatedAt, @@ -97,6 +111,7 @@ function sessionToChatSession(session: Session): ChatSession { export const useChatSessionStore = create((set, get) => ({ // State sessions: loadCachedSessions(), + modelStateByProvider: {}, activeSessionId: null, isLoading: false, @@ -177,6 +192,50 @@ export const useChatSessionStore = create((set, get) => ({ } }, + setModelState: (sessionId, modelState) => { + const session = get().getSession(sessionId); + const providerId = modelState.providerId ?? session?.providerId; + + // If the user has explicitly selected a model (session.currentModelId is set + // and differs from what the provider reports), preserve their choice. + // Only set the session's model when it hasn't been explicitly chosen yet. + const userHasExplicitSelection = + session?.currentModelId && + session.currentModelId !== modelState.currentModelId; + + set((state) => ({ + sessions: state.sessions.map((candidate) => + candidate.id === sessionId + ? { + ...candidate, + ...(userHasExplicitSelection + ? {} + : { + currentModelId: modelState.currentModelId, + modelName: + modelState.currentModelName ?? modelState.currentModelId, + }), + updatedAt: new Date().toISOString(), + } + : candidate, + ), + modelStateByProvider: providerId + ? { + ...state.modelStateByProvider, + [providerId]: { + source: modelState.source, + configId: modelState.configId, + currentModelId: modelState.currentModelId, + currentModelName: modelState.currentModelName, + availableModels: modelState.availableModels, + }, + } + : state.modelStateByProvider, + })); + + persistSessions(get().sessions); + }, + archiveSession: async (id) => { const previousActiveSessionId = get().activeSessionId; set((state) => { diff --git a/src/features/chat/ui/ChatInput.tsx b/src/features/chat/ui/ChatInput.tsx index 9b6e6245..65b4094a 100644 --- a/src/features/chat/ui/ChatInput.tsx +++ b/src/features/chat/ui/ChatInput.tsx @@ -2,6 +2,7 @@ import { useState, useRef, useCallback, useEffect, useMemo } from "react"; import { X } from "lucide-react"; import type { AcpProvider } from "@/shared/api/acp"; import type { Persona } from "@/shared/types/agents"; +import type { ModelOption } from "@/shared/types/chat"; import { cn } from "@/shared/lib/cn"; import { Button } from "@/shared/ui/button"; import { @@ -15,13 +16,6 @@ import { ImageLightbox } from "@/shared/ui/ImageLightbox"; import type { PastedImage } from "@/shared/types/messages"; import { resizeImage } from "../lib/resizeImage"; -export interface ModelOption { - id: string; - name: string; - displayName?: string; - provider?: string; -} - export interface ProjectOption { id: string; name: string; @@ -48,6 +42,8 @@ interface ChatInputProps { onProviderChange?: (providerId: string) => void; // Model currentModel?: string; + currentModelId?: string; + modelsLoading?: boolean; availableModels?: ModelOption[]; onModelChange?: (modelId: string) => void; // Project @@ -134,6 +130,8 @@ export function ChatInput({ selectedProvider = "goose", onProviderChange, currentModel = "Claude Sonnet 4", + currentModelId, + modelsLoading = false, availableModels = [], onModelChange, selectedProjectId = null, @@ -425,6 +423,8 @@ export function ChatInput({ selectedProvider={selectedProvider} onProviderChange={(id) => onProviderChange?.(id)} currentModel={currentModel} + currentModelId={currentModelId} + modelsLoading={modelsLoading} availableModels={availableModels} onModelChange={onModelChange} selectedProjectId={selectedProjectId} diff --git a/src/features/chat/ui/ChatInputToolbar.tsx b/src/features/chat/ui/ChatInputToolbar.tsx index c88e5933..8be44928 100644 --- a/src/features/chat/ui/ChatInputToolbar.tsx +++ b/src/features/chat/ui/ChatInputToolbar.tsx @@ -9,11 +9,12 @@ import { } from "lucide-react"; import type { AcpProvider } from "@/shared/api/acp"; import type { Persona } from "@/shared/types/agents"; +import type { ModelOption } from "@/shared/types/chat"; import { cn } from "@/shared/lib/cn"; import { ChatInputSelector } from "./ChatInputSelector"; import { ContextRing } from "./ContextRing"; import { PersonaPicker } from "./PersonaPicker"; -import type { ModelOption, ProjectOption } from "./ChatInput"; +import type { ProjectOption } from "./ChatInput"; import { DropdownMenu, DropdownMenuTrigger, @@ -63,6 +64,8 @@ interface ChatInputToolbarProps { onProviderChange: (providerId: string) => void; // Model currentModel: string; + currentModelId?: string; + modelsLoading?: boolean; availableModels: ModelOption[]; onModelChange?: (modelId: string) => void; // Project @@ -97,6 +100,8 @@ export function ChatInputToolbar({ selectedProvider, onProviderChange, currentModel, + currentModelId, + modelsLoading = false, availableModels, onModelChange, selectedProjectId, @@ -131,6 +136,8 @@ export function ChatInputToolbar({ const providerLabel = availableProviderItems.find((provider) => provider.id === selectedProvider) ?.label ?? formatProviderLabel(selectedProvider); + const modelTriggerLabel = + currentModel || (modelsLoading ? "Loading models..." : "Select model"); const handleProjectValueChange = (value: string) => { if (value === CREATE_PROJECT_VALUE) { @@ -172,7 +179,10 @@ export function ChatInputToolbar({ /> )} - {availableModels.length > 0 && ( + {(availableModels.length > 0 || + Boolean(selectedProvider) || + currentModel || + modelsLoading) && ( - + Model {availableModels.map((model) => ( @@ -208,7 +226,7 @@ export function ChatInputToolbar({ )} - {model.id === currentModel && ( + {model.id === (currentModelId ?? currentModel) && ( )} diff --git a/src/features/chat/ui/ChatView.tsx b/src/features/chat/ui/ChatView.tsx index 22ee75ff..06640f1b 100644 --- a/src/features/chat/ui/ChatView.tsx +++ b/src/features/chat/ui/ChatView.tsx @@ -10,6 +10,7 @@ import { useProviderSelection } from "@/features/agents/hooks/useProviderSelecti import { useChatSessionStore } from "../stores/chatSessionStore"; import { getProject, type ProjectInfo } from "@/features/projects/api/projects"; import { useProjectStore } from "@/features/projects/stores/projectStore"; +import { acpGetModelState } from "@/shared/api/acp"; import { buildProjectSystemPrompt, composeSystemPrompt, @@ -66,6 +67,9 @@ export function ChatView({ const session = useChatSessionStore((s) => s.sessions.find((candidate) => candidate.id === activeSessionId), ); + const modelStateByProvider = useChatSessionStore( + (s) => s.modelStateByProvider, + ); const projects = useProjectStore((s) => s.projects); const storedProject = useProjectStore((s) => session?.projectId @@ -78,6 +82,7 @@ export function ChatView({ const [homeArtifactsRoot, setHomeArtifactsRoot] = useState( null, ); + const [modelsLoading, setModelsLoading] = useState(false); const project = storedProject ?? fallbackProject; const availableProjects = useMemo( () => @@ -99,6 +104,36 @@ export function ChatView({ globalSelectedProvider; const selectedPersona = personas.find((p) => p.id === selectedPersonaId); + const providerModelState = selectedProvider + ? modelStateByProvider[selectedProvider] + : undefined; + const availableModels = providerModelState?.availableModels ?? []; + const lastModelBootstrapKeyRef = useRef(null); + const bootstrapSessionId = useMemo( + () => + `chat-bootstrap-${activeSessionId}-${selectedProvider}-${selectedPersonaId ?? "default"}`, + [activeSessionId, selectedPersonaId, selectedProvider], + ); + const currentModel = useMemo(() => { + if (modelsLoading) return "Loading models..."; + // Session's saved display name is authoritative after model selection + if (session?.modelName) return session.modelName; + // Fall back to matching against available models list + const matchedModel = session?.currentModelId + ? availableModels.find((model) => model.id === session.currentModelId) + : undefined; + if (matchedModel) return matchedModel.displayName ?? matchedModel.name; + // Provider's reported current model + if (providerModelState?.currentModelName) + return providerModelState.currentModelName; + return ""; + }, [ + modelsLoading, + session?.modelName, + session?.currentModelId, + availableModels, + providerModelState?.currentModelName, + ]); const projectFolders = useMemo( () => getProjectFolderOption(project), [project], @@ -171,12 +206,43 @@ export function ChatView({ return; } + // Show a notification if there are already messages in the chat + const currentMessages = + useChatStore.getState().messagesBySession[activeSessionId] ?? []; + if (currentMessages.length > 0) { + const providerLabel = + providers.find((p) => p.id === providerId)?.label ?? providerId; + // Include the default model if we already have cached model state + const cachedState = + useChatSessionStore.getState().modelStateByProvider[providerId]; + const defaultModelName = + cachedState?.currentModelName ?? cachedState?.currentModelId; + const switchText = defaultModelName + ? `Switched to ${providerLabel} (${defaultModelName})` + : `Switched to ${providerLabel}`; + useChatStore.getState().addMessage(activeSessionId, { + id: crypto.randomUUID(), + role: "system", + created: Date.now(), + content: [ + { + type: "systemNotification", + notificationType: "info", + text: switchText, + }, + ], + metadata: { userVisible: true, agentVisible: false }, + }); + } + setGlobalSelectedProvider(providerId); - useChatSessionStore - .getState() - .updateSession(activeSessionId, { providerId }); + useChatSessionStore.getState().updateSession(activeSessionId, { + providerId, + currentModelId: undefined, + modelName: undefined, + }); }, - [activeSessionId, selectedProvider, setGlobalSelectedProvider], + [activeSessionId, providers, selectedProvider, setGlobalSelectedProvider], ); const handleProjectChange = useCallback( @@ -188,6 +254,156 @@ export function ChatView({ [activeSessionId], ); + const handleModelChange = useCallback( + (modelId: string) => { + if (modelId === session?.currentModelId) { + return; + } + + const matchedModel = availableModels.find( + (model) => model.id === modelId, + ); + const displayName = + matchedModel?.displayName ?? matchedModel?.name ?? modelId; + + // Update session — the pre-send sync in useChat will apply the model + // to the real ACP session when the user sends the next message. + useChatSessionStore.getState().updateSession(activeSessionId, { + currentModelId: modelId, + modelName: displayName, + }); + + // Show a system notification if there are already messages in the chat + const currentMessages = + useChatStore.getState().messagesBySession[activeSessionId] ?? []; + if (currentMessages.length > 0) { + useChatStore.getState().addMessage(activeSessionId, { + id: crypto.randomUUID(), + role: "system", + created: Date.now(), + content: [ + { + type: "systemNotification", + notificationType: "info", + text: `Switched to ${displayName}`, + }, + ], + metadata: { userVisible: true, agentVisible: false }, + }); + } + }, + [activeSessionId, availableModels, session?.currentModelId], + ); + + useEffect(() => { + if (!selectedProvider) { + return; + } + + // Don't fire until effectiveWorkingDir is resolved — avoids a wasted + // first bootstrap that gets immediately cancelled when homeArtifactsRoot arrives. + if (!effectiveWorkingDir) { + return; + } + + // If the shared store already has models for this provider (e.g. from + // HomeScreen), use them immediately and skip the expensive ACP spawn. + const cached = + useChatSessionStore.getState().modelStateByProvider[selectedProvider]; + if (cached && cached.availableModels.length > 0) { + // Read from store at call time to avoid closing over session?.currentModelId + // (which would cause cascading re-fires if added to the dependency array) + const desiredModelId = useChatSessionStore + .getState() + .getSession(activeSessionId)?.currentModelId; + const desiredModel = desiredModelId + ? cached.availableModels.find((model) => model.id === desiredModelId) + : undefined; + + useChatSessionStore.getState().setModelState(activeSessionId, { + providerId: selectedProvider, + source: cached.source, + configId: cached.configId, + currentModelId: desiredModelId ?? cached.currentModelId, + currentModelName: + desiredModel?.displayName ?? + desiredModel?.name ?? + cached.currentModelName, + availableModels: cached.availableModels, + }); + setModelsLoading(false); + return; + } + + const bootstrapKey = [ + bootstrapSessionId, + selectedProvider, + selectedPersonaId ?? "", + effectiveWorkingDir, + ].join(":"); + + if (lastModelBootstrapKeyRef.current === bootstrapKey) { + return; + } + + let cancelled = false; + lastModelBootstrapKeyRef.current = bootstrapKey; + setModelsLoading(true); + + acpGetModelState(bootstrapSessionId, selectedProvider, { + personaId: selectedPersonaId ?? undefined, + workingDir: effectiveWorkingDir, + persistSession: false, + }) + .then((modelState) => { + if (cancelled) { + return; + } + + const desiredModelId = useChatSessionStore + .getState() + .getSession(activeSessionId)?.currentModelId; + const desiredModel = modelState.availableModels.find( + (model) => model.id === desiredModelId, + ); + + useChatSessionStore.getState().setModelState(activeSessionId, { + providerId: selectedProvider, + source: modelState.source, + configId: modelState.configId, + currentModelId: desiredModelId ?? modelState.currentModelId, + currentModelName: + desiredModel?.displayName ?? + desiredModel?.name ?? + modelState.currentModelName, + availableModels: modelState.availableModels, + }); + }) + .catch((error) => { + if (cancelled) { + return; + } + + console.error("Failed to bootstrap ACP model state:", error); + lastModelBootstrapKeyRef.current = null; + }) + .finally(() => { + // Always reset loading — if cancelled, the next effect will set it true again. + // Leaving it true on cancellation causes permanently stuck "Loading models..." + setModelsLoading(false); + }); + + return () => { + cancelled = true; + }; + }, [ + activeSessionId, + bootstrapSessionId, + effectiveWorkingDir, + selectedPersonaId, + selectedProvider, + ]); + // When persona changes, update the provider to match persona's default const handlePersonaChange = useCallback( (personaId: string | null) => { @@ -313,13 +529,20 @@ export function ChatView({ useEffect(() => { if ( (initialMessage || initialImages?.length) && - !initialMessageSent.current + !initialMessageSent.current && + !modelsLoading ) { initialMessageSent.current = true; handleSend(initialMessage ?? "", undefined, initialImages); onInitialMessageConsumed?.(); } - }, [initialMessage, initialImages, handleSend, onInitialMessageConsumed]); + }, [ + initialMessage, + initialImages, + handleSend, + modelsLoading, + onInitialMessageConsumed, + ]); const isStreaming = chatState === "streaming"; const showIndicator = @@ -370,6 +593,11 @@ export function ChatView({ providersLoading={providersLoading} selectedProvider={selectedProvider} onProviderChange={handleProviderChange} + currentModel={currentModel} + currentModelId={session?.currentModelId} + modelsLoading={modelsLoading} + availableModels={availableModels} + onModelChange={handleModelChange} selectedProjectId={session?.projectId ?? null} availableProjects={availableProjects} onProjectChange={handleProjectChange} diff --git a/src/features/chat/ui/__tests__/ChatInput.test.tsx b/src/features/chat/ui/__tests__/ChatInput.test.tsx index 89ed74f0..c1fc736b 100644 --- a/src/features/chat/ui/__tests__/ChatInput.test.tsx +++ b/src/features/chat/ui/__tests__/ChatInput.test.tsx @@ -85,6 +85,7 @@ describe("ChatInput", () => { , ); @@ -101,6 +102,44 @@ describe("ChatInput", () => { expect(screen.getByText("Claude Sonnet 4")).toBeInTheDocument(); }); + it("calls onModelChange when selecting a model", async () => { + const onModelChange = vi.fn(); + const user = userEvent.setup(); + + render( + , + ); + + await user.click(screen.getByRole("button", { name: /select model/i })); + await user.click(screen.getByText("GPT-4.1")); + + expect(onModelChange).toHaveBeenCalledWith("gpt-4.1"); + }); + + it("shows a disabled loading model picker before models are available", () => { + render( + , + ); + + expect( + screen.getByRole("button", { name: /select model/i }), + ).toBeDisabled(); + expect(screen.getByText("Loading models...")).toBeInTheDocument(); + }); + it("shows default provider label", () => { render( ({ { id: "goose", label: "Goose" }, { id: "openai", label: "OpenAI" }, ]), + acpGetModelState: vi.fn().mockResolvedValue({ + source: "session_model", + currentModelId: "claude-sonnet-4", + currentModelName: "Claude Sonnet 4", + availableModels: [ + { id: "claude-sonnet-4", name: "Claude Sonnet 4" }, + { id: "opus-1m", name: "Opus (1M context)" }, + ], + }), })); vi.mock("@/features/agents/hooks/useProviderSelection", () => ({ diff --git a/src/features/home/ui/HomeScreen.tsx b/src/features/home/ui/HomeScreen.tsx index 1f773f30..2520435c 100644 --- a/src/features/home/ui/HomeScreen.tsx +++ b/src/features/home/ui/HomeScreen.tsx @@ -1,4 +1,4 @@ -import { useState, useEffect, useCallback } from "react"; +import { useState, useEffect, useCallback, useMemo } from "react"; import { getStoredProvider, useAgentStore, @@ -7,6 +7,8 @@ import { useProviderSelection } from "@/features/agents/hooks/useProviderSelecti import { ChatInput } from "@/features/chat/ui/ChatInput"; import type { PastedImage } from "@/shared/types/messages"; import { useProjectStore } from "@/features/projects/stores/projectStore"; +import { acpGetModelState } from "@/shared/api/acp"; +import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore"; function HomeClock() { const [time, setTime] = useState(new Date()); @@ -46,6 +48,8 @@ interface HomeScreenProps { providerId?: string, personaId?: string, projectId?: string | null, + modelId?: string, + modelName?: string, images?: PastedImage[], ) => void; onCreateProject?: (options?: { @@ -79,6 +83,28 @@ export function HomeScreen({ const [selectedProjectId, setSelectedProjectId] = useState( null, ); + const modelStateByProvider = useChatSessionStore( + (s) => s.modelStateByProvider, + ); + const providerModelState = selectedProvider + ? modelStateByProvider[selectedProvider] + : undefined; + const availableModels = providerModelState?.availableModels ?? []; + const selectedModelId = providerModelState?.currentModelId; + const selectedModelName = providerModelState?.currentModelName; + const [modelsLoading, setModelsLoading] = useState(false); + const selectedProject = useMemo( + () => + selectedProjectId + ? (projects.find((project) => project.id === selectedProjectId) ?? null) + : null, + [projects, selectedProjectId], + ); + const bootstrapSessionId = useMemo( + () => + `home-bootstrap-${selectedProvider}-${selectedPersonaId ?? "default"}`, + [selectedPersonaId, selectedProvider], + ); const handlePersonaChange = useCallback( (personaId: string | null) => { @@ -106,12 +132,66 @@ export function HomeScreen({ selectedProvider, effectivePersonaId, selectedProjectId, + selectedModelId, + selectedModelName, images, ); }, - [onStartChat, selectedPersonaId, selectedProjectId, selectedProvider], + [ + onStartChat, + selectedModelId, + selectedModelName, + selectedPersonaId, + selectedProjectId, + selectedProvider, + ], ); + useEffect(() => { + let cancelled = false; + setModelsLoading(true); + + acpGetModelState(bootstrapSessionId, selectedProvider, { + personaId: selectedPersonaId ?? undefined, + workingDir: selectedProject?.workingDirs[0] ?? undefined, + persistSession: false, + }) + .then((modelState) => { + if (cancelled) { + return; + } + + useChatSessionStore.getState().setModelState("home", { + providerId: selectedProvider, + source: modelState.source, + configId: modelState.configId, + currentModelId: modelState.currentModelId, + currentModelName: + modelState.currentModelName ?? modelState.currentModelId, + availableModels: modelState.availableModels, + }); + }) + .catch((error) => { + if (cancelled) { + return; + } + + console.error("Failed to bootstrap home model state:", error); + }) + .finally(() => { + setModelsLoading(false); + }); + + return () => { + cancelled = true; + }; + }, [ + bootstrapSessionId, + selectedPersonaId, + selectedProject, + selectedProvider, + ]); + return (
@@ -135,6 +215,28 @@ export function HomeScreen({ providersLoading={providersLoading} selectedProvider={selectedProvider} onProviderChange={setSelectedProvider} + currentModel={ + modelsLoading ? "Loading models..." : (selectedModelName ?? "") + } + currentModelId={selectedModelId} + modelsLoading={modelsLoading} + availableModels={availableModels} + onModelChange={(modelId) => { + const model = availableModels.find( + (candidate) => candidate.id === modelId, + ); + if (selectedProvider) { + useChatSessionStore.getState().setModelState("home", { + providerId: selectedProvider, + source: providerModelState?.source ?? "session_model", + configId: providerModelState?.configId, + currentModelId: modelId, + currentModelName: + model?.displayName ?? model?.name ?? modelId, + availableModels, + }); + } + }} selectedProjectId={selectedProjectId} availableProjects={projects.map((project) => ({ id: project.id, diff --git a/src/shared/api/acp.ts b/src/shared/api/acp.ts index 6eb22adb..31934a3b 100644 --- a/src/shared/api/acp.ts +++ b/src/shared/api/acp.ts @@ -1,4 +1,5 @@ import { invoke } from "@tauri-apps/api/core"; +import type { ModelOption, ModelSelectionSource } from "@/shared/types/chat"; export interface AcpProvider { id: string; @@ -14,6 +15,14 @@ export interface AcpSendMessageOptions { images?: [string, string][]; } +export interface AcpModelState { + source: ModelSelectionSource; + configId?: string; + currentModelId: string; + currentModelName?: string; + availableModels: ModelOption[]; +} + /** Discover ACP providers installed on the system. */ export async function discoverAcpProviders(): Promise { return invoke("discover_acp_providers"); @@ -49,3 +58,45 @@ export async function acpCancelSession( personaId: personaId ?? null, }); } + +/** Load or create an ACP session and return its model state. */ +export async function acpGetModelState( + sessionId: string, + providerId: string, + options: { + personaId?: string; + workingDir?: string; + persistSession?: boolean; + } = {}, +): Promise { + return invoke("acp_get_model_state", { + sessionId, + providerId, + personaId: options.personaId ?? null, + workingDir: options.workingDir ?? null, + persistSession: options.persistSession ?? true, + }); +} + +/** Set the active model for an existing ACP session. */ +export async function acpSetModel( + sessionId: string, + providerId: string, + modelId: string, + options: { + source?: ModelSelectionSource; + configId?: string; + personaId?: string; + workingDir?: string; + } = {}, +): Promise { + return invoke("acp_set_model", { + sessionId, + providerId, + modelId, + source: options.source ?? "session_model", + configId: options.configId ?? null, + personaId: options.personaId ?? null, + workingDir: options.workingDir ?? null, + }); +} diff --git a/src/shared/types/chat.ts b/src/shared/types/chat.ts index a0a161bc..88ddf667 100644 --- a/src/shared/types/chat.ts +++ b/src/shared/types/chat.ts @@ -55,6 +55,7 @@ export interface Session { projectId?: string | null; providerId?: string; personaId?: string; + currentModelId?: string; modelName?: string; createdAt: string; updatedAt: string; @@ -62,6 +63,24 @@ export interface Session { messageCount: number; } +export interface ModelOption { + id: string; + name: string; + displayName?: string; + description?: string; + provider?: string; +} + +export type ModelSelectionSource = "session_model" | "config_option"; + +export interface ProviderModelState { + source: ModelSelectionSource; + configId?: string; + currentModelId: string; + currentModelName?: string; + availableModels: ModelOption[]; +} + // SSE event types (from goosed server) export type MessageEventType = | "message"