Skip to content

Commit 827d639

Browse files
committed
Fix provider model lookup
1 parent 2a6ffb1 commit 827d639

File tree

1 file changed

+148
-6
lines changed

1 file changed

+148
-6
lines changed

src-tauri/src/shared/codex_core.rs

Lines changed: 148 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,149 @@ fn normalize_optional_string(value: Option<String>) -> Option<String> {
985985
.filter(|raw| !raw.is_empty())
986986
}
987987

988+
fn provider_has_model(config_providers: &Value, provider_id: &str, model_id: &str) -> bool {
989+
let provider_id = provider_id.trim();
990+
let model_id = model_id.trim();
991+
if provider_id.is_empty() || model_id.is_empty() {
992+
return false;
993+
}
994+
995+
let providers = config_providers
996+
.get("providers")
997+
.and_then(|v| v.as_array())
998+
.cloned()
999+
.unwrap_or_default();
1000+
1001+
providers.iter().any(|provider| {
1002+
let pid = provider
1003+
.get("id")
1004+
.and_then(|v| v.as_str())
1005+
.map(str::trim)
1006+
.unwrap_or_default();
1007+
if pid != provider_id {
1008+
return false;
1009+
}
1010+
let models = provider.get("models").and_then(|v| v.as_object());
1011+
if let Some(models_map) = models {
1012+
if models_map.contains_key(model_id) {
1013+
return true;
1014+
}
1015+
return models_map.values().any(|model| {
1016+
model.get("id")
1017+
.and_then(|v| v.as_str())
1018+
.map(str::trim)
1019+
.unwrap_or_default()
1020+
== model_id
1021+
});
1022+
}
1023+
false
1024+
})
1025+
}
1026+
1027+
fn resolve_model_override_from_providers(config_providers: &Value, requested_model: &str) -> Option<Value> {
1028+
let requested = requested_model.trim();
1029+
if requested.is_empty() {
1030+
return None;
1031+
}
1032+
1033+
if let Some((provider_id, model_id)) = requested.split_once('/') {
1034+
let provider_id = provider_id.trim();
1035+
let model_id = model_id.trim();
1036+
if provider_id.is_empty() || model_id.is_empty() {
1037+
return None;
1038+
}
1039+
if provider_has_model(config_providers, provider_id, model_id) {
1040+
return Some(json!({ "providerID": provider_id, "modelID": model_id }));
1041+
}
1042+
return None;
1043+
}
1044+
1045+
let providers = config_providers
1046+
.get("providers")
1047+
.and_then(|v| v.as_array())
1048+
.cloned()
1049+
.unwrap_or_default();
1050+
let mut matches: Vec<String> = Vec::new();
1051+
1052+
for provider in &providers {
1053+
let provider_id = provider
1054+
.get("id")
1055+
.and_then(|v| v.as_str())
1056+
.map(str::trim)
1057+
.unwrap_or_default();
1058+
if provider_id.is_empty() {
1059+
continue;
1060+
}
1061+
let Some(models_map) = provider.get("models").and_then(|v| v.as_object()) else {
1062+
continue;
1063+
};
1064+
let found = models_map.contains_key(requested)
1065+
|| models_map.values().any(|model| {
1066+
model.get("id")
1067+
.and_then(|v| v.as_str())
1068+
.map(str::trim)
1069+
.unwrap_or_default()
1070+
== requested
1071+
});
1072+
if found {
1073+
matches.push(provider_id.to_string());
1074+
}
1075+
}
1076+
1077+
if matches.len() == 1 {
1078+
return Some(json!({ "providerID": matches[0], "modelID": requested }));
1079+
}
1080+
1081+
None
1082+
}
1083+
1084+
async fn resolve_prompt_model_override(
1085+
session: &WorkspaceSession,
1086+
requested_model: &str,
1087+
) -> Option<Value> {
1088+
let requested = requested_model.trim();
1089+
if requested.is_empty() {
1090+
return None;
1091+
}
1092+
1093+
let has_provider = requested.contains('/');
1094+
1095+
let cached = session.models_cache.lock().await.clone();
1096+
if let Some(cache) = cached {
1097+
if let Some(model) = resolve_model_override_from_providers(&cache, requested) {
1098+
return Some(model);
1099+
}
1100+
// If the client sent a qualified id but it's no longer present in the
1101+
// current provider list, silently fall back to server default instead
1102+
// of sending an invalid override that yields a scary 400.
1103+
if has_provider {
1104+
return None;
1105+
}
1106+
}
1107+
1108+
// Best-effort refresh for legacy unqualified ids when cache is stale/missing.
1109+
if !has_provider {
1110+
if let Ok(fresh) = session.rest_get("/config/providers").await {
1111+
*session.models_cache.lock().await = Some(fresh.clone());
1112+
if let Some(model) = resolve_model_override_from_providers(&fresh, requested) {
1113+
return Some(model);
1114+
}
1115+
}
1116+
}
1117+
1118+
// If already qualified and no cache was available to validate it, keep the
1119+
// explicit override rather than dropping a likely-valid user choice.
1120+
if let Some((provider_id, model_id)) = requested.split_once('/') {
1121+
let provider_id = provider_id.trim();
1122+
let model_id = model_id.trim();
1123+
if !provider_id.is_empty() && !model_id.is_empty() {
1124+
return Some(json!({ "providerID": provider_id, "modelID": model_id }));
1125+
}
1126+
}
1127+
1128+
None
1129+
}
1130+
9881131
fn emit_turn_error<E: EventSink>(
9891132
event_sink: &E,
9901133
workspace_id: &str,
@@ -1086,12 +1229,11 @@ pub(crate) async fn send_user_message_core<E: EventSink>(
10861229
let requested_model = normalize_optional_string(model);
10871230
let requested_effort = normalize_optional_string(effort);
10881231
if let Some(ref model_id) = requested_model {
1089-
// REST API requires model as { providerID, modelID }.
1090-
// The frontend sends a qualified "provider/model" string.
1091-
if let Some((provider, mid)) = model_id.split_once('/') {
1092-
body["model"] = json!({ "providerID": provider, "modelID": mid });
1093-
} else {
1094-
body["model"] = json!({ "modelID": model_id });
1232+
// Be resilient to stale or legacy (unqualified) persisted selections.
1233+
// If we cannot safely resolve a valid `{ providerID, modelID }`, omit the
1234+
// override and let OpenCode use its current default model.
1235+
if let Some(model_override) = resolve_prompt_model_override(session.as_ref(), model_id).await {
1236+
body["model"] = model_override;
10951237
}
10961238
}
10971239
if let Some(ref effort_level) = requested_effort {

0 commit comments

Comments
 (0)