diff --git a/docs/README.md b/docs/README.md index c4bf30e04..cd134c03c 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,6 +1,6 @@ # DeepChat 文档索引 -本文档反映 `2026-03-23` 完成的 legacy `AgentPresenter` retirement 之后的代码结构。 +本文档反映 `2026-04-11` 完成 legacy `AgentPresenter` retirement 与 legacy provider runtime retirement 之后的代码结构。 当前聊天主链路已经收敛为: @@ -36,6 +36,13 @@ Renderer | [docs/specs/legacy-agentpresenter-retirement/spec.md](./specs/legacy-agentpresenter-retirement/spec.md) | 本次 retirement 的目标、范围、兼容边界 | | [docs/specs/legacy-agentpresenter-retirement/plan.md](./specs/legacy-agentpresenter-retirement/plan.md) | 迁移/归档/验证计划 | | [docs/specs/legacy-agentpresenter-retirement/tasks.md](./specs/legacy-agentpresenter-retirement/tasks.md) | 已执行清单 | +| [docs/specs/legacy-llm-provider-runtime-retirement/spec.md](./specs/legacy-llm-provider-runtime-retirement/spec.md) | legacy provider runtime retirement 规格 | +| [docs/specs/legacy-llm-provider-runtime-retirement/plan.md](./specs/legacy-llm-provider-runtime-retirement/plan.md) | provider runtime 收口与依赖清理计划 | +| [docs/specs/legacy-llm-provider-runtime-retirement/tasks.md](./specs/legacy-llm-provider-runtime-retirement/tasks.md) | provider runtime 退役执行清单 | +| [docs/specs/provider-layer-simplification/spec.md](./specs/provider-layer-simplification/spec.md) | provider layer 第二轮内部收口规格 | +| [docs/specs/provider-layer-simplification/plan.md](./specs/provider-layer-simplification/plan.md) | registry + generic provider 合并计划 | +| [docs/specs/provider-layer-simplification/tasks.md](./specs/provider-layer-simplification/tasks.md) | provider layer 第二轮执行清单 | +| [docs/specs/ai-sdk-runtime/spec.md](./specs/ai-sdk-runtime/spec.md) | AI SDK runtime 规格,现已更新为 retired 状态 | | [docs/specs/architecture-simplification/spec.md](./specs/architecture-simplification/spec.md) | 整体减负治理规格 | | [docs/specs/architecture-simplification/plan.md](./specs/architecture-simplification/plan.md) | 分层/基线/guard 计划 | | [docs/specs/architecture-simplification/tasks.md](./specs/architecture-simplification/tasks.md) | 首期实施清单 | @@ -62,10 +69,14 @@ docs/ ├── specs/ │ ├── agent-cleanup/ │ ├── architecture-simplification/ +│ ├── ai-sdk-runtime/ +│ ├── provider-layer-simplification/ +│ ├── legacy-llm-provider-runtime-retirement/ │ └── legacy-agentpresenter-retirement/ └── archives/ ├── legacy-agentpresenter-architecture.md ├── legacy-agentpresenter-flows.md + ├── legacy-llm-provider-runtime.md ├── thread-presenter-migration-plan.md └── workspace-agent-refactoring-summary.md ``` @@ -78,6 +89,7 @@ docs/ | --- | --- | | [archives/legacy-agentpresenter-architecture.md](./archives/legacy-agentpresenter-architecture.md) | 旧 `AgentPresenter` 架构快照 | | [archives/legacy-agentpresenter-flows.md](./archives/legacy-agentpresenter-flows.md) | 旧 `startStreamCompletion` / permission / loop 流程 | +| [archives/legacy-llm-provider-runtime.md](./archives/legacy-llm-provider-runtime.md) | 旧 provider runtime 的历史归档与提交锚点 | | [archives/thread-presenter-migration-plan.md](./archives/thread-presenter-migration-plan.md) | 历史迁移设计 | | [archives/workspace-agent-refactoring-summary.md](./archives/workspace-agent-refactoring-summary.md) | 历史工作区改造总结 | diff --git a/docs/architecture/baselines/dependency-report.md b/docs/architecture/baselines/dependency-report.md index 46a2b1137..0d4dcf7cb 100644 --- a/docs/architecture/baselines/dependency-report.md +++ b/docs/architecture/baselines/dependency-report.md @@ -32,7 +32,6 @@ Generated on 2026-04-03. 
- `eventbus.ts`: 57 - `presenter/index.ts`: 44 - `presenter/llmProviderPresenter/runtimePorts.ts`: 34 -- `presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts`: 24 - `presenter/llmProviderPresenter/baseProvider.ts`: 19 - `presenter/remoteControlPresenter/types.ts`: 18 - `presenter/sqlitePresenter/tables/baseTable.ts`: 16 @@ -51,9 +50,7 @@ Generated on 2026-04-03. - `presenter/index.ts -> presenter/windowPresenter/index.ts -> presenter/windowPresenter/FloatingChatWindow.ts -> presenter/index.ts` - `presenter/index.ts -> presenter/shortcutPresenter.ts -> presenter/index.ts` - `presenter/index.ts -> presenter/llmProviderPresenter/index.ts -> presenter/llmProviderPresenter/baseProvider.ts -> presenter/devicePresenter/index.ts -> presenter/index.ts` -- `presenter/index.ts -> presenter/llmProviderPresenter/index.ts -> presenter/llmProviderPresenter/managers/providerInstanceManager.ts -> presenter/llmProviderPresenter/providers/deepseekProvider.ts -> presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts -> presenter/index.ts` - `presenter/index.ts -> presenter/llmProviderPresenter/index.ts -> presenter/llmProviderPresenter/managers/providerInstanceManager.ts -> presenter/llmProviderPresenter/providers/githubCopilotProvider.ts -> presenter/githubCopilotDeviceFlow.ts -> presenter/index.ts` -- `presenter/index.ts -> presenter/llmProviderPresenter/index.ts -> presenter/llmProviderPresenter/managers/providerInstanceManager.ts -> presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts -> presenter/index.ts` - `presenter/filePresenter/mime.ts -> presenter/filePresenter/CsvFileAdapter.ts -> presenter/filePresenter/BaseFileAdapter.ts -> presenter/filePresenter/mime.ts` - `presenter/index.ts -> presenter/sessionPresenter/index.ts -> presenter/index.ts` - `presenter/index.ts -> presenter/sessionPresenter/index.ts -> presenter/sessionPresenter/managers/conversationManager.ts -> presenter/index.ts` @@ -115,4 +112,3 @@ Generated on 2026-04-03. 
- `components/json-viewer/JsonValue.ts -> components/json-viewer/JsonObject.ts -> components/json-viewer/JsonValue.ts` - `components/json-viewer/JsonArray.ts -> components/json-viewer/JsonValue.ts -> components/json-viewer/JsonArray.ts` - `composables/usePageCapture.example.ts -> composables/usePageCapture.example.ts` - diff --git a/docs/architecture/baselines/zero-inbound-candidates.md b/docs/architecture/baselines/zero-inbound-candidates.md index 907b5d066..bbaa9505a 100644 --- a/docs/architecture/baselines/zero-inbound-candidates.md +++ b/docs/architecture/baselines/zero-inbound-candidates.md @@ -6,7 +6,7 @@ These files have no in-repo importers inside their scope and need manual classif ## main -- Candidate count: 16 +- Candidate count: 15 - `env.d.ts` - `lib/system.ts` @@ -14,7 +14,6 @@ These files have no in-repo importers inside their scope and need manual classif - `presenter/browser/BrowserContextBuilder.ts` - `presenter/configPresenter/aes.ts` - `presenter/llmProviderPresenter/oauthHelper.ts` -- `presenter/llmProviderPresenter/providers/openAIProvider.ts` - `presenter/mcpPresenter/agentMcpFilter.ts` - `presenter/searchPrompts/searchPrompts.ts` - `presenter/sessionPresenter/events.ts` @@ -73,4 +72,3 @@ These files have no in-repo importers inside their scope and need manual classif - `stores/systemPromptStore.ts` - `utils/maxOutputTokens.ts` - `views/SettingsTabView.vue` - diff --git a/docs/archives/legacy-llm-provider-runtime.md b/docs/archives/legacy-llm-provider-runtime.md new file mode 100644 index 000000000..52ba04296 --- /dev/null +++ b/docs/archives/legacy-llm-provider-runtime.md @@ -0,0 +1,38 @@ +# Legacy LLM Provider Runtime Archive + +## Summary + +DeepChat previously maintained two low-level provider runtimes under `llmProviderPresenter`: + +- the original provider-specific SDK implementations +- the newer shared AI SDK runtime + +That rollback window is now closed. The active codebase only keeps the AI SDK runtime. + +## Timeline + +- AI SDK migration landed in commit `4c8345a7` +- Legacy runtime retirement and dependency cleanup landed after the migration stabilized + +## Where To Find The Old Provider Implementation + +Use commit `3add4093b46f15072d5ec3a65c8097e23b4907c4` to inspect the historical provider implementation and legacy runtime code. + +That commit is the canonical source for: + +- legacy provider request code +- legacy stream parsing branches +- provider-specific MCP conversion APIs +- legacy rollback-path wiring + +## Current State + +- no `DEEPCHAT_LLM_RUNTIME` +- no `llmRuntimeMode` +- no legacy provider SDK fallback branches in active providers +- no provider-specific MCP conversion APIs exposed from presenters + +For current implementation details, read: + +- [docs/specs/ai-sdk-runtime/spec.md](../specs/ai-sdk-runtime/spec.md) +- [docs/specs/legacy-llm-provider-runtime-retirement/spec.md](../specs/legacy-llm-provider-runtime-retirement/spec.md) diff --git a/docs/specs/ai-sdk-runtime/plan.md b/docs/specs/ai-sdk-runtime/plan.md new file mode 100644 index 000000000..ae4b83d9e --- /dev/null +++ b/docs/specs/ai-sdk-runtime/plan.md @@ -0,0 +1,10 @@ +# AI SDK Runtime Plan + +Status: completed, rollback retired. See [../legacy-llm-provider-runtime-retirement/plan.md](../legacy-llm-provider-runtime-retirement/plan.md). + +1. Introduce shared AI SDK runtime modules without changing upper-layer interfaces. +2. Migrate OpenAI-compatible and OpenAI responses providers first. +3. Migrate Anthropic / Gemini / Vertex / Bedrock / Ollama to the shared runtime. +4. 
Keep routing providers (`new-api`, `zenmux`) as thin delegates over migrated providers.
+5. Freeze `LLMCoreStreamEvent` behavior with adapter-focused tests.
+6. Retire the rollback path and delete legacy state machines.
diff --git a/docs/specs/ai-sdk-runtime/spec.md b/docs/specs/ai-sdk-runtime/spec.md
new file mode 100644
index 000000000..e32f579f0
--- /dev/null
+++ b/docs/specs/ai-sdk-runtime/spec.md
@@ -0,0 +1,93 @@
+# AI SDK Runtime Spec
+
+## Status
+
+Completed in commit `4c8345a7`.
+
+As of `2026-04-11`, the rollback path is retired. See [../legacy-llm-provider-runtime-retirement/spec.md](../legacy-llm-provider-runtime-retirement/spec.md).
+
+## Goal
+
+Unify DeepChat's low-level LLM request pipeline on Vercel AI SDK while keeping the upper-layer contracts unchanged:
+
+- `BaseLLMProvider`
+- `LLMProviderPresenter`
+- `LLMCoreStreamEvent`
+- existing provider IDs, model configs, and conversation history
+
+The AI SDK runtime is the only remaining implementation.
+
+## Non-Negotiable Compatibility
+
+- No functional regression in text streaming, reasoning streaming, tool call streaming, image output, prompt cache, proxy handling, request tracing, routing, and embeddings.
+- `LLMCoreStreamEvent` event names, field names, and stop reasons remain unchanged.
+- Existing `function_call_record` history must stay reusable across providers.
+- Existing provider list / model list / provider check / key status responsibilities remain in provider classes.
+
+## Runtime Mode
+
+- Single runtime: `ai-sdk`
+- `DEEPCHAT_LLM_RUNTIME` has been removed
+- config setting `llmRuntimeMode` has been removed
+
+## Scope
+
+Shared runtime under `src/main/presenter/llmProviderPresenter/aiSdk/` provides:
+
+- provider factory
+- model / message mapper
+- MCP tool mapper
+- streaming adapter
+- image runtime
+- embedding runtime
+- provider-options mapper
+- reasoning middleware
+- legacy function-call compatibility middleware
+
+## Provider Rollout
+
+Phase 1:
+
+- `OpenAICompatibleProvider`
+- `OpenAIResponsesProvider`
+- all `extends OpenAICompatibleProvider` providers
+
+Phase 2:
+
+- `AnthropicProvider`
+- `GeminiProvider`
+- `VertexProvider`
+- `AwsBedrockProvider`
+- `OllamaProvider`
+
+Phase 3:
+
+- `NewApiProvider`
+- `ZenmuxProvider`
+
+Out of scope for first unification pass:
+
+- `AcpProvider`
+- `VoiceAIProvider`
+
+## Validation Matrix
+
+- pure text
+- reasoning native
+- reasoning via `<think>` tag extraction
+- native tool streaming
+- legacy prompt-based function-call fallback
+- multi-tool history replay
+- image input
+- image output
+- usage mapping
+- prompt cache mapping
+- proxy / trace / abort
+- embeddings
+- retired rollback path verification
+
+## Legacy Removal Exit Criteria
+
+- AI SDK runtime passes the provider regression matrix
+- duplicated legacy stream parsers / tool parsers have no remaining callers
+- retirement is documented in [../legacy-llm-provider-runtime-retirement/spec.md](../legacy-llm-provider-runtime-retirement/spec.md)
diff --git a/docs/specs/ai-sdk-runtime/tasks.md b/docs/specs/ai-sdk-runtime/tasks.md
new file mode 100644
index 000000000..0e910dc91
--- /dev/null
+++ b/docs/specs/ai-sdk-runtime/tasks.md
@@ -0,0 +1,17 @@
+# AI SDK Runtime Tasks
+
+- [x] Freeze migration scope in SDD docs.
+- [x] Add shared AI SDK runtime modules. (implemented in 4c8345a7)
+- [x] Integrate OpenAI-compatible runtime path. (implemented in 4c8345a7)
+- [x] Integrate OpenAI responses runtime path. (implemented in 4c8345a7)
+- [x] Integrate Anthropic runtime path. (implemented in 4c8345a7)
+- [x] Integrate Gemini runtime path. 
(implemented in 4c8345a7) +- [x] Integrate Vertex runtime path. (implemented in 4c8345a7) +- [x] Integrate Bedrock runtime path. (implemented in 4c8345a7) +- [x] Integrate Ollama runtime path. (implemented in 4c8345a7) +- [x] Add regression tests for runtime adapter behavior. (implemented in 4c8345a7) +- [x] Retire `DEEPCHAT_LLM_RUNTIME` and `llmRuntimeMode`. +- [x] Remove legacy provider SDK fallback branches. +- [x] Remove provider-specific MCP conversion interfaces from presenter ports. +- [x] Record retirement history in [../legacy-llm-provider-runtime-retirement/tasks.md](../legacy-llm-provider-runtime-retirement/tasks.md). +- [x] Run format, i18n, lint, and targeted tests. diff --git a/docs/specs/legacy-llm-provider-runtime-retirement/plan.md b/docs/specs/legacy-llm-provider-runtime-retirement/plan.md new file mode 100644 index 000000000..9c4426e48 --- /dev/null +++ b/docs/specs/legacy-llm-provider-runtime-retirement/plan.md @@ -0,0 +1,20 @@ +# Legacy Provider Runtime Retirement Plan + +## Outcome + +Legacy provider runtime retirement is complete and no rollback path remains in the codebase. + +## Executed Plan + +1. Remove runtime selection and make AI SDK the only request runtime. +2. Collapse provider implementations onto shared AI SDK helpers. +3. Remove legacy MCP tool conversion surface from presenter interfaces. +4. Delete obsolete provider SDK dependencies and refresh lockfiles. +5. Archive the migration history and point readers to the final legacy-code commit. + +## Exit Conditions + +- No remaining source imports of `openai`, `@anthropic-ai/sdk`, `@google/genai`, `together-ai`, or `@aws-sdk/client-bedrock-runtime` +- No `DEEPCHAT_LLM_RUNTIME` or `llmRuntimeMode` references remain +- Main-process provider tests validate AI SDK-only behavior +- Documentation explicitly marks the rollback path as retired diff --git a/docs/specs/legacy-llm-provider-runtime-retirement/spec.md b/docs/specs/legacy-llm-provider-runtime-retirement/spec.md new file mode 100644 index 000000000..f4b4cd730 --- /dev/null +++ b/docs/specs/legacy-llm-provider-runtime-retirement/spec.md @@ -0,0 +1,45 @@ +# Legacy Provider Runtime Retirement Spec + +## Status + +Completed on `2026-04-11`. + +The hidden rollback window is closed. `llmProviderPresenter` now runs on AI SDK only. 
+ +## Goal + +Retire the legacy provider runtime and remove obsolete provider SDK dependencies without changing upper-layer contracts: + +- `BaseLLMProvider` +- `LLMProviderPresenter` +- `LLMCoreStreamEvent` +- existing provider IDs, model configs, conversation history, and `function_call_record` compatibility + +## Scope + +- Remove `DEEPCHAT_LLM_RUNTIME` +- Remove config key `llmRuntimeMode` +- Delete `src/main/presenter/llmProviderPresenter/aiSdk/runtimeMode.ts` +- Remove legacy-only provider branches, stream parsers, and MCP conversion ports +- Keep provider-managed responsibilities that still matter: + - `ollama` local model management + - `@aws-sdk/client-bedrock` model discovery + +## Runtime State After Retirement + +- Single runtime: `ai-sdk` +- No hidden fallback +- No provider-specific MCP conversion APIs exposed from presenters +- Vendor-specific request body customization is handled via AI SDK provider options mapping + +## Historical Anchors + +- AI SDK migration landed in commit `4c8345a7` +- Legacy provider implementation can be inspected at commit `3add4093b46f15072d5ec3a65c8097e23b4907c4` + +## Compatibility Commitments + +- `LLMCoreStreamEvent` names and fields remain unchanged +- Provider IDs and provider settings remain unchanged +- Existing message history and `function_call_record` remain reusable +- Routing providers (`new-api`, `zenmux`) stay as thin delegates over migrated providers diff --git a/docs/specs/legacy-llm-provider-runtime-retirement/tasks.md b/docs/specs/legacy-llm-provider-runtime-retirement/tasks.md new file mode 100644 index 000000000..067666dd6 --- /dev/null +++ b/docs/specs/legacy-llm-provider-runtime-retirement/tasks.md @@ -0,0 +1,17 @@ +# Legacy Provider Runtime Retirement Tasks + +- [x] Delete `src/main/presenter/llmProviderPresenter/aiSdk/runtimeMode.ts` +- [x] Remove `DEEPCHAT_LLM_RUNTIME` and `llmRuntimeMode` +- [x] Convert provider request paths to AI SDK-only implementations +- [x] Remove provider-specific MCP tool conversion interfaces from presenter ports +- [x] Replace legacy SDK type imports with local neutral types where still needed +- [x] Remove obsolete provider SDK dependencies from `package.json` +- [x] Rewrite provider tests around AI SDK runtime helpers and delegate routing +- [x] Archive legacy runtime history and document the last legacy-code commit +- [x] Run `pnpm install` +- [x] Run `pnpm run format` +- [x] Run `pnpm run i18n` +- [x] Run `pnpm run lint` +- [x] Run `pnpm run typecheck` +- [x] Run targeted provider tests for the migrated AI SDK-only paths +- [ ] Run `pnpm run test:main` diff --git a/docs/specs/provider-layer-simplification/plan.md b/docs/specs/provider-layer-simplification/plan.md new file mode 100644 index 000000000..c01e01cfe --- /dev/null +++ b/docs/specs/provider-layer-simplification/plan.md @@ -0,0 +1,15 @@ +# Provider Layer Simplification Plan + +1. Add a provider definition registry that captures runtime kind, model source, health check, and + hook strategy per provider. +2. Introduce a generic `AiSdkProvider` that owns AI SDK-backed text, stream, summary, image, + embeddings, and provider check flows. +3. Switch `ProviderInstanceManager` from vendor constructor maps to: + - special providers for `acp`, `github-copilot`, `voiceai`, `ollama` + - generic `AiSdkProvider` for all other AI SDK-backed providers +4. Move ModelScope MCP sync HTTP logic into shared helpers so `ModelscopeProvider` is no longer + required. +5. 
Adapt provider tests to assert behavior through the generic provider instead of vendor classes. +6. Delete obsolete vendor provider classes after import scans confirm they have no remaining + callers. +7. Run format, i18n, lint, typecheck, and targeted provider tests. diff --git a/docs/specs/provider-layer-simplification/spec.md b/docs/specs/provider-layer-simplification/spec.md new file mode 100644 index 000000000..4e8e19ecd --- /dev/null +++ b/docs/specs/provider-layer-simplification/spec.md @@ -0,0 +1,67 @@ +# Provider Layer Simplification + +## Status + +Completed on `2026-04-11`. + +## Goal + +Collapse the AI SDK-backed provider layer into one internal implementation while keeping all +user-visible provider contracts unchanged. + +The simplified structure is: + +- registry-driven provider resolution +- one generic `AiSdkProvider` for AI SDK-backed providers +- special-case providers kept only when they own non-generic responsibilities + +## In Scope + +- replace vendor class selection in `ProviderInstanceManager` with a registry lookup +- move runtime choice, routing, model source, health-check strategy, and provider-specific hooks + into provider definitions +- keep `AcpProvider`, `GithubCopilotProvider`, `VoiceAIProvider`, and `OllamaProvider` as + independent classes +- keep `AiSdkProvider` as the single generic implementation for AI SDK-backed providers +- move ModelScope MCP sync helpers out of `ModelscopeProvider` +- delete obsolete vendor provider classes once they have no remaining callers + +## Out of Scope + +- removing providers from the settings UI or changing the default provider list +- changing persisted provider IDs, model configs, or conversation history +- refactoring `OllamaProvider` local model management into a different subsystem +- redesigning prompts or harmonizing provider-specific output behavior + +## Compatibility Constraints + +The following must remain stable: + +- `providerId` +- `apiType` +- provider configuration schema +- model configuration schema +- history and `function_call_record` compatibility +- `LLMProviderPresenter.getProviderInstance()` behavior +- `LLMCoreStreamEvent` names and payload structure + +## Result + +After this simplification: + +- `ProviderInstanceManager` acts as a registry-backed factory +- AI SDK-backed providers no longer rely on vendor-specific provider classes +- `providers/` retains only: + - `acpProvider.ts` + - `aiSdkProvider.ts` + - `githubCopilotProvider.ts` + - `ollamaProvider.ts` + - `voiceAIProvider.ts` + +## Acceptance Criteria + +- registry resolution honors `providerId` first and `apiType` second +- routing providers (`new-api`, `zenmux`, `grok`) continue to route by model capability +- ModelScope MCP sync works without `ModelscopeProvider` instance methods +- no remaining runtime imports of deleted vendor provider classes +- targeted provider tests pass on the generic path diff --git a/docs/specs/provider-layer-simplification/tasks.md b/docs/specs/provider-layer-simplification/tasks.md new file mode 100644 index 000000000..26339fcab --- /dev/null +++ b/docs/specs/provider-layer-simplification/tasks.md @@ -0,0 +1,13 @@ +# Provider Layer Simplification Tasks + +- [x] Add `providerRegistry.ts` as the single source of AI SDK-backed provider definitions. +- [x] Add generic `AiSdkProvider` and move shared provider behavior into it. +- [x] Route `ProviderInstanceManager` through registry definitions plus special providers. +- [x] Decouple ModelScope MCP sync from `ModelscopeProvider` instance methods. 
+- [x] Update provider-layer tests to target the generic provider behavior. +- [x] Delete obsolete vendor provider classes from `src/main/presenter/llmProviderPresenter/providers/`. +- [x] Run `pnpm run format`. +- [x] Run `pnpm run i18n`. +- [x] Run `pnpm run lint`. +- [x] Run `pnpm run typecheck`. +- [x] Run targeted provider-layer tests. diff --git a/package.json b/package.json index 0e47385c5..68c35e8b0 100644 --- a/package.json +++ b/package.json @@ -63,17 +63,23 @@ }, "dependencies": { "@agentclientprotocol/sdk": "^0.16.1", - "@anthropic-ai/sdk": "^0.53.0", + "@ai-sdk/amazon-bedrock": "^4.0.92", + "@ai-sdk/anthropic": "^3.0.68", + "@ai-sdk/azure": "^3.0.53", + "@ai-sdk/google": "^3.0.61", + "@ai-sdk/google-vertex": "^4.0.106", + "@ai-sdk/openai": "^3.0.52", + "@ai-sdk/openai-compatible": "^2.0.41", + "@ai-sdk/provider": "^3.0.8", "@aws-sdk/client-bedrock": "^3.958.0", - "@aws-sdk/client-bedrock-runtime": "^3.958.0", "@duckdb/node-api": "1.3.2-alpha.25", "@e2b/code-interpreter": "^1.5.1", "@electron-toolkit/preload": "^3.0.2", "@electron-toolkit/utils": "^4.0.0", - "@google/genai": "^1.46.0", "@jxa/run": "^1.4.0", "@larksuiteoapi/node-sdk": "^1.60.0", "@modelcontextprotocol/sdk": "^1.28.0", + "ai": "^6.0.157", "axios": "^1.13.6", "better-sqlite3-multiple-ciphers": "12.8.0", "cheerio": "^1.2.0", @@ -96,12 +102,11 @@ "nanoid": "^5.1.7", "node-pty": "^1.1.0", "ollama": "^0.5.18", - "openai": "^6.33.0", + "ollama-ai-provider": "^1.2.0", "pdf-parse-new": "^1.4.1", "run-applescript": "^7.1.0", "safe-regex2": "^5.1.0", "sharp": "^0.33.5", - "together-ai": "^0.16.0", "tokenx": "^0.4.1", "turndown": "^7.2.2", "undici": "^7.16.0", @@ -110,6 +115,7 @@ "zod": "^3.25.76" }, "devDependencies": { + "@antv/infographic": "^0.2.7", "@electron-toolkit/tsconfig": "^1.0.1", "@electron/notarize": "^3.1.1", "@iconify-json/lucide": "^1.2.99", @@ -189,7 +195,6 @@ "vue-virtual-scroller": "^2.0.0-beta.10", "vuedraggable": "^4.1.0", "yaml": "^2.8.3", - "@antv/infographic": "^0.2.7", "zod-to-json-schema": "^3.25.1" }, "simple-git-hooks": { diff --git a/src/main/presenter/agentRuntimePresenter/accumulator.ts b/src/main/presenter/agentRuntimePresenter/accumulator.ts index b91905170..6487dd693 100644 --- a/src/main/presenter/agentRuntimePresenter/accumulator.ts +++ b/src/main/presenter/agentRuntimePresenter/accumulator.ts @@ -1,5 +1,6 @@ import type { AssistantMessageBlock } from '@shared/types/agent-interface' import type { LLMCoreStreamEvent } from '@shared/types/core/llm-events' +import type { ChatMessageProviderOptions } from '@shared/types/core/chat-message' import type { StreamState } from './types' export function finalizeTrailingPendingNarrativeBlocks(blocks: AssistantMessageBlock[]): void { @@ -17,32 +18,49 @@ export function finalizeTrailingPendingNarrativeBlocks(blocks: AssistantMessageB function getCurrentBlock( blocks: AssistantMessageBlock[], - type: 'content' | 'reasoning_content' + type: 'content' | 'reasoning_content', + providerOptions?: ChatMessageProviderOptions ): AssistantMessageBlock { + const providerOptionsJson = serializeProviderOptions(providerOptions) const last = blocks[blocks.length - 1] if ( last && last.status === 'pending' && - (last.type === 'content' || last.type === 'reasoning_content') && - last.type !== type + (last.type === 'content' || last.type === 'reasoning_content') ) { + const lastProviderOptionsJson = + typeof last.extra?.providerOptionsJson === 'string' + ? 
last.extra.providerOptionsJson + : undefined + + if (last.type === type && lastProviderOptionsJson === providerOptionsJson) { + return last + } + last.status = 'success' } - const current = blocks[blocks.length - 1] - if (current && current.type === type && current.status === 'pending') { - return current - } const block: AssistantMessageBlock = { type, content: '', status: 'pending', - timestamp: Date.now() + timestamp: Date.now(), + ...(providerOptionsJson ? { extra: { providerOptionsJson } } : {}) } blocks.push(block) return block } +function serializeProviderOptions( + providerOptions?: ChatMessageProviderOptions +): string | undefined { + if (!providerOptions) { + return undefined + } + + return JSON.stringify(providerOptions) +} + function updateReasoningMetadata(state: StreamState, start: number, end: number): void { const relativeStart = Math.max(0, start - state.startTime) const relativeEnd = Math.max(0, end - state.startTime) @@ -61,7 +79,7 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void switch (event.type) { case 'text': { if (state.firstTokenTime === null) state.firstTokenTime = Date.now() - const block = getCurrentBlock(state.blocks, 'content') + const block = getCurrentBlock(state.blocks, 'content', event.provider_options) block.content += event.content state.dirty = true break @@ -69,7 +87,7 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void case 'reasoning': { const currentTime = Date.now() if (state.firstTokenTime === null) state.firstTokenTime = currentTime - const block = getCurrentBlock(state.blocks, 'reasoning_content') + const block = getCurrentBlock(state.blocks, 'reasoning_content', event.provider_options) block.content += event.reasoning_content if ( typeof block.reasoning_time !== 'object' || @@ -91,6 +109,7 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void } case 'tool_call_start': { finalizeTrailingPendingNarrativeBlocks(state.blocks) + const providerOptionsJson = serializeProviderOptions(event.provider_options) const toolBlock: AssistantMessageBlock = { type: 'tool_call', content: '', @@ -101,13 +120,15 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void name: event.tool_call_name, params: '', response: '' - } + }, + ...(providerOptionsJson ? { extra: { providerOptionsJson } } : {}) } state.blocks.push(toolBlock) state.pendingToolCalls.set(event.tool_call_id, { name: event.tool_call_name, arguments: '', - blockIndex: state.blocks.length - 1 + blockIndex: state.blocks.length - 1, + providerOptions: event.provider_options }) state.dirty = true break @@ -116,9 +137,18 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void const pending = state.pendingToolCalls.get(event.tool_call_id) if (pending) { pending.arguments += event.tool_call_arguments_chunk + if (!pending.providerOptions && event.provider_options) { + pending.providerOptions = event.provider_options + } const block = state.blocks[pending.blockIndex] if (block?.tool_call) { block.tool_call.params = pending.arguments + if (event.provider_options) { + block.extra = { + ...block.extra, + providerOptionsJson: serializeProviderOptions(event.provider_options) + } + } } state.dirty = true } @@ -128,19 +158,24 @@ export function accumulate(state: StreamState, event: LLMCoreStreamEvent): void const pending = state.pendingToolCalls.get(event.tool_call_id) if (pending) { const finalArgs = event.tool_call_arguments_complete ?? 
pending.arguments + const providerOptions = event.provider_options ?? pending.providerOptions pending.arguments = finalArgs const block = state.blocks[pending.blockIndex] if (block?.tool_call) { block.tool_call.params = finalArgs block.extra = { ...block.extra, - toolCallArgsComplete: true + toolCallArgsComplete: true, + ...(providerOptions + ? { providerOptionsJson: serializeProviderOptions(providerOptions) } + : {}) } } state.completedToolCalls.push({ id: event.tool_call_id, name: pending.name, - arguments: finalArgs + arguments: finalArgs, + ...(providerOptions ? { providerOptions } : {}) }) state.pendingToolCalls.delete(event.tool_call_id) state.dirty = true diff --git a/src/main/presenter/agentRuntimePresenter/contextBuilder.ts b/src/main/presenter/agentRuntimePresenter/contextBuilder.ts index 32c331a36..323f4b71d 100644 --- a/src/main/presenter/agentRuntimePresenter/contextBuilder.ts +++ b/src/main/presenter/agentRuntimePresenter/contextBuilder.ts @@ -1,5 +1,5 @@ import { approximateTokenSize } from 'tokenx' -import type { ChatMessage } from '@shared/types/core/chat-message' +import type { ChatMessage, ChatMessageProviderOptions } from '@shared/types/core/chat-message' import type { ChatMessageRecord, AssistantMessageBlock, @@ -29,6 +29,33 @@ export type HistoryTurn = { tokens: number } +function parseProviderOptionsJson( + value: string | undefined +): ChatMessageProviderOptions | undefined { + if (!value) { + return undefined + } + + try { + const parsed = JSON.parse(value) + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed as ChatMessageProviderOptions + } + } catch {} + + return undefined +} + +function getBlockProviderOptions( + block: AssistantMessageBlock +): ChatMessageProviderOptions | undefined { + return parseProviderOptionsJson( + typeof block.extra?.providerOptionsJson === 'string' + ? block.extra.providerOptionsJson + : undefined + ) +} + function resolveFileMimeType(file: MessageFile): string { if (typeof file.mimeType === 'string' && file.mimeType.trim()) { return file.mimeType @@ -266,7 +293,10 @@ export function recordToChatMessages( toolCalls.push({ id: toolCall.id, type: 'function', - function: { name: toolCall.name, arguments: toolCall.params || '{}' } + function: { name: toolCall.name, arguments: toolCall.params || '{}' }, + ...(getBlockProviderOptions(block) + ? { provider_options: getBlockProviderOptions(block) } + : {}) }) } @@ -274,13 +304,34 @@ export function recordToChatMessages( return [{ role: 'assistant', content: combinedText }] } + const contentParts = blocks + .filter( + (block): block is AssistantMessageBlock & { content: string } => + block.type === 'content' && typeof block.content === 'string' && block.content.length > 0 + ) + .map((block) => { + const providerOptions = getBlockProviderOptions(block) + return { + type: 'text' as const, + text: block.content, + ...(providerOptions ? { provider_options: providerOptions } : {}) + } + }) + const assistantMessage: ChatMessage = { role: 'assistant', - content: text, + content: contentParts.some((part) => part.provider_options) ? 
contentParts : text, tool_calls: toolCalls } if (preserveInterleavedReasoning && reasoning) { assistantMessage.reasoning_content = reasoning + const reasoningProviderOptions = blocks + .filter((block) => block.type === 'reasoning_content') + .map((block) => getBlockProviderOptions(block)) + .find(Boolean) + if (reasoningProviderOptions) { + assistantMessage.reasoning_provider_options = reasoningProviderOptions + } } const result: ChatMessage[] = [assistantMessage] diff --git a/src/main/presenter/agentRuntimePresenter/dispatch.ts b/src/main/presenter/agentRuntimePresenter/dispatch.ts index 1c3dc95a9..722b2e98a 100644 --- a/src/main/presenter/agentRuntimePresenter/dispatch.ts +++ b/src/main/presenter/agentRuntimePresenter/dispatch.ts @@ -19,7 +19,7 @@ import type { ProcessHooks, StreamState } from './types' -import type { ChatMessage } from '@shared/types/core/chat-message' +import type { ChatMessage, ChatMessageProviderOptions } from '@shared/types/core/chat-message' import { nanoid } from 'nanoid' import type { ToolBatchOutputFitItem, ToolOutputGuard } from './toolOutputGuard' import { buildTerminalErrorBlocks } from './messageStore' @@ -83,6 +83,73 @@ function extractReasoningFromBlocks(blocks: AssistantMessageBlock[]): string { .join('') } +function parseProviderOptionsJson( + value: string | undefined +): ChatMessageProviderOptions | undefined { + if (!value) { + return undefined + } + + try { + const parsed = JSON.parse(value) + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed as ChatMessageProviderOptions + } + } catch {} + + return undefined +} + +function getBlockProviderOptions( + block: AssistantMessageBlock +): ChatMessageProviderOptions | undefined { + return parseProviderOptionsJson( + typeof block.extra?.providerOptionsJson === 'string' + ? block.extra.providerOptionsJson + : undefined + ) +} + +function extractAssistantContent( + blocks: AssistantMessageBlock[] +): ChatMessage['content'] | undefined { + const textBlocks = blocks.filter( + (block): block is AssistantMessageBlock & { content: string } => + block.type === 'content' && typeof block.content === 'string' && block.content.length > 0 + ) + + if (textBlocks.length === 0) { + return undefined + } + + const contentParts = textBlocks.map((block) => { + const providerOptions = getBlockProviderOptions(block) + return { + type: 'text' as const, + text: block.content, + ...(providerOptions ? { provider_options: providerOptions } : {}) + } + }) + + return contentParts.some((part) => part.provider_options) + ? contentParts + : contentParts.map((part) => part.text).join('') +} + +function extractReasoningProviderOptions( + blocks: AssistantMessageBlock[] +): ChatMessageProviderOptions | undefined { + const reasoningBlocks = blocks.filter((block) => block.type === 'reasoning_content') + for (const block of reasoningBlocks) { + const providerOptions = getBlockProviderOptions(block) + if (providerOptions) { + return providerOptions + } + } + + return undefined +} + function toolResponseToText(content: string | MCPContentItem[]): string { if (typeof content === 'string') return content return content @@ -564,20 +631,26 @@ export async function executeTools( } const iterationBlocks = state.blocks.slice(prevBlockCount) - const assistantText = extractTextFromBlocks(iterationBlocks) + const assistantContent = + extractAssistantContent(iterationBlocks) ?? 
extractTextFromBlocks(iterationBlocks) const assistantMessage: ChatMessage = { role: 'assistant', - content: assistantText, + content: assistantContent, tool_calls: state.completedToolCalls.map((tc) => ({ id: tc.id, type: 'function' as const, - function: { name: tc.name, arguments: tc.arguments } + function: { name: tc.name, arguments: tc.arguments }, + ...(tc.providerOptions ? { provider_options: tc.providerOptions } : {}) })) } const reasoning = extractReasoningFromBlocks(iterationBlocks) if (interleavedReasoning.preserveReasoningContent && reasoning) { assistantMessage.reasoning_content = reasoning + const reasoningProviderOptions = extractReasoningProviderOptions(iterationBlocks) + if (reasoningProviderOptions) { + assistantMessage.reasoning_provider_options = reasoningProviderOptions + } } else if ( reasoning && interleavedReasoning.reasoningSupported && diff --git a/src/main/presenter/agentRuntimePresenter/types.ts b/src/main/presenter/agentRuntimePresenter/types.ts index 84263e9d3..cb62d8ad7 100644 --- a/src/main/presenter/agentRuntimePresenter/types.ts +++ b/src/main/presenter/agentRuntimePresenter/types.ts @@ -5,7 +5,7 @@ import type { QuestionOption } from '@shared/types/agent-interface' import type { LLMCoreStreamEvent } from '@shared/types/core/llm-events' -import type { ChatMessage } from '@shared/types/core/chat-message' +import type { ChatMessage, ChatMessageProviderOptions } from '@shared/types/core/chat-message' import type { MCPToolDefinition, MCPToolResponse } from '@shared/types/core/mcp' import type { ModelConfig } from '@shared/presenter' import type { IToolPresenter } from '@shared/types/presenters/tool.presenter' @@ -24,6 +24,7 @@ export interface ToolCallResult { id: string name: string arguments: string + providerOptions?: ChatMessageProviderOptions serverName?: string serverIcons?: string serverDescription?: string @@ -34,7 +35,15 @@ export interface StreamState { metadata: MessageMetadata startTime: number firstTokenTime: number | null - pendingToolCalls: Map + pendingToolCalls: Map< + string, + { + name: string + arguments: string + blockIndex: number + providerOptions?: ChatMessageProviderOptions + } + > completedToolCalls: ToolCallResult[] stopReason: 'complete' | 'tool_use' | 'error' | 'abort' | 'max_tokens' dirty: boolean diff --git a/src/main/presenter/index.ts b/src/main/presenter/index.ts index 3f1c1f667..0ec802f22 100644 --- a/src/main/presenter/index.ts +++ b/src/main/presenter/index.ts @@ -198,14 +198,6 @@ export class Presenter implements IPresenter { this.configPresenter, this.sqlitePresenter, { - mcpToolsToAnthropicTools: (mcpTools, serverName) => - this.mcpPresenter.mcpToolsToAnthropicTools(mcpTools, serverName), - mcpToolsToGeminiTools: (mcpTools, serverName) => - this.mcpPresenter.mcpToolsToGeminiTools(mcpTools, serverName), - mcpToolsToOpenAITools: (mcpTools, serverName) => - this.mcpPresenter.mcpToolsToOpenAITools(mcpTools, serverName), - mcpToolsToOpenAIResponsesTools: (mcpTools, serverName) => - this.mcpPresenter.mcpToolsToOpenAIResponsesTools(mcpTools, serverName), getNpmRegistry: () => this.mcpPresenter.getNpmRegistry?.() ?? null, getUvRegistry: () => this.mcpPresenter.getUvRegistry?.() ?? 
null } diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/index.ts b/src/main/presenter/llmProviderPresenter/aiSdk/index.ts new file mode 100644 index 000000000..8d9f240d6 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/index.ts @@ -0,0 +1 @@ +export * from './runtime' diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/messageMapper.ts b/src/main/presenter/llmProviderPresenter/aiSdk/messageMapper.ts new file mode 100644 index 000000000..c29677029 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/messageMapper.ts @@ -0,0 +1,348 @@ +import type { ChatMessage, MCPToolDefinition } from '@shared/presenter' +import { generateId, type ModelMessage } from 'ai' +import { applyLegacyFunctionCallPrompt } from './middlewares/legacyFunctionCallMiddleware' +import { + buildFunctionCallRecordContent, + serializeChatMessageContent, + splitMergedToolContent, + toToolResultOutput, + tryParseJson +} from './toolProtocol' + +type PendingToolCall = { + id: string + name: string + args?: string +} + +function resolveBinaryData(value: string): string | URL { + if (value.startsWith('data:')) { + return value + } + + try { + return new URL(value) + } catch { + return value + } +} + +function resolveImageMediaType(value: string): string | undefined { + const dataUrlMatch = value.match(/^data:([^;,]+)[;,]/i) + if (dataUrlMatch?.[1]) { + return dataUrlMatch[1] + } + + const normalized = value.toLowerCase() + if (normalized.endsWith('.png')) return 'image/png' + if (normalized.endsWith('.jpg') || normalized.endsWith('.jpeg')) return 'image/jpeg' + if (normalized.endsWith('.webp')) return 'image/webp' + if (normalized.endsWith('.gif')) return 'image/gif' + + return undefined +} + +function mapUserContent(content: ChatMessage['content']): any[] { + if (typeof content === 'string' || content == null) { + return [ + { + type: 'text', + text: content ?? '' + } + ] + } + + return content + .map((part) => { + if (part.type === 'text') { + return { + type: 'text', + text: part.text + } + } + + if ( + part.type === 'image_url' && + part.image_url && + typeof part.image_url.url === 'string' && + part.image_url.url + ) { + const imageUrl = part.image_url.url + const mediaType = resolveImageMediaType(imageUrl) + + return { + type: 'image', + image: resolveBinaryData(imageUrl), + ...(mediaType ? { mediaType } : {}) + } + } + + return null + }) + .filter( + ( + part + ): part is + | { type: 'text'; text: string } + | { type: 'image'; image: string | URL; mediaType?: string } => part !== null + ) +} + +function mapAssistantTextAndReasoning(message: ChatMessage): any[] { + const content: any[] = [] + + if (message.reasoning_content) { + content.push({ + type: 'reasoning', + text: message.reasoning_content, + ...(message.reasoning_provider_options + ? { providerOptions: message.reasoning_provider_options } + : {}) + }) + } + + if (Array.isArray(message.content)) { + for (const part of message.content) { + if (part.type !== 'text' || !part.text) { + continue + } + + content.push({ + type: 'text', + text: part.text, + ...(part.provider_options ? { providerOptions: part.provider_options } : {}) + }) + } + + return content + } + + const text = serializeChatMessageContent(message.content) + if (text) { + content.push({ + type: 'text', + text, + ...(message.provider_options ? 
{ providerOptions: message.provider_options } : {}) + }) + } + + return content +} + +export interface MapMessagesToModelMessagesOptions { + tools: MCPToolDefinition[] + supportsNativeTools: boolean + buildLegacyFunctionCallPrompt?: (tools: MCPToolDefinition[]) => string +} + +export function mapMessagesToModelMessages( + messages: ChatMessage[], + options: MapMessagesToModelMessagesOptions +): ModelMessage[] { + const pendingNativeToolCalls: PendingToolCall[] = [] + const pendingMockToolCalls: PendingToolCall[] = [] + + const enqueueNativeToolCall = (id: string, name: string, args?: string) => { + pendingNativeToolCalls.push({ id, name, args }) + } + + const enqueueMockToolCall = (id: string, name: string, args?: string) => { + pendingMockToolCalls.push({ id, name, args }) + } + + const consumeToolCall = ( + source: PendingToolCall[], + preferredId?: string + ): PendingToolCall | undefined => { + if (preferredId) { + const index = source.findIndex((toolCall) => toolCall.id === preferredId) + if (index !== -1) { + return source.splice(index, 1)[0] + } + } + + return source.shift() + } + + const modelMessages = messages.reduce((acc, message) => { + if (message.role === 'system') { + acc.push({ + role: 'system', + content: serializeChatMessageContent(message.content) + }) + return acc + } + + if (message.role === 'user') { + acc.push({ + role: 'user', + content: mapUserContent(message.content) + } as ModelMessage) + return acc + } + + if (message.role === 'assistant') { + const assistantContent = mapAssistantTextAndReasoning(message) + + if (message.tool_calls?.length) { + if (options.supportsNativeTools) { + for (const toolCall of message.tool_calls) { + const toolCallId = toolCall.id || `tool-call-${generateId()}` + const rawArgs = toolCall.function.arguments + enqueueNativeToolCall(toolCallId, toolCall.function.name, rawArgs) + assistantContent.push({ + type: 'tool-call', + toolCallId, + toolName: toolCall.function.name, + ...(toolCall.provider_options ? { providerOptions: toolCall.provider_options } : {}), + input: + typeof rawArgs === 'string' ? (tryParseJson(rawArgs) ?? { raw: rawArgs }) : rawArgs + }) + } + + acc.push({ + role: 'assistant', + content: assistantContent + } as ModelMessage) + } else { + if (assistantContent.length > 0) { + acc.push({ + role: 'assistant', + content: assistantContent + } as ModelMessage) + } + + for (const toolCall of message.tool_calls) { + enqueueMockToolCall( + toolCall.id || `tool-call-${generateId()}`, + toolCall.function.name, + toolCall.function.arguments + ) + } + } + + return acc + } + + acc.push({ + role: 'assistant', + content: assistantContent + } as ModelMessage) + return acc + } + + if (message.role === 'tool') { + const serialized = + typeof message.content === 'string' + ? message.content + : serializeChatMessageContent(message.content) + + if (options.supportsNativeTools) { + const splitParts = + pendingNativeToolCalls.length > 1 && !message.tool_call_id + ? 
splitMergedToolContent(serialized, pendingNativeToolCalls.length) + : null + + if (splitParts) { + acc.push( + ...(splitParts + .map((part) => { + const pending = consumeToolCall(pendingNativeToolCalls) + if (!pending) { + return undefined + } + + return { + role: 'tool' as const, + content: [ + { + type: 'tool-result', + toolCallId: pending.id, + toolName: pending.name, + output: toToolResultOutput(part) + } + ] + } as ModelMessage + }) + .filter(Boolean) as ModelMessage[]) + ) + + return acc + } + + const pending = consumeToolCall(pendingNativeToolCalls, message.tool_call_id) + acc.push({ + role: 'tool', + content: [ + { + type: 'tool-result', + toolCallId: pending?.id || message.tool_call_id || `tool-result-${generateId()}`, + toolName: pending?.name || 'unknown', + output: toToolResultOutput(serialized), + ...(message.provider_options ? { providerOptions: message.provider_options } : {}) + } + ] + } as ModelMessage) + + return acc + } + + const splitParts = + pendingMockToolCalls.length > 1 && !message.tool_call_id + ? splitMergedToolContent(serialized, pendingMockToolCalls.length) + : null + + if (splitParts) { + for (const part of splitParts) { + const pending = consumeToolCall(pendingMockToolCalls) + if (!pending) { + continue + } + + acc.push({ + role: 'user', + content: [ + { + type: 'text', + text: buildFunctionCallRecordContent( + pending.name, + tryParseJson(pending.args || '{}') ?? {}, + part + ) + } + ] + } as ModelMessage) + } + + return acc + } + + const pending = consumeToolCall(pendingMockToolCalls, message.tool_call_id) + acc.push({ + role: 'user', + content: [ + { + type: 'text', + text: buildFunctionCallRecordContent( + pending?.name || 'unknown', + tryParseJson(pending?.args || '{}') ?? {}, + serialized + ) + } + ] + } as ModelMessage) + } + + return acc + }, []) + + if (!options.supportsNativeTools && options.tools.length > 0) { + return applyLegacyFunctionCallPrompt( + modelMessages, + options.tools, + options.buildLegacyFunctionCallPrompt + ) + } + + return modelMessages +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/middlewares/legacyFunctionCallMiddleware.ts b/src/main/presenter/llmProviderPresenter/aiSdk/middlewares/legacyFunctionCallMiddleware.ts new file mode 100644 index 000000000..917c3b917 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/middlewares/legacyFunctionCallMiddleware.ts @@ -0,0 +1,67 @@ +import type { MCPToolDefinition } from '@shared/presenter' +import type { ModelMessage } from 'ai' + +function appendTextToUserMessage(message: ModelMessage, extraText: string): ModelMessage { + if (message.role !== 'user') { + return message + } + + if (!Array.isArray(message.content)) { + return { + ...message, + role: message.role ?? 'user', + content: [ + { + type: 'text', + text: `${String(message.content ?? 
'')}${extraText}` + } + ] + } + } + + const content = [...message.content] + const lastPart = content.at(-1) + if (lastPart?.type === 'text') { + content[content.length - 1] = { + ...lastPart, + text: `${lastPart.text}${extraText}` + } + } else { + content.push({ + type: 'text', + text: extraText + }) + } + + return { + ...message, + content + } +} + +export function applyLegacyFunctionCallPrompt( + messages: ModelMessage[], + tools: MCPToolDefinition[], + buildPrompt: ((tools: MCPToolDefinition[]) => string) | undefined +): ModelMessage[] { + if (!tools.length || !buildPrompt) { + return messages + } + + const promptSuffix = `\n\n${buildPrompt(tools)}` + const lastUserIndex = [...messages].map((message) => message.role).lastIndexOf('user') + + if (lastUserIndex === -1) { + return [ + ...messages, + { + role: 'user', + content: [{ type: 'text', text: buildPrompt(tools) }] + } + ] + } + + return messages.map((message, index) => + index === lastUserIndex ? appendTextToUserMessage(message, promptSuffix) : message + ) +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/middlewares/reasoningMiddleware.ts b/src/main/presenter/llmProviderPresenter/aiSdk/middlewares/reasoningMiddleware.ts new file mode 100644 index 000000000..51e22668e --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/middlewares/reasoningMiddleware.ts @@ -0,0 +1,7 @@ +import { extractReasoningMiddleware } from 'ai' + +export function createReasoningMiddleware(tagName = 'think') { + return extractReasoningMiddleware({ + tagName + }) +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/providerFactory.ts b/src/main/presenter/llmProviderPresenter/aiSdk/providerFactory.ts new file mode 100644 index 000000000..98d3cb777 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/providerFactory.ts @@ -0,0 +1,603 @@ +import type { + AWS_BEDROCK_PROVIDER, + IConfigPresenter, + LLM_PROVIDER, + VERTEX_PROVIDER +} from '@shared/presenter' +import { wrapLanguageModel } from 'ai' +import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock' +import { createAnthropic } from '@ai-sdk/anthropic' +import { createAzure } from '@ai-sdk/azure' +import { createGoogleGenerativeAI } from '@ai-sdk/google' +import { createVertex } from '@ai-sdk/google-vertex' +import { createOpenAI } from '@ai-sdk/openai' +import { createOpenAICompatible } from '@ai-sdk/openai-compatible' +import { createOllama } from 'ollama-ai-provider' +import { ProxyAgent } from 'undici' +import { proxyConfig } from '../../proxyConfig' +import { createReasoningMiddleware } from './middlewares/reasoningMiddleware' + +export type AiSdkProviderKind = + | 'openai-compatible' + | 'openai-responses' + | 'azure' + | 'anthropic' + | 'gemini' + | 'vertex' + | 'aws-bedrock' + | 'ollama' + +export interface CreateAiSdkProviderContextParams { + providerKind: AiSdkProviderKind + provider: LLM_PROVIDER + configPresenter: IConfigPresenter + defaultHeaders: Record + modelId: string + cleanHeaders?: boolean + wrapThinkReasoning?: boolean +} + +export interface AiSdkProviderContext { + providerOptionsKey: string + apiType: + | 'openai_chat' + | 'openai_responses' + | 'azure_responses' + | 'anthropic' + | 'google' + | 'vertex' + | 'bedrock' + | 'ollama' + model: any + embeddingModel?: any + imageModel?: any + endpoint: string + imageEndpoint?: string + embeddingEndpoint?: string + resolvedModelId?: string +} + +function isObjectRecord(value: unknown): value is Record { + return Boolean(value) && typeof value === 'object' && !Array.isArray(value) +} + +const 
VERTEX_SCHEMA_TYPE_MAP: Record = { + string: 'STRING', + number: 'NUMBER', + integer: 'INTEGER', + boolean: 'BOOLEAN', + object: 'OBJECT', + array: 'ARRAY', + null: 'NULL' +} + +function normalizeVertexSchemaNode(node: unknown): unknown { + if (Array.isArray(node)) { + return node.map((item) => normalizeVertexSchemaNode(item)) + } + + if (!isObjectRecord(node)) { + return node + } + + const normalized: Record = {} + + for (const [key, value] of Object.entries(node)) { + if (key === 'type' && typeof value === 'string') { + normalized[key] = VERTEX_SCHEMA_TYPE_MAP[value.toLowerCase()] ?? value + continue + } + + normalized[key] = normalizeVertexSchemaNode(value) + } + + return normalized +} + +export function normalizeVertexRequestBody(body: unknown): unknown { + if (!isObjectRecord(body)) { + return body + } + + const nextBody: Record = { ...body } + const systemInstruction = nextBody.systemInstruction + + if (isObjectRecord(systemInstruction) && !('role' in systemInstruction)) { + nextBody.systemInstruction = { + role: 'user', + ...systemInstruction + } + } + + if (Array.isArray(nextBody.tools)) { + nextBody.tools = nextBody.tools.map((tool) => { + if (!isObjectRecord(tool) || !Array.isArray(tool.functionDeclarations)) { + return tool + } + + return { + ...tool, + functionDeclarations: tool.functionDeclarations.map((declaration) => { + if (!isObjectRecord(declaration)) { + return declaration + } + + return { + ...declaration, + ...(declaration.parameters + ? { parameters: normalizeVertexSchemaNode(declaration.parameters) } + : {}) + } + }) + } + }) + } + + const toolConfig = nextBody.toolConfig + + if (!isObjectRecord(toolConfig)) { + return nextBody + } + + const functionCallingConfig = toolConfig.functionCallingConfig + if (!isObjectRecord(functionCallingConfig)) { + return nextBody + } + + const hasOnlyAutoMode = + functionCallingConfig.mode === 'AUTO' && + Object.keys(functionCallingConfig).length === 1 && + Object.keys(toolConfig).length === 1 + + if (hasOnlyAutoMode) { + delete nextBody.toolConfig + } + + return nextBody +} + +function normalizeRequestBody( + provider: LLM_PROVIDER, + requestUrl: string, + body: RequestInit['body'] | null | undefined +): RequestInit['body'] | null | undefined { + if (body == null || typeof body !== 'string') { + return body + } + + const isVertexRequest = + provider.apiType === 'vertex' || + requestUrl.includes(':generateContent') || + requestUrl.includes(':streamGenerateContent') + + if (!isVertexRequest) { + return body + } + + try { + const parsed = JSON.parse(body) + const normalized = normalizeVertexRequestBody(parsed) + return JSON.stringify(normalized) + } catch { + return body + } +} + +function createFetchMiddleware( + provider: LLM_PROVIDER, + defaultHeaders: Record, + cleanHeaders = false +) { + const proxyUrl = proxyConfig.getProxyUrl() + const dispatcher = proxyUrl ? new ProxyAgent(proxyUrl) : undefined + + return async (url: string | URL | Request, init?: RequestInit): Promise => { + const requestUrl = typeof url === 'string' ? url : url instanceof URL ? url.toString() : url.url + const nextInit: RequestInit & { dispatcher?: ProxyAgent } = { + ...init + } + + if (dispatcher) { + nextInit.dispatcher = dispatcher + } + + const headers = new Headers(init?.headers ?? 
{}) + Object.entries(defaultHeaders).forEach(([key, value]) => headers.set(key, value)) + + if (cleanHeaders) { + const allowedHeaders = new Set([ + 'authorization', + 'content-type', + 'accept', + 'http-referer', + 'x-title' + ]) + + const sanitized = new Headers() + headers.forEach((value, key) => { + const normalized = key.toLowerCase() + if ( + allowedHeaders.has(normalized) || + (!normalized.startsWith('x-stainless-') && + !normalized.includes('user-agent') && + !normalized.includes('openai-')) + ) { + sanitized.set(key, value) + } + }) + + if (!sanitized.has('Authorization') && provider.apiKey) { + sanitized.set('Authorization', `Bearer ${provider.apiKey}`) + } + + nextInit.headers = sanitized + } else { + nextInit.headers = headers + } + + nextInit.body = normalizeRequestBody(provider, requestUrl, nextInit.body) + return fetch(url, nextInit) + } +} + +function buildOpenAIEndpoint(baseUrl: string, path: string): string { + return `${baseUrl.replace(/\/+$/, '')}${path}` +} + +const DEFAULT_AZURE_V1_API_VERSION = 'v1' +const DEFAULT_AZURE_DEPLOYMENT_API_VERSION = '2024-02-01' + +export interface NormalizedAzureBaseUrl { + baseURL?: string + apiVersion: string + useDeploymentBasedUrls: boolean + deploymentName?: string +} + +export function normalizeAnthropicBaseUrl(baseUrl: string | undefined): string { + const normalized = (baseUrl || 'https://api.anthropic.com').replace(/\/+$/, '') + + if (normalized.endsWith('/v1/messages')) { + return normalized.slice(0, -'/messages'.length) + } + + if (normalized.endsWith('/messages')) { + return normalized.slice(0, -'/messages'.length) + } + + if (normalized.endsWith('/v1')) { + return normalized + } + + return `${normalized}/v1` +} + +export function normalizeVertexBaseUrl( + baseUrl: string | undefined, + apiKey: string | undefined, + apiVersion: 'v1' | 'v1beta1' = 'v1' +): string | undefined { + const normalized = baseUrl?.trim().replace(/\/+$/, '') + if (!normalized) { + return undefined + } + + if (!apiKey) { + return normalized + } + + if (/\/publishers\/google$/i.test(normalized)) { + return normalized + } + + if (new RegExp(`/${apiVersion}$`, 'i').test(normalized)) { + return `${normalized}/publishers/google` + } + + return `${normalized}/${apiVersion}/publishers/google` +} + +export function normalizeAzureBaseUrl( + baseUrl: string | undefined, + apiVersion: string | undefined +): NormalizedAzureBaseUrl { + const normalizedBaseUrl = baseUrl?.trim() + const normalizedApiVersion = apiVersion?.trim() + + if (!normalizedBaseUrl) { + return { + apiVersion: normalizedApiVersion || DEFAULT_AZURE_V1_API_VERSION, + useDeploymentBasedUrls: false + } + } + + try { + const url = new URL(normalizedBaseUrl) + url.search = '' + url.hash = '' + + let pathname = url.pathname.replace(/\/+$/, '') + let deploymentName: string | undefined + + const deploymentMatch = pathname.match(/\/openai\/deployments\/([^/]+)(?:\/.*)?$/i) + if (deploymentMatch?.[1]) { + deploymentName = decodeURIComponent(deploymentMatch[1]) + pathname = pathname.slice(0, deploymentMatch.index ?? pathname.length) || '/openai' + } else if (/\/openai\/v1$/i.test(pathname)) { + pathname = pathname.replace(/\/openai\/v1$/i, '/openai') + } else if (!/\/openai$/i.test(pathname)) { + pathname = pathname ? `${pathname}/openai` : '/openai' + } + + url.pathname = pathname || '/openai' + + return { + baseURL: url.toString().replace(/\/+$/, ''), + apiVersion: + normalizedApiVersion || + (deploymentName ? 
DEFAULT_AZURE_DEPLOYMENT_API_VERSION : DEFAULT_AZURE_V1_API_VERSION), + useDeploymentBasedUrls: Boolean(deploymentName), + deploymentName + } + } catch { + const fallbackBaseUrl = normalizedBaseUrl.replace(/\/+$/, '') + + return { + baseURL: fallbackBaseUrl.endsWith('/openai') + ? fallbackBaseUrl + : fallbackBaseUrl.endsWith('/openai/v1') + ? fallbackBaseUrl.slice(0, -'/v1'.length) + : `${fallbackBaseUrl}/openai`, + apiVersion: normalizedApiVersion || DEFAULT_AZURE_V1_API_VERSION, + useDeploymentBasedUrls: false + } + } +} + +function buildAzureEndpoint( + baseURL: string | undefined, + path: string, + apiVersion: string, + deploymentName: string, + useDeploymentBasedUrls: boolean +): string { + const basePath = (baseURL || '').replace(/\/+$/, '') + const endpoint = useDeploymentBasedUrls + ? `${basePath}/deployments/${deploymentName}${path}` + : `${basePath}/v1${path}` + + return `${endpoint}?api-version=${encodeURIComponent(apiVersion)}` +} + +export function createAiSdkProviderContext( + params: CreateAiSdkProviderContextParams +): AiSdkProviderContext { + const baseUrl = params.provider.baseUrl || '' + const fetch = createFetchMiddleware( + params.provider, + params.defaultHeaders, + params.cleanHeaders === true + ) + const maybeWrapModel = (model: any): any => + params.wrapThinkReasoning === false + ? model + : wrapLanguageModel({ + model, + middleware: createReasoningMiddleware() + }) + + switch (params.providerKind) { + case 'openai-responses': { + const provider = createOpenAI({ + baseURL: baseUrl, + apiKey: params.provider.apiKey, + headers: params.defaultHeaders, + fetch + }) + + return { + providerOptionsKey: 'openai', + apiType: 'openai_responses', + model: maybeWrapModel(provider.responses(params.modelId) as any), + embeddingModel: provider.embedding(params.modelId), + imageModel: provider.image(params.modelId), + endpoint: buildOpenAIEndpoint(baseUrl || 'https://api.openai.com/v1', '/responses') + } + } + + case 'azure': { + const azureApiVersion = params.configPresenter.getSetting('azureApiVersion') + const azureConfig = normalizeAzureBaseUrl(baseUrl || undefined, azureApiVersion) + const deploymentName = azureConfig.deploymentName || params.modelId + const provider = createAzure({ + baseURL: azureConfig.baseURL, + apiKey: params.provider.apiKey || undefined, + headers: params.defaultHeaders, + fetch, + apiVersion: azureConfig.apiVersion, + useDeploymentBasedUrls: azureConfig.useDeploymentBasedUrls + }) + + return { + providerOptionsKey: 'azure', + apiType: 'azure_responses', + model: maybeWrapModel(provider.responses(deploymentName) as any), + embeddingModel: provider.embedding(deploymentName), + imageModel: provider.image(deploymentName), + endpoint: buildAzureEndpoint( + azureConfig.baseURL, + '/responses', + azureConfig.apiVersion, + deploymentName, + azureConfig.useDeploymentBasedUrls + ), + imageEndpoint: buildAzureEndpoint( + azureConfig.baseURL, + '/images/generations', + azureConfig.apiVersion, + deploymentName, + azureConfig.useDeploymentBasedUrls + ), + embeddingEndpoint: buildAzureEndpoint( + azureConfig.baseURL, + '/embeddings', + azureConfig.apiVersion, + deploymentName, + azureConfig.useDeploymentBasedUrls + ), + resolvedModelId: deploymentName + } + } + + case 'openai-compatible': { + if (params.provider.id === 'openai') { + const provider = createOpenAI({ + baseURL: baseUrl, + apiKey: params.provider.apiKey, + headers: params.defaultHeaders, + fetch + }) + + return { + providerOptionsKey: 'openai', + apiType: 'openai_chat', + model: 
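// --- Illustrative aside (not part of the patch): the expected behaviour of
// normalizeAzureBaseUrl for the two common Azure OpenAI URL shapes, following the
// implementation above. Resource names and model ids are made up.
import { normalizeAzureBaseUrl } from './providerFactory'

// Deployment-style URL: the deployment name is extracted and deployment-based
// endpoints are used with the legacy api-version default.
const legacy = normalizeAzureBaseUrl(
  'https://my-resource.openai.azure.com/openai/deployments/gpt-4o',
  undefined
)
// legacy => {
//   baseURL: 'https://my-resource.openai.azure.com/openai',
//   apiVersion: '2024-02-01',
//   useDeploymentBasedUrls: true,
//   deploymentName: 'gpt-4o'
// }

// Bare resource URL: falls back to the unified /openai surface with api-version 'v1'.
const unified = normalizeAzureBaseUrl('https://my-resource.openai.azure.com', undefined)
// unified => {
//   baseURL: 'https://my-resource.openai.azure.com/openai',
//   apiVersion: 'v1',
//   useDeploymentBasedUrls: false
// }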
maybeWrapModel(provider.chat(params.modelId) as any), + embeddingModel: provider.embedding(params.modelId), + imageModel: provider.image(params.modelId), + endpoint: buildOpenAIEndpoint(baseUrl || 'https://api.openai.com/v1', '/chat/completions') + } + } + + const provider = createOpenAICompatible({ + name: params.provider.id, + baseURL: baseUrl, + apiKey: params.provider.apiKey, + headers: params.defaultHeaders, + fetch, + includeUsage: true + }) + + return { + providerOptionsKey: params.provider.id, + apiType: 'openai_chat', + model: maybeWrapModel(provider.chatModel(params.modelId) as any), + embeddingModel: provider.embeddingModel(params.modelId), + imageModel: provider.imageModel(params.modelId), + endpoint: buildOpenAIEndpoint(baseUrl, '/chat/completions') + } + } + + case 'anthropic': { + const anthropicBaseUrl = normalizeAnthropicBaseUrl(baseUrl) + const provider = createAnthropic({ + baseURL: anthropicBaseUrl, + apiKey: params.provider.apiKey || process.env.ANTHROPIC_API_KEY, + headers: params.defaultHeaders, + fetch, + name: 'anthropic' + }) + + return { + providerOptionsKey: 'anthropic', + apiType: 'anthropic', + model: maybeWrapModel(provider.messages(params.modelId) as any), + endpoint: `${anthropicBaseUrl}/messages` + } + } + + case 'gemini': { + const provider = createGoogleGenerativeAI({ + baseURL: baseUrl || undefined, + apiKey: params.provider.apiKey || process.env.GEMINI_API_KEY, + headers: params.defaultHeaders, + fetch + }) + + return { + providerOptionsKey: 'google', + apiType: 'google', + model: maybeWrapModel(provider.languageModel(params.modelId) as any), + embeddingModel: provider.embeddingModel(params.modelId), + imageModel: provider.imageModel(params.modelId), + endpoint: baseUrl || 'https://generativelanguage.googleapis.com' + } + } + + case 'vertex': { + const vertexProvider = params.provider as VERTEX_PROVIDER + const vertexApiVersion = (vertexProvider.apiVersion as 'v1' | 'v1beta1') || 'v1' + const vertexBaseUrl = normalizeVertexBaseUrl( + vertexProvider.baseUrl, + vertexProvider.apiKey || undefined, + vertexApiVersion + ) + const provider = createVertex({ + apiKey: vertexProvider.apiKey || undefined, + baseURL: vertexBaseUrl, + project: vertexProvider.projectId || process.env.GOOGLE_VERTEX_PROJECT, + location: vertexProvider.location || process.env.GOOGLE_VERTEX_LOCATION, + headers: params.defaultHeaders, + fetch, + googleAuthOptions: + vertexProvider.accountClientEmail && vertexProvider.accountPrivateKey + ? 
{ + credentials: { + client_email: vertexProvider.accountClientEmail, + private_key: vertexProvider.accountPrivateKey + } + } + : undefined + }) + + return { + providerOptionsKey: 'vertex', + apiType: 'vertex', + model: maybeWrapModel(provider.languageModel(params.modelId) as any), + embeddingModel: provider.embeddingModel(params.modelId), + imageModel: provider.imageModel(params.modelId), + endpoint: vertexBaseUrl || 'https://aiplatform.googleapis.com/v1/publishers/google' + } + } + + case 'aws-bedrock': { + const bedrockProvider = params.provider as AWS_BEDROCK_PROVIDER + const provider = createAmazonBedrock({ + apiKey: bedrockProvider.apiKey || undefined, + baseURL: bedrockProvider.baseUrl || undefined, + region: bedrockProvider.credential?.region || process.env.AWS_REGION || 'us-east-1', + accessKeyId: bedrockProvider.credential?.accessKeyId || process.env.AWS_ACCESS_KEY_ID, + secretAccessKey: + bedrockProvider.credential?.secretAccessKey || process.env.AWS_SECRET_ACCESS_KEY, + headers: params.defaultHeaders, + fetch + }) + + return { + providerOptionsKey: 'bedrock', + apiType: 'bedrock', + model: maybeWrapModel(provider.languageModel(params.modelId) as any), + embeddingModel: (provider as any).embeddingModel?.(params.modelId), + imageModel: (provider as any).imageModel?.(params.modelId), + endpoint: bedrockProvider.baseUrl || 'https://bedrock-runtime.amazonaws.com' + } + } + + case 'ollama': { + const provider = createOllama({ + baseURL: baseUrl || undefined, + headers: params.defaultHeaders, + fetch + }) + + return { + providerOptionsKey: 'ollama', + apiType: 'ollama', + model: maybeWrapModel(provider(params.modelId) as any), + embeddingModel: + (provider as any).embeddingModel?.(params.modelId) ?? + (provider as any).textEmbeddingModel?.(params.modelId), + endpoint: baseUrl || 'http://127.0.0.1:11434' + } + } + } +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/providerOptionsMapper.ts b/src/main/presenter/llmProviderPresenter/aiSdk/providerOptionsMapper.ts new file mode 100644 index 000000000..779c66777 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/providerOptionsMapper.ts @@ -0,0 +1,308 @@ +import type { MCPToolDefinition, ModelConfig } from '@shared/presenter' +import type { ModelMessage } from 'ai' +import { resolvePromptCachePlan } from '../promptCacheStrategy' +import { modelCapabilities } from '../../configPresenter/modelCapabilities' +import { providerDbLoader } from '../../configPresenter/providerDbLoader' + +type ProviderOptionsRecord = Record> + +function cloneMessage(message: ModelMessage): ModelMessage { + return { + ...(message as any), + ...(Array.isArray((message as any).content) + ? 
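// --- Illustrative aside (not part of the patch): a minimal sketch of calling the factory
// above for a generic OpenAI-compatible provider. The provider object and configPresenter
// are assumed to come from the existing presenter wiring; only the fields the factory reads
// are shown, optional flags are omitted, and the header and model id are placeholders.
import type { IConfigPresenter, LLM_PROVIDER } from '@shared/presenter'
import { createAiSdkProviderContext } from './providerFactory'

declare const provider: LLM_PROVIDER // hypothetical: one entry from configPresenter.getProviders()
declare const configPresenter: IConfigPresenter

const context = createAiSdkProviderContext({
  providerKind: 'openai-compatible',
  provider,
  configPresenter,
  defaultHeaders: { 'HTTP-Referer': 'https://deepchat.example' },
  modelId: 'gpt-4o-mini'
})
// context.model / context.embeddingModel feed straight into streamText / embedMany in the
// runtime, and context.endpoint is what the request trace reports, e.g.
// '<baseUrl>/chat/completions' for the openai-compatible branch.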
{ + content: (message as any).content.map((part: any) => ({ ...part })) + } + : {}) + } as ModelMessage +} + +function applyExplicitAnthropicCacheBreakpoint(messages: ModelMessage[]): ModelMessage[] { + const cloned = messages.map(cloneMessage) + + for (let messageIndex = cloned.length - 1; messageIndex >= 0; messageIndex -= 1) { + const message = cloned[messageIndex] + + if (message.role === 'system') { + continue + } + + if (!Array.isArray(message.content)) { + continue + } + + for (let partIndex = message.content.length - 1; partIndex >= 0; partIndex -= 1) { + const part = message.content[partIndex] + if (part?.type !== 'text' || typeof part.text !== 'string' || !part.text.trim()) { + continue + } + + message.content[partIndex] = { + ...part, + providerOptions: { + ...(part.providerOptions as Record | undefined), + anthropic: { + cacheControl: { + type: 'ephemeral' + } + } + } + } + + return cloned + } + } + + return cloned +} + +export interface BuildProviderOptionsParams { + providerId: string + providerOptionsKey: string + apiType: + | 'openai_chat' + | 'openai_responses' + | 'azure_responses' + | 'anthropic' + | 'google' + | 'vertex' + | 'bedrock' + | 'ollama' + modelId: string + modelConfig: ModelConfig + tools: MCPToolDefinition[] + messages: ModelMessage[] +} + +export interface ProviderOptionsMappingResult { + messages: ModelMessage[] + providerOptions?: ProviderOptionsRecord +} + +function isOfficialAnthropicProvider(providerId: string): boolean { + return providerId.trim().toLowerCase() === 'anthropic' +} + +function supportsDoubaoThinking(providerId: string, modelId: string): boolean { + if (providerId !== 'doubao') { + return false + } + + const model = providerDbLoader.getModel(providerId, modelId) + const notes = model?.extra_capabilities?.reasoning?.notes + return Array.isArray(notes) && notes.includes('doubao-thinking-parameter') +} + +function supportsSiliconcloudThinking(modelId: string): boolean { + const normalizedModelId = modelId.toLowerCase() + return [ + 'qwen/qwen3-8b', + 'qwen/qwen3-14b', + 'qwen/qwen3-32b', + 'qwen/qwen3-30b-a3b', + 'qwen/qwen3-235b-a22b', + 'tencent/hunyuan-a13b-instruct', + 'zai-org/glm-4.5v', + 'deepseek-ai/deepseek-v3.1', + 'pro/deepseek-ai/deepseek-v3.1' + ].some((supportedModel) => normalizedModelId.includes(supportedModel)) +} + +function supportsGrokReasoningEffort(modelId: string): boolean { + return ['grok-3-mini', 'grok-3-mini-fast'].some((model) => + modelId.toLowerCase().includes(model.toLowerCase()) + ) +} + +export function buildProviderOptions( + params: BuildProviderOptionsParams +): ProviderOptionsMappingResult { + const providerOptions: ProviderOptionsRecord = {} + let messages = params.messages + + const promptCachePlan = resolvePromptCachePlan({ + providerId: params.providerId, + apiType: + params.apiType === 'openai_responses' + ? 'openai_responses' + : params.apiType === 'anthropic' || params.apiType === 'bedrock' + ? 
'anthropic' + : 'openai_chat', + modelId: params.modelId, + messages: params.messages as unknown[], + tools: params.tools, + conversationId: params.modelConfig.conversationId + }) + + switch (params.apiType) { + case 'openai_chat': + case 'openai_responses': { + const config: Record = {} + if (params.modelConfig.reasoningEffort && params.providerId !== 'grok') { + config.reasoningEffort = params.modelConfig.reasoningEffort + } + if (params.modelConfig.verbosity) { + config.textVerbosity = params.modelConfig.verbosity + } + if (params.modelConfig.maxCompletionTokens) { + config.maxCompletionTokens = params.modelConfig.maxCompletionTokens + } + if (promptCachePlan.cacheKey) { + config.promptCacheKey = promptCachePlan.cacheKey + } + if ( + supportsDoubaoThinking(params.providerId, params.modelId) && + params.modelConfig.reasoning + ) { + config.thinking = { + type: 'enabled' + } + } + if ( + params.providerId === 'siliconcloud' && + supportsSiliconcloudThinking(params.modelId) && + params.modelConfig.reasoning + ) { + config.enable_thinking = true + } + if ( + params.providerId === 'dashscope' && + modelCapabilities.supportsReasoning(params.providerId, params.modelId) && + params.modelConfig.reasoning + ) { + config.enable_thinking = true + const dbBudget = modelCapabilities.getThinkingBudgetRange( + params.providerId, + params.modelId + ).default + const budget = params.modelConfig.thinkingBudget ?? dbBudget + if (typeof budget === 'number') { + config.thinking_budget = budget + } + } + if ( + params.providerId === 'grok' && + params.modelConfig.reasoningEffort && + supportsGrokReasoningEffort(params.modelId) + ) { + config.reasoning_effort = params.modelConfig.reasoningEffort + } + if (Object.keys(config).length > 0) { + providerOptions[params.providerOptionsKey] = config + } + break + } + + case 'azure_responses': { + const config: Record = {} + if (params.modelConfig.reasoningEffort) { + config.reasoningEffort = params.modelConfig.reasoningEffort + } + if (params.modelConfig.verbosity) { + config.textVerbosity = params.modelConfig.verbosity + } + if (params.modelConfig.maxCompletionTokens) { + config.maxCompletionTokens = params.modelConfig.maxCompletionTokens + } + if (Object.keys(config).length > 0) { + providerOptions[params.providerOptionsKey] = config + } + break + } + + case 'anthropic': + case 'bedrock': { + const officialAnthropicProvider = + params.apiType === 'anthropic' && isOfficialAnthropicProvider(params.providerId) + const config: Record = { + toolStreaming: officialAnthropicProvider + } + if (officialAnthropicProvider && params.modelConfig.reasoning) { + config.sendReasoning = true + } + if (officialAnthropicProvider && params.modelConfig.reasoningEffort) { + config.effort = + params.modelConfig.reasoningEffort === 'low' + ? 'low' + : params.modelConfig.reasoningEffort === 'high' + ? 
'high' + : 'medium' + } + if (params.modelConfig.thinkingBudget !== undefined) { + config.thinking = { + type: 'enabled', + budgetTokens: params.modelConfig.thinkingBudget + } + } + if (promptCachePlan.mode === 'anthropic_auto') { + config.cacheControl = { + type: 'ephemeral' + } + } + if (Object.keys(config).length > 0) { + providerOptions.anthropic = config + } + if (promptCachePlan.mode === 'anthropic_explicit') { + messages = applyExplicitAnthropicCacheBreakpoint(messages) + } + break + } + + case 'google': { + const config: Record = {} + if (params.tools.length > 0) { + config.streamFunctionCallArguments = true + } + if (params.modelConfig.thinkingBudget !== undefined || params.modelConfig.reasoningEffort) { + config.thinkingConfig = { + ...(params.modelConfig.thinkingBudget !== undefined + ? { thinkingBudget: params.modelConfig.thinkingBudget } + : {}), + ...(params.modelConfig.reasoningEffort + ? { thinkingLevel: params.modelConfig.reasoningEffort } + : {}), + includeThoughts: true + } + } + if (Object.keys(config).length > 0) { + providerOptions[params.providerOptionsKey] = config + } + break + } + + case 'vertex': { + const config: Record = { + streamFunctionCallArguments: params.tools.length > 0 + } + if (params.modelConfig.thinkingBudget !== undefined || params.modelConfig.reasoningEffort) { + config.thinkingConfig = { + ...(params.modelConfig.thinkingBudget !== undefined + ? { thinkingBudget: params.modelConfig.thinkingBudget } + : {}), + ...(params.modelConfig.reasoningEffort + ? { thinkingLevel: params.modelConfig.reasoningEffort } + : {}), + includeThoughts: true + } + } + providerOptions[params.providerOptionsKey] = config + break + } + + case 'ollama': { + const config: Record = {} + if (params.modelConfig.reasoningEffort) { + config.reasoning_effort = params.modelConfig.reasoningEffort + } + if (Object.keys(config).length > 0) { + providerOptions[params.providerOptionsKey] = config + } + break + } + } + + return { + messages, + providerOptions: Object.keys(providerOptions).length > 0 ? 
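// --- Illustrative aside (not part of the patch): the rough shape that buildProviderOptions
// produces for an official Anthropic model with extended thinking enabled, per the
// anthropic/bedrock branch above. The modelConfig literal is abbreviated behind a cast, the
// model id is a placeholder, and whether cacheControl also appears depends on what
// resolvePromptCachePlan reports for the provider/model.
import type { ModelConfig } from '@shared/presenter'
import { buildProviderOptions } from './providerOptionsMapper'

const { providerOptions } = buildProviderOptions({
  providerId: 'anthropic',
  providerOptionsKey: 'anthropic',
  apiType: 'anthropic',
  modelId: 'claude-sonnet-4',
  modelConfig: { reasoning: true, reasoningEffort: 'high', thinkingBudget: 4096 } as ModelConfig,
  tools: [],
  messages: []
})
// providerOptions?.anthropic is roughly:
// {
//   toolStreaming: true,
//   sendReasoning: true,
//   effort: 'high',
//   thinking: { type: 'enabled', budgetTokens: 4096 }
// }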
providerOptions : undefined + } +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/runtime.ts b/src/main/presenter/llmProviderPresenter/aiSdk/runtime.ts new file mode 100644 index 000000000..db54e57a1 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/runtime.ts @@ -0,0 +1,330 @@ +import { embedMany, generateId, generateImage, generateText, streamText } from 'ai' +import type { + ChatMessage, + IConfigPresenter, + LLM_EMBEDDING_ATTRS, + LLM_PROVIDER, + LLMResponse, + MCPToolDefinition, + ModelConfig +} from '@shared/presenter' +import { ApiEndpointType } from '@shared/model' +import { presenter } from '@/presenter' +import { EMBEDDING_TEST_KEY, isNormalized } from '@/utils/vector' +import type { LLMCoreStreamEvent } from '@shared/types/core/llm-events' +import { mcpToolsToAISDKTools } from './toolMapper' +import { mapMessagesToModelMessages } from './messageMapper' +import { buildProviderOptions } from './providerOptionsMapper' +import { type AiSdkProviderKind, createAiSdkProviderContext } from './providerFactory' +import { adaptAiSdkStream } from './streamAdapter' + +export interface AiSdkRuntimeContext { + providerKind: AiSdkProviderKind + provider: LLM_PROVIDER + configPresenter: IConfigPresenter + defaultHeaders: Record + buildLegacyFunctionCallPrompt?: (tools: MCPToolDefinition[]) => string + emitRequestTrace?: ( + modelConfig: ModelConfig, + payload: { + endpoint: string + headers?: Record + body?: unknown + } + ) => Promise + buildTraceHeaders?: () => Record + cleanHeaders?: boolean + supportsNativeTools?: (modelId: string, modelConfig: ModelConfig) => boolean + shouldUseImageGeneration?: (modelId: string, modelConfig: ModelConfig) => boolean +} + +function normalizePromptValue(value: unknown): string { + if (typeof value === 'string') { + return value + } + + if (typeof value === 'number' || typeof value === 'boolean' || typeof value === 'bigint') { + return String(value) + } + + if (Array.isArray(value)) { + return value + .map((item) => { + if (typeof item === 'string') { + return item + } + + if (item && typeof item === 'object' && 'text' in item && typeof item.text === 'string') { + return item.text + } + + return '' + }) + .filter((item) => item.trim().length > 0) + .join('\n') + } + + if (value && typeof value === 'object') { + if ('text' in value && typeof value.text === 'string') { + return value.text + } + + const stringified = String(value) + return stringified === '[object Object]' ? '' : stringified + } + + return '' +} + +function extractImagePrompt(messages: ChatMessage[]): string { + return messages + .map((message) => (message.role === 'user' ? 
normalizePromptValue(message.content) : '')) + .filter((content) => content.trim().length > 0) + .join('\n\n') +} + +function resolveSupportsNativeTools( + context: AiSdkRuntimeContext, + modelId: string, + modelConfig: ModelConfig +): boolean { + if (context.supportsNativeTools) { + return context.supportsNativeTools(modelId, modelConfig) + } + + return modelConfig.functionCall === true +} + +function shouldUseImageGenerationRuntime( + context: AiSdkRuntimeContext, + modelId: string, + modelConfig: ModelConfig +): boolean { + if (context.shouldUseImageGeneration) { + return context.shouldUseImageGeneration(modelId, modelConfig) + } + + return modelConfig.apiEndpoint === ApiEndpointType.Image +} + +async function buildPromptRuntime( + context: AiSdkRuntimeContext, + messages: ChatMessage[], + modelId: string, + modelConfig: ModelConfig, + tools: MCPToolDefinition[] +) { + const supportsNativeTools = resolveSupportsNativeTools(context, modelId, modelConfig) + const providerContext = createAiSdkProviderContext({ + providerKind: context.providerKind, + provider: context.provider, + configPresenter: context.configPresenter, + defaultHeaders: context.defaultHeaders, + modelId, + cleanHeaders: context.cleanHeaders + }) + const mappedMessages = mapMessagesToModelMessages(messages, { + tools, + supportsNativeTools, + buildLegacyFunctionCallPrompt: context.buildLegacyFunctionCallPrompt + }) + const toolsMap = supportsNativeTools ? mcpToolsToAISDKTools(tools) : {} + const providerOptionResult = buildProviderOptions({ + providerId: context.provider.id, + providerOptionsKey: providerContext.providerOptionsKey, + apiType: providerContext.apiType, + modelId, + modelConfig, + tools, + messages: mappedMessages + }) + + return { + providerContext, + messages: providerOptionResult.messages, + providerOptions: providerOptionResult.providerOptions, + tools: toolsMap, + supportsNativeTools + } +} + +function usageToLlmResponse( + usage: + | { + inputTokens?: number + outputTokens?: number + totalTokens?: number + } + | undefined +): LLMResponse['totalUsage'] | undefined { + if (!usage) { + return undefined + } + + return { + prompt_tokens: usage.inputTokens ?? 0, + completion_tokens: usage.outputTokens ?? 0, + total_tokens: usage.totalTokens ?? (usage.inputTokens ?? 0) + (usage.outputTokens ?? 0) + } +} + +export async function runAiSdkGenerateText( + context: AiSdkRuntimeContext, + messages: ChatMessage[], + modelId: string, + modelConfig: ModelConfig, + temperature?: number, + maxTokens?: number +): Promise { + const runtime = await buildPromptRuntime(context, messages, modelId, modelConfig, []) + + await context.emitRequestTrace?.(modelConfig, { + endpoint: runtime.providerContext.endpoint, + headers: context.buildTraceHeaders?.() ?? context.defaultHeaders, + body: { + model: runtime.providerContext.resolvedModelId ?? 
modelId, + maxOutputTokens: maxTokens, + temperature + } + }) + + const result = await generateText({ + model: runtime.providerContext.model, + messages: runtime.messages, + providerOptions: runtime.providerOptions as any, + temperature, + maxOutputTokens: maxTokens + }) + + return { + content: result.text, + reasoning_content: result.reasoningText, + totalUsage: usageToLlmResponse(result.totalUsage) + } +} + +export async function* runAiSdkCoreStream( + context: AiSdkRuntimeContext, + messages: ChatMessage[], + modelId: string, + modelConfig: ModelConfig, + temperature: number, + maxTokens: number, + tools: MCPToolDefinition[] +): AsyncGenerator { + if (shouldUseImageGenerationRuntime(context, modelId, modelConfig)) { + const prompt = extractImagePrompt(messages) + + const providerContext = createAiSdkProviderContext({ + providerKind: context.providerKind, + provider: context.provider, + configPresenter: context.configPresenter, + defaultHeaders: context.defaultHeaders, + modelId, + cleanHeaders: context.cleanHeaders + }) + + if (!providerContext.imageModel) { + throw new Error(`Image generation is not supported by provider ${context.provider.id}`) + } + + await context.emitRequestTrace?.(modelConfig, { + endpoint: providerContext.imageEndpoint ?? providerContext.endpoint, + headers: context.buildTraceHeaders?.() ?? context.defaultHeaders, + body: { + model: providerContext.resolvedModelId ?? modelId, + prompt + } + }) + + const result = await generateImage({ + model: providerContext.imageModel, + prompt + }) + + for (const image of result.images) { + const dataUrl = `data:${image.mediaType};base64,${image.base64}` + const cachedImage = await presenter.devicePresenter.cacheImage(dataUrl) + yield { + type: 'image_data', + image_data: { + data: cachedImage, + mimeType: image.mediaType + } + } + } + + yield { + type: 'stop', + stop_reason: 'complete' + } + return + } + + const runtime = await buildPromptRuntime(context, messages, modelId, modelConfig, tools) + + await context.emitRequestTrace?.(modelConfig, { + endpoint: runtime.providerContext.endpoint, + headers: context.buildTraceHeaders?.() ?? context.defaultHeaders, + body: { + model: runtime.providerContext.resolvedModelId ?? 
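// --- Illustrative aside (not part of the patch): a minimal sketch of driving the
// non-streaming path above. The context mirrors AiSdkRuntimeContext; provider,
// configPresenter and modelConfig are assumed to come from the presenter layer, and the
// model id, prompt, temperature and token limit are placeholders.
import type { ChatMessage, IConfigPresenter, LLM_PROVIDER, ModelConfig } from '@shared/presenter'
import { runAiSdkGenerateText, type AiSdkRuntimeContext } from './runtime'

declare const provider: LLM_PROVIDER
declare const configPresenter: IConfigPresenter
declare const modelConfig: ModelConfig

async function summarizeOnce(): Promise<void> {
  const context: AiSdkRuntimeContext = {
    providerKind: 'openai-compatible',
    provider,
    configPresenter,
    defaultHeaders: {}
  }

  const response = await runAiSdkGenerateText(
    context,
    [{ role: 'user', content: 'Summarize this thread in one sentence.' }] as ChatMessage[],
    'gpt-4o-mini',
    modelConfig,
    0.7, // temperature
    512 // maxTokens
  )
  // response.content carries the text, response.reasoning_content any reasoning the
  // middleware extracted, and response.totalUsage the mapped token counts.
  void response
}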
modelId, + maxOutputTokens: maxTokens, + temperature, + tools: tools.map((tool) => tool.function.name) + } + }) + + const result = streamText({ + model: runtime.providerContext.model, + messages: runtime.messages, + tools: runtime.tools, + providerOptions: runtime.providerOptions as any, + temperature, + maxOutputTokens: maxTokens + }) + + yield* adaptAiSdkStream(result.fullStream, { + supportsNativeTools: runtime.supportsNativeTools, + cacheImage: (data) => presenter.devicePresenter.cacheImage(data) + }) +} + +export async function runAiSdkEmbeddings( + context: AiSdkRuntimeContext, + modelId: string, + texts: string[] +): Promise { + const providerContext = createAiSdkProviderContext({ + providerKind: context.providerKind, + provider: context.provider, + configPresenter: context.configPresenter, + defaultHeaders: context.defaultHeaders, + modelId, + cleanHeaders: context.cleanHeaders, + wrapThinkReasoning: false + }) + + if (!providerContext.embeddingModel) { + throw new Error(`embedding is not supported by provider ${context.provider.id}`) + } + + const result = await embedMany({ + model: providerContext.embeddingModel, + values: texts + }) + + return result.embeddings +} + +export async function runAiSdkDimensions( + context: AiSdkRuntimeContext, + modelId: string +): Promise { + const embeddings = await runAiSdkEmbeddings(context, modelId, [ + EMBEDDING_TEST_KEY || generateId() + ]) + return { + dimensions: embeddings[0].length, + normalized: isNormalized(embeddings[0]) + } +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/streamAdapter.ts b/src/main/presenter/llmProviderPresenter/aiSdk/streamAdapter.ts new file mode 100644 index 000000000..566108842 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/streamAdapter.ts @@ -0,0 +1,230 @@ +import { createStreamEvent, type LLMCoreStreamEvent } from '@shared/types/core/llm-events' +import type { ChatMessageProviderOptions } from '@shared/types/core/chat-message' +import type { ToolSet, TextStreamPart } from 'ai' +import { parseLegacyFunctionCalls } from './toolProtocol' + +const FUNCTION_CALL_TAG = '' +const FUNCTION_CALL_CLOSE_TAG = '' + +function resolveSafeTextLength(buffer: string): number { + const maxCheck = Math.min(buffer.length, FUNCTION_CALL_TAG.length - 1) + + for (let suffixLength = maxCheck; suffixLength > 0; suffixLength -= 1) { + if (FUNCTION_CALL_TAG.startsWith(buffer.slice(-suffixLength))) { + return buffer.length - suffixLength + } + } + + return buffer.length +} + +function mapFinishReason( + reason: string | undefined +): 'tool_use' | 'max_tokens' | 'stop_sequence' | 'error' | 'complete' { + switch (reason) { + case 'tool-calls': + return 'tool_use' + case 'length': + return 'max_tokens' + case 'error': + return 'error' + case 'stop': + return 'stop_sequence' + default: + return 'complete' + } +} + +function toUsageEvent(usage: { + inputTokens?: number + outputTokens?: number + totalTokens?: number + inputTokenDetails?: { + cacheReadTokens?: number + cacheWriteTokens?: number + } +}): LLMCoreStreamEvent { + return createStreamEvent.usage({ + prompt_tokens: usage.inputTokens ?? 0, + completion_tokens: usage.outputTokens ?? 0, + total_tokens: usage.totalTokens ?? (usage.inputTokens ?? 0) + (usage.outputTokens ?? 0), + ...(usage.inputTokenDetails?.cacheReadTokens + ? { cached_tokens: usage.inputTokenDetails.cacheReadTokens } + : {}), + ...(usage.inputTokenDetails?.cacheWriteTokens + ? 
{ cache_write_tokens: usage.inputTokenDetails.cacheWriteTokens } + : {}) + }) +} + +function toProviderOptions(value: unknown): ChatMessageProviderOptions | undefined { + if (!value || typeof value !== 'object' || Array.isArray(value)) { + return undefined + } + + return value as ChatMessageProviderOptions +} + +export interface AdaptAiSdkStreamOptions { + supportsNativeTools: boolean + cacheImage?: (data: string) => Promise +} + +export async function* adaptAiSdkStream( + fullStream: AsyncIterable>, + options: AdaptAiSdkStreamOptions +): AsyncGenerator { + const toolArgumentBuffers = new Map() + const endedToolCalls = new Set() + let bufferedLegacyText = '' + let legacyToolUseDetected = false + + const emitLegacyTextBuffer = async function* ( + flushAll = false + ): AsyncGenerator { + while (true) { + const startIndex = bufferedLegacyText.indexOf(FUNCTION_CALL_TAG) + + if (startIndex === -1) { + const safeLength = flushAll + ? bufferedLegacyText.length + : resolveSafeTextLength(bufferedLegacyText) + if (safeLength > 0) { + yield createStreamEvent.text(bufferedLegacyText.slice(0, safeLength)) + bufferedLegacyText = bufferedLegacyText.slice(safeLength) + } + return + } + + if (startIndex > 0) { + yield createStreamEvent.text(bufferedLegacyText.slice(0, startIndex)) + bufferedLegacyText = bufferedLegacyText.slice(startIndex) + } + + const endIndex = bufferedLegacyText.indexOf(FUNCTION_CALL_CLOSE_TAG) + if (endIndex === -1) { + return + } + + const blockEnd = endIndex + FUNCTION_CALL_CLOSE_TAG.length + const block = bufferedLegacyText.slice(0, blockEnd) + bufferedLegacyText = bufferedLegacyText.slice(blockEnd) + + const toolCalls = parseLegacyFunctionCalls(block) + if (!toolCalls.length) { + yield createStreamEvent.text(block) + continue + } + + legacyToolUseDetected = true + for (const toolCall of toolCalls) { + yield createStreamEvent.toolCallStart(toolCall.id, toolCall.function.name) + yield createStreamEvent.toolCallChunk(toolCall.id, toolCall.function.arguments) + yield createStreamEvent.toolCallEnd(toolCall.id, toolCall.function.arguments) + } + } + } + + for await (const part of fullStream) { + switch (part.type) { + case 'text-delta': { + if (options.supportsNativeTools) { + yield createStreamEvent.text(part.text, toProviderOptions((part as any).providerMetadata)) + break + } + + bufferedLegacyText += part.text + yield* emitLegacyTextBuffer(false) + break + } + + case 'reasoning-delta': + yield createStreamEvent.reasoning( + part.text, + toProviderOptions((part as any).providerMetadata) + ) + break + + case 'tool-input-start': + toolArgumentBuffers.set(part.id, '') + yield createStreamEvent.toolCallStart( + part.id, + part.toolName, + toProviderOptions((part as any).providerMetadata) + ) + break + + case 'tool-input-delta': + toolArgumentBuffers.set(part.id, `${toolArgumentBuffers.get(part.id) ?? ''}${part.delta}`) + yield createStreamEvent.toolCallChunk( + part.id, + part.delta, + toProviderOptions((part as any).providerMetadata) + ) + break + + case 'tool-input-end': + endedToolCalls.add(part.id) + yield createStreamEvent.toolCallEnd( + part.id, + toolArgumentBuffers.get(part.id), + toProviderOptions((part as any).providerMetadata) + ) + break + + case 'tool-call': + if (!endedToolCalls.has(part.toolCallId)) { + const serializedInput = JSON.stringify(part.input ?? 
{}) + const providerOptions = toProviderOptions((part as any).providerMetadata) + yield createStreamEvent.toolCallStart(part.toolCallId, part.toolName, providerOptions) + yield createStreamEvent.toolCallChunk(part.toolCallId, serializedInput, providerOptions) + yield createStreamEvent.toolCallEnd(part.toolCallId, serializedInput, providerOptions) + endedToolCalls.add(part.toolCallId) + } + break + + case 'file': { + const mediaType = part.file.mediaType + if (typeof mediaType !== 'string' || !mediaType.startsWith('image/')) { + break + } + + const dataUrl = `data:${mediaType};base64,${part.file.base64}` + let cachedImage = dataUrl + + if (options.cacheImage) { + try { + cachedImage = await options.cacheImage(dataUrl) + } catch (error) { + console.warn('[AI SDK Stream Adapter] Failed to cache image part:', error) + } + } + + yield createStreamEvent.imageData({ + data: cachedImage, + mimeType: mediaType + }) + break + } + + case 'finish': + if (!options.supportsNativeTools) { + yield* emitLegacyTextBuffer(true) + } + yield toUsageEvent(part.totalUsage) + yield createStreamEvent.stop( + !options.supportsNativeTools && legacyToolUseDetected + ? 'tool_use' + : mapFinishReason(part.finishReason) + ) + break + + case 'error': + yield createStreamEvent.error( + part.error instanceof Error ? part.error.message : String(part.error) + ) + yield createStreamEvent.stop('error') + break + } + } +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/toolMapper.ts b/src/main/presenter/llmProviderPresenter/aiSdk/toolMapper.ts new file mode 100644 index 000000000..e5cae14fc --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/toolMapper.ts @@ -0,0 +1,184 @@ +import type { MCPToolDefinition } from '@shared/presenter' +import { jsonSchema, tool, type ToolSet } from 'ai' + +type JsonSchema = Record +const UNSAFE_TOOL_NAMES = new Set(['__proto__', 'constructor', 'prototype']) + +function isObjectSchema(value: unknown): value is JsonSchema { + return Boolean(value) && typeof value === 'object' && !Array.isArray(value) +} + +function intersectRequiredKeys(variants: JsonSchema[]): string[] | undefined { + if (!variants.length) { + return undefined + } + + const requiredLists = variants.map((variant) => + Array.isArray(variant.required) + ? variant.required.filter((key): key is string => typeof key === 'string') + : [] + ) + + const [first, ...rest] = requiredLists + const intersection = first.filter((key) => rest.every((required) => required.includes(key))) + + return intersection.length > 0 ? intersection : undefined +} + +function unionRequiredKeys(variants: JsonSchema[]): string[] | undefined { + const union = Array.from( + new Set( + variants.flatMap((variant) => + Array.isArray(variant.required) + ? variant.required.filter((key): key is string => typeof key === 'string') + : [] + ) + ) + ) + + return union.length > 0 ? 
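// --- Illustrative aside (not part of the patch): a minimal sketch of consuming
// adaptAiSdkStream, defined above, directly from an AI SDK streamText result. The model
// handle is assumed to come from createAiSdkProviderContext, and only the event type is
// inspected here; the events are the ones built via createStreamEvent.text / .reasoning /
// .toolCallStart / .toolCallChunk / .toolCallEnd / .imageData / .usage / .stop.
import { streamText } from 'ai'
import { adaptAiSdkStream } from './streamAdapter'

declare const model: any // hypothetical: providerContext.model from the factory

async function logStreamEvents(prompt: string): Promise<void> {
  const result = streamText({ model, prompt })
  for await (const event of adaptAiSdkStream(result.fullStream, { supportsNativeTools: true })) {
    // With supportsNativeTools: false the adapter instead buffers text and converts
    // <function_call> blocks into tool call events before emitting them.
    console.log(event.type)
  }
}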
union : undefined +} + +function mergePropertySchemas(existing: unknown, incoming: unknown): unknown { + if (!isObjectSchema(existing) || !isObjectSchema(incoming)) { + return incoming + } + + if (JSON.stringify(existing) === JSON.stringify(incoming)) { + return existing + } + + if ( + existing.type === incoming.type && + typeof existing.const === 'string' && + typeof incoming.const === 'string' + ) { + return { + type: existing.type, + enum: Array.from(new Set([existing.const, incoming.const])) + } + } + + return { + anyOf: [existing, incoming] + } +} + +function mergeVariantProperties(variants: JsonSchema[]): Record | undefined { + const propertyMaps = variants + .map((variant) => (isObjectSchema(variant.properties) ? variant.properties : undefined)) + .filter((value): value is Record => Boolean(value)) + + if (!propertyMaps.length) { + return undefined + } + + const merged: Record = Object.create(null) + + for (const propertyMap of propertyMaps) { + for (const [key, value] of Object.entries(propertyMap)) { + if (UNSAFE_TOOL_NAMES.has(key)) { + continue + } + + merged[key] = key in merged ? mergePropertySchemas(merged[key], value) : value + } + } + + return merged +} + +function normalizeSchemaNode(node: unknown): unknown { + if (Array.isArray(node)) { + return node.map((item) => normalizeSchemaNode(item)) + } + + if (!isObjectSchema(node)) { + return node + } + + const normalized: JsonSchema = Object.fromEntries( + Object.entries(node).map(([key, value]) => [key, normalizeSchemaNode(value)]) + ) + + if (typeof normalized.type === 'string' && normalized.type.toLowerCase() === 'none') { + delete normalized.type + } + + return normalized +} + +export function normalizeToolInputSchema(schema: Record): Record { + const normalized = normalizeSchemaNode(schema) + if (!isObjectSchema(normalized)) { + return { + type: 'object', + properties: {} + } + } + + if (normalized.type === 'object') { + return normalized + } + + const branchKey = ['anyOf', 'oneOf', 'allOf'].find((key) => Array.isArray(normalized[key])) + const variants = branchKey + ? (normalized[branchKey] as unknown[]) + .filter(isObjectSchema) + .filter((item) => item.type === 'object') + : [] + + if (!variants.length) { + const required = Array.isArray(normalized.required) + ? normalized.required.filter((key): key is string => typeof key === 'string') + : undefined + const additionalProperties = + typeof normalized.additionalProperties === 'boolean' || + isObjectSchema(normalized.additionalProperties) + ? normalized.additionalProperties + : undefined + + return { + type: 'object', + properties: isObjectSchema(normalized.properties) ? normalized.properties : {}, + ...(required?.length ? { required } : {}), + ...(additionalProperties !== undefined ? { additionalProperties } : {}) + } + } + + const { type: _type, properties: _properties, required: _required, ...rest } = normalized + const sanitizedRest = Object.fromEntries( + Object.entries(rest).filter(([key]) => !['anyOf', 'oneOf', 'allOf'].includes(key)) + ) + const required = + branchKey === 'allOf' ? unionRequiredKeys(variants) : intersectRequiredKeys(variants) + + return { + ...sanitizedRest, + type: 'object', + properties: mergeVariantProperties(variants) ?? {}, + ...(required ? { required } : {}), + ...(variants.every((variant) => variant.additionalProperties === false) + ? 
{ additionalProperties: false } + : {}) + } +} + +export function mcpToolsToAISDKTools(tools: MCPToolDefinition[]): ToolSet { + return tools.reduce( + (acc, toolDef) => { + const name = toolDef.function.name + if (!name || UNSAFE_TOOL_NAMES.has(name)) { + return acc + } + + acc[name] = tool({ + description: toolDef.function.description, + inputSchema: jsonSchema(normalizeToolInputSchema(toolDef.function.parameters as JsonSchema)) + }) + + return acc + }, + Object.create(null) as ToolSet + ) +} diff --git a/src/main/presenter/llmProviderPresenter/aiSdk/toolProtocol.ts b/src/main/presenter/llmProviderPresenter/aiSdk/toolProtocol.ts new file mode 100644 index 000000000..2002ece98 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/aiSdk/toolProtocol.ts @@ -0,0 +1,224 @@ +import type { ChatMessage } from '@shared/presenter' +import { generateId } from 'ai' +import { jsonrepair } from 'jsonrepair' + +export interface ParsedLegacyToolCall { + id: string + type: 'function' + function: { + name: string + arguments: string + } +} + +export function serializeChatMessageContent(content: ChatMessage['content']): string { + if (content == null) { + return '' + } + + if (typeof content === 'string') { + return content + } + + return content + .map((part) => { + if (part.type === 'text') { + return part.text + } + + if (part.type === 'image_url') { + return part.image_url.url + } + + return JSON.stringify(part) + }) + .join('\n') +} + +export function tryParseJson(value: string): unknown { + try { + return JSON.parse(value) + } catch { + try { + return JSON.parse(jsonrepair(value)) + } catch { + return undefined + } + } +} + +export function parseLegacyFunctionCalls( + response: string, + fallbackIdPrefix = 'tool-call' +): ParsedLegacyToolCall[] { + const functionCallMatches = response.match(/([\s\S]*?)<\/function_call>/g) + if (!functionCallMatches) { + return [] + } + + return functionCallMatches + .map((match, index) => { + const content = match.replace(/<\/?function_call>/g, '').trim() + if (!content) { + return null + } + + const parsed = tryParseJson(content) + if (!parsed || typeof parsed !== 'object') { + return null + } + + const record = parsed as Record + let functionName: string | undefined + let functionArgs: unknown + + if (record.function_call && typeof record.function_call === 'object') { + const call = record.function_call as Record + functionName = typeof call.name === 'string' ? call.name : undefined + functionArgs = call.arguments + } else if (typeof record.name === 'string' && record.arguments !== undefined) { + functionName = record.name + functionArgs = record.arguments + } else if (record.function && typeof record.function === 'object') { + const call = record.function as Record + functionName = typeof call.name === 'string' ? call.name : undefined + functionArgs = call.arguments + } else { + const keys = Object.keys(record) + if (keys.length === 1) { + const inner = record[keys[0]] + if (inner && typeof inner === 'object') { + const nested = inner as Record + if (typeof nested.name === 'string' && nested.arguments !== undefined) { + functionName = nested.name + functionArgs = nested.arguments + } + } + } + } + + if (!functionName || functionArgs === undefined) { + return null + } + + return { + id: `${fallbackIdPrefix}-${index}-${generateId()}`, + type: 'function' as const, + function: { + name: functionName, + arguments: + typeof functionArgs === 'string' ? functionArgs : JSON.stringify(functionArgs ?? 
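// --- Illustrative aside (not part of the patch): what the schema normalizer above does to
// a tool whose MCP inputSchema is a union rather than a plain object, following the merge
// rules defined here. The schema itself is made up.
import { normalizeToolInputSchema } from './toolMapper'

const normalizedSchema = normalizeToolInputSchema({
  anyOf: [
    { type: 'object', properties: { path: { type: 'string' } }, required: ['path'] },
    { type: 'object', properties: { url: { type: 'string' } }, required: ['url'] }
  ]
})
// normalizedSchema => {
//   type: 'object',
//   properties: { path: { type: 'string' }, url: { type: 'string' } }
// }
// (no `required`, because only keys required by every anyOf variant are kept)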
{}) + } + } + }) + .filter((call): call is ParsedLegacyToolCall => call !== null) +} + +export function buildFunctionCallRecordContent( + name: string, + args: unknown, + response: unknown +): string { + return `${JSON.stringify({ + function_call_record: { + name, + arguments: args, + response + } + })}` +} + +export function splitMergedToolContent(content: string, expectedParts: number): string[] | null { + if (!content || expectedParts <= 1) { + return null + } + + const trimmed = content.trim() + if (!trimmed) { + return null + } + + const splitByDelimiter = (delimiter: RegExp): string[] | null => { + const parts = trimmed + .split(delimiter) + .map((part) => part.trim()) + .filter((part) => part.length > 0) + + return parts.length === expectedParts ? parts : null + } + + const tryJsonArray = (): string[] | null => { + if (!trimmed.startsWith('[')) { + return null + } + + const parsed = tryParseJson(trimmed) + if (!Array.isArray(parsed) || parsed.length !== expectedParts) { + return null + } + + return parsed.map((part) => (typeof part === 'string' ? part : JSON.stringify(part))) + } + + const strategies: Array<() => string[] | null> = [ + tryJsonArray, + () => splitByDelimiter(/\n-{3,}\n+/g), + () => splitByDelimiter(/\n={3,}\n+/g), + () => splitByDelimiter(/\n\*{3,}\n+/g), + () => splitByDelimiter(/\n\s*\n+/g) + ] + + for (const strategy of strategies) { + const parts = strategy() + if (parts) { + return parts + } + } + + return null +} + +export function toToolResultOutput(value: unknown): any { + if (Array.isArray(value)) { + return { + type: 'content', + value: value.map((entry) => { + if (typeof entry === 'string') { + return { + type: 'text', + text: entry + } + } + + if (entry && typeof entry === 'object') { + return entry as Record + } + + return { + type: 'text', + text: JSON.stringify(entry) + } + }) + } + } + + if (typeof value === 'string') { + const parsed = tryParseJson(value) + if (parsed !== undefined) { + return { + type: 'json', + value: parsed + } + } + + return { + type: 'text', + value + } + } + + return { + type: 'json', + value: value ?? 
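// --- Illustrative aside (not part of the patch): the kind of model output the legacy text
// protocol parser above is meant to handle when a provider has no native tool calling. The
// tool name and arguments are made up; the generated id suffix is shown as a placeholder.
import { parseLegacyFunctionCalls } from './toolProtocol'

const calls = parseLegacyFunctionCalls(
  'Let me check.<function_call>{"name":"search","arguments":{"q":"deepchat"}}</function_call>'
)
// calls => [{
//   id: 'tool-call-0-<generated>',
//   type: 'function',
//   function: { name: 'search', arguments: '{"q":"deepchat"}' }
// }]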
null + } +} diff --git a/src/main/presenter/llmProviderPresenter/baseProvider.ts b/src/main/presenter/llmProviderPresenter/baseProvider.ts index 137cc453e..2ef2818e5 100644 --- a/src/main/presenter/llmProviderPresenter/baseProvider.ts +++ b/src/main/presenter/llmProviderPresenter/baseProvider.ts @@ -17,6 +17,7 @@ import { CONFIG_EVENTS } from '@/events' import logger from '@shared/logger' import { resolveRequestTraceContext, type ProviderRequestTracePayload } from './requestTrace' import type { ProviderMcpRuntimePort } from './runtimePorts' +import { normalizeToolInputSchema } from './aiSdk/toolMapper' /** * Base LLM Provider Abstract Class @@ -81,6 +82,14 @@ export abstract class BaseLLMProvider { return this.provider.capabilityProviderId || this.provider.id } + private escapeXmlAttribute(value: string): string { + return value + .replace(/&/g, '&') + .replace(/"/g, '"') + .replace(//g, '>') + } + /** * Load cached model data from configuration * Called in constructor to avoid needing to re-fetch model lists every time @@ -717,24 +726,82 @@ ${this.convertToolsToXml(tools)} * @returns XML 格式的工具定义字符串 */ protected convertToolsToXml(tools: MCPToolDefinition[]): string { + const resolveParameterType = (parameter: unknown): string | undefined => { + if (!parameter || typeof parameter !== 'object' || Array.isArray(parameter)) { + return undefined + } + + if (typeof (parameter as { type?: unknown }).type === 'string') { + return (parameter as { type: string }).type + } + + for (const branchKey of ['anyOf', 'oneOf', 'allOf'] as const) { + const branches = (parameter as Record)[branchKey] + if (!Array.isArray(branches)) { + continue + } + + const types = Array.from( + new Set( + branches + .filter( + (branch): branch is Record => + Boolean(branch) && typeof branch === 'object' && !Array.isArray(branch) + ) + .map((branch) => branch.type) + .filter((type): type is string => typeof type === 'string') + ) + ) + + if (types.length === 1) { + return types[0] + } + } + + return undefined + } + const xmlTools = tools .map((tool) => { const { name, description, parameters } = tool.function - const { properties, required = [] } = parameters + const normalizedParameters = normalizeToolInputSchema( + (parameters as Record | undefined) ?? {} + ) + const properties = + normalizedParameters.properties && + typeof normalizedParameters.properties === 'object' && + !Array.isArray(normalizedParameters.properties) + ? (normalizedParameters.properties as Record) + : {} + const required = Array.isArray(normalizedParameters.required) + ? normalizedParameters.required.filter( + (value): value is string => typeof value === 'string' + ) + : [] // 构建参数 XML const paramsXml = Object.entries(properties) .map(([paramName, paramDef]) => { const requiredAttr = required.includes(paramName) ? ' required="true"' : '' - const descriptionAttr = paramDef.description - ? ` description="${paramDef.description}"` - : '' - const typeAttr = paramDef.type ? ` type="${paramDef.type}"` : '' + const paramMeta = + paramDef && typeof paramDef === 'object' && !Array.isArray(paramDef) + ? (paramDef as Record) + : {} + const descriptionAttr = + typeof paramMeta.description === 'string' + ? ` description="${this.escapeXmlAttribute(paramMeta.description)}"` + : '' + const paramType = resolveParameterType(paramMeta) + const typeAttr = paramType ? 
` type="${paramType}"` : '' return `` }) .join('\n ') + if (!paramsXml) { + return `` + } + // 构建工具 XML return ` ${paramsXml} diff --git a/src/main/presenter/llmProviderPresenter/index.ts b/src/main/presenter/llmProviderPresenter/index.ts index 30c035f69..d4dc6b31b 100644 --- a/src/main/presenter/llmProviderPresenter/index.ts +++ b/src/main/presenter/llmProviderPresenter/index.ts @@ -91,8 +91,7 @@ export class LLMProviderPresenter implements ILlmProviderPresenter { getProviderInstance: this.getProviderInstance.bind(this) }) this.modelScopeSyncManager = new ModelScopeSyncManager({ - configPresenter, - getProviderInstance: this.getProviderInstance.bind(this) + configPresenter }) this.rateLimitManager.initializeProviderRateLimitConfigs() diff --git a/src/main/presenter/llmProviderPresenter/managers/modelScopeSyncManager.ts b/src/main/presenter/llmProviderPresenter/managers/modelScopeSyncManager.ts index 4e2cc6fea..ac9d960dd 100644 --- a/src/main/presenter/llmProviderPresenter/managers/modelScopeSyncManager.ts +++ b/src/main/presenter/llmProviderPresenter/managers/modelScopeSyncManager.ts @@ -4,24 +4,19 @@ import { ModelScopeMcpSyncOptions, ModelScopeMcpSyncResult } from '@shared/presenter' -import { BaseLLMProvider } from '../baseProvider' -import { ModelscopeProvider, ModelScopeMcpServer } from '../providers/modelscopeProvider' +import { + convertModelScopeMcpServerToConfig, + fetchModelScopeMcpServers, + ModelScopeMcpServer +} from '../modelScopeMcp' interface ModelScopeSyncManagerOptions { configPresenter: IConfigPresenter - getProviderInstance: (providerId: string) => BaseLLMProvider } export class ModelScopeSyncManager { constructor(private readonly options: ModelScopeSyncManagerOptions) {} - private isModelscopeProvider(provider: BaseLLMProvider): provider is ModelscopeProvider { - return ( - typeof (provider as Partial).syncMcpServers === 'function' && - typeof (provider as Partial).convertMcpServerToConfig === 'function' - ) - } - async syncModelScopeMcpServers( providerId: string, syncOptions?: ModelScopeMcpSyncOptions @@ -35,10 +30,10 @@ export class ModelScopeSyncManager { throw new Error(error) } - const provider = this.options.getProviderInstance(providerId) + const provider = this.options.configPresenter.getProviderById(providerId) - if (!this.isModelscopeProvider(provider)) { - const error = 'Provider is not a ModelScope provider instance' + if (!provider) { + const error = 'Provider is not configured' console.error(`[ModelScope MCP Sync] Error: ${error}`) throw new Error(error) } @@ -56,7 +51,7 @@ export class ModelScopeSyncManager { const syncTask = async () => { console.log(`[ModelScope MCP Sync] Fetching MCP servers from ModelScope API...`) - const mcpResponse = await provider.syncMcpServers(syncOptions) + const mcpResponse = await fetchModelScopeMcpServers(provider, syncOptions) if (!mcpResponse || !mcpResponse.success || !mcpResponse.data?.mcp_server_list) { const errorMsg = 'Invalid response from ModelScope MCP API' @@ -84,7 +79,7 @@ export class ModelScopeSyncManager { return null } - const config = provider.convertMcpServerToConfig(server) + const config = convertModelScopeMcpServerToConfig(server) const name = server.name || server.id const displayName = server.chinese_name || server.name || server.id diff --git a/src/main/presenter/llmProviderPresenter/managers/providerInstanceManager.ts b/src/main/presenter/llmProviderPresenter/managers/providerInstanceManager.ts index 84380e734..33c78137a 100644 --- 
a/src/main/presenter/llmProviderPresenter/managers/providerInstanceManager.ts +++ b/src/main/presenter/llmProviderPresenter/managers/providerInstanceManager.ts @@ -1,53 +1,19 @@ import { ProviderBatchUpdate, ProviderChange } from '@shared/provider-operations' -import { IConfigPresenter, LLM_PROVIDER } from '@shared/presenter' +import { LLM_PROVIDER } from '@shared/presenter' import { BaseLLMProvider } from '../baseProvider' -import { DeepseekProvider } from '../providers/deepseekProvider' -import { SiliconcloudProvider } from '../providers/siliconcloudProvider' -import { DashscopeProvider } from '../providers/dashscopeProvider' -import { OpenAICompatibleProvider } from '../providers/openAICompatibleProvider' -import { PPIOProvider } from '../providers/ppioProvider' -import { TokenFluxProvider } from '../providers/tokenfluxProvider' -import { GeminiProvider } from '../providers/geminiProvider' -import { GithubProvider } from '../providers/githubProvider' import { GithubCopilotProvider } from '../providers/githubCopilotProvider' import { OllamaProvider } from '../providers/ollamaProvider' -import { AnthropicProvider } from '../providers/anthropicProvider' -import { AwsBedrockProvider } from '../providers/awsBedrockProvider' -import { DoubaoProvider } from '../providers/doubaoProvider' -import { TogetherProvider } from '../providers/togetherProvider' -import { GrokProvider } from '../providers/grokProvider' -import { GroqProvider } from '../providers/groqProvider' -import { ZhipuProvider } from '../providers/zhipuProvider' -import { LMStudioProvider } from '../providers/lmstudioProvider' -import { OpenAIResponsesProvider } from '../providers/openAIResponsesProvider' -import { CherryInProvider } from '../providers/cherryInProvider' -import { VertexProvider } from '../providers/vertexProvider' -import { OpenRouterProvider } from '../providers/openRouterProvider' -import { MinimaxProvider } from '../providers/minimaxProvider' -import { AihubmixProvider } from '../providers/aihubmixProvider' -import { _302AIProvider } from '../providers/_302AIProvider' -import { ModelscopeProvider } from '../providers/modelscopeProvider' import { AcpProvider } from '../providers/acpProvider' -import { VercelAIGatewayProvider } from '../providers/vercelAIGatewayProvider' -import { PoeProvider } from '../providers/poeProvider' -import { JiekouProvider } from '../providers/jiekouProvider' -import { ZenmuxProvider } from '../providers/zenmuxProvider' -import { O3fanProvider } from '../providers/o3fanProvider' import { VoiceAIProvider } from '../providers/voiceAIProvider' -import { NewApiProvider } from '../providers/newApiProvider' +import { AiSdkProvider } from '../providers/aiSdkProvider' import { RateLimitManager } from './rateLimitManager' import { StreamState } from '../types' import { AcpSessionPersistence } from '../acp' import type { ProviderMcpRuntimePort } from '../runtimePorts' - -type ProviderConstructor = new ( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - ...rest: any[] -) => BaseLLMProvider +import { resolveAiSdkProviderDefinition } from '../providerRegistry' interface ProviderInstanceManagerOptions { - configPresenter: IConfigPresenter + configPresenter: import('@shared/presenter').IConfigPresenter activeStreams: Map rateLimitManager: RateLimitManager getCurrentProviderId: () => string | null @@ -57,90 +23,11 @@ interface ProviderInstanceManagerOptions { } export class ProviderInstanceManager { - private static readonly PROVIDER_ID_MAP = ProviderInstanceManager.buildProviderIdMap() 
- private static readonly PROVIDER_TYPE_MAP = ProviderInstanceManager.buildProviderTypeMap() - private readonly providers: Map = new Map() private readonly providerInstances: Map = new Map() constructor(private readonly options: ProviderInstanceManagerOptions) {} - private static buildProviderIdMap(): Map { - return new Map([ - ['302ai', _302AIProvider], - ['minimax', MinimaxProvider], - ['grok', GrokProvider], - ['openrouter', OpenRouterProvider], - ['ppio', PPIOProvider], - ['tokenflux', TokenFluxProvider], - ['deepseek', DeepseekProvider], - ['aihubmix', AihubmixProvider], - ['modelscope', ModelscopeProvider], - ['silicon', SiliconcloudProvider], - ['siliconcloud', SiliconcloudProvider], - ['dashscope', DashscopeProvider], - ['gemini', GeminiProvider], - ['zhipu', ZhipuProvider], - ['vertex', VertexProvider], - ['github', GithubProvider], - ['github-copilot', GithubCopilotProvider], - ['ollama', OllamaProvider], - ['anthropic', AnthropicProvider], - ['doubao', DoubaoProvider], - ['openai', OpenAIResponsesProvider], - ['voiceai', VoiceAIProvider], - ['openai-responses', OpenAIResponsesProvider], - ['cherryin', CherryInProvider], - ['new-api', NewApiProvider], - ['lmstudio', LMStudioProvider], - ['together', TogetherProvider], - ['groq', GroqProvider], - ['vercel-ai-gateway', VercelAIGatewayProvider], - ['poe', PoeProvider], - ['aws-bedrock', AwsBedrockProvider], - ['jiekou', JiekouProvider], - ['zenmux', ZenmuxProvider], - ['o3fan', O3fanProvider], - ['acp', AcpProvider] - ]) - } - - private static buildProviderTypeMap(): Map { - return new Map([ - ['minimax', AnthropicProvider], - ['deepseek', DeepseekProvider], - ['silicon', SiliconcloudProvider], - ['siliconcloud', SiliconcloudProvider], - ['dashscope', DashscopeProvider], - ['ppio', PPIOProvider], - ['gemini', GeminiProvider], - ['vertex', VertexProvider], - ['zhipu', ZhipuProvider], - ['github', GithubProvider], - ['github-copilot', GithubCopilotProvider], - ['ollama', OllamaProvider], - ['anthropic', AnthropicProvider], - ['doubao', DoubaoProvider], - ['openai', OpenAIResponsesProvider], - ['openai-completions', OpenAICompatibleProvider], - ['voiceai', VoiceAIProvider], - ['openai-compatible', OpenAICompatibleProvider], - ['openai-responses', OpenAIResponsesProvider], - ['new-api', NewApiProvider], - ['lmstudio', LMStudioProvider], - ['together', TogetherProvider], - ['groq', GroqProvider], - ['grok', GrokProvider], - ['vercel-ai-gateway', VercelAIGatewayProvider], - ['poe', PoeProvider], - ['aws-bedrock', AwsBedrockProvider], - ['jiekou', JiekouProvider], - ['zenmux', ZenmuxProvider], - ['acp', AcpProvider], - ['o3fan', O3fanProvider] - ]) - } - init(): void { const providers = this.options.configPresenter.getProviders() for (const provider of providers) { @@ -390,22 +277,6 @@ export class ProviderInstanceManager { */ private createProviderInstance(provider: LLM_PROVIDER): BaseLLMProvider | undefined { try { - let ProviderClass = ProviderInstanceManager.PROVIDER_ID_MAP.get(provider.id) - - if (!ProviderClass) { - ProviderClass = ProviderInstanceManager.PROVIDER_TYPE_MAP.get(provider.apiType) - if (ProviderClass) { - console.log( - `No specific provider found for id: ${provider.id}, falling back to apiType: ${provider.apiType}` - ) - } - } - - if (!ProviderClass) { - console.warn(`Unknown provider type: ${provider.apiType} for provider id: ${provider.id}`) - return undefined - } - if (provider.id === 'acp') { if (!this.options.acpSessionPersistence) { throw new Error('ACP session persistence is not configured') @@ -418,7 +289,25 @@ 
export class ProviderInstanceManager { ) } - return new ProviderClass(provider, this.options.configPresenter, this.options.mcpRuntime) + if (provider.id === 'github-copilot') { + return new GithubCopilotProvider(provider, this.options.configPresenter) + } + + if (provider.id === 'voiceai') { + return new VoiceAIProvider(provider, this.options.configPresenter) + } + + if (provider.id === 'ollama' || provider.apiType === 'ollama') { + return new OllamaProvider(provider, this.options.configPresenter, this.options.mcpRuntime) + } + + const definition = resolveAiSdkProviderDefinition(provider) + if (!definition) { + console.warn(`Unknown provider type: ${provider.apiType} for provider id: ${provider.id}`) + return undefined + } + + return new AiSdkProvider(provider, this.options.configPresenter, this.options.mcpRuntime) } catch (error) { console.error(`Failed to create provider instance for ${provider.id}:`, error) return undefined diff --git a/src/main/presenter/llmProviderPresenter/modelScopeMcp.ts b/src/main/presenter/llmProviderPresenter/modelScopeMcp.ts new file mode 100644 index 000000000..0698298d0 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/modelScopeMcp.ts @@ -0,0 +1,139 @@ +import type { LLM_PROVIDER, MCPServerConfig, ModelScopeMcpSyncOptions } from '@shared/presenter' + +export interface ModelScopeMcpServerResponse { + code: number + data: { + mcp_server_list: ModelScopeMcpServer[] + total_count: number + } + message: string + request_id: string + success: boolean +} + +export interface ModelScopeMcpServer { + name: string + description: string + id: string + chinese_name?: string + logo_url: string + operational_urls: Array<{ + id: string + url: string + }> + tags: string[] + locales: { + zh: { + name: string + description: string + } + en: { + name: string + description: string + } + } +} + +export async function fetchModelScopeMcpServers( + provider: Pick, + _syncOptions?: ModelScopeMcpSyncOptions +): Promise { + if (!provider.apiKey) { + throw new Error('API key is required for MCP sync') + } + + const response = await fetch('https://www.modelscope.cn/openapi/v1/mcp/servers/operational', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${provider.apiKey}` + }, + signal: AbortSignal.timeout(30000) + }) + + if (response.status === 401 || response.status === 403) { + throw new Error('ModelScope MCP sync unauthorized: Invalid or expired API key') + } + + if (response.status === 500 || !response.ok) { + const errorText = await response.text() + throw new Error( + `ModelScope MCP sync failed: ${response.status} ${response.statusText} - ${errorText}` + ) + } + + const data: ModelScopeMcpServerResponse = await response.json() + + if (!data.success) { + throw new Error(`ModelScope MCP sync failed: ${data.message}`) + } + + return data +} + +export function convertModelScopeMcpServerToConfig( + mcpServer: ModelScopeMcpServer +): MCPServerConfig { + if (!mcpServer.operational_urls || mcpServer.operational_urls.length === 0) { + throw new Error(`No operational URLs found for server ${mcpServer.id}`) + } + + const baseUrl = mcpServer.operational_urls[0].url + const emojis = [ + '🔧', + '⚡', + '🚀', + '🔨', + '⚙️', + '🛠️', + '🔥', + '💡', + '⭐', + '🎯', + '🎨', + '🔮', + '💎', + '🎪', + '🎭', + '🔬', + '📱', + '💻', + '🖥️', + '⌨️', + '🖱️', + '📡', + '📣', + '🔔', + '📻', + '📷', + '🔍', + '💰', + '🎮', + '📝', + '📊', + '📦', + '✉️', + '🗞️', + '🔖' + ] + const randomEmoji = emojis[Math.floor(Math.random() * emojis.length)] + const displayName = 
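// --- Illustrative aside (not part of the patch): the instantiation order the manager's
// createProviderInstance now follows after this change, written out as a sketch. Only a few
// providers keep dedicated classes; everything else resolves through the registry and lands
// on the shared AiSdkProvider.
//
//   acp                      -> AcpProvider (requires acpSessionPersistence + mcpRuntime)
//   github-copilot           -> GithubCopilotProvider
//   voiceai                  -> VoiceAIProvider
//   ollama (id or apiType)   -> OllamaProvider
//   any registry definition  -> AiSdkProvider
//   no definition            -> undefined (logged as an unknown provider type)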
mcpServer.chinese_name || mcpServer.name || mcpServer.id + + return { + command: '', + args: [], + env: {}, + descriptions: + mcpServer.locales?.zh?.description || + mcpServer.description || + `ModelScope MCP Server: ${displayName}`, + icons: randomEmoji, + autoApprove: ['all'], + enabled: false, + disable: false, + type: 'sse', + baseUrl, + source: 'modelscope', + sourceId: mcpServer.id + } +} diff --git a/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts b/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts index 8de06538c..ccab3f652 100644 --- a/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts +++ b/src/main/presenter/llmProviderPresenter/promptCacheStrategy.ts @@ -1,6 +1,4 @@ import { createHash } from 'crypto' -import Anthropic from '@anthropic-ai/sdk' -import type { ChatCompletionContentPart, ChatCompletionMessageParam } from 'openai/resources' import type { MCPToolDefinition } from '@shared/presenter' import { resolvePromptCacheMode, type PromptCacheMode } from './promptCacheCapabilities' @@ -32,7 +30,23 @@ type EphemeralCacheControl = { type: 'ephemeral' } const EPHEMERAL_CACHE_CONTROL: EphemeralCacheControl = { type: 'ephemeral' } -type AnthropicTextBlockWithCache = Anthropic.TextBlockParam & { +type PromptCacheTextPart = { + type: 'text' + text: string + cache_control?: EphemeralCacheControl +} + +type PromptCacheOpenAIMessage = { + role: string + content?: string | Array> +} + +type PromptCacheAnthropicMessage = { + role: 'user' | 'assistant' + content: string | Array> +} + +type AnthropicTextBlockWithCache = PromptCacheTextPart & { cache_control?: EphemeralCacheControl } @@ -59,7 +73,7 @@ function buildPromptCacheKey( } function findOpenAIChatBreakpoint( - messages: ChatCompletionMessageParam[] + messages: PromptCacheOpenAIMessage[] ): PromptCacheBreakpointPlan | undefined { let prefixEnd = messages.length @@ -102,7 +116,7 @@ function findOpenAIChatBreakpoint( } function findAnthropicBreakpoint( - messages: Anthropic.MessageParam[] + messages: PromptCacheAnthropicMessage[] ): PromptCacheBreakpointPlan | undefined { let prefixEnd = messages.length @@ -168,8 +182,8 @@ export function resolvePromptCachePlan(params: ResolvePromptCachePlanParams): Pr const breakpointPlan = params.apiType === 'anthropic' - ? findAnthropicBreakpoint(params.messages as Anthropic.MessageParam[]) - : findOpenAIChatBreakpoint(params.messages as ChatCompletionMessageParam[]) + ? findAnthropicBreakpoint(params.messages as PromptCacheAnthropicMessage[]) + : findOpenAIChatBreakpoint(params.messages as PromptCacheOpenAIMessage[]) return { mode, @@ -207,9 +221,9 @@ export function applyAnthropicTopLevelCacheControl { @@ -249,10 +263,11 @@ export function applyOpenAIChatExplicitCacheBreakpoint( } return { - ...part, + type: 'text', + text: part.text, cache_control: EPHEMERAL_CACHE_CONTROL - } as unknown as ChatCompletionContentPart - }) as ChatCompletionMessageParam['content'] + } satisfies PromptCacheTextPart + }) as PromptCacheOpenAIMessage['content'] } else { return messages } @@ -262,15 +277,15 @@ export function applyOpenAIChatExplicitCacheBreakpoint( ? 
({ ...message, content: nextContent - } as ChatCompletionMessageParam) + } as PromptCacheOpenAIMessage) : message ) } export function applyAnthropicExplicitCacheBreakpoint( - messages: Anthropic.MessageParam[], + messages: PromptCacheAnthropicMessage[], plan: PromptCachePlan -): Anthropic.MessageParam[] { +): PromptCacheAnthropicMessage[] { if (plan.mode !== 'anthropic_explicit' || !plan.breakpointPlan) { return messages } @@ -283,7 +298,7 @@ export function applyAnthropicExplicitCacheBreakpoint( } const content = target.content - let nextContent: Anthropic.MessageParam['content'] = content + let nextContent: PromptCacheAnthropicMessage['content'] = content if (typeof content === 'string') { if (!content.trim() || contentIndex !== 0) { @@ -309,7 +324,8 @@ export function applyAnthropicExplicitCacheBreakpoint( } return { - ...block, + type: 'text', + text: block.text, cache_control: EPHEMERAL_CACHE_CONTROL } satisfies AnthropicTextBlockWithCache }) @@ -322,7 +338,7 @@ export function applyAnthropicExplicitCacheBreakpoint( ? ({ ...message, content: nextContent - } as Anthropic.MessageParam) + } as PromptCacheAnthropicMessage) : message ) } diff --git a/src/main/presenter/llmProviderPresenter/providerRegistry.ts b/src/main/presenter/llmProviderPresenter/providerRegistry.ts new file mode 100644 index 000000000..950802ff2 --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/providerRegistry.ts @@ -0,0 +1,411 @@ +import type { LLM_PROVIDER } from '@shared/presenter' +import type { AiSdkProviderKind } from './aiSdk/providerFactory' + +export type AiSdkBehaviorPreset = + | 'openai' + | 'title-summary' + | 'english-summary' + | 'chinese-summary' + | 'anthropic' + | 'google' + +export type AiSdkModelSourceStrategy = + | 'openai' + | 'github' + | 'together' + | 'provider-db' + | 'config-db' + | 'bedrock' + | 'new-api' + | 'openrouter' + | 'ppio' + | 'groq' + | 'tokenflux' + | '302ai' + +export type AiSdkKeyStatusStrategy = + | 'none' + | 'openrouter' + | 'deepseek' + | 'ppio' + | 'tokenflux' + | '302ai' + | 'cherryin' + | 'modelscope' + | 'siliconcloud' + +export type AiSdkCheckStrategy = 'fetch-models' | 'key-status' | 'generate-text' + +export type AiSdkCredentialStrategy = 'none' | 'api-key' | 'anthropic' | 'vertex' | 'bedrock' + +export type AiSdkRouteStrategy = 'none' | 'grok' | 'new-api' | 'zenmux' + +export type AiSdkEmbeddingStrategy = 'none' | 'openai' | 'google' | 'new-api' | 'zenmux' + +export interface AiSdkProviderDefinition { + runtimeKind: AiSdkProviderKind + behaviorPreset: AiSdkBehaviorPreset + modelSource: AiSdkModelSourceStrategy + checkStrategy: AiSdkCheckStrategy + credentialStrategy?: AiSdkCredentialStrategy + keyStatusStrategy?: AiSdkKeyStatusStrategy + routeStrategy?: AiSdkRouteStrategy + embeddingStrategy?: AiSdkEmbeddingStrategy + providerDbGroup?: string + providerDbSourceId?: string + checkModelId?: string + checkPrompt?: string + checkTemperature?: number + checkMaxTokens?: number + defaultHeadersPatch?: Record +} + +const createDefinition = (definition: AiSdkProviderDefinition): AiSdkProviderDefinition => + definition + +const OPENAI_BASE = createDefinition({ + runtimeKind: 'openai-compatible', + behaviorPreset: 'openai', + modelSource: 'openai', + checkStrategy: 'fetch-models', + keyStatusStrategy: 'none', + routeStrategy: 'none', + embeddingStrategy: 'openai' +}) + +const TITLE_SUMMARY_OPENAI = createDefinition({ + ...OPENAI_BASE, + behaviorPreset: 'title-summary' +}) + +const ENGLISH_SUMMARY_OPENAI = createDefinition({ + ...OPENAI_BASE, + behaviorPreset: 
'english-summary' +}) + +const CHINESE_SUMMARY_OPENAI = createDefinition({ + ...OPENAI_BASE, + behaviorPreset: 'chinese-summary' +}) + +const PROVIDER_ID_REGISTRY = new Map([ + [ + '302ai', + createDefinition({ + ...OPENAI_BASE, + modelSource: '302ai', + checkStrategy: 'key-status', + keyStatusStrategy: '302ai' + }) + ], + [ + 'aihubmix', + createDefinition({ + ...OPENAI_BASE, + defaultHeadersPatch: { + 'APP-Code': 'SMUE7630' + } + }) + ], + [ + 'anthropic', + createDefinition({ + runtimeKind: 'anthropic', + behaviorPreset: 'anthropic', + modelSource: 'config-db', + checkStrategy: 'generate-text', + credentialStrategy: 'anthropic', + keyStatusStrategy: 'none', + routeStrategy: 'none', + embeddingStrategy: 'none', + checkModelId: 'claude-sonnet-4-5-20250929', + checkPrompt: 'Hello', + checkTemperature: 0.2, + checkMaxTokens: 16 + }) + ], + [ + 'aws-bedrock', + createDefinition({ + runtimeKind: 'aws-bedrock', + behaviorPreset: 'anthropic', + modelSource: 'bedrock', + checkStrategy: 'generate-text', + credentialStrategy: 'bedrock', + keyStatusStrategy: 'none', + routeStrategy: 'none', + embeddingStrategy: 'none', + checkModelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + checkPrompt: 'Hi', + checkTemperature: 0.2, + checkMaxTokens: 16, + providerDbSourceId: 'amazon-bedrock' + }) + ], + [ + 'azure-openai', + createDefinition({ + ...OPENAI_BASE, + runtimeKind: 'azure' + }) + ], + [ + 'cherryin', + createDefinition({ + ...OPENAI_BASE, + checkStrategy: 'key-status', + keyStatusStrategy: 'cherryin' + }) + ], + [ + 'dashscope', + createDefinition({ + ...ENGLISH_SUMMARY_OPENAI + }) + ], + [ + 'deepseek', + createDefinition({ + ...OPENAI_BASE, + checkStrategy: 'key-status', + keyStatusStrategy: 'deepseek' + }) + ], + [ + 'doubao', + createDefinition({ + ...OPENAI_BASE, + modelSource: 'provider-db', + providerDbGroup: 'default' + }) + ], + [ + 'gemini', + createDefinition({ + runtimeKind: 'gemini', + behaviorPreset: 'google', + modelSource: 'config-db', + checkStrategy: 'generate-text', + credentialStrategy: 'api-key', + keyStatusStrategy: 'none', + routeStrategy: 'none', + embeddingStrategy: 'google', + checkModelId: 'gemini-2.0-flash', + checkPrompt: 'Hello', + checkTemperature: 0.2, + checkMaxTokens: 16 + }) + ], + [ + 'github', + createDefinition({ + ...OPENAI_BASE, + modelSource: 'github' + }) + ], + [ + 'grok', + createDefinition({ + ...OPENAI_BASE, + routeStrategy: 'grok' + }) + ], + [ + 'groq', + createDefinition({ + ...ENGLISH_SUMMARY_OPENAI, + modelSource: 'groq' + }) + ], + [ + 'jiekou', + createDefinition({ + ...OPENAI_BASE + }) + ], + [ + 'lmstudio', + createDefinition({ + ...OPENAI_BASE + }) + ], + [ + 'minimax', + createDefinition({ + runtimeKind: 'anthropic', + behaviorPreset: 'anthropic', + modelSource: 'provider-db', + providerDbGroup: 'default', + checkStrategy: 'generate-text', + credentialStrategy: 'anthropic', + keyStatusStrategy: 'none', + routeStrategy: 'none', + embeddingStrategy: 'none', + checkModelId: 'claude-sonnet-4-5-20250929', + checkPrompt: 'Hello', + checkTemperature: 0.2, + checkMaxTokens: 16 + }) + ], + [ + 'modelscope', + createDefinition({ + ...TITLE_SUMMARY_OPENAI, + checkStrategy: 'key-status', + keyStatusStrategy: 'modelscope' + }) + ], + [ + 'new-api', + createDefinition({ + ...OPENAI_BASE, + modelSource: 'new-api', + routeStrategy: 'new-api', + embeddingStrategy: 'new-api' + }) + ], + [ + 'o3fan', + createDefinition({ + ...TITLE_SUMMARY_OPENAI, + modelSource: 'provider-db', + providerDbGroup: 'o3fan' + }) + ], + [ + 'openai', + createDefinition({ + 
...OPENAI_BASE, + runtimeKind: 'openai-responses' + }) + ], + [ + 'openai-responses', + createDefinition({ + ...OPENAI_BASE, + runtimeKind: 'openai-responses' + }) + ], + [ + 'openrouter', + createDefinition({ + ...OPENAI_BASE, + modelSource: 'openrouter', + checkStrategy: 'key-status', + keyStatusStrategy: 'openrouter' + }) + ], + [ + 'poe', + createDefinition({ + ...OPENAI_BASE + }) + ], + [ + 'ppio', + createDefinition({ + ...TITLE_SUMMARY_OPENAI, + modelSource: 'ppio', + checkStrategy: 'key-status', + keyStatusStrategy: 'ppio' + }) + ], + [ + 'silicon', + createDefinition({ + ...CHINESE_SUMMARY_OPENAI, + checkStrategy: 'key-status', + keyStatusStrategy: 'siliconcloud' + }) + ], + [ + 'siliconcloud', + createDefinition({ + ...CHINESE_SUMMARY_OPENAI, + checkStrategy: 'key-status', + keyStatusStrategy: 'siliconcloud' + }) + ], + [ + 'together', + createDefinition({ + ...CHINESE_SUMMARY_OPENAI, + modelSource: 'together' + }) + ], + [ + 'tokenflux', + createDefinition({ + ...TITLE_SUMMARY_OPENAI, + modelSource: 'tokenflux', + checkStrategy: 'key-status', + keyStatusStrategy: 'tokenflux' + }) + ], + [ + 'vercel-ai-gateway', + createDefinition({ + ...OPENAI_BASE + }) + ], + [ + 'vertex', + createDefinition({ + runtimeKind: 'vertex', + behaviorPreset: 'google', + modelSource: 'config-db', + checkStrategy: 'generate-text', + credentialStrategy: 'vertex', + keyStatusStrategy: 'none', + routeStrategy: 'none', + embeddingStrategy: 'google', + checkModelId: 'gemini-1.5-flash-001', + checkPrompt: 'Hello from Vertex AI', + checkTemperature: 0.2, + checkMaxTokens: 16 + }) + ], + [ + 'zenmux', + createDefinition({ + ...OPENAI_BASE, + routeStrategy: 'zenmux', + embeddingStrategy: 'zenmux' + }) + ], + [ + 'zhipu', + createDefinition({ + ...TITLE_SUMMARY_OPENAI, + modelSource: 'provider-db', + providerDbGroup: 'zhipu' + }) + ] +]) + +const PROVIDER_API_TYPE_REGISTRY = new Map([ + ['anthropic', PROVIDER_ID_REGISTRY.get('anthropic')!], + ['aws-bedrock', PROVIDER_ID_REGISTRY.get('aws-bedrock')!], + ['doubao', PROVIDER_ID_REGISTRY.get('doubao')!], + ['gemini', PROVIDER_ID_REGISTRY.get('gemini')!], + ['grok', PROVIDER_ID_REGISTRY.get('grok')!], + ['groq', PROVIDER_ID_REGISTRY.get('groq')!], + ['new-api', PROVIDER_ID_REGISTRY.get('new-api')!], + ['o3fan', PROVIDER_ID_REGISTRY.get('o3fan')!], + ['openai', PROVIDER_ID_REGISTRY.get('openai')!], + ['openai-compatible', OPENAI_BASE], + ['openai-completions', OPENAI_BASE], + ['openai-responses', PROVIDER_ID_REGISTRY.get('openai-responses')!], + ['together', PROVIDER_ID_REGISTRY.get('together')!], + ['vertex', PROVIDER_ID_REGISTRY.get('vertex')!], + ['zenmux', PROVIDER_ID_REGISTRY.get('zenmux')!] 
+]) + +export function resolveAiSdkProviderDefinition( + provider: LLM_PROVIDER +): AiSdkProviderDefinition | null { + const providerId = provider.id.trim().toLowerCase() + const apiType = provider.apiType.trim().toLowerCase() + + return PROVIDER_ID_REGISTRY.get(providerId) || PROVIDER_API_TYPE_REGISTRY.get(apiType) || null +} diff --git a/src/main/presenter/llmProviderPresenter/providers/_302AIProvider.ts b/src/main/presenter/llmProviderPresenter/providers/_302AIProvider.ts deleted file mode 100644 index ea4ea6207..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/_302AIProvider.ts +++ /dev/null @@ -1,253 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - ChatMessage, - KeyStatus, - MODEL_META, - IConfigPresenter -} from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for 302AI API balance response -interface _302AIBalanceResponse { - data: { - balance: string - } -} - -// Define interface for 302AI model response based on actual API format -interface _302AIModelResponse { - id: string - object: string - category?: string - category_en?: string - content_length: number // This is the context length - created_on?: string - description?: string - description_en?: string - description_jp?: string - first_byte_req_time?: string - is_moderated: boolean - max_completion_tokens: number // This is the max output tokens - price?: { - input_token?: string - output_token?: string - per_request?: string - } - supported_tools: boolean // This indicates function calling support -} - -export class _302AIProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Get current API key status from 302AI - * @returns Promise API key status information - */ - public async getKeyStatus(): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required') - } - - const response = await fetch('https://api.302.ai/dashboard/balance', { - method: 'GET', - headers: { - Authorization: `Bearer ${this.provider.apiKey}`, - 'Content-Type': 'application/json' - } - }) - - if (!response.ok) { - const errorText = await response.text() - throw new Error( - `302AI API key check failed: ${response.status} ${response.statusText} - ${errorText}` - ) - } - - const balanceResponse: _302AIBalanceResponse = await response.json() - const balance = parseFloat(balanceResponse.data.balance) - const remaining = '$' + balanceResponse.data.balance - - return { - limit_remaining: remaining, - remainNum: balance - } - } - - /** - * Override check method to use 302AI's API key status endpoint - * @returns Promise<{ isOk: boolean; errorMsg: string | null }> - */ - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - const keyStatus = await this.getKeyStatus() - // Check if there's remaining quota - if (keyStatus.remainNum !== undefined && 
keyStatus.remainNum <= 0) { - return { - isOk: false, - errorMsg: `API key quota exhausted. Remaining: ${keyStatus.limit_remaining}` - } - } - - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - let errorMessage = 'An unknown error occurred during 302AI API key check.' - if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('302AI API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } - - /** - * Override fetchOpenAIModels to parse 302AI specific model data and update model configurations - * @param options - Request options - * @returns Promise - Array of model metadata - */ - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - try { - const response = await this.openai.models.list(options) - // console.log('302AI models response:', JSON.stringify(response, null, 2)) - - const models: MODEL_META[] = [] - - for (const model of response.data) { - // Type the model as 302AI specific response - const _302aiModel = model as unknown as _302AIModelResponse - - // Extract model information - const modelId = _302aiModel.id - - // Check for function calling support using supported_tools field - const hasFunctionCalling = _302aiModel.supported_tools === true - - // Check for vision support based on model ID and description patterns - const hasVision = - modelId.includes('vision') || - modelId.includes('gpt-4o') || - (_302aiModel.description && _302aiModel.description.includes('vision')) || - (_302aiModel.description_en && - _302aiModel.description_en.toLowerCase().includes('vision')) || - modelId.includes('claude') || // Some Claude models support vision - modelId.includes('gemini') || // Gemini models often support vision - (modelId.includes('qwen') && modelId.includes('vl')) // Qwen VL models - - // Get existing model configuration first - const existingConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - - // Extract configuration values with proper fallback priority: API -> existing config -> default - const contextLength = _302aiModel.content_length || existingConfig.contextLength || 4096 - - // Use max_completion_tokens if available, otherwise fall back to existing config or default - const maxTokens = - _302aiModel.max_completion_tokens > 0 - ? 
_302aiModel.max_completion_tokens - : existingConfig.maxTokens || 2048 - - // Build new configuration based on API response - const newConfig = { - contextLength: contextLength, - maxTokens: maxTokens, - functionCall: hasFunctionCalling, - vision: hasVision, - reasoning: existingConfig.reasoning || false, // Keep existing reasoning setting - temperature: existingConfig.temperature, // Keep existing temperature - type: existingConfig.type // Keep existing type - } - - // Check if configuration has changed - const configChanged = - existingConfig.contextLength !== newConfig.contextLength || - existingConfig.maxTokens !== newConfig.maxTokens || - existingConfig.functionCall !== newConfig.functionCall || - existingConfig.vision !== newConfig.vision - - // Update configuration if changed - if (configChanged) { - console.log(`Updating configuration for 302AI model ${modelId}:`, { - old: { - contextLength: existingConfig.contextLength, - maxTokens: existingConfig.maxTokens, - functionCall: existingConfig.functionCall, - vision: existingConfig.vision - }, - new: newConfig, - apiData: { - content_length: _302aiModel.content_length, - max_completion_tokens: _302aiModel.max_completion_tokens, - supported_tools: _302aiModel.supported_tools, - category: _302aiModel.category, - description: _302aiModel.description - } - }) - - this.configPresenter.setModelConfig(modelId, this.provider.id, newConfig, { - source: 'provider' - }) - } - - // Create MODEL_META object - const modelMeta: MODEL_META = { - id: modelId, - name: modelId, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: contextLength, - maxTokens: maxTokens, - vision: hasVision, - functionCall: hasFunctionCalling, - reasoning: existingConfig.reasoning || false - } - - models.push(modelMeta) - } - - console.log(`Processed ${models.length} 302AI models with dynamic configuration updates`) - return models - } catch (error) { - console.error('Error fetching 302AI models:', error) - // Fallback to parent implementation - return super.fetchOpenAIModels(options) - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/aiSdkProvider.ts b/src/main/presenter/llmProviderPresenter/providers/aiSdkProvider.ts new file mode 100644 index 000000000..d7aee850a --- /dev/null +++ b/src/main/presenter/llmProviderPresenter/providers/aiSdkProvider.ts @@ -0,0 +1,1925 @@ +import { EMBEDDING_TEST_KEY, isNormalized } from '@/utils/vector' +import { + ApiEndpointType, + ModelType, + isNewApiEndpointType, + resolveNewApiCapabilityProviderId, + type NewApiEndpointType +} from '@shared/model' +import { + DEFAULT_MODEL_CONTEXT_LENGTH, + DEFAULT_MODEL_MAX_TOKENS, + resolveModelContextLength, + resolveModelFunctionCall, + resolveModelMaxTokens +} from '@shared/modelConfigDefaults' +import { + AWS_BEDROCK_PROVIDER, + ChatMessage, + IConfigPresenter, + KeyStatus, + LLM_EMBEDDING_ATTRS, + LLM_PROVIDER, + LLMCoreStreamEvent, + LLMResponse, + MCPToolDefinition, + MODEL_META, + ModelConfig, + VERTEX_PROVIDER +} from '@shared/presenter' +import { BedrockClient, ListFoundationModelsCommand } from '@aws-sdk/client-bedrock' +import { ProxyAgent } from 'undici' +import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' +import { + runAiSdkCoreStream, + runAiSdkDimensions, + runAiSdkEmbeddings, + runAiSdkGenerateText, + type AiSdkRuntimeContext +} from '../aiSdk' +import type { AiSdkProviderKind } from '../aiSdk/providerFactory' +import { normalizeAzureBaseUrl } from '../aiSdk/providerFactory' +import { proxyConfig } from 
'../../proxyConfig'
+import type { ProviderMcpRuntimePort } from '../runtimePorts'
+import {
+  type AiSdkBehaviorPreset,
+  type AiSdkCredentialStrategy,
+  type AiSdkEmbeddingStrategy,
+  type AiSdkKeyStatusStrategy,
+  type AiSdkModelSourceStrategy,
+  type AiSdkProviderDefinition,
+  type AiSdkRouteStrategy,
+  resolveAiSdkProviderDefinition
+} from '../providerRegistry'
+import { providerDbLoader } from '../../configPresenter/providerDbLoader'
+import { modelCapabilities } from '../../configPresenter/modelCapabilities'
+
+const OPENAI_IMAGE_GENERATION_MODELS = ['gpt-4o-all', 'gpt-4o-image']
+const OPENAI_IMAGE_GENERATION_MODEL_PREFIXES = ['dall-e-', 'gpt-image-']
+const DEFAULT_NEW_API_BASE_URL = 'https://www.newapi.ai'
+const ZENMUX_ANTHROPIC_BASE_URL = 'https://zenmux.ai/api/anthropic'
+
+type RouteDecision = {
+  providerKind: AiSdkProviderKind
+  providerPatch?: Partial<LLM_PROVIDER>
+  modelConfigPatch?: Partial<ModelConfig>
+  endpointType?: NewApiEndpointType | 'grok-image'
+}
+
+const isOpenAIImageGenerationModel = (modelId: string): boolean =>
+  OPENAI_IMAGE_GENERATION_MODELS.includes(modelId) ||
+  OPENAI_IMAGE_GENERATION_MODEL_PREFIXES.some((prefix) => modelId.startsWith(prefix))
+
+export function normalizeExtractedImageText(content: string): string {
+  const normalized = content
+    .replace(/\r\n/g, '\n')
+    .replace(/\n\s*\n/g, '\n')
+    .trim()
+  if (!normalized) {
+    return ''
+  }
+
+  const semanticText = normalized.replace(/[`*_~!()[\]]/g, '').trim()
+  return semanticText.length > 0 ? normalized : ''
+}
+
+function toModelRecordArray(payload: unknown): Array<Record<string, unknown>> {
+  if (Array.isArray(payload)) {
+    return payload.filter(
+      (item): item is Record<string, unknown> =>
+        Boolean(item) && typeof item === 'object' && !Array.isArray(item)
+    )
+  }
+
+  if (!payload || typeof payload !== 'object') {
+    return []
+  }
+
+  const record = payload as Record<string, unknown>
+  for (const key of ['data', 'body', 'models']) {
+    const value = record[key]
+    if (Array.isArray(value)) {
+      return value.filter(
+        (item): item is Record<string, unknown> =>
+          Boolean(item) && typeof item === 'object' && !Array.isArray(item)
+      )
+    }
+  }
+
+  return []
+}
+
+function toErrorMessage(error: unknown, fallback: string): string {
+  if (error instanceof Error && error.message) {
+    return error.message
+  }
+  if (typeof error === 'string' && error.trim()) {
+    return error
+  }
+  return fallback
+}
+
+export class AiSdkProvider extends BaseLLMProvider {
+  private readonly definition: AiSdkProviderDefinition
+
+  constructor(
+    provider: LLM_PROVIDER,
+    configPresenter: IConfigPresenter,
+    mcpRuntime?: ProviderMcpRuntimePort
+  ) {
+    super(provider, configPresenter, mcpRuntime)
+    const definition = resolveAiSdkProviderDefinition(provider)
+    if (!definition) {
+      throw new Error(
+        `No AI SDK definition found for provider ${provider.id} (${provider.apiType})`
+      )
+    }
+    this.definition = definition
+    this.init()
+  }
+
+  private getRouteStrategy(): AiSdkRouteStrategy {
+    return this.definition.routeStrategy ?? 'none'
+  }
+
+  private getBehaviorPreset(decision: RouteDecision): AiSdkBehaviorPreset {
+    switch (this.getRouteStrategy()) {
+      case 'new-api':
+      case 'zenmux':
+        if (decision.providerKind === 'anthropic' || decision.providerKind === 'aws-bedrock') {
+          return 'anthropic'
+        }
+        if (decision.providerKind === 'gemini' || decision.providerKind === 'vertex') {
+          return 'google'
+        }
+        return this.definition.behaviorPreset
+      default:
+        return this.definition.behaviorPreset
+    }
+  }
+
+  private getNormalizedNewApiHost(): string {
+    const rawBaseUrl = (this.provider.baseUrl || DEFAULT_NEW_API_BASE_URL).trim()
+    const normalizedBaseUrl = rawBaseUrl.replace(/\/+$/, '')
+    return normalizedBaseUrl.replace(/\/(v1|v1beta(?:\d+)?)$/i, '') || DEFAULT_NEW_API_BASE_URL
+  }
+
+  private getStoredModel(modelId: string): MODEL_META | undefined {
+    return [...this.models, ...this.customModels].find((model) => model.id === modelId)
+  }
+
+  private getDefaultNewApiEndpointType(model: Pick<MODEL_META, 'type' | 'supportedEndpointTypes'>) {
+    const supportedEndpointTypes = model.supportedEndpointTypes ?? []
+    if (supportedEndpointTypes.length === 0) {
+      return model.type === ModelType.ImageGeneration ? 'image-generation' : undefined
+    }
+
+    if (
+      model.type === ModelType.ImageGeneration &&
+      supportedEndpointTypes.includes('image-generation')
+    ) {
+      return 'image-generation'
+    }
+
+    return supportedEndpointTypes[0]
+  }
+
+  private resolveNewApiEndpointType(modelId: string): NewApiEndpointType {
+    const modelConfig = this.getProviderModelConfig(modelId)
+    if (isNewApiEndpointType(modelConfig.endpointType)) {
+      return modelConfig.endpointType
+    }
+
+    const storedModel = this.getStoredModel(modelId)
+    if (storedModel && isNewApiEndpointType(storedModel.endpointType)) {
+      return storedModel.endpointType
+    }
+
+    const defaultEndpointType = storedModel
+      ? this.getDefaultNewApiEndpointType(storedModel)
+      : undefined
+    return defaultEndpointType ??
'openai' + } + + private resolveRouteDecision(modelId: string, _modelConfig?: ModelConfig): RouteDecision { + const strategy = this.getRouteStrategy() + + if (strategy === 'grok' && modelId.startsWith('grok-2-image')) { + return { + providerKind: this.definition.runtimeKind, + endpointType: 'grok-image', + modelConfigPatch: { + apiEndpoint: ApiEndpointType.Image + } + } + } + + if (strategy === 'zenmux' && modelId.trim().toLowerCase().startsWith('anthropic/')) { + return { + providerKind: 'anthropic', + providerPatch: { + apiType: 'anthropic', + baseUrl: ZENMUX_ANTHROPIC_BASE_URL + } + } + } + + if (strategy === 'new-api') { + const endpointType = this.resolveNewApiEndpointType(modelId) + const host = this.getNormalizedNewApiHost() + + switch (endpointType) { + case 'anthropic': + return { + providerKind: 'anthropic', + endpointType, + providerPatch: { + apiType: 'anthropic', + baseUrl: host, + capabilityProviderId: resolveNewApiCapabilityProviderId('anthropic') + } + } + case 'gemini': + return { + providerKind: 'gemini', + endpointType, + providerPatch: { + apiType: 'gemini', + baseUrl: host, + capabilityProviderId: resolveNewApiCapabilityProviderId('gemini') + } + } + case 'openai-response': + return { + providerKind: 'openai-responses', + endpointType, + providerPatch: { + apiType: 'openai-responses', + baseUrl: `${host}/v1`, + capabilityProviderId: resolveNewApiCapabilityProviderId('openai-response') + } + } + case 'image-generation': + return { + providerKind: 'openai-compatible', + endpointType, + providerPatch: { + apiType: 'openai-completions', + baseUrl: `${host}/v1`, + capabilityProviderId: resolveNewApiCapabilityProviderId('image-generation') + }, + modelConfigPatch: { + apiEndpoint: ApiEndpointType.Image, + type: ModelType.ImageGeneration, + endpointType: 'image-generation' + } + } + case 'openai': + default: + return { + providerKind: 'openai-compatible', + endpointType, + providerPatch: { + apiType: 'openai-completions', + baseUrl: `${host}/v1`, + capabilityProviderId: resolveNewApiCapabilityProviderId('openai') + } + } + } + } + + return { + providerKind: this.definition.runtimeKind + } + } + + private getRuntimeProvider(decision: RouteDecision): LLM_PROVIDER { + return { + ...this.provider, + ...decision.providerPatch + } + } + + private getResolvedModelConfig(modelId: string, modelConfig?: ModelConfig): ModelConfig { + return { + ...this.configPresenter.getModelConfig(modelId, this.provider.id), + ...modelConfig + } + } + + private getModelConfigForDecision(modelId: string, modelConfig?: ModelConfig): ModelConfig { + const decision = this.resolveRouteDecision(modelId, modelConfig) + return { + ...this.getResolvedModelConfig(modelId, modelConfig), + ...decision.modelConfigPatch + } + } + + public getProviderModelConfig(modelId: string): ModelConfig { + return this.configPresenter.getModelConfig(modelId, this.provider.id) ?? ({} as ModelConfig) + } + + public stringifyMessageContent(content: ChatMessage['content']): string { + if (typeof content === 'string') { + return content + } + + if (!Array.isArray(content)) { + return '' + } + + return content + .map((part) => { + if (part.type === 'text' && typeof part.text === 'string') { + return part.text + } + return '' + }) + .filter(Boolean) + .join('\n') + } + + public buildFallbackSummaryTitle(messages: ChatMessage[]): string { + const latestUserMessage = [...messages].reverse().find((message) => message.role === 'user') + const textContent = this.stringifyMessageContent(latestUserMessage?.content ?? 
'') + const normalizedTitle = textContent.replace(/\s+/g, ' ').trim() + if (!normalizedTitle) { + return 'New Conversation' + } + + return normalizedTitle.slice(0, 60) + } + + public getDbProviderModels(providerId = this.provider.id): MODEL_META[] { + return this.configPresenter.getDbProviderModels(providerId) + } + + public updateProviderManagedModelConfig(modelId: string, config: Partial): void { + this.configPresenter.setModelConfig( + modelId, + this.provider.id, + { + ...this.getProviderModelConfig(modelId), + ...config + }, + { + source: 'provider' + } + ) + } + + public getModelFetchTimeoutMs(): number { + return this.getModelFetchTimeout() + } + + private getFetchDispatcher(): ProxyAgent | undefined { + const proxyUrl = proxyConfig.getProxyUrl() + return proxyUrl ? new ProxyAgent(proxyUrl) : undefined + } + + private isAzureOpenAI(decision: RouteDecision, runtimeProvider: LLM_PROVIDER): boolean { + return decision.providerKind === 'azure' || runtimeProvider.id === 'azure-openai' + } + + private isOfficialOpenAIService(decision: RouteDecision, runtimeProvider: LLM_PROVIDER): boolean { + return runtimeProvider.id === 'openai' && !this.isAzureOpenAI(decision, runtimeProvider) + } + + private resolveTraceAuthToken(runtimeProvider: LLM_PROVIDER): string { + return runtimeProvider.oauthToken || runtimeProvider.apiKey || 'MISSING_API_KEY' + } + + private buildTraceHeaders( + decision: RouteDecision, + runtimeProvider: LLM_PROVIDER, + defaultHeaders: Record + ): Record { + const headers: Record = { + 'Content-Type': 'application/json', + ...defaultHeaders + } + + if (this.isAzureOpenAI(decision, runtimeProvider)) { + headers['api-key'] = this.resolveTraceAuthToken(runtimeProvider) + } else { + headers.Authorization = `Bearer ${this.resolveTraceAuthToken(runtimeProvider)}` + } + + return headers + } + + private getRequestHeaders( + decision: RouteDecision, + runtimeProvider: LLM_PROVIDER, + defaultHeaders: Record, + contentType?: string + ): Record { + const headers: Record = { + ...defaultHeaders + } + + if (contentType) { + headers['Content-Type'] = contentType + } + + if (this.isAzureOpenAI(decision, runtimeProvider)) { + headers['api-key'] = runtimeProvider.apiKey + } else { + headers.Authorization = `Bearer ${runtimeProvider.oauthToken || runtimeProvider.apiKey}` + } + + return headers + } + + private buildModelsUrl(decision: RouteDecision, runtimeProvider: LLM_PROVIDER): string { + if (this.isAzureOpenAI(decision, runtimeProvider)) { + const azureApiVersion = this.configPresenter.getSetting('azureApiVersion') + const azureConfig = normalizeAzureBaseUrl( + runtimeProvider.baseUrl || undefined, + azureApiVersion + ) + const baseURL = azureConfig.baseURL?.replace(/\/+$/, '') || '' + return `${baseURL}/models?api-version=${encodeURIComponent(azureConfig.apiVersion)}` + } + + const baseUrl = (runtimeProvider.baseUrl || 'https://api.openai.com/v1').replace(/\/+$/, '') + return `${baseUrl}/models` + } + + private buildRuntimeContext( + modelId: string, + modelConfig?: ModelConfig + ): { context: AiSdkRuntimeContext; decision: RouteDecision; resolvedModelConfig: ModelConfig } { + const decision = this.resolveRouteDecision(modelId, modelConfig) + const runtimeProvider = this.getRuntimeProvider(decision) + const defaultHeaders = { + ...this.defaultHeaders, + ...this.definition.defaultHeadersPatch + } + const resolvedModelConfig = this.getModelConfigForDecision(modelId, modelConfig) + + const cleanHeaders = this.isAzureOpenAI(decision, runtimeProvider) + ? 
false
+      : !this.isOfficialOpenAIService(decision, runtimeProvider)
+
+    const shouldUseImageGeneration =
+      decision.endpointType === 'grok-image' || decision.endpointType === 'image-generation'
+        ? () => true
+        : this.isAzureOpenAI(decision, runtimeProvider)
+          ? (_runtimeModelId: string, runtimeModelConfig: ModelConfig) =>
+              runtimeModelConfig.apiEndpoint === ApiEndpointType.Image
+          : decision.providerKind === 'gemini' || decision.providerKind === 'vertex'
+            ? (_runtimeModelId: string, runtimeModelConfig: ModelConfig) =>
+                runtimeModelConfig.apiEndpoint === ApiEndpointType.Image
+            : decision.providerKind === 'openai-responses'
+              ? (runtimeModelId: string, runtimeModelConfig: ModelConfig) =>
+                  isOpenAIImageGenerationModel(runtimeModelId) ||
+                  runtimeModelConfig.apiEndpoint === ApiEndpointType.Image
+              : (runtimeModelId: string, runtimeModelConfig: ModelConfig) =>
+                  isOpenAIImageGenerationModel(runtimeModelId) ||
+                  runtimeModelConfig.apiEndpoint === ApiEndpointType.Image
+
+    return {
+      decision,
+      resolvedModelConfig,
+      context: {
+        providerKind: decision.providerKind,
+        provider: runtimeProvider,
+        configPresenter: this.configPresenter,
+        defaultHeaders,
+        buildLegacyFunctionCallPrompt: (tools) => this.getFunctionCallWrapPrompt(tools),
+        emitRequestTrace: (runtimeModelConfig, payload) =>
+          this.emitRequestTrace(runtimeModelConfig, payload),
+        buildTraceHeaders: () => this.buildTraceHeaders(decision, runtimeProvider, defaultHeaders),
+        cleanHeaders,
+        supportsNativeTools: (_runtimeModelId, runtimeModelConfig) =>
+          runtimeModelConfig.functionCall === true,
+        shouldUseImageGeneration
+      }
+    }
+  }
+
+  public async requestProviderJson<T>(
+    url: string,
+    init: RequestInit = {},
+    timeout?: number,
+    decision?: RouteDecision
+  ): Promise<T> {
+    const resolvedDecision = decision ?? { providerKind: this.definition.runtimeKind }
+    const runtimeProvider = this.getRuntimeProvider(resolvedDecision)
+    const defaultHeaders = {
+      ...this.defaultHeaders,
+      ...this.definition.defaultHeadersPatch
+    }
+    const controller = new AbortController()
+    const timeoutId =
+      typeof timeout === 'number' && timeout > 0
+        ? setTimeout(() => controller.abort(), timeout)
+        : undefined
+
+    try {
+      const dispatcher = this.getFetchDispatcher()
+      const response = await fetch(url, {
+        ...init,
+        headers: {
+          ...this.getRequestHeaders(
+            resolvedDecision,
+            runtimeProvider,
+            defaultHeaders,
+            init.body && !(init.body instanceof FormData) ? 'application/json' : undefined
+          ),
+          ...(init.headers as Record<string, string> | undefined)
+        },
+        signal: controller.signal,
+        ...(dispatcher ? ({ dispatcher } as Record<string, unknown>) : {})
+      } as RequestInit)
+
+      if (!response.ok) {
+        const errorText = await response.text()
+        throw new Error(errorText || `Request failed with status ${response.status}`)
+      }
+
+      return (await response.json()) as T
+    } finally {
+      if (timeoutId) {
+        clearTimeout(timeoutId)
+      }
+    }
+  }
+
+  public async fetchOpenAIModelRecords(
+    options?: { timeout: number },
+    decision?: RouteDecision
+  ): Promise<Array<Record<string, unknown>>> {
+    const resolvedDecision = decision ??
{ providerKind: this.definition.runtimeKind } + const runtimeProvider = this.getRuntimeProvider(resolvedDecision) + const payload = await this.requestProviderJson( + this.buildModelsUrl(resolvedDecision, runtimeProvider), + { method: 'GET' }, + options?.timeout, + resolvedDecision + ) + return toModelRecordArray(payload) + } + + public async fetchDefaultOpenAIModels( + options?: { timeout: number }, + decision?: RouteDecision + ): Promise { + const response = await this.fetchOpenAIModelRecords(options, decision) + const models: MODEL_META[] = [] + + for (const model of response) { + if (typeof model.id !== 'string') { + continue + } + + models.push({ + id: model.id, + name: model.id, + group: 'default', + providerId: this.provider.id, + isCustom: false, + contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, + maxTokens: DEFAULT_MODEL_MAX_TOKENS + }) + } + + return models + } + + public async runText( + messages: ChatMessage[], + modelId: string, + temperature?: number, + maxTokens?: number, + modelConfig?: ModelConfig + ): Promise { + if (!this.isInitialized) { + throw new Error('Provider not initialized') + } + if (!modelId) { + throw new Error('Model ID is required') + } + + const { context, resolvedModelConfig } = this.buildRuntimeContext(modelId, modelConfig) + return runAiSdkGenerateText( + context, + messages, + modelId, + resolvedModelConfig, + temperature, + maxTokens + ) + } + + public async *streamText( + messages: ChatMessage[], + modelId: string, + modelConfig: ModelConfig, + temperature: number, + maxTokens: number, + tools: MCPToolDefinition[] + ): AsyncGenerator { + if (!this.isInitialized) { + throw new Error('Provider not initialized') + } + if (!modelId) { + throw new Error('Model ID is required') + } + + const { context, resolvedModelConfig } = this.buildRuntimeContext(modelId, modelConfig) + yield* runAiSdkCoreStream( + context, + messages, + modelId, + resolvedModelConfig, + temperature, + maxTokens, + tools + ) + } + + public async collectStreamResponse( + messages: ChatMessage[], + modelId: string, + temperature?: number, + maxTokens?: number, + tools: MCPToolDefinition[] = [], + modelConfig?: ModelConfig + ): Promise { + const response: LLMResponse = { + content: '' + } + const resolvedModelConfig = + modelConfig ?? + ({ + ...this.getProviderModelConfig(modelId), + apiEndpoint: ApiEndpointType.Image + } as ModelConfig) + + for await (const event of this.streamText( + messages, + modelId, + resolvedModelConfig, + temperature ?? resolvedModelConfig.temperature ?? 0.7, + maxTokens ?? resolvedModelConfig.maxTokens ?? 1024, + tools + )) { + switch (event.type) { + case 'text': + response.content += event.content + break + case 'reasoning': + response.reasoning_content = `${response.reasoning_content ?? 
''}${event.reasoning_content}` + break + case 'image_data': + if (!response.content) { + response.content = event.image_data.data + } + break + case 'usage': + response.totalUsage = event.usage + break + case 'error': + throw new Error(event.error_message) + } + } + + return response + } + + public async runEmbeddings(modelId: string, texts: string[]): Promise { + const { context } = this.buildRuntimeContext(modelId) + return runAiSdkEmbeddings(context, modelId, texts) + } + + private async runEmbeddingsWithDecision( + modelId: string, + texts: string[], + decision: RouteDecision + ): Promise { + const runtimeProvider = this.getRuntimeProvider(decision) + const defaultHeaders = { + ...this.defaultHeaders, + ...this.definition.defaultHeadersPatch + } + const context: AiSdkRuntimeContext = { + providerKind: decision.providerKind, + provider: runtimeProvider, + configPresenter: this.configPresenter, + defaultHeaders, + buildLegacyFunctionCallPrompt: (tools) => this.getFunctionCallWrapPrompt(tools), + emitRequestTrace: (runtimeModelConfig, payload) => + this.emitRequestTrace(runtimeModelConfig, payload), + buildTraceHeaders: () => this.buildTraceHeaders(decision, runtimeProvider, defaultHeaders), + cleanHeaders: this.isAzureOpenAI(decision, runtimeProvider) + ? false + : !this.isOfficialOpenAIService(decision, runtimeProvider), + supportsNativeTools: (_runtimeModelId, runtimeModelConfig) => + runtimeModelConfig.functionCall === true, + shouldUseImageGeneration: (_runtimeModelId, runtimeModelConfig) => + runtimeModelConfig.apiEndpoint === ApiEndpointType.Image + } + + return runAiSdkEmbeddings(context, modelId, texts) + } + + public async runDimensions(modelId: string): Promise { + if (modelId === 'text-embedding-3-small' || modelId === 'text-embedding-ada-002') { + return { + dimensions: 1536, + normalized: true + } + } + + if (modelId === 'text-embedding-3-large') { + return { + dimensions: 3072, + normalized: true + } + } + + try { + const embeddings = await this.runEmbeddings(modelId, [EMBEDDING_TEST_KEY]) + return { + dimensions: embeddings[0].length, + normalized: isNormalized(embeddings[0]) + } + } catch (error) { + console.error(`[AiSdkProvider] Failed to get dimensions for model ${modelId}:`, error) + const { context } = this.buildRuntimeContext(modelId) + return runAiSdkDimensions(context, modelId) + } + } + + private async runDimensionsWithDecision( + modelId: string, + decision: RouteDecision + ): Promise { + try { + const embeddings = await this.runEmbeddingsWithDecision( + modelId, + [EMBEDDING_TEST_KEY], + decision + ) + return { + dimensions: embeddings[0].length, + normalized: isNormalized(embeddings[0]) + } + } catch (error) { + console.error(`[AiSdkProvider] Failed to get dimensions for model ${modelId}:`, error) + const runtimeProvider = this.getRuntimeProvider(decision) + const defaultHeaders = { + ...this.defaultHeaders, + ...this.definition.defaultHeadersPatch + } + const context: AiSdkRuntimeContext = { + providerKind: decision.providerKind, + provider: runtimeProvider, + configPresenter: this.configPresenter, + defaultHeaders, + buildLegacyFunctionCallPrompt: (tools) => this.getFunctionCallWrapPrompt(tools), + emitRequestTrace: (runtimeModelConfig, payload) => + this.emitRequestTrace(runtimeModelConfig, payload), + buildTraceHeaders: () => this.buildTraceHeaders(decision, runtimeProvider, defaultHeaders), + cleanHeaders: this.isAzureOpenAI(decision, runtimeProvider) + ? 
false + : !this.isOfficialOpenAIService(decision, runtimeProvider), + supportsNativeTools: (_runtimeModelId, runtimeModelConfig) => + runtimeModelConfig.functionCall === true, + shouldUseImageGeneration: (_runtimeModelId, runtimeModelConfig) => + runtimeModelConfig.apiEndpoint === ApiEndpointType.Image + } + return runAiSdkDimensions(context, modelId) + } + } + + private mapConfigDbModels(providerId = this.provider.id): MODEL_META[] { + return this.getDbProviderModels(providerId).map((model) => ({ + id: model.id, + name: model.name, + group: model.group || 'default', + providerId: this.provider.id, + isCustom: false, + contextLength: model.contextLength, + maxTokens: model.maxTokens, + vision: model.vision || false, + functionCall: model.functionCall || false, + reasoning: model.reasoning || false, + ...(model.type ? { type: model.type } : {}) + })) + } + + private mapProviderDbModels(group: string): MODEL_META[] { + const resolvedId = modelCapabilities.resolveProviderId(this.provider.id) || this.provider.id + const provider = providerDbLoader.getProvider(resolvedId) + if (!provider || !Array.isArray(provider.models)) { + return [] + } + + return provider.models.map((model) => { + const inputs = model.modalities?.input + const outputs = model.modalities?.output + const hasImageInput = Array.isArray(inputs) && inputs.includes('image') + const hasImageOutput = Array.isArray(outputs) && outputs.includes('image') + const modelType = hasImageOutput ? ModelType.ImageGeneration : ModelType.Chat + + return { + id: model.id, + name: model.display_name || model.name || model.id, + group, + providerId: this.provider.id, + isCustom: false, + contextLength: resolveModelContextLength(model.limit?.context), + maxTokens: resolveModelMaxTokens(model.limit?.output), + vision: hasImageInput, + functionCall: resolveModelFunctionCall(model.tool_call), + reasoning: Boolean(model.reasoning?.supported), + enableSearch: Boolean(model.search?.supported), + type: modelType + } + }) + } + + private syncProviderModelConfig(modelId: string, nextConfig: Partial): void { + const existingConfig = this.getProviderModelConfig(modelId) + const merged = { + ...existingConfig, + ...nextConfig + } + + const changed = Object.keys(nextConfig).some( + (key) => existingConfig[key as keyof ModelConfig] !== merged[key as keyof ModelConfig] + ) + + if (changed) { + this.updateProviderManagedModelConfig(modelId, merged) + } + } + + private async fetchProviderModelsByStrategy( + strategy: AiSdkModelSourceStrategy + ): Promise { + switch (strategy) { + case 'config-db': + return this.mapConfigDbModels(this.definition.providerDbSourceId) + case 'provider-db': + return this.mapProviderDbModels(this.definition.providerDbGroup || 'default') + case 'github': { + const response = await this.fetchOpenAIModelRecords({ + timeout: this.getModelFetchTimeout() + }) + return response + .filter((model) => typeof model.name === 'string') + .map((model) => ({ + id: model.name as string, + name: model.name as string, + group: 'default', + providerId: this.provider.id, + isCustom: false, + contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, + maxTokens: DEFAULT_MODEL_MAX_TOKENS, + description: typeof model.description === 'string' ? 
model.description : undefined + })) + } + case 'together': { + const response = await this.fetchOpenAIModelRecords({ + timeout: this.getModelFetchTimeout() + }) + return response + .filter((model) => model.type === 'chat' || model.type === 'language') + .map((model) => ({ + id: model.id as string, + name: model.id as string, + group: 'default', + providerId: this.provider.id, + isCustom: false, + contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, + maxTokens: DEFAULT_MODEL_MAX_TOKENS + })) + } + case 'openrouter': + case 'ppio': + case 'groq': + case 'tokenflux': + case '302ai': + return this.fetchOpenAiDerivedModels(strategy) + case 'bedrock': + return this.fetchBedrockModels() + case 'new-api': + return this.fetchNewApiModels() + case 'openai': + default: + return this.fetchDefaultOpenAIModels({ timeout: this.getModelFetchTimeout() }).then( + (models) => + this.getRouteStrategy() === 'zenmux' + ? models.map((model) => ({ + ...model, + group: 'ZenMux' + })) + : models + ) + } + } + + private async fetchOpenAiDerivedModels( + strategy: 'openrouter' | 'ppio' | 'groq' | 'tokenflux' | '302ai' + ): Promise { + try { + const response = await this.fetchOpenAIModelRecords({ timeout: this.getModelFetchTimeout() }) + const models: MODEL_META[] = [] + + for (const model of response) { + const modelId = typeof model.id === 'string' ? model.id : '' + if (!modelId) { + continue + } + + const existingConfig = this.getProviderModelConfig(modelId) + + if (strategy === 'groq') { + const status = + typeof model.status === 'number' + ? model.status + : typeof model.active === 'boolean' + ? model.active + ? 1 + : 0 + : 1 + if (status === 0 || model.active === false) { + continue + } + } + + const features = Array.isArray(model.features) + ? model.features.filter((item): item is string => typeof item === 'string') + : [] + const supportedParameters = Array.isArray(model.supported_parameters) + ? model.supported_parameters.filter((item): item is string => typeof item === 'string') + : [] + const inputModalities = Array.isArray( + (model.architecture as Record)?.input_modalities + ) + ? ((model.architecture as Record).input_modalities as unknown[]).filter( + (item): item is string => typeof item === 'string' + ) + : [] + + const contextLength = + strategy === 'openrouter' + ? (typeof model.context_length === 'number' ? model.context_length : undefined) || + (typeof (model.top_provider as Record)?.context_length === 'number' + ? ((model.top_provider as Record).context_length as number) + : undefined) || + existingConfig.contextLength || + 4096 + : strategy === 'ppio' + ? (typeof model.context_size === 'number' ? model.context_size : undefined) || + existingConfig.contextLength || + 4096 + : strategy === 'groq' + ? (typeof model.context_size === 'number' ? model.context_size : undefined) || + (typeof model.context_window === 'number' ? model.context_window : undefined) || + existingConfig.contextLength || + 4096 + : strategy === 'tokenflux' + ? (typeof model.context_length === 'number' ? model.context_length : undefined) || + existingConfig.contextLength || + 4096 + : (typeof model.content_length === 'number' ? model.content_length : undefined) || + existingConfig.contextLength || + 4096 + + const maxTokens = + strategy === 'openrouter' + ? (typeof (model.top_provider as Record)?.max_completion_tokens === + 'number' + ? ((model.top_provider as Record).max_completion_tokens as number) + : undefined) || + existingConfig.maxTokens || + 2048 + : strategy === 'ppio' + ? (typeof model.max_output_tokens === 'number' + ? 
model.max_output_tokens + : undefined) || + existingConfig.maxTokens || + 2048 + : strategy === 'groq' + ? (typeof model.max_output_tokens === 'number' + ? model.max_output_tokens + : undefined) || + (typeof model.max_tokens === 'number' ? model.max_tokens : undefined) || + existingConfig.maxTokens || + 2048 + : strategy === 'tokenflux' + ? existingConfig.maxTokens || Math.min(contextLength / 2, 4096) + : typeof model.max_completion_tokens === 'number' && + model.max_completion_tokens > 0 + ? (model.max_completion_tokens as number) + : existingConfig.maxTokens || 2048 + + const hasFunctionCalling = + strategy === 'openrouter' + ? supportedParameters.includes('tools') + : strategy === 'ppio' + ? features.includes('function-calling') + : strategy === 'groq' + ? features.includes('function-calling') || + (!modelId.toLowerCase().includes('distil') && + !modelId.toLowerCase().includes('gemma')) + : strategy === 'tokenflux' + ? true + : model.supported_tools === true + + const hasVision = + strategy === 'openrouter' + ? inputModalities.includes('image') + : strategy === 'ppio' + ? features.includes('vision') + : strategy === 'groq' + ? features.includes('vision') || + modelId.toLowerCase().includes('vision') || + modelId.toLowerCase().includes('llava') + : strategy === 'tokenflux' + ? Boolean(model.supports_vision) + : modelId.includes('vision') || + modelId.includes('gpt-4o') || + (typeof model.description === 'string' && + model.description.includes('vision')) || + (typeof model.description_en === 'string' && + model.description_en.toLowerCase().includes('vision')) || + modelId.includes('claude') || + modelId.includes('gemini') || + (modelId.includes('qwen') && modelId.includes('vl')) + + const reasoning = + strategy === 'openrouter' + ? supportedParameters.includes('reasoning') || + supportedParameters.includes('include_reasoning') || + existingConfig.reasoning || + false + : existingConfig.reasoning || false + + this.syncProviderModelConfig(modelId, { + contextLength, + maxTokens, + functionCall: hasFunctionCalling, + vision: hasVision, + reasoning, + temperature: existingConfig.temperature, + type: existingConfig.type + }) + + models.push({ + id: modelId, + name: + strategy === 'ppio' && typeof model.display_name === 'string' + ? model.display_name + : strategy === 'groq' && typeof model.display_name === 'string' + ? model.display_name + : strategy === 'tokenflux' && typeof model.name === 'string' + ? model.name + : strategy === 'openrouter' && typeof model.name === 'string' + ? model.name + : modelId, + group: 'default', + providerId: this.provider.id, + isCustom: false, + contextLength, + maxTokens, + description: + typeof model.description === 'string' + ? model.description + : strategy === 'groq' + ? 
`Groq model ${modelId}` + : undefined, + vision: hasVision, + functionCall: hasFunctionCalling, + reasoning + }) + } + + return models + } catch (error) { + console.error(`Error fetching ${strategy} models:`, error) + return this.fetchDefaultOpenAIModels({ timeout: this.getModelFetchTimeout() }) + } + } + + private async fetchBedrockModels(): Promise { + const provider = this.provider as AWS_BEDROCK_PROVIDER + const accessKeyId = provider.credential?.accessKeyId || process.env.BEDROCK_ACCESS_KEY_ID + const secretAccessKey = + provider.credential?.secretAccessKey || process.env.BEDROCK_SECRET_ACCESS_KEY + const region = provider.credential?.region || process.env.BEDROCK_REGION + + if (!accessKeyId || !secretAccessKey || !region) { + return this.mapConfigDbModels(this.definition.providerDbSourceId).filter((model) => + model.id.startsWith('anthropic.') + ) + } + + try { + const client = new BedrockClient({ + credentials: { + accessKeyId, + secretAccessKey + }, + region + }) + const response = await client.send(new ListFoundationModelsCommand({})) + return ( + response.modelSummaries + ?.filter( + (model) => model.modelId && /^anthropic\.claude-[a-z0-9-]+(:\d+)$/g.test(model.modelId) + ) + ?.filter((model) => model.modelLifecycle?.status === 'ACTIVE') + ?.filter( + (model) => model.inferenceTypesSupported && model.inferenceTypesSupported.length > 0 + ) + .map((model) => ({ + id: model.inferenceTypesSupported?.includes('ON_DEMAND') + ? model.modelId! + : `${region.split('-')[0]}.${model.modelId}`, + name: model.modelId?.replace('anthropic.', '') || '', + providerId: this.provider.id, + maxTokens: 64_000, + group: `AWS Bedrock Claude - ${ + model.modelId?.includes('opus') + ? 'opus' + : model.modelId?.includes('sonnet') + ? 'sonnet' + : model.modelId?.includes('haiku') + ? 'haiku' + : 'other' + }`, + isCustom: false, + contextLength: 200_000, + vision: false, + functionCall: false, + reasoning: false + })) || [] + ) + } catch (error) { + console.error('获取AWS Bedrock Anthropic模型列表出错:', error) + return this.mapConfigDbModels(this.definition.providerDbSourceId).filter((model) => + model.id.startsWith('anthropic.') + ) + } + } + + private async fetchNewApiModels(): Promise { + type NewApiModelRecord = { + id?: unknown + name?: unknown + owned_by?: unknown + description?: unknown + type?: unknown + supported_endpoint_types?: unknown + context_length?: unknown + contextLength?: unknown + input_token_limit?: unknown + max_input_tokens?: unknown + max_tokens?: unknown + max_output_tokens?: unknown + output_token_limit?: unknown + } + + type NewApiModelsResponse = { + data?: NewApiModelRecord[] + } + + const host = this.getNormalizedNewApiHost() + const payload = await this.requestProviderJson( + `${host}/v1/models`, + { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}`, + 'Content-Type': 'application/json', + ...this.defaultHeaders + } + }, + this.getModelFetchTimeout() + ) + const rawModels = Array.isArray(payload.data) ? payload.data : [] + + const models = rawModels + .filter((rawModel): rawModel is NewApiModelRecord & { id: string } => { + return typeof rawModel.id === 'string' && rawModel.id.trim().length > 0 + }) + .map((rawModel) => { + const supportedEndpointTypes = Array.isArray(rawModel.supported_endpoint_types) + ? rawModel.supported_endpoint_types.filter(isNewApiEndpointType) + : [] + + const normalizedRawType = + typeof rawModel.type === 'string' ? 
rawModel.type.trim().toLowerCase() : '' + const normalizedModelId = rawModel.id.toLowerCase() + const type = + normalizedRawType === 'imagegeneration' || + normalizedRawType === 'image-generation' || + normalizedRawType === 'image' || + supportedEndpointTypes.includes('image-generation') + ? ModelType.ImageGeneration + : normalizedRawType === 'embedding' || + normalizedRawType === 'embeddings' || + normalizedModelId.includes('embedding') + ? ModelType.Embedding + : normalizedRawType === 'rerank' || normalizedModelId.includes('rerank') + ? ModelType.Rerank + : undefined + + const contextLengthCandidate = [ + rawModel.context_length, + rawModel.contextLength, + rawModel.input_token_limit, + rawModel.max_input_tokens + ].find( + (candidate): candidate is number => + typeof candidate === 'number' && Number.isFinite(candidate) + ) + + const maxTokensCandidate = [ + rawModel.max_tokens, + rawModel.max_output_tokens, + rawModel.output_token_limit + ].find( + (candidate): candidate is number => + typeof candidate === 'number' && Number.isFinite(candidate) + ) + + const defaultEndpointType = + supportedEndpointTypes.length === 0 + ? type === ModelType.ImageGeneration + ? 'image-generation' + : undefined + : type === ModelType.ImageGeneration && + supportedEndpointTypes.includes('image-generation') + ? 'image-generation' + : supportedEndpointTypes[0] + + return { + id: rawModel.id, + name: typeof rawModel.name === 'string' ? rawModel.name : rawModel.id, + group: typeof rawModel.owned_by === 'string' ? rawModel.owned_by : 'default', + providerId: this.provider.id, + isCustom: false, + supportedEndpointTypes, + endpointType: defaultEndpointType, + ...(typeof rawModel.description === 'string' + ? { description: rawModel.description } + : {}), + ...(type ? { type } : {}), + ...(contextLengthCandidate !== undefined + ? { contextLength: contextLengthCandidate } + : {}), + ...(maxTokensCandidate !== undefined ? { maxTokens: maxTokensCandidate } : {}) + } satisfies MODEL_META + }) + + for (const model of models) { + if (this.configPresenter.hasUserModelConfig(model.id, this.provider.id)) { + continue + } + + const existingConfig = this.getProviderModelConfig(model.id) + this.updateProviderManagedModelConfig(model.id, { + ...existingConfig, + type: model.type ?? existingConfig.type, + apiEndpoint: + model.endpointType === 'image-generation' ? ApiEndpointType.Image : ApiEndpointType.Chat, + endpointType: model.endpointType ?? existingConfig.endpointType + }) + } + + return models + } + + protected async fetchProviderModels(): Promise { + return this.fetchProviderModelsByStrategy(this.definition.modelSource) + } + + public onProxyResolved(): void {} + + private resolveKeyStatusStrategy(): AiSdkKeyStatusStrategy { + return this.definition.keyStatusStrategy ?? 
'none' + } + + public async getKeyStatus(): Promise { + switch (this.resolveKeyStatusStrategy()) { + case 'openrouter': { + const response = await fetch('https://openrouter.ai/api/v1/key', { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}`, + 'Content-Type': 'application/json' + } + }) + if (response.status !== 200) { + const errorText = await response.text() + throw new Error( + `OpenRouter API key check failed: ${response.status} ${response.statusText} - ${errorText}` + ) + } + const payload = (await response.json()) as { + data: { + usage: number + limit_remaining: number | null + } + } + const keyStatus: KeyStatus = { + usage: '$' + payload.data.usage + } + if (payload.data.limit_remaining !== null) { + keyStatus.limit_remaining = '$' + payload.data.limit_remaining + keyStatus.remainNum = payload.data.limit_remaining + } + return keyStatus + } + case 'deepseek': { + const response = await fetch('https://api.deepseek.com/user/balance', { + method: 'GET', + headers: { + Accept: 'application/json', + Authorization: `Bearer ${this.provider.apiKey}` + } + }) + if (!response.ok) { + const errorText = await response.text() + throw new Error( + `DeepSeek API key check failed: ${response.status} ${response.statusText} - ${errorText}` + ) + } + const payload = (await response.json()) as { + is_available: boolean + balance_infos: Array<{ currency: string; total_balance: string }> + } + if (!payload.is_available) { + throw new Error('DeepSeek API key is not available') + } + const balanceInfo = + payload.balance_infos.find((info) => info.currency === 'CNY') || + payload.balance_infos.find((info) => info.currency === 'USD') || + payload.balance_infos[0] + if (!balanceInfo) { + throw new Error('No balance information available') + } + const totalBalance = Number.parseFloat(balanceInfo.total_balance) + const currencySymbol = balanceInfo.currency === 'USD' ? 
'$' : '¥' + return { + limit_remaining: `${currencySymbol}${totalBalance}`, + remainNum: totalBalance + } + } + case 'ppio': { + const response = await fetch('https://api.ppinfra.com/v3/user', { + method: 'GET', + headers: { + Authorization: this.provider.apiKey, + 'Content-Type': 'application/json' + } + }) + if (!response.ok) { + const errorText = await response.text() + throw new Error( + `PPIO API key check failed: ${response.status} ${response.statusText} - ${errorText}` + ) + } + const payload = (await response.json()) as { credit_balance: number } + return { + limit_remaining: '¥' + payload.credit_balance / 10000, + remainNum: payload.credit_balance + } + } + case 'tokenflux': { + const response = await fetch(`${this.provider.baseUrl}/models`, { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}`, + 'Content-Type': 'application/json' + } + }) + if (!response.ok) { + const errorText = await response.text() + throw new Error( + `TokenFlux API key check failed: ${response.status} ${response.statusText} - ${errorText}` + ) + } + return { + limit_remaining: 'Available', + remainNum: undefined + } + } + case '302ai': { + const response = await fetch('https://api.302.ai/dashboard/balance', { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}`, + 'Content-Type': 'application/json' + } + }) + if (!response.ok) { + const errorText = await response.text() + throw new Error( + `302AI API key check failed: ${response.status} ${response.statusText} - ${errorText}` + ) + } + const payload = (await response.json()) as { data: { balance: string } } + return { + limit_remaining: `$${payload.data.balance}`, + remainNum: Number.parseFloat(payload.data.balance) + } + } + case 'cherryin': { + const baseUrl = (this.provider.baseUrl || 'https://open.cherryin.ai/v1').replace(/\/$/, '') + const usageResponse = await fetch(`${baseUrl}/dashboard/billing/usage`, { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}`, + 'Content-Type': 'application/json' + } + }) + if (!usageResponse.ok) { + const errorText = await usageResponse.text() + throw new Error( + `CherryIn usage check failed: ${usageResponse.status} ${usageResponse.statusText} - ${errorText}` + ) + } + const usageData = (await usageResponse.json()) as { total_usage: number } + const usageUsd = Number.isFinite(Number(usageData.total_usage)) + ? 
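Each `getKeyStatus` branch above queries a vendor-specific balance endpoint and collapses the payload into the shared key-status shape (`usage`, `limit_remaining`, `remainNum`). A sketch of the DeepSeek case, which prefers a CNY balance, then USD, then whatever is listed first (`KeyStatusLike` and `fromDeepSeekBalance` are illustrative names):

```ts
type KeyStatusLike = { usage?: string; limit_remaining?: string; remainNum?: number }

function fromDeepSeekBalance(
  infos: Array<{ currency: string; total_balance: string }>
): KeyStatusLike {
  const info =
    infos.find((i) => i.currency === 'CNY') ??
    infos.find((i) => i.currency === 'USD') ??
    infos[0]
  if (!info) throw new Error('No balance information available')
  const total = Number.parseFloat(info.total_balance)
  const symbol = info.currency === 'USD' ? '$' : '¥'
  return { limit_remaining: `${symbol}${total}`, remainNum: total }
}

// fromDeepSeekBalance([{ currency: 'CNY', total_balance: '12.50' }])
// -> { limit_remaining: '¥12.5', remainNum: 12.5 }
```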
Number(usageData.total_usage) / 100 + : 0 + return { + usage: `$${usageUsd.toFixed(2)}` + } + } + case 'modelscope': { + const response = await this.fetchOpenAIModelRecords({ timeout: 10000 }) + return { + limit_remaining: 'Available', + remainNum: response.length + } + } + case 'siliconcloud': { + const response = await fetch('https://api.siliconflow.cn/v1/user/info', { + method: 'GET', + headers: { + Authorization: `Bearer ${this.provider.apiKey}`, + 'Content-Type': 'application/json' + } + }) + if (!response.ok) { + const errorText = await response.text() + throw new Error( + `SiliconCloud API key check failed: ${response.status} ${response.statusText} - ${errorText}` + ) + } + const payload = (await response.json()) as { + code: number + message: string + status: boolean + data: { totalBalance: string } + } + if (payload.code !== 20000 || !payload.status) { + throw new Error(`SiliconCloud API error: ${payload.message}`) + } + const totalBalance = Number.parseFloat(payload.data.totalBalance) + return { + limit_remaining: `¥${totalBalance}`, + remainNum: totalBalance + } + } + case 'none': + default: + return null + } + } + + private validateCredentials(strategy: AiSdkCredentialStrategy): string | null { + switch (strategy) { + case 'api-key': + return this.provider.apiKey ? null : 'Missing API key' + case 'anthropic': + return this.provider.apiKey || process.env.ANTHROPIC_API_KEY ? null : 'Missing API key' + case 'vertex': { + const provider = this.provider as VERTEX_PROVIDER + return provider.projectId && + provider.location && + (provider.apiKey || (provider.accountClientEmail && provider.accountPrivateKey)) + ? null + : 'projectId, location, and API credentials are required for Vertex AI' + } + case 'bedrock': { + const provider = this.provider as AWS_BEDROCK_PROVIDER + const accessKeyId = provider.credential?.accessKeyId || process.env.BEDROCK_ACCESS_KEY_ID + const secretAccessKey = + provider.credential?.secretAccessKey || process.env.BEDROCK_SECRET_ACCESS_KEY + const region = provider.credential?.region || process.env.BEDROCK_REGION + return accessKeyId && secretAccessKey && region ? null : 'Missing AWS Bedrock credentials' + } + case 'none': + default: + return null + } + } + + public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { + switch (this.definition.checkStrategy) { + case 'key-status': + try { + const keyStatus = await this.getKeyStatus() + if (keyStatus?.remainNum !== undefined && keyStatus.remainNum <= 0) { + return { + isOk: false, + errorMsg: `API key quota exhausted. Remaining: ${keyStatus.limit_remaining}` + } + } + return { isOk: true, errorMsg: null } + } catch (error) { + return { + isOk: false, + errorMsg: toErrorMessage(error, 'Provider check failed') + } + } + case 'generate-text': { + const credentialError = this.validateCredentials( + this.definition.credentialStrategy ?? 'none' + ) + if (credentialError) { + return { + isOk: false, + errorMsg: credentialError + } + } + + try { + await this.runText( + [{ role: 'user', content: this.definition.checkPrompt || 'Hello' }], + this.definition.checkModelId || '', + this.definition.checkTemperature ?? 0.2, + this.definition.checkMaxTokens ?? 
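The `validateCredentials` hunk above gates the `generate-text` check on per-strategy credential requirements, falling back to environment variables for Anthropic and Bedrock. A self-contained sketch under simplified types (`CredentialInput` and `missingCredential` are stand-ins, not the shipped signatures):

```ts
type CredentialInput = {
  strategy: 'api-key' | 'anthropic' | 'vertex' | 'bedrock' | 'none'
  apiKey?: string
  projectId?: string
  location?: string
  accountClientEmail?: string
  accountPrivateKey?: string
  bedrock?: { accessKeyId?: string; secretAccessKey?: string; region?: string }
}

// Returns an error message when something required is missing, null when OK.
function missingCredential(input: CredentialInput): string | null {
  switch (input.strategy) {
    case 'api-key':
      return input.apiKey ? null : 'Missing API key'
    case 'anthropic':
      return input.apiKey || process.env.ANTHROPIC_API_KEY ? null : 'Missing API key'
    case 'vertex':
      return input.projectId &&
        input.location &&
        (input.apiKey || (input.accountClientEmail && input.accountPrivateKey))
        ? null
        : 'projectId, location, and API credentials are required for Vertex AI'
    case 'bedrock': {
      const accessKeyId = input.bedrock?.accessKeyId || process.env.BEDROCK_ACCESS_KEY_ID
      const secretAccessKey =
        input.bedrock?.secretAccessKey || process.env.BEDROCK_SECRET_ACCESS_KEY
      const region = input.bedrock?.region || process.env.BEDROCK_REGION
      return accessKeyId && secretAccessKey && region ? null : 'Missing AWS Bedrock credentials'
    }
    default:
      return null
  }
}
```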
16 + ) + return { isOk: true, errorMsg: null } + } catch (error) { + return { + isOk: false, + errorMsg: toErrorMessage(error, 'Provider check failed') + } + } + } + case 'fetch-models': + default: + try { + await this.fetchProviderModels() + return { isOk: true, errorMsg: null } + } catch (error) { + return { + isOk: false, + errorMsg: toErrorMessage(error, 'Provider check failed') + } + } + } + } + + private buildTranscript(messages: ChatMessage[]): string { + return messages + .map((message) => `${message.role}: ${this.stringifyMessageContent(message.content)}`) + .join('\n') + } + + private async runSummaryTitlePrompt( + messages: ChatMessage[], + modelId: string, + temperature: number, + maxTokens?: number + ): Promise { + const response = await this.runText( + [ + { + role: 'user', + content: `${SUMMARY_TITLES_PROMPT}\n\n${this.buildTranscript(messages)}` + } + ], + modelId, + temperature, + maxTokens + ) + return response.content.trim() + } + + private async runPromptCompletion( + prompt: string, + modelId: string, + temperature?: number, + maxTokens?: number, + systemPrompt?: string + ): Promise { + return this.runText( + [ + ...(systemPrompt ? [{ role: 'system' as const, content: systemPrompt }] : []), + { role: 'user', content: prompt } + ], + modelId, + temperature, + maxTokens + ) + } + + private async getSuggestionsByPreset( + preset: AiSdkBehaviorPreset, + context: string | ChatMessage[], + modelId: string, + temperature?: number, + maxTokens?: number, + systemPrompt?: string + ): Promise { + const promptContext = Array.isArray(context) ? this.buildTranscript(context) : context + + if (preset === 'anthropic') { + const response = await this.runPromptCompletion( + `根据下面的上下文,给出3个可能的回复建议,每个建议一行,不要有编号或者额外的解释:\n\n${promptContext}`, + modelId, + temperature ?? 0.7, + maxTokens ?? 128, + systemPrompt + ) + return response.content + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .slice(0, 3) + } + + if (preset === 'google') { + const response = await this.runPromptCompletion( + `Based on the following context, please provide up to 5 reasonable suggestion options, each on a new line without numbering:\n\n${promptContext}`, + modelId, + temperature ?? 0.7, + maxTokens ?? 128, + systemPrompt + ) + return response.content + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .slice(0, 5) + } + + const messages = Array.isArray(context) + ? context + : [{ role: 'user' as const, content: context }] + const lastUserMessage = messages.filter((message) => message.role === 'user').pop() + if (!lastUserMessage) { + return [] + } + + const response = await this.runText( + [ + { + role: 'system', + content: + 'Based on the last user message in the conversation history, provide 3 brief, relevant follow-up suggestions or questions. Output ONLY the suggestions, each on a new line.' + }, + ...messages.slice(-5) + ], + modelId, + temperature ?? 0.7, + maxTokens ?? 
60 + ) + + return response.content + .split('\n') + .map((item) => item.trim()) + .filter((item) => item.length > 0 && !item.match(/^[0-9.\-*\s]*/)) + } + + public async summaryTitles(messages: ChatMessage[], modelId: string): Promise { + const decision = this.resolveRouteDecision(modelId) + if (decision.endpointType === 'image-generation') { + return this.buildFallbackSummaryTitle(messages) + } + + const preset = this.getBehaviorPreset(decision) + + switch (preset) { + case 'anthropic': + return this.runSummaryTitlePrompt(messages, modelId, 0.3, 50) + case 'google': { + const title = await this.runSummaryTitlePrompt(messages, modelId, 0.4) + return title || 'New Conversation' + } + case 'openai': + case 'title-summary': + case 'english-summary': + case 'chinese-summary': + default: { + const title = await this.runSummaryTitlePrompt(messages, modelId, 0.5) + return title.replace(/["']/g, '').trim() + } + } + } + + public async completions( + messages: ChatMessage[], + modelId: string, + temperature?: number, + maxTokens?: number + ): Promise { + const decision = this.resolveRouteDecision(modelId) + if (decision.endpointType === 'grok-image' || decision.endpointType === 'image-generation') { + return this.collectStreamResponse(messages, modelId, temperature, maxTokens) + } + + return this.runText(messages, modelId, temperature, maxTokens) + } + + public async summaries( + text: string, + modelId: string, + temperature?: number, + maxTokens?: number, + systemPrompt?: string + ): Promise { + const decision = this.resolveRouteDecision(modelId) + if (decision.endpointType === 'grok-image' || decision.endpointType === 'image-generation') { + return this.collectStreamResponse( + [{ role: 'user', content: text }], + modelId, + temperature, + maxTokens + ) + } + + const preset = this.getBehaviorPreset(decision) + switch (preset) { + case 'anthropic': + return this.runPromptCompletion( + `请对以下内容进行摘要:\n\n${text}\n\n请提供一个简洁明了的摘要。`, + modelId, + temperature, + maxTokens, + systemPrompt + ) + case 'google': + return this.runPromptCompletion( + `Please generate a concise summary for the following content:\n\n${text}`, + modelId, + temperature, + maxTokens, + systemPrompt + ) + case 'title-summary': + return this.runPromptCompletion( + "You need to summarize the user's conversation into a title of no more than 10 words, with the title language matching the user's primary language, without using punctuation or other special symbols:\n" + + text, + modelId, + temperature, + maxTokens, + systemPrompt + ) + case 'english-summary': + return this.runPromptCompletion( + `Please summarize the following content using concise language and highlighting key points:\n${text}`, + modelId, + temperature, + maxTokens, + systemPrompt + ) + case 'chinese-summary': + return this.runPromptCompletion( + `请总结以下内容,使用简洁的语言,突出重点:\n${text}`, + modelId, + temperature, + maxTokens, + systemPrompt + ) + case 'openai': + default: + if (this.provider.id === 'deepseek') { + return this.runPromptCompletion( + `${SUMMARY_TITLES_PROMPT}\n\n${text}`, + modelId, + temperature, + maxTokens, + systemPrompt + ) + } + return this.runText( + [ + { role: 'system', content: 'Summarize the following text concisely:' }, + { role: 'user', content: text } + ], + modelId, + temperature, + maxTokens + ) + } + } + + public async generateText( + prompt: string, + modelId: string, + temperature?: number, + maxTokens?: number, + systemPrompt?: string + ): Promise { + const decision = this.resolveRouteDecision(modelId) + if (decision.endpointType === 
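The `summaryTitles` hunk above routes title generation through the behavior preset: Google falls back to 'New Conversation' on an empty result, the default path strips quote characters, and the Anthropic path returns the trimmed output as-is. A small sketch of that post-processing (`finalizeTitle` is an illustrative name):

```ts
type BehaviorPreset =
  | 'anthropic'
  | 'google'
  | 'openai'
  | 'title-summary'
  | 'english-summary'
  | 'chinese-summary'

function finalizeTitle(rawTitle: string, preset: BehaviorPreset): string {
  const title = rawTitle.trim()
  switch (preset) {
    case 'anthropic':
      return title
    case 'google':
      return title || 'New Conversation'
    default:
      // OpenAI-style presets drop stray quote characters around the title.
      return title.replace(/["']/g, '').trim()
  }
}

// finalizeTitle('  "Weekly sync notes" ', 'openai') -> 'Weekly sync notes'
// finalizeTitle('', 'google') -> 'New Conversation'
```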
'grok-image' || decision.endpointType === 'image-generation') { + return this.collectStreamResponse( + [{ role: 'user', content: prompt }], + modelId, + temperature, + maxTokens + ) + } + + return this.runPromptCompletion(prompt, modelId, temperature, maxTokens, systemPrompt) + } + + public async suggestions( + context: string | ChatMessage[], + modelId: string, + temperature?: number, + maxTokens?: number, + systemPrompt?: string + ): Promise { + const decision = this.resolveRouteDecision(modelId) + return this.getSuggestionsByPreset( + this.getBehaviorPreset(decision), + context, + modelId, + temperature, + maxTokens, + systemPrompt + ) + } + + public async *coreStream( + messages: ChatMessage[], + modelId: string, + modelConfig: ModelConfig, + temperature: number, + maxTokens: number, + tools: MCPToolDefinition[] + ): AsyncGenerator { + yield* this.streamText(messages, modelId, modelConfig, temperature, maxTokens, tools) + } + + private getEmbeddingStrategy(): AiSdkEmbeddingStrategy { + return this.definition.embeddingStrategy ?? 'none' + } + + public async getEmbeddings(modelId: string, texts: string[]): Promise { + switch (this.getEmbeddingStrategy()) { + case 'openai': + case 'google': + return this.runEmbeddings(modelId, texts) + case 'new-api': { + return this.runEmbeddingsWithDecision(modelId, texts, { + providerKind: 'openai-compatible', + providerPatch: { + apiType: 'openai-completions', + baseUrl: `${this.getNormalizedNewApiHost()}/v1`, + capabilityProviderId: resolveNewApiCapabilityProviderId('openai') + } + }) + } + case 'zenmux': + if (modelId.trim().toLowerCase().startsWith('anthropic/')) { + throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`) + } + return this.runEmbeddings(modelId, texts) + case 'none': + default: + throw new Error('embedding is not supported by this provider') + } + } + + public async getDimensions(modelId: string): Promise { + switch (this.getEmbeddingStrategy()) { + case 'openai': + case 'google': + return this.runDimensions(modelId) + case 'new-api': { + return this.runDimensionsWithDecision(modelId, { + providerKind: 'openai-compatible', + providerPatch: { + apiType: 'openai-completions', + baseUrl: `${this.getNormalizedNewApiHost()}/v1`, + capabilityProviderId: resolveNewApiCapabilityProviderId('openai') + } + }) + } + case 'zenmux': + if (modelId.trim().toLowerCase().startsWith('anthropic/')) { + throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`) + } + return this.runDimensions(modelId) + case 'none': + default: + throw new Error('embedding is not supported by this provider') + } + } +} diff --git a/src/main/presenter/llmProviderPresenter/providers/aihubmixProvider.ts b/src/main/presenter/llmProviderPresenter/providers/aihubmixProvider.ts deleted file mode 100644 index 107598e4a..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/aihubmixProvider.ts +++ /dev/null @@ -1,66 +0,0 @@ -import { LLM_PROVIDER, LLMResponse, ChatMessage, IConfigPresenter } from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { proxyConfig } from '@/presenter/proxyConfig' -import { ProxyAgent } from 'undici' -import OpenAI from 'openai' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class AihubmixProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - protected 
createOpenAIClient(): void { - // Get proxy configuration - const proxyUrl = proxyConfig.getProxyUrl() - const fetchOptions: { dispatcher?: ProxyAgent } = {} - - if (proxyUrl) { - console.log(`[Aihubmix Provider] Using proxy: ${proxyUrl}`) - const proxyAgent = new ProxyAgent(proxyUrl) - fetchOptions.dispatcher = proxyAgent - } - - this.openai = new OpenAI({ - apiKey: this.provider.apiKey, - baseURL: this.provider.baseUrl, - defaultHeaders: { - ...this.defaultHeaders, - 'APP-Code': 'SMUE7630' - }, - fetchOptions - }) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts b/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts deleted file mode 100644 index 5b29a7708..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/anthropicProvider.ts +++ /dev/null @@ -1,948 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - LLMCoreStreamEvent, - ModelConfig, - MCPToolDefinition, - ChatMessage, - IConfigPresenter -} from '@shared/presenter' -import { createStreamEvent } from '@shared/types/core/llm-events' -import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' -import Anthropic from '@anthropic-ai/sdk' -import { proxyConfig } from '../../proxyConfig' -import { ProxyAgent } from 'undici' -import type { Usage } from '@anthropic-ai/sdk/resources' -import type { ProviderMcpRuntimePort } from '../runtimePorts' -import { - applyAnthropicExplicitCacheBreakpoint, - applyAnthropicTopLevelCacheControl, - resolvePromptCachePlan -} from '../promptCacheStrategy' - -type CacheAwareAnthropicUsage = Usage & { - cache_read_input_tokens?: number - cache_creation_input_tokens?: number - cacheReadInputTokens?: number - cacheWriteInputTokens?: number -} - -function getAnthropicUsageNumber( - usage: CacheAwareAnthropicUsage | undefined, - snakeKey: 'cache_read_input_tokens' | 'cache_creation_input_tokens', - camelKey: 'cacheReadInputTokens' | 'cacheWriteInputTokens' -): number { - const value = usage?.[snakeKey] ?? usage?.[camelKey] - return typeof value === 'number' && Number.isFinite(value) ? value : 0 -} - -function buildAnthropicUsageSnapshot(usage: CacheAwareAnthropicUsage | undefined): { - prompt_tokens: number - completion_tokens: number - total_tokens: number - cached_tokens?: number - cache_write_tokens?: number -} | null { - if (!usage) { - return null - } - - const uncachedInputTokens = - typeof usage.input_tokens === 'number' && Number.isFinite(usage.input_tokens) - ? usage.input_tokens - : 0 - const completionTokens = - typeof usage.output_tokens === 'number' && Number.isFinite(usage.output_tokens) - ? 
usage.output_tokens - : 0 - const cachedTokens = getAnthropicUsageNumber( - usage, - 'cache_read_input_tokens', - 'cacheReadInputTokens' - ) - const cacheWriteTokens = getAnthropicUsageNumber( - usage, - 'cache_creation_input_tokens', - 'cacheWriteInputTokens' - ) - const promptTokens = uncachedInputTokens + cachedTokens + cacheWriteTokens - - return { - prompt_tokens: promptTokens, - completion_tokens: completionTokens, - total_tokens: promptTokens + completionTokens, - ...(cachedTokens > 0 ? { cached_tokens: cachedTokens } : {}), - ...(cacheWriteTokens > 0 ? { cache_write_tokens: cacheWriteTokens } : {}) - } -} - -export class AnthropicProvider extends BaseLLMProvider { - private anthropic!: Anthropic - private defaultModel = 'claude-sonnet-4-5-20250929' - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - this.init() - } - - private buildAnthropicEndpoint(): string { - const baseUrl = (this.provider.baseUrl || 'https://api.anthropic.com').replace(/\/+$/, '') - return `${baseUrl}/v1/messages` - } - - private buildAnthropicApiKeyHeaders(): Record { - return { - 'Content-Type': 'application/json', - 'anthropic-version': '2023-06-01', - 'x-api-key': this.provider.apiKey || process.env.ANTHROPIC_API_KEY || 'MISSING_API_KEY' - } - } - - private applyPromptCache>( - requestParams: T, - modelId: string, - messages: Anthropic.MessageParam[], - conversationId?: string - ): T { - const plan = resolvePromptCachePlan({ - providerId: this.provider.id, - apiType: 'anthropic', - modelId, - messages: messages as unknown[], - conversationId - }) - const nextRequestParams = - plan.mode === 'anthropic_explicit' - ? { - ...requestParams, - messages: applyAnthropicExplicitCacheBreakpoint(messages, plan) - } - : requestParams - - return applyAnthropicTopLevelCacheControl(nextRequestParams, plan) - } - - public onProxyResolved(): void { - this.init() - } - - protected async init() { - if (this.provider.enable) { - try { - const apiKey = this.provider.apiKey || process.env.ANTHROPIC_API_KEY || null - if (!apiKey) { - console.warn('[Anthropic Provider] No API key available') - return - } - - const proxyUrl = proxyConfig.getProxyUrl() - const fetchOptions: { dispatcher?: ProxyAgent } = {} - - if (proxyUrl) { - console.log('[Anthropic Provider] Proxy enabled') - const proxyAgent = new ProxyAgent(proxyUrl) - fetchOptions.dispatcher = proxyAgent - } - - this.anthropic = new Anthropic({ - apiKey, - baseURL: this.provider.baseUrl || 'https://api.anthropic.com', - defaultHeaders: this.defaultHeaders, - fetchOptions - }) - - await super.init() - } catch (error) { - console.error('Failed to initialize Anthropic provider:', error) - } - } - } - - protected async fetchProviderModels(): Promise { - try { - const models = await this.anthropic.models.list() - - // 引入getModelConfig函数 - if (models && models.data && Array.isArray(models.data)) { - const processedModels: MODEL_META[] = [] - - for (const model of models.data) { - // 确保模型有必要的属性 - if (model.id) { - // 获取额外的配置信息 - const modelConfig = this.configPresenter.getModelConfig(model.id, this.provider.id) - - // 提取模型组名称,通常是Claude后面的版本号 - - processedModels.push({ - id: model.id, - name: model.display_name || model.id, - providerId: this.provider.id, - maxTokens: modelConfig?.maxTokens || 64_000, - group: 'Claude', - isCustom: false, - contextLength: modelConfig?.contextLength || 200000, - vision: modelConfig?.vision || false, - functionCall: 
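The retired `buildAnthropicUsageSnapshot` helper above reconstructed billed prompt tokens as uncached input plus cache reads plus cache writes, and only surfaced the cache fields when non-zero. A worked sketch of that arithmetic under simplified types (`usageSnapshot` is an illustrative name):

```ts
type CacheAwareUsage = {
  input_tokens?: number
  output_tokens?: number
  cache_read_input_tokens?: number
  cache_creation_input_tokens?: number
}

function usageSnapshot(usage: CacheAwareUsage) {
  const uncached = usage.input_tokens ?? 0
  const completion = usage.output_tokens ?? 0
  const cachedRead = usage.cache_read_input_tokens ?? 0
  const cacheWrite = usage.cache_creation_input_tokens ?? 0
  // Billed prompt tokens include both cached reads and cache writes.
  const prompt = uncached + cachedRead + cacheWrite
  return {
    prompt_tokens: prompt,
    completion_tokens: completion,
    total_tokens: prompt + completion,
    ...(cachedRead > 0 ? { cached_tokens: cachedRead } : {}),
    ...(cacheWrite > 0 ? { cache_write_tokens: cacheWrite } : {})
  }
}

// usageSnapshot({ input_tokens: 120, output_tokens: 80, cache_read_input_tokens: 900 })
// -> { prompt_tokens: 1020, completion_tokens: 80, total_tokens: 1100, cached_tokens: 900 }
```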
modelConfig?.functionCall || false, - reasoning: modelConfig?.reasoning || false - }) - } - } - - // 如果成功解析出模型,则返回 - if (processedModels.length > 0) { - return processedModels - } - } - - // 如果API请求失败或返回数据解析失败,优先使用聚合 Provider DB 的模型列表 - console.log('从API获取模型列表失败,使用 Provider DB 作为兜底') - } catch (error) { - console.error('获取Anthropic模型列表出错:', error) - } - const dbModels = this.configPresenter.getDbProviderModels(this.provider.id).map((m) => ({ - id: m.id, - name: m.name, - providerId: this.provider.id, - maxTokens: m.maxTokens, - group: m.group || 'default', - isCustom: false, - contextLength: m.contextLength, - vision: m.vision || false, - functionCall: m.functionCall || false, - reasoning: m.reasoning || false, - ...(m.type ? { type: m.type } : {}) - })) - - return dbModels - } - - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - if (!this.anthropic) { - return { isOk: false, errorMsg: 'Anthropic SDK not initialized' } - } - - await this.anthropic.messages.create({ - model: this.defaultModel, - max_tokens: 10, - messages: [{ role: 'user', content: 'Hello' }] - }) - - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - console.error('Anthropic API check failed:', error) - const errorMessage = error instanceof Error ? error.message : String(error) - return { isOk: false, errorMsg: `API check failed: ${errorMessage}` } - } - } - - private formatMessages(messages: ChatMessage[]): { - system?: string - messages: Anthropic.MessageParam[] - } { - // console.log('开始格式化消息,总消息数:', messages.length) - - // 提取系统消息 - let systemContent = '' - for (const msg of messages) { - if (msg.role === 'system') { - systemContent += - (typeof msg.content === 'string' - ? msg.content - : msg.content && Array.isArray(msg.content) - ? msg.content - .filter((c) => c.type === 'text') - .map((c) => c.text || '') - .join('\n') - : '') + '\n' - } - } - - // 定义消息组和内容块的类型 - type ContentBlock = Anthropic.ContentBlockParam - type ToolCall = { id: string; function: { name: string; arguments?: string } } - type MessageGroup = { - role: string - contents: ContentBlock[] - toolCalls?: string[] - hasToolUse?: boolean - } - - // 预处理:对消息进行分组 - // 新的逻辑:每个assistant消息如果包含tool_calls,就单独成组 - const messageGroups: MessageGroup[] = [] - let currentGroup: MessageGroup | null = null - - // 用于跟踪tool_calls和tool响应的匹配 - const toolResponseMap = new Map() - - // 第一阶段:建立初始分组和收集工具响应 - for (let i = 0; i < messages.length; i++) { - const msg = messages[i] - if (msg.role === 'system') continue // 系统消息已处理 - - // console.log( - // `处理第${i + 1}条消息, 角色:${msg.role}`, - // msg.content - // ? typeof msg.content === 'string' - // ? '内容长度:' + msg.content.length - // : '数组内容长度:' + (Array.isArray(msg.content) ? msg.content.length : 0) - // : '无内容' - // ) - - // 处理工具响应,将其与对应的工具调用关联 - if (msg.role === 'tool' && 'tool_call_id' in msg) { - const toolCallId = msg.tool_call_id as string - const responseContent = - typeof msg.content === 'string' - ? msg.content - : Array.isArray(msg.content) - ? 
JSON.stringify(msg.content) - : '' - - toolResponseMap.set(toolCallId, { - type: 'tool_result', - tool_use_id: toolCallId, - content: responseContent - } as ContentBlock) - - // console.log('记录tool响应,tool_call_id:', toolCallId) - continue - } - - // 处理用户消息 - 开始新组 - if (msg.role === 'user') { - if (currentGroup) { - messageGroups.push(currentGroup) - } - - let formattedContent: ContentBlock[] = [] - if (typeof msg.content === 'string') { - formattedContent = [{ type: 'text', text: msg.content }] - } else if (msg.content && Array.isArray(msg.content)) { - formattedContent = msg.content.map((c) => { - if (c.type === 'image_url' && c.image_url) { - return { - type: 'image', - source: c.image_url.url.startsWith('data:image') - ? { - type: 'base64', - data: c.image_url.url.split(',')[1], - media_type: c.image_url.url.split(';')[0].split(':')[1] as - | 'image/jpeg' - | 'image/png' - | 'image/gif' - | 'image/webp' - } - : { type: 'url', url: c.image_url.url } - } as ContentBlock - } else { - return { type: 'text', text: c.text || '' } as ContentBlock - } - }) - } - - currentGroup = { - role: 'user', - contents: formattedContent - } - - // console.log('开始新的用户消息组') - continue - } - - // 处理assistant消息 - 添加到当前组或开始新组 - if (msg.role === 'assistant') { - // 检查是否需要新建一个组: - // 1. 当前还没有组 - // 2. 当前组不是assistant - // 3. 当前组是assistant但包含了工具调用 - const shouldCreateNewGroup = - !currentGroup || currentGroup.role !== 'assistant' || currentGroup.hasToolUse === true - - if (shouldCreateNewGroup) { - if (currentGroup) { - messageGroups.push(currentGroup) - } - - currentGroup = { - role: 'assistant', - contents: [], - toolCalls: [], - hasToolUse: false - } - } - - // 确保currentGroup已初始化 - if (!currentGroup) { - currentGroup = { - role: 'assistant', - contents: [], - toolCalls: [], - hasToolUse: false - } - } - - // 处理常规内容 - if (msg.content) { - let assistantContent: ContentBlock[] = [] - if (typeof msg.content === 'string') { - if (msg.content.trim()) { - assistantContent = [{ type: 'text', text: msg.content }] - } - } else if (Array.isArray(msg.content)) { - // 处理各种内容类型 - for (const content of msg.content) { - if (content.type === 'text') { - currentGroup.contents.push({ - type: 'text', - text: content.text || '' - } as ContentBlock) - } else if (content.type === 'image_url' && content.image_url) { - currentGroup.contents.push({ - type: 'image', - source: content.image_url.url.startsWith('data:image') - ? 
{ - type: 'base64', - data: content.image_url.url.split(',')[1], - media_type: content.image_url.url.split(';')[0].split(':')[1] as - | 'image/jpeg' - | 'image/png' - | 'image/gif' - | 'image/webp' - } - : { type: 'url', url: content.image_url.url } - } as ContentBlock) - } - } - - continue - } - - currentGroup.contents.push(...assistantContent) - // console.log('添加文本内容到当前assistant组, 项数:', assistantContent.length) - } - - // 处理tool_calls - if ('tool_calls' in msg && Array.isArray(msg.tool_calls)) { - // console.log('处理assistant消息中的tool_calls', msg.tool_calls.length) - - // 标记当前组包含工具调用 - if (currentGroup) { - currentGroup.hasToolUse = true - } - - for (const toolCall of msg.tool_calls as ToolCall[]) { - try { - // @ts-ignore - 转换为Anthropic格式 - currentGroup.contents.push({ - type: 'tool_use', - id: toolCall.id, - name: toolCall.function.name, - input: JSON.parse(toolCall.function.arguments || '{}') - } as ContentBlock) - - // console.log('添加tool_call到当前assistant组:', toolCall.function.name) - - // 记录工具调用,稍后查找响应 - if (!currentGroup.toolCalls) currentGroup.toolCalls = [] - currentGroup.toolCalls.push(toolCall.id) - } catch (e) { - console.error('Error processing tool_call:', e) - } - } - } - } - } - - // 添加最后一个组 - if (currentGroup) { - messageGroups.push(currentGroup) - } - - // console.log('预处理完成,消息组数量:', messageGroups.length) - - // 第二阶段:生成最终的格式化消息 - const formattedMessages: Anthropic.MessageParam[] = [] - - for (const group of messageGroups) { - if (group.contents.length === 0) continue - - // 添加组的主要内容 - formattedMessages.push({ - role: group.role as 'user' | 'assistant', - content: group.contents as Anthropic.ContentBlockParam[] - }) - - // console.log(`添加${group.role}组,内容项数:${group.contents.length}`) - - // 如果是assistant组且有工具调用,添加对应的工具响应 - if (group.role === 'assistant' && group.toolCalls && group.toolCalls.length > 0) { - for (const toolCallId of group.toolCalls) { - const toolResponse = toolResponseMap.get(toolCallId) - if (toolResponse) { - formattedMessages.push({ - role: 'user', - content: [toolResponse] - }) - - // console.log('添加工具响应,tool_call_id:', toolCallId) - } - } - } - } - - // console.log('格式化完成, 最终消息数:', formattedMessages.length) - // 为调试目的,打印前3条消息的结构 - // formattedMessages.slice(0, Math.min(3, formattedMessages.length)).forEach((msg, i) => { - // console.log(`最终消息#${i + 1}:`, { - // role: msg.role, - // contentLength: Array.isArray(msg.content) ? msg.content.length : 0, - // contentTypes: Array.isArray(msg.content) - // ? msg.content.map((c) => c.type).join(',') - // : typeof msg.content - // }) - // }) - - return { - system: systemContent || undefined, - messages: formattedMessages - } - } - - public async summaryTitles(messages: ChatMessage[], modelId: string): Promise { - const prompt = `${SUMMARY_TITLES_PROMPT}\n\n${messages.map((m) => `${m.role}: ${m.content}`).join('\n')}` - const response = await this.generateText(prompt, modelId, 0.3, 50) - - return response.content.trim() - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - try { - const formattedMessages = this.formatMessages(messages) - - const requestParams: any = { - model: modelId, - max_tokens: maxTokens || 1024, - temperature: temperature ?? 
0.7, - messages: formattedMessages.messages - } - - if (formattedMessages.system) { - requestParams.system = formattedMessages.system - } - - const cachedRequestParams = this.applyPromptCache( - requestParams, - modelId, - formattedMessages.messages - ) - - if (!this.anthropic) { - throw new Error('Anthropic client is not initialized') - } - const response = await this.anthropic.messages.create(cachedRequestParams) - - const resultResp: LLMResponse = { - content: '' - } - - // 添加usage信息 - if (response.usage) { - const usageSnapshot = buildAnthropicUsageSnapshot( - response.usage as CacheAwareAnthropicUsage - ) - resultResp.totalUsage = { - prompt_tokens: usageSnapshot?.prompt_tokens ?? 0, - completion_tokens: usageSnapshot?.completion_tokens ?? 0, - total_tokens: usageSnapshot?.total_tokens ?? 0 - } - } - - // 获取文本内容 - const content = response.content - .filter((block) => block.type === 'text') - .map((block) => (block.type === 'text' ? block.text : '')) - .join('') - - // 处理标签 - if (content.includes('')) { - const thinkStart = content.indexOf('') - const thinkEnd = content.indexOf('') - - if (thinkEnd > thinkStart) { - // 提取reasoning_content - resultResp.reasoning_content = content.substring(thinkStart + 7, thinkEnd).trim() - - // 合并前后的普通内容 - const beforeThink = content.substring(0, thinkStart).trim() - const afterThink = content.substring(thinkEnd + 8).trim() - resultResp.content = [beforeThink, afterThink].filter(Boolean).join('\n') - } else { - // 如果没有找到配对的结束标签,将所有内容作为普通内容 - resultResp.content = content - } - } else { - // 没有think标签,所有内容作为普通内容 - resultResp.content = content - } - - return resultResp - } catch (error) { - console.error('Anthropic completions error:', error) - throw error - } - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number, - systemPrompt?: string - ): Promise { - const prompt = `请对以下内容进行摘要: - -${text} - -请提供一个简洁明了的摘要。` - - return this.generateText(prompt, modelId, temperature, maxTokens, systemPrompt) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number, - systemPrompt?: string - ): Promise { - try { - const requestParams: any = { - model: modelId, - max_tokens: maxTokens || 1024, - temperature: temperature ?? 0.7, - messages: [{ role: 'user' as const, content: [{ type: 'text' as const, text: prompt }] }] - } - - if (systemPrompt) { - requestParams.system = systemPrompt - } - - const cachedRequestParams = this.applyPromptCache( - requestParams, - modelId, - requestParams.messages - ) - - if (!this.anthropic) { - throw new Error('Anthropic client is not initialized') - } - const response = await this.anthropic.messages.create(cachedRequestParams) - - return { - content: response.content - .filter((block: any) => block.type === 'text') - .map((block: any) => (block.type === 'text' ? block.text : '')) - .join(''), - reasoning_content: undefined - } - } catch (error) { - console.error('Anthropic generate text error:', error) - throw error - } - } - - async suggestions( - context: string, - modelId: string, - temperature?: number, - maxTokens?: number, - systemPrompt?: string - ): Promise { - const prompt = ` -根据下面的上下文,给出3个可能的回复建议,每个建议一行,不要有编号或者额外的解释: - -${context} -` - try { - const requestParams: any = { - model: modelId, - max_tokens: maxTokens || 1024, - temperature: temperature ?? 
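The retired `completions()` hunk above split the model output into visible content and `reasoning_content`. The tag literals appear to have been stripped from this copy of the diff; the `+ 7` / `+ 8` offsets suggest they were `<think>` and `</think>`, which is assumed in this sketch (`splitThinkBlock` is an illustrative name):

```ts
function splitThinkBlock(content: string): { content: string; reasoning?: string } {
  const open = '<think>'   // length 7, matching the "+ 7" offset above (assumed)
  const close = '</think>' // length 8, matching the "+ 8" offset above (assumed)
  const start = content.indexOf(open)
  const end = content.indexOf(close)
  if (start === -1 || end <= start) {
    // No paired tags: everything stays as plain content.
    return { content }
  }
  const reasoning = content.slice(start + open.length, end).trim()
  const before = content.slice(0, start).trim()
  const after = content.slice(end + close.length).trim()
  return { content: [before, after].filter(Boolean).join('\n'), reasoning }
}

// splitThinkBlock('<think>check units first</think>42 km')
// -> { content: '42 km', reasoning: 'check units first' }
```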
0.7, - messages: [{ role: 'user' as const, content: [{ type: 'text' as const, text: prompt }] }] - } - - if (systemPrompt) { - requestParams.system = systemPrompt - } - - const cachedRequestParams = this.applyPromptCache( - requestParams, - modelId, - requestParams.messages - ) - - if (!this.anthropic) { - throw new Error('Anthropic client is not initialized') - } - const response = await this.anthropic.messages.create(cachedRequestParams) - - const suggestions = response.content - .filter((block: any) => block.type === 'text') - .map((block: any) => (block.type === 'text' ? block.text : '')) - .join('') - .split('\n') - .map((s: string) => s.trim()) - .filter((s: string) => s.length > 0) - .slice(0, 3) - - return suggestions - } catch (error) { - console.error('Anthropic suggestions error:', error) - return ['建议生成失败'] - } - } - - // 添加coreStream方法 - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!modelId) throw new Error('Model ID is required') - console.log('modelConfig', modelConfig, modelId) - - if (!this.anthropic) throw new Error('Anthropic client is not initialized') - try { - // 格式化消息 - const formattedMessagesObject = this.formatMessages(messages) - - // 将MCP工具转换为Anthropic工具格式 - const anthropicTools = - mcpTools.length > 0 - ? await this.mcpRuntime?.mcpToolsToAnthropicTools(mcpTools, this.provider.id) - : undefined - - // 创建基本请求参数 - const streamParams = { - model: modelId, - max_tokens: maxTokens || 1024, - temperature: temperature ?? 0.7, - messages: formattedMessagesObject.messages, - stream: true - } as Anthropic.Messages.MessageCreateParamsStreaming - - // 启用Claude 3.7模型的思考功能 - if (modelId.includes('claude-3-7')) { - streamParams.thinking = { budget_tokens: 1024, type: 'enabled' } - } - - // 如果有系统消息,添加到请求参数中 - if (formattedMessagesObject.system) { - // @ts-ignore - system属性在类型定义中可能不存在,但API已支持 - streamParams.system = formattedMessagesObject.system - } - - // 添加工具参数 - if (anthropicTools && anthropicTools.length > 0) { - // @ts-ignore - 类型不匹配,但格式是正确的 - streamParams.tools = anthropicTools - } - const cachedStreamParams = this.applyPromptCache( - streamParams as unknown as Record, - modelId, - formattedMessagesObject.messages, - modelConfig.conversationId - ) as unknown as Anthropic.Messages.MessageCreateParamsStreaming - await this.emitRequestTrace(modelConfig, { - endpoint: this.buildAnthropicEndpoint(), - headers: this.buildAnthropicApiKeyHeaders(), - body: cachedStreamParams - }) - // console.log('streamParams', JSON.stringify(streamParams.messages)) - // 创建Anthropic流 - const stream = await this.anthropic.messages.create(cachedStreamParams) - - // 状态变量 - let accumulatedJson = '' - let toolUseDetected = false - let currentToolId = '' - let currentToolName = '' - let currentToolInputs: Record = {} - let usageMetadata: Usage | undefined - // 处理流中的各种事件 - for await (const chunk of stream) { - // 处理使用统计 - if (chunk.type === 'message_start' && chunk.message.usage) { - usageMetadata = chunk.message.usage - } - - // 处理工具调用开始 - // @ts-ignore - Anthropic SDK类型定义不完整 - if (chunk.type === 'content_block_start' && chunk.content_block?.type === 'tool_use') { - toolUseDetected = true - // @ts-ignore - content_block不在类型定义中 - currentToolId = chunk.content_block.id || `anthropic-tool-${Date.now()}` - // @ts-ignore - content_block不在类型定义中 - currentToolName = chunk.content_block.name || '' - currentToolInputs = {} - accumulatedJson = '' - - // 发送工具调用开始事件 
- if (currentToolName) { - yield createStreamEvent.toolCallStart(currentToolId, currentToolName) - } - continue - } - - // 处理工具调用参数更新 - input_json_delta - // @ts-ignore - 类型定义中没有工具相关字段 - if (chunk.type === 'content_block_delta' && chunk.delta?.type === 'input_json_delta') { - // @ts-ignore - partial_json不在类型定义中 - const partialJson = chunk.delta.partial_json - if (partialJson) { - accumulatedJson += partialJson - - // 发送工具调用参数块事件 - yield createStreamEvent.toolCallChunk(currentToolId, partialJson) - } - continue - } - - // 处理工具使用更新 - tool_use_delta - // @ts-ignore - 类型定义中没有工具相关字段 - if (chunk.type === 'content_block_delta' && chunk.delta?.type === 'tool_use_delta') { - // @ts-ignore - delta.name不在类型定义中 - if (chunk.delta.name && !currentToolName) { - // @ts-ignore - 访问delta.name - currentToolName = chunk.delta.name - yield createStreamEvent.toolCallStart(currentToolId, currentToolName) - } - - // @ts-ignore - delta.input不在类型定义中 - if (chunk.delta.input) { - currentToolInputs = { - ...currentToolInputs, - // @ts-ignore - 访问delta.input - ...chunk.delta.input - } - } - continue - } - - // 处理内容块结束 - if (chunk.type === 'content_block_stop') { - // 处理工具调用完成 - if (toolUseDetected && currentToolName && accumulatedJson) { - try { - // 尝试解析完整的JSON - const jsonStr = accumulatedJson.trim() - if (jsonStr && (jsonStr.startsWith('{') || jsonStr.startsWith('['))) { - try { - const jsonObject = JSON.parse(jsonStr) - if (jsonObject && typeof jsonObject === 'object') { - currentToolInputs = { ...currentToolInputs, ...jsonObject } - } - } catch (e) { - console.error('解析完整JSON失败:', e) - } - } - } catch (e) { - console.error('处理累积JSON失败:', e) - } - - // 发送工具调用结束事件 - const argsString = JSON.stringify(currentToolInputs) - yield createStreamEvent.toolCallEnd(currentToolId, argsString) - - // 重置工具调用状态 - accumulatedJson = '' - } - continue - } - - // 检查消息是否因为工具调用而停止 - if (chunk.type === 'message_delta' && chunk.delta?.stop_reason === 'tool_use') { - // 设置为工具使用停止,主循环会处理工具调用 - continue - } - - // 处理思考内容(如果有) - // @ts-ignore - 类型定义中没有thinking相关字段 - if (chunk.type === 'content_block_delta' && chunk.delta?.type === 'thinking_delta') { - // @ts-ignore - delta.thinking不在类型定义中 - const thinkingText = chunk.delta.thinking - if (thinkingText) { - yield createStreamEvent.reasoning(thinkingText) - } - continue - } - - // 处理常规文本内容 - if (chunk.type === 'content_block_delta' && chunk.delta.type === 'text_delta') { - const text = chunk.delta.text - if (text) { - // 处理标签 - if (text.includes('')) { - const parts = text.split('') - if (parts[0]) { - yield createStreamEvent.text(parts[0]) - } - - if (parts[1]) { - // 检查是否包含 - const thinkParts = parts[1].split('') - if (thinkParts.length > 1) { - yield createStreamEvent.reasoning(thinkParts[0]) - - if (thinkParts[1]) { - yield createStreamEvent.text(thinkParts[1]) - } - } else { - yield createStreamEvent.reasoning(parts[1]) - } - } - } else if (text.includes('')) { - const parts = text.split('') - yield createStreamEvent.reasoning(parts[0]) - - if (parts[1]) { - yield createStreamEvent.text(parts[1]) - } - } else { - yield createStreamEvent.text(text) - } - } - continue - } - } - if (usageMetadata) { - const usageSnapshot = buildAnthropicUsageSnapshot(usageMetadata as CacheAwareAnthropicUsage) - if (usageSnapshot) { - yield createStreamEvent.usage(usageSnapshot) - } - } - // 发送停止事件 - yield createStreamEvent.stop(toolUseDetected ? 'tool_use' : 'complete') - } catch (error) { - console.error('Anthropic coreStream error:', error) - yield createStreamEvent.error(error instanceof Error ? 
error.message : '未知错误') - yield createStreamEvent.stop('error') - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts b/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts deleted file mode 100644 index 00ee51071..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider.ts +++ /dev/null @@ -1,986 +0,0 @@ -import { - LLMResponse, - MODEL_META, - LLMCoreStreamEvent, - ModelConfig, - MCPToolDefinition, - ChatMessage, - AWS_BEDROCK_PROVIDER, - IConfigPresenter -} from '@shared/presenter' -import { createStreamEvent } from '@shared/types/core/llm-events' -import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' -import { BedrockClient, ListFoundationModelsCommand } from '@aws-sdk/client-bedrock' -import { - BedrockRuntimeClient, - InvokeModelCommand, - InvokeModelCommandOutput, - InvokeModelWithResponseStreamCommand -} from '@aws-sdk/client-bedrock-runtime' -import Anthropic from '@anthropic-ai/sdk' -import { Usage } from '@anthropic-ai/sdk/resources/messages' -import type { ProviderMcpRuntimePort } from '../runtimePorts' -import { - applyAnthropicExplicitCacheBreakpoint, - resolvePromptCachePlan -} from '../promptCacheStrategy' - -type CacheAwareBedrockUsage = Usage & { - cache_read_input_tokens?: number - cache_creation_input_tokens?: number - cacheReadInputTokens?: number - cacheWriteInputTokens?: number -} - -function getBedrockUsageNumber( - usage: CacheAwareBedrockUsage | undefined, - snakeKey: 'cache_read_input_tokens' | 'cache_creation_input_tokens', - camelKey: 'cacheReadInputTokens' | 'cacheWriteInputTokens' -): number { - const value = usage?.[snakeKey] ?? usage?.[camelKey] - return typeof value === 'number' && Number.isFinite(value) ? value : 0 -} - -function buildBedrockUsageSnapshot(usage: CacheAwareBedrockUsage | undefined): { - prompt_tokens: number - completion_tokens: number - total_tokens: number - cached_tokens?: number - cache_write_tokens?: number -} | null { - if (!usage) { - return null - } - - const uncachedInputTokens = - typeof usage.input_tokens === 'number' && Number.isFinite(usage.input_tokens) - ? usage.input_tokens - : 0 - const completionTokens = - typeof usage.output_tokens === 'number' && Number.isFinite(usage.output_tokens) - ? usage.output_tokens - : 0 - const cachedTokens = getBedrockUsageNumber( - usage, - 'cache_read_input_tokens', - 'cacheReadInputTokens' - ) - const cacheWriteTokens = getBedrockUsageNumber( - usage, - 'cache_creation_input_tokens', - 'cacheWriteInputTokens' - ) - const promptTokens = uncachedInputTokens + cachedTokens + cacheWriteTokens - - return { - prompt_tokens: promptTokens, - completion_tokens: completionTokens, - total_tokens: promptTokens + completionTokens, - ...(cachedTokens > 0 ? { cached_tokens: cachedTokens } : {}), - ...(cacheWriteTokens > 0 ? 
{ cache_write_tokens: cacheWriteTokens } : {}) - } -} - -export class AwsBedrockProvider extends BaseLLMProvider { - private bedrock!: BedrockClient - private bedrockRuntime!: BedrockRuntimeClient - private defaultModel = 'anthropic.claude-3-5-sonnet-20240620-v1:0' - - constructor( - provider: AWS_BEDROCK_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - this.init() - } - - private getBedrockRegion(): string { - const provider = this.provider as AWS_BEDROCK_PROVIDER - return provider.credential?.region || process.env.BEDROCK_REGION || 'unknown-region' - } - - private buildBedrockStreamEndpoint(modelId: string): string { - const region = this.getBedrockRegion() - return `https://bedrock-runtime.${region}.amazonaws.com/model/${encodeURIComponent(modelId)}/invoke-with-response-stream` - } - - private decodeBedrockBody(body: unknown): unknown { - if (typeof body === 'string') { - try { - return JSON.parse(body) - } catch { - return body - } - } - if (body instanceof Uint8Array) { - const text = new TextDecoder().decode(body) - try { - return JSON.parse(text) - } catch { - return text - } - } - return body - } - - private applyPromptCache( - messages: Anthropic.MessageParam[], - modelId: string, - conversationId?: string - ): Anthropic.MessageParam[] { - const plan = resolvePromptCachePlan({ - providerId: this.provider.id, - apiType: 'anthropic', - modelId, - messages: messages as unknown[], - conversationId - }) - return applyAnthropicExplicitCacheBreakpoint(messages, plan) - } - - public onProxyResolved(): void { - this.init() - } - - protected async init() { - if (this.provider.enable) { - try { - const provider = this.provider as AWS_BEDROCK_PROVIDER - const accessKeyId = provider.credential?.accessKeyId || process.env.BEDROCK_ACCESS_KEY_ID - const secretAccessKey = - provider.credential?.secretAccessKey || process.env.BEDROCK_SECRET_ACCESS_KEY - const region = provider.credential?.region || process.env.BEDROCK_REGION - - if (!accessKeyId || !secretAccessKey || !region) { - throw new Error('Access Key Id, Secret Access Key and Region are all needed.') - } - - this.bedrock = new BedrockClient({ - credentials: { accessKeyId, secretAccessKey }, - region - }) - - this.bedrockRuntime = new BedrockRuntimeClient({ - credentials: { accessKeyId, secretAccessKey }, - region - }) - - await super.init() - } catch (error) { - console.error('Failed to initialize AWS Bedrock provider:', error) - } - } - } - - protected async fetchProviderModels(): Promise { - try { - const region = await this.bedrock.config.region() - const command = new ListFoundationModelsCommand({}) - const response = await this.bedrock.send(command) - const models = response.modelSummaries - - return ( - models - ?.filter((m) => m.modelId && /^anthropic.claude-[a-z0-9-]+(:\d+)$/g.test(m.modelId)) - ?.filter((m) => m.modelLifecycle?.status === 'ACTIVE') - ?.filter((m) => m.inferenceTypesSupported && m.inferenceTypesSupported.length > 0) - .map((m) => ({ - id: `${m.inferenceTypesSupported?.includes('ON_DEMAND') ? m.modelId! : `${region.split('-')[0]}.${m.modelId}`}`, - name: m.modelId?.replace('anthropic.', '') || '', - providerId: this.provider.id, - maxTokens: 64_000, - group: `AWS Bedrock Claude - ${m.modelId?.includes('opus') ? 'opus' : m.modelId?.includes('sonnet') ? 'sonnet' : m.modelId?.includes('haiku') ? 
'haiku' : 'other'}`, - isCustom: false, - contextLength: 200_000, - vision: false, - functionCall: false, - reasoning: false - })) || [] - ) - } catch (error) { - console.error('获取AWS Bedrock Anthropic模型列表出错:', error) - } - - // 如果API请求失败或返回数据解析失败,优先使用聚合 Provider DB 的模型列表(仅筛选 Bedrock 上的 Anthropic 模型) - console.log('从API获取模型列表失败,使用 Provider DB 作为兜底(筛选 anthropic.*)') - const dbModels = this.configPresenter - .getDbProviderModels('amazon-bedrock') - .filter((m) => m.id.startsWith('anthropic.')) - .map((m) => ({ - id: m.id, - name: m.name, - providerId: this.provider.id, - maxTokens: m.maxTokens, - group: m.group || 'Bedrock Claude', - isCustom: false, - contextLength: m.contextLength, - vision: m.vision || false, - functionCall: m.functionCall || false, - reasoning: m.reasoning || false, - ...(m.type ? { type: m.type } : {}) - })) - - return dbModels - } - - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - if (!this.bedrockRuntime) { - return { isOk: false, errorMsg: '未初始化 AWS Bedrock SDK' } - } - - // 发送一个简单请求来检查 API 连接状态 - // Prepare the payload for the Messages API request. - const payload = { - anthropic_version: 'bedrock-2023-05-31', - max_tokens: 10, - messages: [ - { - role: 'user', - content: [{ type: 'text', text: 'Hi' }] - } - ] - } - const command = new InvokeModelCommand({ - contentType: 'application/json', - body: JSON.stringify(payload), - modelId: this.defaultModel - }) - const apiResponse = await this.bedrockRuntime.send(command) - - // Decode and return the response(s) - const decodedResponseBody = new TextDecoder().decode(apiResponse.body) - /** @type {MessagesResponseBody} */ - const responseBody = JSON.parse(decodedResponseBody) - const responseText = responseBody.content[0].text - - return { isOk: responseText.length > 0, errorMsg: null } - } catch (error: unknown) { - console.error('AWS Bedrock Claude API check failed:', error) - const errorMessage = error instanceof Error ? 
error.message : String(error) - return { isOk: false, errorMsg: `API 检查失败: ${errorMessage}` } - } - } - - // 依赖 generateText - public async summaryTitles(messages: ChatMessage[], modelId: string): Promise { - const prompt = `${SUMMARY_TITLES_PROMPT}\n\n${messages.map((m) => `${m.role}: ${m.content}`).join('\n')}` - const response = await this.generateText(prompt, modelId, 0.3, 50) - - return response.content.trim() - } - - // 依赖 generateText - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number, - systemPrompt?: string - ): Promise { - const prompt = `请对以下内容进行摘要: - -${text} - -请提供一个简洁明了的摘要。` - - return this.generateText(prompt, modelId, temperature, maxTokens, systemPrompt) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number, - systemPrompt?: string - ): Promise { - try { - const payload = { - anthropic_version: 'bedrock-2023-05-31', - max_tokens: maxTokens, - temperature, - system: systemPrompt, - messages: [ - { - role: 'user', - content: [{ type: 'text', text: prompt }] - } - ] - } - const command = new InvokeModelCommand({ - contentType: 'application/json', - body: JSON.stringify(payload), - modelId - }) - - const response = await this.bedrockRuntime.send(command) - - // Decode and return the response(s) - const decodedResponseBody = new TextDecoder().decode(response.body) - /** @type {MessagesResponseBody} */ - const responseBody = JSON.parse(decodedResponseBody) - return { content: responseBody.content[0].text, reasoning_content: undefined } - } catch (error) { - console.error('AWS Bedrock generate text error:', error) - throw error - } - } - - private formatMessages(messages: ChatMessage[]): { - system?: string - messages: Anthropic.MessageParam[] - } { - // console.log('开始格式化消息,总消息数:', messages.length) - - // 提取系统消息 - let systemContent = '' - for (const msg of messages) { - if (msg.role === 'system') { - systemContent += - (typeof msg.content === 'string' - ? msg.content - : msg.content && Array.isArray(msg.content) - ? msg.content - .filter((c) => c.type === 'text') - .map((c) => c.text || '') - .join('\n') - : '') + '\n' - } - } - - // 定义消息组和内容块的类型 - type ContentBlock = Anthropic.ContentBlockParam - type ToolCall = { id: string; function: { name: string; arguments?: string } } - type MessageGroup = { - role: string - contents: ContentBlock[] - toolCalls?: string[] - hasToolUse?: boolean - } - - // 预处理:对消息进行分组 - // 新的逻辑:每个assistant消息如果包含tool_calls,就单独成组 - const messageGroups: MessageGroup[] = [] - let currentGroup: MessageGroup | null = null - - // 用于跟踪tool_calls和tool响应的匹配 - const toolResponseMap = new Map() - - // 第一阶段:建立初始分组和收集工具响应 - for (let i = 0; i < messages.length; i++) { - const msg = messages[i] - if (msg.role === 'system') continue // 系统消息已处理 - - // console.log( - // `处理第${i + 1}条消息, 角色:${msg.role}`, - // msg.content - // ? typeof msg.content === 'string' - // ? '内容长度:' + msg.content.length - // : '数组内容长度:' + (Array.isArray(msg.content) ? msg.content.length : 0) - // : '无内容' - // ) - - // 处理工具响应,将其与对应的工具调用关联 - if (msg.role === 'tool' && 'tool_call_id' in msg) { - const toolCallId = msg.tool_call_id as string - const responseContent = - typeof msg.content === 'string' - ? msg.content - : Array.isArray(msg.content) - ? 
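The retired `AwsBedrockProvider.generateText` hunk above drove the Anthropic Messages API through Bedrock's `InvokeModelCommand` with a `bedrock-2023-05-31` payload. A sketch of that request body (`buildInvokeBody` is an illustrative name, not from the diff):

```ts
function buildInvokeBody(
  prompt: string,
  maxTokens?: number,
  temperature?: number,
  systemPrompt?: string
): string {
  return JSON.stringify({
    anthropic_version: 'bedrock-2023-05-31',
    max_tokens: maxTokens,
    temperature,
    system: systemPrompt,
    messages: [{ role: 'user', content: [{ type: 'text', text: prompt }] }]
  })
}

// The command itself was sent as:
//   new InvokeModelCommand({ contentType: 'application/json', body: buildInvokeBody(...), modelId })
// and the response body was decoded with TextDecoder + JSON.parse before reading content[0].text.
```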
JSON.stringify(msg.content) - : '' - - toolResponseMap.set(toolCallId, { - type: 'tool_result', - tool_use_id: toolCallId, - content: responseContent - } as ContentBlock) - - // console.log('记录tool响应,tool_call_id:', toolCallId) - continue - } - - // 处理用户消息 - 开始新组 - if (msg.role === 'user') { - if (currentGroup) { - messageGroups.push(currentGroup) - } - - let formattedContent: ContentBlock[] = [] - if (typeof msg.content === 'string') { - formattedContent = [{ type: 'text', text: msg.content }] - } else if (msg.content && Array.isArray(msg.content)) { - formattedContent = msg.content.map((c) => { - if (c.type === 'image_url' && c.image_url) { - return { - type: 'image', - source: c.image_url.url.startsWith('data:image') - ? { - type: 'base64', - data: c.image_url.url.split(',')[1], - media_type: c.image_url.url.split(';')[0].split(':')[1] as - | 'image/jpeg' - | 'image/png' - | 'image/gif' - | 'image/webp' - } - : { type: 'url', url: c.image_url.url } - } as ContentBlock - } else { - return { type: 'text', text: c.text || '' } as ContentBlock - } - }) - } - - currentGroup = { - role: 'user', - contents: formattedContent - } - - // console.log('开始新的用户消息组') - continue - } - - // 处理assistant消息 - 添加到当前组或开始新组 - if (msg.role === 'assistant') { - // 检查是否需要新建一个组: - // 1. 当前还没有组 - // 2. 当前组不是assistant - // 3. 当前组是assistant但包含了工具调用 - const shouldCreateNewGroup = - !currentGroup || currentGroup.role !== 'assistant' || currentGroup.hasToolUse === true - - if (shouldCreateNewGroup) { - if (currentGroup) { - messageGroups.push(currentGroup) - } - - currentGroup = { - role: 'assistant', - contents: [], - toolCalls: [], - hasToolUse: false - } - } - - // 确保currentGroup已初始化 - if (!currentGroup) { - currentGroup = { - role: 'assistant', - contents: [], - toolCalls: [], - hasToolUse: false - } - } - - // 处理常规内容 - if (msg.content) { - let assistantContent: ContentBlock[] = [] - if (typeof msg.content === 'string') { - if (msg.content.trim()) { - assistantContent = [{ type: 'text', text: msg.content }] - } - } else if (Array.isArray(msg.content)) { - // 处理各种内容类型 - for (const content of msg.content) { - if (content.type === 'text') { - currentGroup.contents.push({ - type: 'text', - text: content.text || '' - } as ContentBlock) - } else if (content.type === 'image_url' && content.image_url) { - currentGroup.contents.push({ - type: 'image', - source: content.image_url.url.startsWith('data:image') - ? 
{ - type: 'base64', - data: content.image_url.url.split(',')[1], - media_type: content.image_url.url.split(';')[0].split(':')[1] as - | 'image/jpeg' - | 'image/png' - | 'image/gif' - | 'image/webp' - } - : { type: 'url', url: content.image_url.url } - } as ContentBlock) - } - } - - continue - } - - currentGroup.contents.push(...assistantContent) - // console.log('添加文本内容到当前assistant组, 项数:', assistantContent.length) - } - - // 处理tool_calls - if ('tool_calls' in msg && Array.isArray(msg.tool_calls)) { - // console.log('处理assistant消息中的tool_calls', msg.tool_calls.length) - - // 标记当前组包含工具调用 - if (currentGroup) { - currentGroup.hasToolUse = true - } - - for (const toolCall of msg.tool_calls as ToolCall[]) { - try { - // @ts-ignore - 转换为Anthropic格式 - currentGroup.contents.push({ - type: 'tool_use', - id: toolCall.id, - name: toolCall.function.name, - input: JSON.parse(toolCall.function.arguments || '{}') - } as ContentBlock) - - // console.log('添加tool_call到当前assistant组:', toolCall.function.name) - - // 记录工具调用,稍后查找响应 - if (!currentGroup.toolCalls) currentGroup.toolCalls = [] - currentGroup.toolCalls.push(toolCall.id) - } catch (e) { - console.error('Error processing tool_call:', e) - } - } - } - } - } - - // 添加最后一个组 - if (currentGroup) { - messageGroups.push(currentGroup) - } - - // console.log('预处理完成,消息组数量:', messageGroups.length) - - // 第二阶段:生成最终的格式化消息 - const formattedMessages: Anthropic.MessageParam[] = [] - - for (const group of messageGroups) { - if (group.contents.length === 0) continue - - // 添加组的主要内容 - formattedMessages.push({ - role: group.role as 'user' | 'assistant', - content: group.contents as Anthropic.ContentBlockParam[] - }) - - // console.log(`添加${group.role}组,内容项数:${group.contents.length}`) - - // 如果是assistant组且有工具调用,添加对应的工具响应 - if (group.role === 'assistant' && group.toolCalls && group.toolCalls.length > 0) { - for (const toolCallId of group.toolCalls) { - const toolResponse = toolResponseMap.get(toolCallId) - if (toolResponse) { - formattedMessages.push({ - role: 'user', - content: [toolResponse] - }) - - // console.log('添加工具响应,tool_call_id:', toolCallId) - } - } - } - } - - // console.log('格式化完成, 最终消息数:', formattedMessages.length) - // 为调试目的,打印前3条消息的结构 - // formattedMessages.slice(0, Math.min(3, formattedMessages.length)).forEach((msg, i) => { - // console.log(`最终消息#${i + 1}:`, { - // role: msg.role, - // contentLength: Array.isArray(msg.content) ? msg.content.length : 0, - // contentTypes: Array.isArray(msg.content) - // ? 
msg.content.map((c) => c.type).join(',') - // : typeof msg.content - // }) - // }) - - return { - system: systemContent || undefined, - messages: formattedMessages - } - } - - // 依赖 formatMessages - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - try { - if (!this.bedrockRuntime) { - throw new Error('AWS Bedrock client is not initialized') - } - - const formattedMessages = this.formatMessages(messages) - const cachedMessages = this.applyPromptCache(formattedMessages.messages, modelId) - - // 创建基本请求参数 - const payload = { - anthropic_version: 'bedrock-2023-05-31', - max_tokens: maxTokens, - temperature, - system: formattedMessages.system, - messages: cachedMessages - } - - const command = new InvokeModelCommand({ - contentType: 'application/json', - body: JSON.stringify(payload), - modelId - }) - - // 执行请求 - const response = (await this.bedrockRuntime.send(command)) as InvokeModelCommandOutput & { - usage?: any - } - - const resultResp: LLMResponse = { - content: '' - } - - // 添加usage信息 - if (response.usage) { - const usageSnapshot = buildBedrockUsageSnapshot(response.usage as CacheAwareBedrockUsage) - resultResp.totalUsage = { - prompt_tokens: usageSnapshot?.prompt_tokens ?? 0, - completion_tokens: usageSnapshot?.completion_tokens ?? 0, - total_tokens: usageSnapshot?.total_tokens ?? 0 - } - } - - // 获取文本内容 - // const content = response.content - // .filter((block) => block.type === 'text') - // .map((block) => (block.type === 'text' ? block.text : '')) - // .join('') - const decodedResponseBody = new TextDecoder().decode(response.body) - /** @type {MessagesResponseBody} */ - const responseBody = JSON.parse(decodedResponseBody) - const content = responseBody.content[0].text - - // 处理标签 - if (content.includes('')) { - const thinkStart = content.indexOf('') - const thinkEnd = content.indexOf('') - - if (thinkEnd > thinkStart) { - // 提取reasoning_content - resultResp.reasoning_content = content.substring(thinkStart + 7, thinkEnd).trim() - - // 合并前后的普通内容 - const beforeThink = content.substring(0, thinkStart).trim() - const afterThink = content.substring(thinkEnd + 8).trim() - resultResp.content = [beforeThink, afterThink].filter(Boolean).join('\n') - } else { - // 如果没有找到配对的结束标签,将所有内容作为普通内容 - resultResp.content = content - } - } else { - // 没有think标签,所有内容作为普通内容 - resultResp.content = content - } - - return resultResp - } catch (error) { - console.error('AWS Bedrock Claude completions error:', error) - throw error - } - } - - // 依赖 formatMessages - // 添加coreStream方法 - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.bedrockRuntime) throw new Error('AWS Bedrock client is not initialized') - if (!modelId) throw new Error('Model ID is required') - console.log('modelConfig', modelConfig, modelId) - try { - // 格式化消息 - const formattedMessagesObject = this.formatMessages(messages) - const cachedMessages = this.applyPromptCache( - formattedMessagesObject.messages, - modelId, - modelConfig.conversationId - ) - console.log('formattedMessagesObject', JSON.stringify(formattedMessagesObject)) - - // 将MCP工具转换为Anthropic工具格式 - const anthropicTools = - mcpTools.length > 0 - ? 
await this.mcpRuntime?.mcpToolsToAnthropicTools(mcpTools, this.provider.id) - : undefined - - // 创建基本请求参数 - // const streamParams = { - // model: modelId, - // max_tokens: maxTokens || 1024, - // temperature: temperature || 0.7, - // messages: formattedMessagesObject.messages, - // stream: true - // } as Anthropic.Messages.MessageCreateParamsStreaming - const payload = { - anthropic_version: 'bedrock-2023-05-31', - max_tokens: maxTokens || 1024, - temperature: temperature ?? 0.7, - system: formattedMessagesObject.system, - messages: cachedMessages, - thinking: undefined as any, - tools: undefined as any - } - const command = new InvokeModelWithResponseStreamCommand({ - contentType: 'application/json', - body: JSON.stringify(payload), - modelId - }) - - // 启用Claude 3.7模型的思考功能 - if (modelId.includes('claude-3-7')) { - payload.thinking = { budget_tokens: 1024, type: 'enabled' } - } - - // // 如果有系统消息,添加到请求参数中 - // if (formattedMessagesObject.system) { - // // @ts-ignore - system属性在类型定义中可能不存在,但API已支持 - // streamParams.system = formattedMessagesObject.system - // } - - // 添加工具参数 - if (anthropicTools && anthropicTools.length > 0) { - // @ts-ignore - 类型不匹配,但格式是正确的 - payload.tools = anthropicTools - } - - await this.emitRequestTrace(modelConfig, { - endpoint: this.buildBedrockStreamEndpoint(modelId), - headers: { - 'Content-Type': 'application/json', - 'x-aws-region': this.getBedrockRegion() - }, - body: this.decodeBedrockBody(command.input.body) - }) - // 创建Anthropic流 - const response = await this.bedrockRuntime.send(command) - const body = await response.body - if (!body) { - throw new Error('No response body from AWS Bedrock') - } - - // 状态变量 - let accumulatedJson = '' - let toolUseDetected = false - let currentToolId = '' - let currentToolName = '' - let currentToolInputs: Record = {} - let usageMetadata: Usage | undefined - // 处理流中的各种事件 - for await (const item of body) { - if (!item.chunk) continue - - const chunk = JSON.parse(new TextDecoder().decode(item.chunk.bytes)) - // 处理使用统计 - if (chunk.type === 'message_start' && chunk.message.usage) { - usageMetadata = chunk.message.usage - } - - // 处理工具调用开始 - // @ts-ignore - Anthropic SDK类型定义不完整 - if (chunk.type === 'content_block_start' && chunk.content_block?.type === 'tool_use') { - toolUseDetected = true - // @ts-ignore - content_block不在类型定义中 - currentToolId = chunk.content_block.id || `anthropic-tool-${Date.now()}` - // @ts-ignore - content_block不在类型定义中 - currentToolName = chunk.content_block.name || '' - currentToolInputs = {} - accumulatedJson = '' - - // 发送工具调用开始事件 - if (currentToolName) { - yield { - type: 'tool_call_start', - tool_call_id: currentToolId, - tool_call_name: currentToolName - } - } - continue - } - - // 处理工具调用参数更新 - input_json_delta - // @ts-ignore - 类型定义中没有工具相关字段 - if (chunk.type === 'content_block_delta' && chunk.delta?.type === 'input_json_delta') { - // @ts-ignore - partial_json不在类型定义中 - const partialJson = chunk.delta.partial_json - if (partialJson) { - accumulatedJson += partialJson - - // 发送工具调用参数块事件 - yield { - type: 'tool_call_chunk', - tool_call_id: currentToolId, - tool_call_arguments_chunk: partialJson - } - } - continue - } - - // 处理工具使用更新 - tool_use_delta - // @ts-ignore - 类型定义中没有工具相关字段 - if (chunk.type === 'content_block_delta' && chunk.delta?.type === 'tool_use_delta') { - // @ts-ignore - delta.name不在类型定义中 - if (chunk.delta.name && !currentToolName) { - // @ts-ignore - 访问delta.name - currentToolName = chunk.delta.name - yield { - type: 'tool_call_start', - tool_call_id: currentToolId, - tool_call_name: 
currentToolName - } - } - - // @ts-ignore - delta.input不在类型定义中 - if (chunk.delta.input) { - currentToolInputs = { - ...currentToolInputs, - // @ts-ignore - 访问delta.input - ...chunk.delta.input - } - } - continue - } - - // 处理内容块结束 - if (chunk.type === 'content_block_stop') { - // 处理工具调用完成 - if (toolUseDetected && currentToolName && accumulatedJson) { - try { - // 尝试解析完整的JSON - const jsonStr = accumulatedJson.trim() - if (jsonStr && (jsonStr.startsWith('{') || jsonStr.startsWith('['))) { - try { - const jsonObject = JSON.parse(jsonStr) - if (jsonObject && typeof jsonObject === 'object') { - currentToolInputs = { ...currentToolInputs, ...jsonObject } - } - } catch (e) { - console.error('解析完整JSON失败:', e) - } - } - } catch (e) { - console.error('处理累积JSON失败:', e) - } - - // 发送工具调用结束事件 - const argsString = JSON.stringify(currentToolInputs) - yield { - type: 'tool_call_end', - tool_call_id: currentToolId, - tool_call_arguments_complete: argsString - } - - // 重置工具调用状态 - accumulatedJson = '' - } - continue - } - - // 检查消息是否因为工具调用而停止 - if (chunk.type === 'message_delta' && chunk.delta?.stop_reason === 'tool_use') { - // 设置为工具使用停止,主循环会处理工具调用 - continue - } - - // 处理思考内容(如果有) - // @ts-ignore - 类型定义中没有thinking相关字段 - if (chunk.type === 'content_block_delta' && chunk.delta?.type === 'thinking_delta') { - // @ts-ignore - delta.thinking不在类型定义中 - const thinkingText = chunk.delta.thinking - if (thinkingText) { - yield { - type: 'reasoning', - reasoning_content: thinkingText - } - } - continue - } - - // 处理常规文本内容 - if (chunk.type === 'content_block_delta' && chunk.delta.type === 'text_delta') { - const text = chunk.delta.text - if (text) { - // 处理标签 - if (text.includes('')) { - const parts = text.split('') - if (parts[0]) { - yield { - type: 'text', - content: parts[0] - } - } - - if (parts[1]) { - // 检查是否包含 - const thinkParts = parts[1].split('') - if (thinkParts.length > 1) { - yield { - type: 'reasoning', - reasoning_content: thinkParts[0] - } - - if (thinkParts[1]) { - yield { - type: 'text', - content: thinkParts[1] - } - } - } else { - yield { - type: 'reasoning', - reasoning_content: parts[1] - } - } - } - } else if (text.includes('')) { - const parts = text.split('') - yield { - type: 'reasoning', - reasoning_content: parts[0] - } - - if (parts[1]) { - yield { - type: 'text', - content: parts[1] - } - } - } else { - yield { - type: 'text', - content: text - } - } - } - continue - } - } - if (usageMetadata) { - const usageSnapshot = buildBedrockUsageSnapshot(usageMetadata as CacheAwareBedrockUsage) - if (usageSnapshot) { - yield createStreamEvent.usage(usageSnapshot) - } - } - // 发送停止事件 - yield createStreamEvent.stop(toolUseDetected ? 'tool_use' : 'complete') - } catch (error) { - console.error('AWS Bedrock Claude coreStream error:', error) - yield createStreamEvent.error(error instanceof Error ? 
error.message : '未知错误') - yield createStreamEvent.stop('error') - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/cherryInProvider.ts b/src/main/presenter/llmProviderPresenter/providers/cherryInProvider.ts deleted file mode 100644 index a80a3fa36..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/cherryInProvider.ts +++ /dev/null @@ -1,92 +0,0 @@ -import { LLM_PROVIDER, MODEL_META, IConfigPresenter, KeyStatus } from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -interface CherryInUsageResponse { - total_usage: number -} - -export class CherryInProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - private getBaseUrl(): string { - return (this.provider.baseUrl || 'https://open.cherryin.ai/v1').replace(/\/$/, '') - } - - public async getKeyStatus(): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required') - } - - const baseUrl = this.getBaseUrl() - const headers = { - Authorization: `Bearer ${this.provider.apiKey}`, - 'Content-Type': 'application/json' - } - - const usageResponse = await fetch(`${baseUrl}/dashboard/billing/usage`, { - method: 'GET', - headers - }) - - if (!usageResponse.ok) { - const errorText = await usageResponse.text() - throw new Error( - `CherryIn usage check failed: ${usageResponse.status} ${usageResponse.statusText} - ${errorText}` - ) - } - - const usageData: CherryInUsageResponse = await usageResponse.json() - - const totalUsage = Number(usageData?.total_usage) - - const usageUsd = Number.isFinite(totalUsage) ? totalUsage / 100 : 0 - - return { - usage: `$${usageUsd.toFixed(2)}` - } - } - - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - await this.getKeyStatus() - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - let errorMessage = 'An unknown error occurred during CherryIn API key check.' - if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('CherryIn API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } - - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - try { - const models = await super.fetchOpenAIModels(options) - if (models.length > 0) { - return models.map((model) => ({ - ...model, - group: model.group === 'default' ? 
'cherryin' : model.group, - providerId: this.provider.id - })) - } - } catch (error) { - console.warn( - '[CherryInProvider] Failed to fetch models via API, falling back to defaults', - error - ) - } - - return [] - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/dashscopeProvider.ts b/src/main/presenter/llmProviderPresenter/providers/dashscopeProvider.ts deleted file mode 100644 index f1f7153e3..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/dashscopeProvider.ts +++ /dev/null @@ -1,133 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - ChatMessage, - IConfigPresenter, - LLMCoreStreamEvent, - ModelConfig, - MCPToolDefinition -} from '@shared/presenter' -import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class DashscopeProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - private supportsEnableThinking(modelId: string): boolean { - return modelCapabilities.supportsReasoning(this.provider.id, modelId) - } - - /** - * Override coreStream method to support DashScope's enable_thinking and enable_search parameters - */ - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - - const shouldAddEnableThinking = this.supportsEnableThinking(modelId) && modelConfig?.reasoning - const chatCompletions = this.openai.chat.completions - const originalCreate = chatCompletions.create - - if (shouldAddEnableThinking) { - // Original create method - const originalCreateWithContext = originalCreate.bind(chatCompletions) - // Replace create method to add enable_thinking parameter - chatCompletions.create = ((params: any, options?: any) => { - const modifiedParams = { ...params } - - modifiedParams.enable_thinking = true - const dbBudget = modelCapabilities.getThinkingBudgetRange(this.provider.id, modelId).default - const budget = modelConfig?.thinkingBudget ?? 
dbBudget - if (typeof budget === 'number') { - modifiedParams.thinking_budget = budget - } - - return originalCreateWithContext(modifiedParams, options) - }) as typeof chatCompletions.create - } - - try { - yield* super.coreStream(messages, modelId, modelConfig, temperature, maxTokens, mcpTools) - } finally { - if (shouldAddEnableThinking) { - chatCompletions.create = originalCreate - } - } - } - - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - const response = await this.openai.models.list(options) - return response.data.map((model) => ({ - id: model.id, - name: model.id, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, - maxTokens: DEFAULT_MODEL_MAX_TOKENS - })) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `Please summarize the following content using concise language and highlighting key points:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/deepseekProvider.ts b/src/main/presenter/llmProviderPresenter/providers/deepseekProvider.ts deleted file mode 100644 index d3eb5a765..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/deepseekProvider.ts +++ /dev/null @@ -1,158 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - ChatMessage, - KeyStatus, - IConfigPresenter -} from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { SUMMARY_TITLES_PROMPT } from '../baseProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for DeepSeek API key response -interface DeepSeekBalanceResponse { - is_available: boolean - balance_infos: Array<{ - currency: string - total_balance: string - granted_balance: string - topped_up_balance: string - }> -} - -export class DeepseekProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `${SUMMARY_TITLES_PROMPT}\n\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Get current API key status from DeepSeek - * @returns Promise API key status information - */ - public async getKeyStatus(): 
Promise<KeyStatus> {
-    if (!this.provider.apiKey) {
-      throw new Error('API key is required')
-    }
-
-    const response = await fetch('https://api.deepseek.com/user/balance', {
-      method: 'GET',
-      headers: {
-        Accept: 'application/json',
-        Authorization: `Bearer ${this.provider.apiKey}`
-      }
-    })
-
-    if (!response.ok) {
-      const errorText = await response.text()
-      throw new Error(
-        `DeepSeek API key check failed: ${response.status} ${response.statusText} - ${errorText}`
-      )
-    }
-
-    const balanceResponse: DeepSeekBalanceResponse = await response.json()
-
-    if (!balanceResponse.is_available) {
-      throw new Error('DeepSeek API key is not available')
-    }
-
-    // Find CNY balance info first, then USD, then default to first available
-    const balanceInfo =
-      balanceResponse.balance_infos.find((info) => info.currency === 'CNY') ||
-      balanceResponse.balance_infos.find((info) => info.currency === 'USD') ||
-      balanceResponse.balance_infos[0]
-
-    if (!balanceInfo) {
-      throw new Error('No balance information available')
-    }
-
-    const totalBalance = parseFloat(balanceInfo.total_balance)
-    const currencySymbol = balanceInfo.currency === 'USD' ? '$' : '¥'
-
-    // Map to unified KeyStatus format
-    return {
-      limit_remaining: `${currencySymbol}${totalBalance}`,
-      remainNum: totalBalance
-    }
-  }
-
-  /**
-   * Override check method to use DeepSeek's API key status endpoint
-   * @returns Promise<{ isOk: boolean; errorMsg: string | null }>
-   */
-  public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> {
-    try {
-      const keyStatus = await this.getKeyStatus()
-
-      // Check if there's remaining quota
-      if (keyStatus.remainNum !== undefined && keyStatus.remainNum <= 0) {
-        return {
-          isOk: false,
-          errorMsg: `API key quota exhausted. Remaining: ${keyStatus.limit_remaining}`
-        }
-      }
-
-      return { isOk: true, errorMsg: null }
-    } catch (error: unknown) {
-      let errorMessage = 'An unknown error occurred during DeepSeek API key check.'
- if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('DeepSeek API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/doubaoProvider.ts b/src/main/presenter/llmProviderPresenter/providers/doubaoProvider.ts deleted file mode 100644 index f278913d8..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/doubaoProvider.ts +++ /dev/null @@ -1,144 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - ChatMessage, - IConfigPresenter, - LLMCoreStreamEvent, - ModelConfig, - MCPToolDefinition -} from '@shared/presenter' -import { ModelType } from '@shared/model' -import { - resolveModelContextLength, - resolveModelFunctionCall, - resolveModelMaxTokens -} from '@shared/modelConfigDefaults' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { providerDbLoader } from '../../configPresenter/providerDbLoader' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -const DOUBAO_THINKING_NOTE = 'doubao-thinking-parameter' - -export class DoubaoProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - // Initialize Doubao model configuration - super(provider, configPresenter, mcpRuntime) - } - - private supportsThinking(modelId: string): boolean { - const model = providerDbLoader.getModel(this.provider.id, modelId) - const notes = model?.extra_capabilities?.reasoning?.notes - return Array.isArray(notes) && notes.includes(DOUBAO_THINKING_NOTE) - } - - /** - * Override coreStream method to support Doubao's thinking parameter - */ - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - - const shouldAddThinking = this.supportsThinking(modelId) && modelConfig?.reasoning - - if (shouldAddThinking) { - // Original create method - const originalCreate = this.openai.chat.completions.create.bind(this.openai.chat.completions) - // Replace create method to add thinking parameter - this.openai.chat.completions.create = ((params: any, options?: any) => { - const modifiedParams = { - ...params, - thinking: { - type: 'enabled' - } - } - return originalCreate(modifiedParams, options) - }) as any - - try { - const effectiveModelConfig = { ...modelConfig, reasoning: false } - yield* super.coreStream( - messages, - modelId, - effectiveModelConfig, - temperature, - maxTokens, - mcpTools - ) - } finally { - this.openai.chat.completions.create = originalCreate - } - } else { - yield* super.coreStream(messages, modelId, modelConfig, temperature, maxTokens, mcpTools) - } - } - - protected async fetchOpenAIModels(): Promise { - const provider = providerDbLoader.getProvider(this.provider.id) - if (!provider || !Array.isArray(provider.models)) { - return [] - } - - return provider.models.map((model) => { - const inputs = model.modalities?.input - const outputs = model.modalities?.output - const hasImageInput = Array.isArray(inputs) && inputs.includes('image') - const hasImageOutput = Array.isArray(outputs) && outputs.includes('image') - const modelType = hasImageOutput ? 
ModelType.ImageGeneration : ModelType.Chat - - return { - id: model.id, - name: model.display_name || model.name || model.id, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: resolveModelContextLength(model.limit?.context), - maxTokens: resolveModelMaxTokens(model.limit?.output), - vision: hasImageInput, - functionCall: resolveModelFunctionCall(model.tool_call), - reasoning: Boolean(model.reasoning?.supported), - enableSearch: Boolean(model.search?.supported), - type: modelType - } - }) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/geminiProvider.ts b/src/main/presenter/llmProviderPresenter/providers/geminiProvider.ts deleted file mode 100644 index c4f0680d8..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/geminiProvider.ts +++ /dev/null @@ -1,1133 +0,0 @@ -import { - Content, - FunctionCallingConfigMode, - GenerateContentParameters, - GenerateContentResponseUsageMetadata, - GoogleGenAI, - HarmBlockThreshold, - HarmCategory, - Modality, - Part, - SafetySetting, - Tool, - GenerateContentConfig -} from '@google/genai' -import { ModelType } from '@shared/model' -import { - ChatMessage, - IConfigPresenter, - LLM_PROVIDER, - LLMCoreStreamEvent, - LLMResponse, - MCPToolDefinition, - MODEL_META, - ModelConfig -} from '@shared/presenter' -import { createStreamEvent } from '@shared/types/core/llm-events' -import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' -import { eventBus, SendTarget } from '@/eventbus' -import { CONFIG_EVENTS } from '@/events' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Mapping from simple keys to API HarmCategory constants -const keyToHarmCategoryMap: Record = { - harassment: HarmCategory.HARM_CATEGORY_HARASSMENT, - hateSpeech: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - sexuallyExplicit: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - dangerousContent: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT -} - -// Value mapping from config storage to API HarmBlockThreshold constants -// Assuming config stores 'BLOCK_NONE', 'BLOCK_LOW_AND_ABOVE', etc. 
directly -const valueToHarmBlockThresholdMap: Record = { - BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE, - BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED -} -const safetySettingKeys = Object.keys(keyToHarmCategoryMap) - -export class GeminiProvider extends BaseLLMProvider { - private genAI: GoogleGenAI - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - this.genAI = new GoogleGenAI({ - apiKey: this.provider.apiKey, - httpOptions: { baseUrl: this.provider.baseUrl } - }) - this.init() - } - - public onProxyResolved(): void { - this.init() - } - - // 确保带有 models/ 前缀 - private ensureGoogleModelName(modelId: string): string { - return modelId?.startsWith('models/') ? modelId : `models/${modelId}` - } - - private buildGeminiStreamEndpoint(modelId: string): string { - const baseUrl = (this.provider.baseUrl || 'https://generativelanguage.googleapis.com').replace( - /\/+$/, - '' - ) - const normalizedModel = this.ensureGoogleModelName(modelId).replace(/^\/+/, '') - return `${baseUrl}/v1beta/${normalizedModel}:streamGenerateContent` - } - - private buildGeminiTraceHeaders(): Record { - return { - 'Content-Type': 'application/json', - 'x-goog-api-key': this.provider.apiKey || 'MISSING_API_KEY' - } - } - - // Implement abstract method fetchProviderModels from BaseLLMProvider - protected async fetchProviderModels(): Promise { - try { - const modelsResponse = await this.genAI.models.list() - // console.log('gemini models response:', modelsResponse) - - // 将 pager 转换为数组 - const models: any[] = [] - for await (const model of modelsResponse) { - models.push(model) - } - - if (models.length === 0) { - console.warn('No models found in Gemini API response, using Provider DB models') - const dbModels = this.configPresenter.getDbProviderModels(this.provider.id).map((m) => ({ - id: m.id, - name: m.name, - group: m.group || 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: m.contextLength, - maxTokens: m.maxTokens, - vision: m.vision || false, - functionCall: m.functionCall || false, - reasoning: m.reasoning || false, - ...(m.type ? { type: m.type } : {}) - })) - return dbModels - } - - // 映射 API 返回的模型数据(能力统一读 Provider DB) - const normalizeModelId = (mid: string): string => String(mid || '').replace(/^models\//i, '') - const apiModels: MODEL_META[] = models - .filter((model: any) => { - const name = String(model.name || '').toLowerCase() - return ( - !name.includes('embedding') && - !name.includes('aqa') && - !name.includes('text-embedding') && - !name.includes('gemma-3n-e4b-it') - ) - }) - .map((model: any) => { - const apiModelId: string = model.name - const displayName: string = model.displayName || apiModelId - - const normalizedId = normalizeModelId(apiModelId) - - const vision = modelCapabilities.supportsVision(this.provider.id, normalizedId) - const functionCall = modelCapabilities.supportsToolCall(this.provider.id, normalizedId) - const reasoning = modelCapabilities.supportsReasoning(this.provider.id, normalizedId) - const isImageOutput = modelCapabilities.supportsImageOutput( - this.provider.id, - normalizedId - ) - const modelType = isImageOutput ? 
ModelType.ImageGeneration : ModelType.Chat - - let group = 'default' - if (/\b(exp|preview)\b/i.test(apiModelId)) group = 'experimental' - else if (/\bgemma\b/i.test(apiModelId)) group = 'gemma' - - return { - id: apiModelId, - name: displayName, - group, - providerId: this.provider.id, - isCustom: false, - contextLength: model.inputTokenLimit, - maxTokens: model.outputTokenLimit, - vision, - functionCall, - reasoning, - ...(modelType !== ModelType.Chat && { type: modelType }) - } as MODEL_META - }) - - // console.log('Mapped Gemini models:', apiModels) - return apiModels - } catch (error) { - console.warn('Failed to fetch models from Gemini API:', error) - // If API call fails, fallback to Provider DB mapping - const dbModels = this.configPresenter.getDbProviderModels(this.provider.id).map((m) => ({ - id: m.id, - name: m.name, - group: m.group || 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: m.contextLength, - maxTokens: m.maxTokens, - vision: m.vision || false, - functionCall: m.functionCall || false, - reasoning: m.reasoning || false, - ...(m.type ? { type: m.type } : {}) - })) - return dbModels - } - } - - // Implement summaryTitles abstract method from BaseLLMProvider - public async summaryTitles( - messages: { role: 'system' | 'user' | 'assistant'; content: string }[], - modelId: string - ): Promise { - console.log('gemini ignore modelId', modelId) - // Use Gemini API to generate conversation titles - try { - const conversationText = messages.map((m) => `${m.role}: ${m.content}`).join('\n') - const prompt = `${SUMMARY_TITLES_PROMPT}\n\n${conversationText}` - - const result = await this.genAI.models.generateContent({ - model: this.ensureGoogleModelName(modelId), - contents: [{ role: 'user', parts: [{ text: prompt }] }], - config: this.getGenerateContentConfig(0.4, undefined, modelId, false) - }) - - return result.text?.trim() || 'New Conversation' - } catch (error) { - console.error('Failed to generate conversation title:', error) - return 'New Conversation' - } - } - - // Override fetchModels method since Gemini doesn't have a model fetching API - async fetchModels(): Promise { - // Gemini没有获取模型的API,直接使用init方法中的硬编码模型列表 - return this.models - } - - // Override check method to use the first default model for testing - async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - if (!this.provider.apiKey) { - return { isOk: false, errorMsg: 'Missing API key' } - } - - // Use the first model for simple testing - const testModelId = - this.models.find((m) => m.type !== ModelType.ImageGeneration)?.id || - this.models[0]?.id || - 'gemini-2.0-flash' - - const result = await this.genAI.models.generateContent({ - model: this.ensureGoogleModelName(testModelId), - contents: [{ role: 'user', parts: [{ text: 'Hello' }] }] - }) - return { isOk: result && result.text ? true : false, errorMsg: null } - } catch (error) { - console.error('Provider check failed:', this.provider.name, error) - return { isOk: false, errorMsg: error instanceof Error ? 
error.message : String(error) } - } - } - - protected async init() { - if (this.provider.enable) { - try { - this.isInitialized = true - // Use API to get model list, fallback to static list if failed - this.models = await this.fetchProviderModels() - await this.autoEnableModelsIfNeeded() - // Gemini is relatively slow, special compensation - eventBus.sendToRenderer( - CONFIG_EVENTS.MODEL_LIST_CHANGED, - SendTarget.ALL_WINDOWS, - this.provider.id - ) - console.info('Provider initialized successfully:', this.provider.name) - } catch (error) { - console.warn('Provider initialization failed:', this.provider.name, error) - } - } - } - - /** - * 重写 autoEnableModelsIfNeeded 方法 - * 不自动启用模型,交由用户手动选择。 - */ - protected async autoEnableModelsIfNeeded() { - if (!this.models || this.models.length === 0) return - const providerId = this.provider.id - - // 检查是否有自定义模型 - const customModels = this.configPresenter.getCustomModels(providerId) - if (customModels && customModels.length > 0) return - - // 检查是否有任何模型的状态被手动修改过 - const hasManuallyModifiedModels = this.models.some((model) => - this.configPresenter.getModelStatus(providerId, model.id) - ) - if (hasManuallyModifiedModels) return - - // 检查是否有任何已启用的模型 - const hasEnabledModels = this.models.some((model) => - this.configPresenter.getModelStatus(providerId, model.id) - ) - - // 不再自动启用模型,让用户手动选择启用需要的模型 - if (!hasEnabledModels) { - console.info( - `Provider ${this.provider.name} models loaded, please manually enable the models you need` - ) - } - } - - // Helper function to get and format safety settings - private async getFormattedSafetySettings(): Promise { - const safetySettings: SafetySetting[] = [] - - for (const key of safetySettingKeys) { - try { - // Use configPresenter to get the setting value for the 'gemini' provider - // Assuming getSetting returns the string value like 'BLOCK_MEDIUM_AND_ABOVE' - const settingValue = - (await this.configPresenter.getSetting( - `geminiSafety_${key}` // Match the key used in settings store - )) || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' // Default if not set - - const threshold = valueToHarmBlockThresholdMap[settingValue] - const category = keyToHarmCategoryMap[key] - - // Only add if threshold is defined, category is defined, and threshold is not BLOCK_NONE - if ( - threshold && - category && - threshold !== 'BLOCK_NONE' && - threshold !== 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' - ) { - safetySettings.push({ category, threshold }) - } - } catch (error) { - console.warn(`Failed to retrieve or map safety setting for ${key}:`, error) - } - } - - return safetySettings.length > 0 ? 
safetySettings : undefined - } - - // 判断模型是否支持 thinkingBudget - private supportsThinkingBudget(modelId: string): boolean { - const normalized = modelId.replace(/^models\//i, '') - const range = modelCapabilities.getThinkingBudgetRange( - this.getCapabilityProviderId(), - normalized - ) - return ( - typeof range.default === 'number' || - typeof range.min === 'number' || - typeof range.max === 'number' - ) - } - - // 获取生成配置,不再创建模型实例 - private getGenerateContentConfig( - temperature?: number, - maxTokens?: number, - modelId?: string, - reasoning?: boolean, - thinkingBudget?: number - ): GenerateContentConfig { - const config: GenerateContentConfig = { - temperature, - maxOutputTokens: maxTokens, - topP: 1 // topP默认为1.0 - } - - // 从当前模型列表中查找指定的模型 - if (modelId && this.models) { - const model = this.models.find((m) => m.id === modelId) - if (model && model.type === ModelType.ImageGeneration) { - config.responseModalities = [Modality.TEXT, Modality.IMAGE] - } - } - - // 正确配置思考功能 - if (reasoning) { - config.thinkingConfig = { - includeThoughts: true - } - - // 仅对支持 thinkingBudget 的 Gemini 2.5 系列模型添加 thinkingBudget 参数 - if (modelId && this.supportsThinkingBudget(modelId) && thinkingBudget !== undefined) { - config.thinkingConfig.thinkingBudget = thinkingBudget - } - } - - return config - } - - // 将 ChatMessage 转换为 Gemini 格式的消息 - private formatGeminiMessages(messages: ChatMessage[]): { - systemInstruction: string - contents: Content[] - } { - // 提取系统消息 - const systemMessages = messages.filter((msg) => msg.role === 'system') - let systemContent = '' - if (systemMessages.length > 0) { - systemContent = systemMessages.map((msg) => msg.content).join('\n') - } - - // 创建Gemini内容数组 - const formattedContents: Content[] = [] - - // 处理非系统消息 - const nonSystemMessages = messages.filter((msg) => msg.role !== 'system') - for (let i = 0; i < nonSystemMessages.length; i++) { - const message = nonSystemMessages[i] - - // 检查是否是带有tool_calls的assistant消息 - if (message.role === 'assistant' && 'tool_calls' in message) { - // 处理tool_calls消息 - for (const toolCall of message.tool_calls || []) { - // 添加模型发出的函数调用 - formattedContents.push({ - role: 'model', - parts: [ - { - functionCall: { - name: toolCall.function.name, - args: JSON.parse(toolCall.function.arguments || '{}') - } - } - ] - }) - - // 查找对应的工具响应消息 - const nextMessage = i + 1 < nonSystemMessages.length ? nonSystemMessages[i + 1] : null - if ( - nextMessage && - nextMessage.role === 'tool' && - 'tool_call_id' in nextMessage && - nextMessage.tool_call_id === toolCall.id - ) { - // 添加用户角色的函数响应 - formattedContents.push({ - role: 'user', - parts: [ - { - functionResponse: { - name: toolCall.function.name, - response: { - result: - typeof nextMessage.content === 'string' - ? nextMessage.content - : JSON.stringify(nextMessage.content) - } - } - } - ] - }) - - // 跳过下一条消息,因为已经处理过了 - i++ - } - } - continue - } - - // 为每条消息创建parts数组 - const parts: Part[] = [] - - // 检查消息是否包含工具调用或工具响应 - if (message.role === 'tool' && Array.isArray(message.content)) { - // 处理工具消息 - for (const part of message.content) { - // @ts-ignore - 处理类型兼容性 - if (part.type === 'function_call' && part.function_call) { - // 处理函数调用 - parts.push({ - // @ts-ignore - 处理类型兼容性 - functionCall: { - // @ts-ignore - 处理类型兼容性 - name: part.function_call.name || '', - // @ts-ignore - 处理类型兼容性 - args: part.function_call.arguments ? 
JSON.parse(part.function_call.arguments) : {}
-              }
-            })
-          // @ts-ignore - 处理类型兼容性
-          } else if (part.type === 'function_response') {
-            // 处理函数响应
-            // @ts-ignore - 处理类型兼容性
-            parts.push({ text: part.function_response || '' })
-          }
-        }
-      } else if (typeof message.content === 'string') {
-        // 处理消息内容 - 可能是字符串或包含图片的数组
-        // 处理纯文本消息
-        // 只添加非空文本
-        if (message.content.trim() !== '') {
-          parts.push({ text: message.content })
-        }
-      } else if (Array.isArray(message.content)) {
-        // 处理多模态消息(带图片等)
-        for (const part of message.content) {
-          if (part.type === 'text') {
-            // 只添加非空文本
-            if (part.text && part.text.trim() !== '') {
-              parts.push({ text: part.text })
-            }
-          } else if (part.type === 'image_url' && part.image_url) {
-            // 处理图片(假设是 base64 格式)
-            const matches = part.image_url.url.match(/^data:([^;]+);base64,(.+)$/)
-            if (matches && matches.length === 3) {
-              const mimeType = matches[1]
-              const base64Data = matches[2]
-              parts.push({
-                inlineData: {
-                  data: base64Data,
-                  mimeType: mimeType
-                }
-              })
-            }
-          }
-        }
-      }
-
-      // 只有当parts不为空时,才添加到formattedContents中
-      if (parts.length > 0) {
-        // 将消息角色转换为Gemini支持的角色
-        let role: 'user' | 'model' = 'user'
-        if (message.role === 'assistant') {
-          role = 'model'
-        } else if (message.role === 'tool') {
-          // 工具消息作为用户消息处理
-          role = 'user'
-        }
-
-        formattedContents.push({
-          role: role,
-          parts: parts
-        })
-      }
-    }
-
-    return { systemInstruction: systemContent, contents: formattedContents }
-  }
-
-  // 处理 Gemini API 响应,支持新旧格式的思考内容
-  private processGeminiResponse(result: any): LLMResponse {
-    const resultResp: LLMResponse = {
-      content: ''
-    }
-
-    let textContent = ''
-    let thoughtContent = ''
-
-    // 检查是否有候选响应和 parts
-    if (result.candidates && result.candidates[0]?.content?.parts) {
-      for (const part of result.candidates[0].content.parts) {
-        // 检查是否是思考内容 (新格式)
-        if ((part as any).thought === true && part.text) {
-          thoughtContent += part.text
-        } else if (part.text) {
-          textContent += part.text
-        }
-      }
-    } else {
-      // 回退到使用 result.text
-      textContent = result.text || ''
-    }
-
-    // 如果没有检测到新格式的思考内容,检查旧格式的 <think> 标签
-    if (!thoughtContent && textContent.includes('<think>')) {
-      const thinkStart = textContent.indexOf('<think>')
-      const thinkEnd = textContent.indexOf('</think>')
-
-      if (thinkEnd > thinkStart) {
-        // 提取reasoning_content
-        thoughtContent = textContent.substring(thinkStart + 7, thinkEnd).trim()
-
-        // 合并前后的普通内容
-        const beforeThink = textContent.substring(0, thinkStart).trim()
-        const afterThink = textContent.substring(thinkEnd + 8).trim()
-        textContent = [beforeThink, afterThink].filter(Boolean).join('\n')
-      }
-    }
-
-    resultResp.content = textContent
-    if (thoughtContent) {
-      resultResp.reasoning_content = thoughtContent
-    }
-
-    return resultResp
-  }
-
-  // 实现抽象方法
-  async completions(
-    messages: { role: 'system' | 'user' | 'assistant'; content: string }[],
-    modelId: string,
-    temperature?: number,
-    maxTokens?: number
-  ): Promise<LLMResponse> {
-    try {
-      if (!this.genAI) {
-        throw new Error('Google Generative AI client is not initialized')
-      }
-
-      const { systemInstruction, contents } = this.formatGeminiMessages(messages)
-
-      // 创建 GenerateContentConfig
-      const generateContentConfig: GenerateContentConfig = this.getGenerateContentConfig(
-        temperature ??
0.7,
-        maxTokens,
-        modelId,
-        false // completions 方法中不处理 reasoning
-      )
-
-      if (systemInstruction) {
-        generateContentConfig.systemInstruction = systemInstruction
-      }
-
-      // 一次性创建 requestParams
-      const requestParams: GenerateContentParameters = {
-        model: modelId,
-        contents,
-        config: generateContentConfig
-      }
-
-      const result = await this.genAI.models.generateContent({
-        ...requestParams,
-        model: this.ensureGoogleModelName(requestParams.model as string)
-      })
-
-      const resultResp: LLMResponse = {
-        content: ''
-      }
-
-      // 尝试获取tokens信息 - 使用新SDK的usageMetadata结构
-      try {
-        if (result.usageMetadata) {
-          const usage = result.usageMetadata
-          resultResp.totalUsage = {
-            prompt_tokens: usage.promptTokenCount || 0,
-            completion_tokens: usage.candidatesTokenCount || 0,
-            total_tokens: usage.totalTokenCount || 0
-          }
-        } else {
-          // 估算token数量 - 简单方法,可以根据实际需要调整
-          const promptText = messages.map((m) => m.content).join(' ')
-          const responseText = result.text || ''
-
-          // 简单估算: 英文约1个token/4个字符,中文约1个token/1.5个字符
-          const estimateTokens = (text: string): number => {
-            const chineseCharCount = (text.match(/[\u4e00-\u9fa5]/g) || []).length
-            const otherCharCount = text.length - chineseCharCount
-            return Math.ceil(chineseCharCount / 1.5 + otherCharCount / 4)
-          }
-
-          const promptTokens = estimateTokens(promptText)
-          const completionTokens = estimateTokens(responseText)
-
-          resultResp.totalUsage = {
-            prompt_tokens: promptTokens,
-            completion_tokens: completionTokens,
-            total_tokens: promptTokens + completionTokens
-          }
-        }
-      } catch (e) {
-        console.warn('Failed to estimate token count for Gemini response', e)
-      }
-
-      // 处理响应内容,支持新格式的思考内容
-      let textContent = ''
-      let thoughtContent = ''
-
-      // 检查是否有候选响应和 parts
-      if (result.candidates && result.candidates[0]?.content?.parts) {
-        for (const part of result.candidates[0].content.parts) {
-          // 检查是否是思考内容 (新格式)
-          if ((part as any).thought === true && part.text) {
-            thoughtContent += part.text
-          } else if (part.text) {
-            textContent += part.text
-          }
-        }
-      } else {
-        // 回退到使用 result.text
-        textContent = result.text || ''
-      }
-
-      // 如果没有检测到新格式的思考内容,检查旧格式的 <think> 标签
-      if (!thoughtContent && textContent.includes('<think>')) {
-        const thinkStart = textContent.indexOf('<think>')
-        const thinkEnd = textContent.indexOf('</think>')
-
-        if (thinkEnd > thinkStart) {
-          // 提取reasoning_content
-          thoughtContent = textContent.substring(thinkStart + 7, thinkEnd).trim()
-
-          // 合并前后的普通内容
-          const beforeThink = textContent.substring(0, thinkStart).trim()
-          const afterThink = textContent.substring(thinkEnd + 8).trim()
-          textContent = [beforeThink, afterThink].filter(Boolean).join('\n')
-        }
-      }
-
-      resultResp.content = textContent
-      if (thoughtContent) {
-        resultResp.reasoning_content = thoughtContent
-      }
-
-      return resultResp
-    } catch (error) {
-      console.error('Gemini completions error:', error)
-      throw error
-    }
-  }
-
-  async summaries(
-    text: string,
-    modelId: string,
-    temperature?: number,
-    maxTokens?: number
-  ): Promise<LLMResponse> {
-    if (!this.isInitialized) {
-      throw new Error('Provider not initialized')
-    }
-
-    if (!modelId) {
-      throw new Error('Model ID is required')
-    }
-
-    try {
-      const prompt = `Please generate a concise summary for the following content:\n\n${text}`
-
-      const result = await this.genAI.models.generateContent({
-        model: this.ensureGoogleModelName(modelId),
-        contents: [{ role: 'user', parts: [{ text: prompt }] }],
-        config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false)
-      })
-
-      return this.processGeminiResponse(result)
-    } catch (error) {
-      console.error('Gemini summaries
error:', error) - throw error - } - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - if (!modelId) { - throw new Error('Model ID is required') - } - - try { - const result = await this.genAI.models.generateContent({ - model: this.ensureGoogleModelName(modelId), - contents: [{ role: 'user', parts: [{ text: prompt }] }], - config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false) - }) - - return this.processGeminiResponse(result) - } catch (error) { - console.error('Gemini generateText error:', error) - throw error - } - } - - async suggestions( - context: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - if (!modelId) { - throw new Error('Model ID is required') - } - - try { - const prompt = `Based on the following context, please provide up to 5 reasonable suggestion options, each not exceeding 100 characters. Please return in JSON array format without other explanations:\n\n${context}` - - const result = await this.genAI.models.generateContent({ - model: this.ensureGoogleModelName(modelId), - contents: [{ role: 'user', parts: [{ text: prompt }] }], - config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false) - }) - - const responseText = result.text || '' - - // 尝试从响应中解析出JSON数组 - try { - const cleanedText = responseText.replace(/```json|```/g, '').trim() - const suggestions = JSON.parse(cleanedText) - if (Array.isArray(suggestions)) { - return suggestions.map((item) => item.toString()) - } - } catch (parseError) { - console.error('Gemini suggestions parseError:', parseError) - // 如果解析失败,尝试分行处理 - const lines = responseText - .split('\n') - .map((line) => line.trim()) - .filter((line) => line && !line.startsWith('```') && !line.includes(':')) - .map((line) => line.replace(/^[0-9]+\.\s*/, '').replace(/^-\s*/, '')) - - if (lines.length > 0) { - return lines.slice(0, 5) - } - } - - // If all fail, return a default prompt - return ['Unable to generate suggestions'] - } catch (error) { - console.error('Gemini suggestions error:', error) - return ['Error occurred, unable to get suggestions'] - } - } - /** - * 核心流式处理方法 - * 实现BaseLLMProvider中的抽象方法 - */ - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - console.log('modelConfig', modelConfig, modelId) - - // 检查是否是图片生成模型 - const isImageGenerationModel = modelConfig?.type === ModelType.ImageGeneration - - // 如果是图片生成模型,使用特殊处理 - if (isImageGenerationModel) { - yield* this.handleImageGenerationStream(messages, modelId, temperature, maxTokens) - return - } - - const safetySettings = await this.getFormattedSafetySettings() - console.log('safetySettings', safetySettings) - - // 添加Gemini工具调用 - let geminiTools: Tool[] = [] - - // Load MCP tools if available - if (mcpTools.length > 0) - geminiTools = (await this.mcpRuntime?.mcpToolsToGeminiTools(mcpTools, this.provider.id)) ?? [] - - // 格式化消息为Gemini格式 - const formattedParts = this.formatGeminiMessages(messages) - - // 1. 
获取基础 config - const generateContentConfig: GenerateContentConfig = this.getGenerateContentConfig( - temperature, - maxTokens, - modelId, - modelConfig.reasoning, - modelConfig.thinkingBudget - ) - - // 2. 在本地变量上添加其他属性 - if (formattedParts.systemInstruction) { - generateContentConfig.systemInstruction = formattedParts.systemInstruction - } - - if (geminiTools.length > 0) { - generateContentConfig.tools = geminiTools - // 仅当存在 functionDeclarations 时才配置 functionCallingConfig - const hasFunctionDeclarations = geminiTools.some((t: any) => { - const fns = t?.functionDeclarations - return Array.isArray(fns) && fns.length > 0 - }) - if (hasFunctionDeclarations) { - generateContentConfig.toolConfig = { - functionCallingConfig: { - mode: FunctionCallingConfigMode.AUTO // 允许模型自动决定是否调用工具 - } - } - } - } - - if (safetySettings) { - generateContentConfig.safetySettings = safetySettings - } - - // 3. 一次性创建完整的 requestParams - const requestParams: GenerateContentParameters = { - model: modelId, - contents: formattedParts.contents, - config: generateContentConfig - } - - const streamRequestParams = { - ...requestParams, - model: this.ensureGoogleModelName(requestParams.model as string) - } - - await this.emitRequestTrace(modelConfig, { - endpoint: this.buildGeminiStreamEndpoint(modelId), - headers: this.buildGeminiTraceHeaders(), - body: streamRequestParams - }) - - // 发送流式请求 - const result = await this.genAI.models.generateContentStream(streamRequestParams) - - // 状态变量 - let buffer = '' - let isInThinkTag = false - let toolUseDetected = false - let usageMetadata: GenerateContentResponseUsageMetadata | undefined - let isNewThoughtFormatDetected = modelConfig.reasoning === true - - // 流处理循环 - for await (const chunk of result) { - // 处理用量统计 - if (chunk.usageMetadata) { - usageMetadata = chunk.usageMetadata - } - - // console.log('chunk.candidates', JSON.stringify(chunk.candidates, null, 2)) - // 检查是否包含函数调用 - if (chunk.candidates && chunk.candidates[0]?.content?.parts?.[0]?.functionCall) { - const functionCall = chunk.candidates[0].content.parts[0].functionCall - const functionName = functionCall.name - const functionArgs = functionCall.args || {} - const toolCallId = `gemini-tool-${Date.now()}` - - toolUseDetected = true - - // 发送工具调用开始事件 - yield createStreamEvent.toolCallStart(toolCallId, functionName || '') - - // 发送工具调用参数 - const argsString = JSON.stringify(functionArgs) - yield createStreamEvent.toolCallChunk(toolCallId, argsString) - - // 发送工具调用结束事件 - yield createStreamEvent.toolCallEnd(toolCallId, argsString) - - // 设置停止原因为工具使用 - break - } - - // 处理内容块 - let content = '' - let thoughtContent = '' - - // 处理文本和图像内容 - if (chunk.candidates && chunk.candidates[0]?.content?.parts) { - for (const part of chunk.candidates[0].content.parts) { - // 检查是否是思考内容 (新格式) - if ((part as any).thought === true && part.text) { - isNewThoughtFormatDetected = true - thoughtContent += part.text - } else if (part.text) { - content += part.text - } else if (part.inlineData && part.inlineData.data && part.inlineData.mimeType) { - // 处理图像数据 - yield createStreamEvent.imageData({ - data: part.inlineData.data, - mimeType: part.inlineData.mimeType - }) - } - } - } else { - // 兼容处理 - content = chunk.text || '' - } - - // 如果检测到思考内容,直接发送 - if (thoughtContent) { - yield createStreamEvent.reasoning(thoughtContent) - } - - if (!content) continue - - if (isNewThoughtFormatDetected) { - yield createStreamEvent.text(content) - } else { - buffer += content - - if (buffer.includes('') && !isInThinkTag) { - const thinkStart = buffer.indexOf('') - 
if (thinkStart > 0) { - yield createStreamEvent.text(buffer.substring(0, thinkStart)) - } - buffer = buffer.substring(thinkStart + 7) - isInThinkTag = true - } - - if (isInThinkTag && buffer.includes('')) { - const thinkEnd = buffer.indexOf('') - const reasoningContent = buffer.substring(0, thinkEnd) - if (reasoningContent) { - yield createStreamEvent.reasoning(reasoningContent) - } - buffer = buffer.substring(thinkEnd + 8) - isInThinkTag = false - } - - if (!isInThinkTag && buffer) { - yield createStreamEvent.text(buffer) - buffer = '' - } - } - } - - if (usageMetadata) { - yield createStreamEvent.usage({ - prompt_tokens: usageMetadata.promptTokenCount || 0, - completion_tokens: usageMetadata.candidatesTokenCount || 0, - total_tokens: usageMetadata.totalTokenCount || 0 - }) - } - - // 处理剩余缓冲区内容 - if (!isNewThoughtFormatDetected && buffer) { - if (isInThinkTag) { - yield createStreamEvent.reasoning(buffer) - } else { - yield createStreamEvent.text(buffer) - } - } - - // 发送停止事件 - yield createStreamEvent.stop(toolUseDetected ? 'tool_use' : 'complete') - } - - /** - * 处理图片生成模型的流式输出 - */ - private async *handleImageGenerationStream( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): AsyncGenerator { - try { - // 提取用户消息并构建parts数组 - const userMessage = messages.findLast((msg) => msg.role === 'user') - if (!userMessage) { - throw new Error('No user message found for image generation') - } - - // 构建包含文本和图片的parts数组,参考formatGeminiMessages的逻辑 - const parts: Part[] = [] - - if (typeof userMessage.content === 'string') { - // 处理纯文本消息 - if (userMessage.content.trim() !== '') { - parts.push({ text: userMessage.content }) - } - } else if (Array.isArray(userMessage.content)) { - // 处理多模态消息(带图片等) - for (const part of userMessage.content) { - if (part.type === 'text') { - // 只添加非空文本 - if (part.text && part.text.trim() !== '') { - parts.push({ text: part.text }) - } - } else if (part.type === 'image_url' && part.image_url) { - // 处理图片(假设是 base64 格式) - const matches = part.image_url.url.match(/^data:([^;]+);base64,(.+)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - parts.push({ - inlineData: { - data: base64Data, - mimeType: mimeType - } - }) - } - } - } - } - - // 如果没有有效的parts,抛出错误 - if (parts.length === 0) { - throw new Error('No valid content found for image generation') - } - - // 发送生成请求 - const result = await this.genAI.models.generateContentStream({ - model: this.ensureGoogleModelName(modelId), - contents: [{ role: 'user', parts }], - config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false) // 图像生成不需要reasoning - }) - - // 处理流式响应 - for await (const chunk of result) { - if (chunk.candidates && chunk.candidates[0]?.content?.parts) { - for (const part of chunk.candidates[0].content.parts) { - if (part.text) { - // 输出文本内容 - yield createStreamEvent.text(part.text) - } else if (part.inlineData) { - // 输出图像数据 - yield createStreamEvent.imageData({ - data: part.inlineData.data || '', - mimeType: part.inlineData.mimeType || '' - }) - } - } - } - } - - // 发送停止事件 - yield createStreamEvent.stop('complete') - } catch (error) { - console.error('Image generation stream error:', error) - yield createStreamEvent.error( - error instanceof Error ? 
error.message : 'Image generation failed' - ) - yield createStreamEvent.stop('error') - } - } - - async getEmbeddings(modelId: string, texts: string[]): Promise { - if (!this.genAI) throw new Error('Google Generative AI client is not initialized') - // Gemini embedContent 支持批量输入 - const resp = await this.genAI.models.embedContent({ - model: this.ensureGoogleModelName(modelId), - contents: texts.map((text) => ({ - parts: [{ text }] - })) - }) - // resp.embeddings?: ContentEmbedding[] - if (resp && Array.isArray(resp.embeddings)) { - return resp.embeddings.map((e) => (Array.isArray(e.values) ? e.values : [])) - } - // 若无返回,抛出异常 - throw new Error('Gemini embedding API did not return embeddings') - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/githubProvider.ts b/src/main/presenter/llmProviderPresenter/providers/githubProvider.ts deleted file mode 100644 index 6f2bc269b..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/githubProvider.ts +++ /dev/null @@ -1,67 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - ChatMessage, - IConfigPresenter -} from '@shared/presenter' -import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { ModelsPage } from 'openai/resources' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class GithubProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - const response = (await this.openai.models.list(options)) as ModelsPage & { - body: { - id: string - name: string - description: string - }[] - } - return response.body.map((model) => ({ - id: model.name, - name: model.name, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, - maxTokens: DEFAULT_MODEL_MAX_TOKENS, - description: model.description - })) - } - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/grokProvider.ts b/src/main/presenter/llmProviderPresenter/providers/grokProvider.ts deleted file mode 100644 index 77f830ceb..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/grokProvider.ts +++ /dev/null @@ -1,230 +0,0 @@ -import { LLM_PROVIDER, LLMResponse, ChatMessage, IConfigPresenter } from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { ModelConfig, MCPToolDefinition, LLMCoreStreamEvent } from '@shared/presenter' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class GrokProvider extends OpenAICompatibleProvider { - // Image generation model ID - private static readonly IMAGE_MODEL_ID = 'grok-2-image' - // private static readonly IMAGE_ENDPOINT = '/images/generations' - - // Reasoning models that support reasoning_content - private static readonly 
REASONING_MODELS: string[] = ['grok-4', 'grok-3-mini', 'grok-3-mini-fast'] - - // Models that support reasoning_effort parameter (grok-4 does not) - private static readonly REASONING_EFFORT_MODELS: string[] = ['grok-3-mini', 'grok-3-mini-fast'] - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - // Check if it's an image model - private isImageModel(modelId: string): boolean { - return modelId.startsWith(GrokProvider.IMAGE_MODEL_ID) - } - - // Check if model supports reasoning - private isReasoningModel(modelId: string): boolean { - return GrokProvider.REASONING_MODELS.some((model) => - modelId.toLowerCase().includes(model.toLowerCase()) - ) - } - - // Check if model supports reasoning_effort parameter - private supportsReasoningEffort(modelId: string): boolean { - return GrokProvider.REASONING_EFFORT_MODELS.some((model) => - modelId.toLowerCase().includes(model.toLowerCase()) - ) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - // Image generation models require special handling - if (this.isImageModel(modelId)) { - return this.handleImageGeneration(messages) - } - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - // Image generation models do not support summaries - if (this.isImageModel(modelId)) { - throw new Error('Image generation model does not support summaries') - } - return this.openAICompletion( - [ - { - role: 'user', - content: `Please summarize the following content using concise language and highlighting key points:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - // Image generation models use special handling - if (this.isImageModel(modelId)) { - return this.handleImageGeneration([{ role: 'user', content: prompt }]) - } - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - // Special method for handling image generation requests - private async handleImageGeneration( - messages: ChatMessage[] - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - // Extract prompt (use the last user message) - const userMessage = messages.findLast((msg) => msg.role === 'user') - if (!userMessage) { - throw new Error('No user message found for image generation') - } - - const prompt = - typeof userMessage.content === 'string' - ? 
userMessage.content - : userMessage.content - ?.filter((c) => c.type === 'text') - .map((c) => c.text) - .join('\n') || '' - - // Create image generation request - try { - const response = await this.openai.images.generate({ - model: GrokProvider.IMAGE_MODEL_ID, - prompt, - response_format: 'b64_json' - }) - // Handle response - if (response.data && response.data.length > 0) { - const imageData = response.data[0] - if (imageData.b64_json) { - // Return base64-encoded image data while preserving original data - return { - content: `![Generated Image](data:image/png;base64,${imageData.b64_json})`, - imageData: imageData.b64_json, - mimeType: 'image/png' - } - } else if (imageData.url) { - // Return image URL - return { - content: `![Generated Image](${imageData.url})` - } - } - } - throw new Error('No image data received from API') - } catch (error: unknown) { - console.error('Image generation failed:', error) - throw new Error( - `Image generation failed: ${error instanceof Error ? error.message : 'Unknown error'}` - ) - } - } - - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - - // Handle image generation models - if (this.isImageModel(modelId)) { - const result = await this.handleImageGeneration(messages) - // Use additional fields directly - if (result.imageData && result.mimeType) { - yield { - type: 'image_data', - image_data: { - data: result.imageData, - mimeType: result.mimeType - } - } - } else { - // If no imageData field, fallback to text format - yield { - type: 'text', - content: result.content - } - } - // Add brief delay to ensure all RESPONSE events are processed - await new Promise((resolve) => setTimeout(resolve, 300)) - return - } - - // Handle reasoning models - const shouldAddReasoningEffort = this.isReasoningModel(modelId) && modelConfig?.reasoningEffort - const needsParameterModification = shouldAddReasoningEffort - - if (needsParameterModification) { - const originalCreate = this.openai.chat.completions.create.bind(this.openai.chat.completions) - this.openai.chat.completions.create = ((params: any, options?: any) => { - const modifiedParams = { ...params } - - // Add reasoning effort parameter if supported - if (shouldAddReasoningEffort && this.supportsReasoningEffort(modelId)) { - modifiedParams.reasoning_effort = modelConfig.reasoningEffort - } - - return originalCreate(modifiedParams, options) - }) as any - - try { - const effectiveModelConfig = { - ...modelConfig, - reasoningEffort: undefined - } - yield* super.coreStream( - messages, - modelId, - effectiveModelConfig, - temperature, - maxTokens, - mcpTools - ) - } finally { - this.openai.chat.completions.create = originalCreate - } - } else { - yield* super.coreStream(messages, modelId, modelConfig, temperature, maxTokens, mcpTools) - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/groqProvider.ts b/src/main/presenter/llmProviderPresenter/providers/groqProvider.ts deleted file mode 100644 index 02bec185e..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/groqProvider.ts +++ /dev/null @@ -1,194 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - ChatMessage, - MODEL_META, - IConfigPresenter -} from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { 
ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for Groq model response (following PPIO naming convention) -interface GroqModelResponse { - id: string - object: string - owned_by: string - created: number - display_name?: string - description?: string - context_size: number // Groq uses context_window, but we'll map it to context_size - max_output_tokens: number // Groq may use max_tokens, but we'll map it to max_output_tokens - features?: string[] - status?: number // Groq uses active boolean, we'll map it to status number - model_type?: string - // Groq specific fields that we need to handle - active?: boolean - context_window?: number - max_tokens?: number - public_apps?: boolean -} - -export class GroqProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `Please summarize the following content using concise language and highlighting key points:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Override fetchOpenAIModels to parse Groq specific model data and update model configurations - * @param options - Request options - * @returns Promise - Array of model metadata - */ - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - try { - const response = await this.openai.models.list(options) - // console.log('Groq models response:', JSON.stringify(response, null, 2)) - - const models: MODEL_META[] = [] - - for (const model of response.data) { - // Type the model as Groq specific response - const groqModel = model as unknown as GroqModelResponse - - // Skip inactive models (map Groq's active field to status) - const modelStatus = groqModel.status ?? (groqModel.active ? 
1 : 0) - if (modelStatus === 0 || groqModel.active === false) { - continue - } - - // Extract model information - const modelId = groqModel.id - const features = groqModel.features || [] - - // Map Groq fields to PPIO-style naming - const contextSize = groqModel.context_size || groqModel.context_window || 4096 - const maxOutputTokens = groqModel.max_output_tokens || groqModel.max_tokens || 2048 - - // Check features for capabilities or infer from model name - const hasFunctionCalling = - features.includes('function-calling') || - (!modelId.toLowerCase().includes('distil') && !modelId.toLowerCase().includes('gemma')) - const hasVision = - features.includes('vision') || - modelId.toLowerCase().includes('vision') || - modelId.toLowerCase().includes('llava') - - // Get existing model configuration first - const existingConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - - // Extract configuration values with proper fallback priority: API -> existing config -> default - const contextLength = contextSize || existingConfig.contextLength || 4096 - const maxTokens = maxOutputTokens || existingConfig.maxTokens || 2048 - - // Build new configuration based on API response - const newConfig = { - contextLength: contextLength, - maxTokens: maxTokens, - functionCall: hasFunctionCalling, - vision: hasVision, - reasoning: existingConfig.reasoning, // Keep existing reasoning setting - temperature: existingConfig.temperature, // Keep existing temperature - type: existingConfig.type // Keep existing type - } - - // Check if configuration has changed - const configChanged = - existingConfig.contextLength !== newConfig.contextLength || - existingConfig.maxTokens !== newConfig.maxTokens || - existingConfig.functionCall !== newConfig.functionCall || - existingConfig.vision !== newConfig.vision - - // Update configuration if changed - if (configChanged) { - // console.log(`Updating configuration for model ${modelId}:`, { - // old: { - // contextLength: existingConfig.contextLength, - // maxTokens: existingConfig.maxTokens, - // functionCall: existingConfig.functionCall, - // vision: existingConfig.vision - // }, - // new: newConfig - // }) - - this.configPresenter.setModelConfig(modelId, this.provider.id, newConfig, { - source: 'provider' - }) - } - - // Create MODEL_META object - const modelMeta: MODEL_META = { - id: modelId, - name: groqModel.display_name || modelId, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: contextLength, - maxTokens: maxTokens, - description: groqModel.description || `Groq model ${modelId}`, - vision: hasVision, - functionCall: hasFunctionCalling, - reasoning: existingConfig.reasoning || false - } - - models.push(modelMeta) - } - - console.log(`Processed ${models.length} Groq models with dynamic configuration updates`) - return models - } catch (error) { - console.error('Error fetching Groq models:', error) - // Fallback to parent implementation - return super.fetchOpenAIModels(options) - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/jiekouProvider.ts b/src/main/presenter/llmProviderPresenter/providers/jiekouProvider.ts deleted file mode 100644 index 29000d59b..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/jiekouProvider.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { LLM_PROVIDER, MODEL_META, IConfigPresenter } from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class 
JiekouProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - const models = await super.fetchOpenAIModels(options) - return models.map((model) => ({ - ...model, - group: 'JieKou.AI' - })) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/lmstudioProvider.ts b/src/main/presenter/llmProviderPresenter/providers/lmstudioProvider.ts deleted file mode 100644 index ccc7c13a5..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/lmstudioProvider.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { IConfigPresenter, LLM_PROVIDER } from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' -export class LMStudioProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/minimaxProvider.ts b/src/main/presenter/llmProviderPresenter/providers/minimaxProvider.ts deleted file mode 100644 index 2513bb47b..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/minimaxProvider.ts +++ /dev/null @@ -1,44 +0,0 @@ -import { MODEL_META } from '@shared/presenter' -import { ModelType } from '@shared/model' -import { - resolveModelContextLength, - resolveModelFunctionCall, - resolveModelMaxTokens -} from '@shared/modelConfigDefaults' -import { AnthropicProvider } from './anthropicProvider' -import { providerDbLoader } from '../../configPresenter/providerDbLoader' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' - -export class MinimaxProvider extends AnthropicProvider { - protected async fetchProviderModels(): Promise { - const resolvedId = modelCapabilities.resolveProviderId(this.provider.id) || this.provider.id - const provider = providerDbLoader.getProvider(resolvedId) - - if (provider && Array.isArray(provider.models) && provider.models.length > 0) { - return provider.models.map((model) => { - const inputs = model.modalities?.input - const outputs = model.modalities?.output - const hasImageInput = Array.isArray(inputs) && inputs.includes('image') - const hasImageOutput = Array.isArray(outputs) && outputs.includes('image') - const modelType = hasImageOutput ? 
ModelType.ImageGeneration : ModelType.Chat - - return { - id: model.id, - name: model.display_name || model.name || model.id, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: resolveModelContextLength(model.limit?.context), - maxTokens: resolveModelMaxTokens(model.limit?.output), - vision: hasImageInput, - functionCall: resolveModelFunctionCall(model.tool_call), - reasoning: Boolean(model.reasoning?.supported), - enableSearch: Boolean(model.search?.supported), - type: modelType - } - }) - } - - return super.fetchProviderModels() - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/modelscopeProvider.ts b/src/main/presenter/llmProviderPresenter/providers/modelscopeProvider.ts deleted file mode 100644 index 2f9960693..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/modelscopeProvider.ts +++ /dev/null @@ -1,339 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - ChatMessage, - KeyStatus, - IConfigPresenter, - MCPServerConfig, - ModelScopeMcpSyncOptions -} from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for ModelScope MCP API response -export interface ModelScopeMcpServerResponse { - code: number - data: { - mcp_server_list: ModelScopeMcpServer[] - total_count: number - } - message: string - request_id: string - success: boolean -} - -// Define interface for ModelScope MCP server (updated for operational API) -export interface ModelScopeMcpServer { - name: string - description: string - id: string - chinese_name?: string // Chinese name field - logo_url: string - operational_urls: Array<{ - id: string - url: string - }> - tags: string[] - locales: { - zh: { - name: string - description: string - } - en: { - name: string - description: string - } - } -} - -export class ModelscopeProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `You need to summarize the user's conversation into a title of no more than 10 words, with the title language matching the user's primary language, without using punctuation or other special symbols:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Get current API key status from ModelScope - * @returns Promise API key status information - */ - public async getKeyStatus(): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required') - } - - try { - // Use models endpoint to check API key validity - const response = await this.openai.models.list({ timeout: 10000 }) - - return { - limit_remaining: 'Available', - remainNum: response.data?.length || 0 - } - } catch (error) { - console.error('ModelScope API key check failed:', error) 
- throw new Error( - `ModelScope API key check failed: ${error instanceof Error ? error.message : String(error)}` - ) - } - } - - /** - * Override check method to use ModelScope's API validation - * @returns Promise<{ isOk: boolean; errorMsg: string | null }> - */ - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - await this.getKeyStatus() - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - let errorMessage = 'An unknown error occurred during ModelScope API key check.' - if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('ModelScope API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } - - /** - * Sync operational MCP servers from ModelScope API - * @param _options - Sync options including filters (currently not used by operational API) - * @returns Promise MCP servers response - */ - public async syncMcpServers( - _syncOptions?: ModelScopeMcpSyncOptions - ): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required for MCP sync') - } - - try { - // Use the operational API endpoint - GET request, no body needed - const response = await fetch('https://www.modelscope.cn/openapi/v1/mcp/servers/operational', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${this.provider.apiKey}` - }, - signal: AbortSignal.timeout(30000) // 30 second timeout - }) - - // Handle authentication errors - if (response.status === 401 || response.status === 403) { - throw new Error('ModelScope MCP sync unauthorized: Invalid or expired API key') - } - - // Handle server errors - if (response.status === 500 || !response.ok) { - const errorText = await response.text() - throw new Error( - `ModelScope MCP sync failed: ${response.status} ${response.statusText} - ${errorText}` - ) - } - - const data: ModelScopeMcpServerResponse = await response.json() - - if (!data.success) { - throw new Error(`ModelScope MCP sync failed: ${data.message}`) - } - - console.log( - `Successfully fetched ${data.data.mcp_server_list.length} operational MCP servers from ModelScope` - ) - return data - } catch (error) { - console.error('ModelScope MCP sync error:', error) - throw error - } - } - - /** - * Convert ModelScope operational MCP server to internal MCP server config format - * @param mcpServer - ModelScope MCP server data - * @returns Internal MCP server config - */ - public convertMcpServerToConfig(mcpServer: ModelScopeMcpServer): MCPServerConfig { - // Check if operational URLs are available - if (!mcpServer.operational_urls || mcpServer.operational_urls.length === 0) { - throw new Error(`No operational URLs found for server ${mcpServer.id}`) - } - - // Use the first operational URL - const baseUrl = mcpServer.operational_urls[0].url - - // Generate random emoji for icon - const emojis = [ - '🔧', - '⚡', - '🚀', - '🔨', - '⚙️', - '🛠️', - '🔥', - '💡', - '⭐', - '🎯', - '🎨', - '🔮', - '💎', - '🎪', - '🎭', - '🎨', - '🔬', - '📱', - '💻', - '🖥️', - '⌨️', - '🖱️', - '📡', - '🔊', - '📢', - '📣', - '📯', - '🔔', - '🔕', - '📻', - '📺', - '📷', - '📹', - '🎥', - '📽️', - '🔍', - '🔎', - '💰', - '💳', - '💸', - '💵', - '🎲', - '🃏', - '🎮', - '🕹️', - '🎯', - '🎳', - '🎨', - '🖌️', - '🖍️', - '📝', - '✏️', - '📏', - '📐', - '📌', - '📍', - '🗂️', - '📂', - '📁', - '📰', - '📄', - '📃', - '📜', - '📋', - '📊', - '📈', - '📉', - '📦', - '📫', - '📪', - '📬', - '📭', - '📮', - '🗳️', - '✉️', - '📧', - '📨', - '📩', - '📤', - '📥', - '📬', - '📭', - '📮', - '🗂️', - 
'📂', - '📁', - '🗄️', - '🗃️', - '📋', - '📑', - '📄', - '📃', - '📰', - '🗞️', - '📜', - '🔖' - ] - const randomEmoji = emojis[Math.floor(Math.random() * emojis.length)] - - // Get display name: chinese_name first, then name, then id - const displayName = mcpServer.chinese_name || mcpServer.name || mcpServer.id - - return { - command: '', // Not needed for SSE type - args: [], // Not needed for SSE type - env: {}, - descriptions: - mcpServer.locales?.zh?.description || - mcpServer.description || - `ModelScope MCP Server: ${displayName}`, - icons: randomEmoji, // Random emoji instead of URL - autoApprove: ['all'], - enabled: false, - disable: false, // Default to disabled for safety - type: 'sse' as const, // SSE type for operational servers - baseUrl: baseUrl, // Use operational URL - source: 'modelscope', - sourceId: mcpServer.id - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/newApiProvider.ts b/src/main/presenter/llmProviderPresenter/providers/newApiProvider.ts deleted file mode 100644 index f214dab68..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/newApiProvider.ts +++ /dev/null @@ -1,672 +0,0 @@ -import Anthropic from '@anthropic-ai/sdk' -import { - ChatMessage, - IConfigPresenter, - KeyStatus, - LLMCoreStreamEvent, - LLM_EMBEDDING_ATTRS, - LLM_PROVIDER, - LLMResponse, - MCPToolDefinition, - MODEL_META, - ModelConfig -} from '@shared/presenter' -import { - ApiEndpointType, - ModelType, - isNewApiEndpointType, - resolveNewApiCapabilityProviderId, - type NewApiEndpointType -} from '@shared/model' -import { ProxyAgent } from 'undici' -import { BaseLLMProvider } from '../baseProvider' -import { proxyConfig } from '../../proxyConfig' -import { AnthropicProvider } from './anthropicProvider' -import { GeminiProvider } from './geminiProvider' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { OpenAIResponsesProvider } from './openAIResponsesProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -type NewApiModelRecord = { - id?: unknown - name?: unknown - owned_by?: unknown - description?: unknown - type?: unknown - supported_endpoint_types?: unknown - context_length?: unknown - contextLength?: unknown - input_token_limit?: unknown - max_input_tokens?: unknown - max_tokens?: unknown - max_output_tokens?: unknown - output_token_limit?: unknown -} - -type NewApiModelsResponse = { - data?: NewApiModelRecord[] -} - -const DEFAULT_NEW_API_BASE_URL = 'https://www.newapi.ai' - -class NewApiOpenAIChatDelegate extends OpenAICompatibleProvider { - protected override async init() { - this.isInitialized = true - } -} - -class NewApiOpenAIResponsesDelegate extends OpenAIResponsesProvider { - protected override async init() { - this.isInitialized = true - } -} - -class NewApiGeminiDelegate extends GeminiProvider { - protected override async init() { - this.isInitialized = true - } -} - -class NewApiAnthropicDelegate extends AnthropicProvider { - private clientInitialized = false - - protected override async init() {} - - public async ensureClientInitialized(): Promise { - const apiKey = this.provider.apiKey || process.env.ANTHROPIC_API_KEY || null - if (!apiKey) { - this.clientInitialized = false - this.isInitialized = false - return - } - - const proxyUrl = proxyConfig.getProxyUrl() - const fetchOptions: { dispatcher?: ProxyAgent } = {} - - if (proxyUrl) { - fetchOptions.dispatcher = new ProxyAgent(proxyUrl) - } - - const self = this as unknown as { anthropic?: Anthropic } - self.anthropic = new Anthropic({ - apiKey, - 
baseURL: this.provider.baseUrl || DEFAULT_NEW_API_BASE_URL, - defaultHeaders: this.defaultHeaders, - fetchOptions - }) - - this.clientInitialized = true - this.isInitialized = true - } - - public isClientInitialized(): boolean { - return this.clientInitialized - } - - public override onProxyResolved(): void { - void this.ensureClientInitialized() - } -} - -export class NewApiProvider extends BaseLLMProvider { - private readonly openaiChatDelegate: NewApiOpenAIChatDelegate - private readonly openaiResponsesDelegate: NewApiOpenAIResponsesDelegate - private readonly anthropicDelegate: NewApiAnthropicDelegate - private readonly geminiDelegate: NewApiGeminiDelegate - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - - const host = this.getNormalizedBaseHost() - - this.openaiChatDelegate = new NewApiOpenAIChatDelegate( - this.buildDelegateProvider({ - apiType: 'openai-completions', - baseUrl: `${host}/v1`, - capabilityProviderId: resolveNewApiCapabilityProviderId('openai') - }), - configPresenter, - mcpRuntime - ) - - this.openaiResponsesDelegate = new NewApiOpenAIResponsesDelegate( - this.buildDelegateProvider({ - apiType: 'openai-responses', - baseUrl: `${host}/v1`, - capabilityProviderId: resolveNewApiCapabilityProviderId('openai-response') - }), - configPresenter, - mcpRuntime - ) - - this.anthropicDelegate = new NewApiAnthropicDelegate( - this.buildDelegateProvider({ - apiType: 'anthropic', - baseUrl: host, - capabilityProviderId: resolveNewApiCapabilityProviderId('anthropic') - }), - configPresenter, - mcpRuntime - ) - - this.geminiDelegate = new NewApiGeminiDelegate( - this.buildDelegateProvider({ - apiType: 'gemini', - baseUrl: host, - capabilityProviderId: resolveNewApiCapabilityProviderId('gemini') - }), - configPresenter, - mcpRuntime - ) - - this.init() - } - - private getNormalizedBaseHost(): string { - const rawBaseUrl = (this.provider.baseUrl || DEFAULT_NEW_API_BASE_URL).trim() - const normalizedBaseUrl = rawBaseUrl.replace(/\/+$/, '') - return normalizedBaseUrl.replace(/\/(v1|v1beta(?:\d+)?)$/i, '') || DEFAULT_NEW_API_BASE_URL - } - - private getStoredModelMeta(modelId: string): MODEL_META | undefined { - return [...this.models, ...this.customModels].find((model) => model.id === modelId) - } - - private buildDelegateProvider(overrides: Partial): LLM_PROVIDER { - return { - ...this.provider, - ...overrides - } - } - - private getDefaultEndpointType(model: Pick) { - const supportedEndpointTypes = model.supportedEndpointTypes ?? [] - if (supportedEndpointTypes.length === 0) { - return model.type === ModelType.ImageGeneration ? 'image-generation' : undefined - } - - if ( - model.type === ModelType.ImageGeneration && - supportedEndpointTypes.includes('image-generation') - ) { - return 'image-generation' - } - - return supportedEndpointTypes[0] - } - - private resolveEndpointType(modelId: string): NewApiEndpointType { - const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - if (isNewApiEndpointType(modelConfig.endpointType)) { - return modelConfig.endpointType - } - - const storedModel = this.getStoredModelMeta(modelId) - if (storedModel && isNewApiEndpointType(storedModel.endpointType)) { - return storedModel.endpointType - } - - const defaultEndpointType = storedModel ? this.getDefaultEndpointType(storedModel) : undefined - return defaultEndpointType ?? 
'openai' - } - - private buildImageModelConfig(modelId: string, modelConfig?: ModelConfig): ModelConfig { - const baseConfig = modelConfig ?? this.configPresenter.getModelConfig(modelId, this.provider.id) - return { - ...baseConfig, - apiEndpoint: ApiEndpointType.Image, - type: ModelType.ImageGeneration, - endpointType: 'image-generation' - } - } - - private buildFallbackSummaryTitle(messages: ChatMessage[]): string { - const latestUserMessage = [...messages].reverse().find((message) => message.role === 'user') - const rawContent = latestUserMessage?.content - - const textContent = - typeof rawContent === 'string' - ? rawContent - : Array.isArray(rawContent) - ? rawContent - .filter((part) => part.type === 'text' && typeof part.text === 'string') - .map((part) => part.text) - .join(' ') - : '' - - const normalizedTitle = textContent.replace(/\s+/g, ' ').trim() - if (!normalizedTitle) { - return 'New Conversation' - } - - return normalizedTitle.slice(0, 60) - } - - private inferModelType(rawModel: NewApiModelRecord, supported: NewApiEndpointType[]) { - const normalizedRawType = - typeof rawModel.type === 'string' ? rawModel.type.trim().toLowerCase() : '' - const normalizedModelId = typeof rawModel.id === 'string' ? rawModel.id.toLowerCase() : '' - - if ( - normalizedRawType === 'imagegeneration' || - normalizedRawType === 'image-generation' || - normalizedRawType === 'image' || - supported.includes('image-generation') - ) { - return ModelType.ImageGeneration - } - - if ( - normalizedRawType === 'embedding' || - normalizedRawType === 'embeddings' || - normalizedModelId.includes('embedding') - ) { - return ModelType.Embedding - } - - if (normalizedRawType === 'rerank' || normalizedModelId.includes('rerank')) { - return ModelType.Rerank - } - - return undefined - } - - private toGeminiMessages(messages: ChatMessage[]): Array<{ - role: 'system' | 'user' | 'assistant' - content: string - }> { - return messages - .filter((message): message is ChatMessage & { role: 'system' | 'user' | 'assistant' } => { - return message.role === 'system' || message.role === 'user' || message.role === 'assistant' - }) - .map((message) => ({ - role: message.role, - content: - typeof message.content === 'string' - ? message.content - : Array.isArray(message.content) - ? 
message.content - .filter((part) => part.type === 'text' && typeof part.text === 'string') - .map((part) => part.text) - .join('\n') - : '' - })) - } - - private resolveContextLength(rawModel: NewApiModelRecord): number | undefined { - const candidates = [ - rawModel.context_length, - rawModel.contextLength, - rawModel.input_token_limit, - rawModel.max_input_tokens - ] - - const firstNumber = candidates.find( - (candidate): candidate is number => - typeof candidate === 'number' && Number.isFinite(candidate) - ) - return firstNumber - } - - private resolveMaxTokens(rawModel: NewApiModelRecord): number | undefined { - const candidates = [ - rawModel.max_tokens, - rawModel.max_output_tokens, - rawModel.output_token_limit - ] - - const firstNumber = candidates.find( - (candidate): candidate is number => - typeof candidate === 'number' && Number.isFinite(candidate) - ) - return firstNumber - } - - private async ensureAnthropicDelegateReady(): Promise { - await this.anthropicDelegate.ensureClientInitialized() - - if (!this.anthropicDelegate.isClientInitialized()) { - throw new Error('Anthropic SDK not initialized') - } - - return this.anthropicDelegate - } - - private async collectImageCompletion( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - const response: LLMResponse = { - content: '' - } - - const modelConfig = this.buildImageModelConfig(modelId) - - for await (const event of this.openaiChatDelegate.coreStream( - messages, - modelId, - modelConfig, - temperature ?? modelConfig.temperature ?? 0.7, - maxTokens ?? modelConfig.maxTokens ?? 1024, - [] - )) { - switch (event.type) { - case 'text': - response.content += event.content - break - case 'reasoning': - response.reasoning_content = `${response.reasoning_content ?? ''}${event.reasoning_content}` - break - case 'image_data': - if (!response.content) { - response.content = event.image_data.data - } - break - case 'usage': - response.totalUsage = event.usage - break - case 'error': - throw new Error(event.error_message) - } - } - - return response - } - - private async syncProviderManagedEndpointType(models: MODEL_META[]): Promise { - for (const model of models) { - if (this.configPresenter.hasUserModelConfig(model.id, this.provider.id)) { - continue - } - - const existingConfig = this.configPresenter.getModelConfig(model.id, this.provider.id) - const defaultEndpointType = this.getDefaultEndpointType(model) - const nextApiEndpoint = - defaultEndpointType === 'image-generation' ? ApiEndpointType.Image : ApiEndpointType.Chat - - this.configPresenter.setModelConfig( - model.id, - this.provider.id, - { - ...existingConfig, - type: model.type ?? existingConfig.type, - apiEndpoint: nextApiEndpoint, - endpointType: defaultEndpointType ?? existingConfig.endpointType - }, - { source: 'provider' } - ) - } - } - - protected async fetchProviderModels(): Promise { - const controller = new AbortController() - const timeout = setTimeout(() => controller.abort(), this.getModelFetchTimeout()) - - try { - const proxyUrl = proxyConfig.getProxyUrl() - const dispatcher = proxyUrl ? new ProxyAgent(proxyUrl) : undefined - const response = await fetch(`${this.getNormalizedBaseHost()}/v1/models`, { - method: 'GET', - headers: { - Authorization: `Bearer ${this.provider.apiKey}`, - 'Content-Type': 'application/json', - ...this.defaultHeaders - }, - signal: controller.signal, - ...(dispatcher ? 
({ dispatcher } as Record) : {}) - }) - - if (!response.ok) { - const responseText = await response.text() - throw new Error(responseText || `Failed to fetch models: ${response.status}`) - } - - const payload = (await response.json()) as NewApiModelsResponse - const rawModels = Array.isArray(payload.data) ? payload.data : [] - - const models = rawModels - .filter((rawModel): rawModel is NewApiModelRecord & { id: string } => { - return typeof rawModel.id === 'string' && rawModel.id.trim().length > 0 - }) - .map((rawModel) => { - const supportedEndpointTypes = Array.isArray(rawModel.supported_endpoint_types) - ? rawModel.supported_endpoint_types.filter(isNewApiEndpointType) - : [] - const type = this.inferModelType(rawModel, supportedEndpointTypes) - const contextLength = this.resolveContextLength(rawModel) - const maxTokens = this.resolveMaxTokens(rawModel) - const model: MODEL_META = { - id: rawModel.id, - name: typeof rawModel.name === 'string' ? rawModel.name : rawModel.id, - group: typeof rawModel.owned_by === 'string' ? rawModel.owned_by : 'default', - providerId: this.provider.id, - isCustom: false, - supportedEndpointTypes, - endpointType: this.getDefaultEndpointType({ - supportedEndpointTypes, - type - }), - ...(typeof rawModel.description === 'string' - ? { description: rawModel.description } - : {}), - ...(type ? { type } : {}), - ...(contextLength !== undefined ? { contextLength } : {}), - ...(maxTokens !== undefined ? { maxTokens } : {}) - } - return model - }) - - await this.syncProviderManagedEndpointType(models) - return models - } finally { - clearTimeout(timeout) - } - } - - public override onProxyResolved(): void { - this.openaiChatDelegate.onProxyResolved() - this.openaiResponsesDelegate.onProxyResolved() - this.geminiDelegate.onProxyResolved() - this.anthropicDelegate.onProxyResolved() - } - - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - await this.fetchProviderModels() - return { isOk: true, errorMsg: null } - } catch (error) { - return { - isOk: false, - errorMsg: error instanceof Error ? 
error.message : String(error) - } - } - } - - public async summaryTitles(messages: ChatMessage[], modelId: string): Promise { - const endpointType = this.resolveEndpointType(modelId) - - switch (endpointType) { - case 'anthropic': { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.summaryTitles(messages, modelId) - } - case 'gemini': - return this.geminiDelegate.summaryTitles(this.toGeminiMessages(messages), modelId) - case 'openai-response': - return this.openaiResponsesDelegate.summaryTitles(messages, modelId) - case 'image-generation': - return this.buildFallbackSummaryTitle(messages) - case 'openai': - default: - return this.openaiChatDelegate.summaryTitles(messages, modelId) - } - } - - public async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - const endpointType = this.resolveEndpointType(modelId) - - switch (endpointType) { - case 'anthropic': { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.completions(messages, modelId, temperature, maxTokens) - } - case 'gemini': - return this.geminiDelegate.completions( - this.toGeminiMessages(messages), - modelId, - temperature, - maxTokens - ) - case 'openai-response': - return this.openaiResponsesDelegate.completions(messages, modelId, temperature, maxTokens) - case 'image-generation': - return this.collectImageCompletion(messages, modelId, temperature, maxTokens) - case 'openai': - default: - return this.openaiChatDelegate.completions(messages, modelId, temperature, maxTokens) - } - } - - public async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - const endpointType = this.resolveEndpointType(modelId) - - switch (endpointType) { - case 'anthropic': { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.summaries(text, modelId, temperature, maxTokens) - } - case 'gemini': - return this.geminiDelegate.summaries(text, modelId, temperature, maxTokens) - case 'openai-response': - return this.openaiResponsesDelegate.summaries(text, modelId, temperature, maxTokens) - case 'image-generation': - return this.collectImageCompletion( - [{ role: 'user', content: text }], - modelId, - temperature, - maxTokens - ) - case 'openai': - default: - return this.openaiChatDelegate.summaries(text, modelId, temperature, maxTokens) - } - } - - public async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - const endpointType = this.resolveEndpointType(modelId) - - switch (endpointType) { - case 'anthropic': { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.generateText(prompt, modelId, temperature, maxTokens) - } - case 'gemini': - return this.geminiDelegate.generateText(prompt, modelId, temperature, maxTokens) - case 'openai-response': - return this.openaiResponsesDelegate.generateText(prompt, modelId, temperature, maxTokens) - case 'image-generation': - return this.collectImageCompletion( - [{ role: 'user', content: prompt }], - modelId, - temperature, - maxTokens - ) - case 'openai': - default: - return this.openaiChatDelegate.generateText(prompt, modelId, temperature, maxTokens) - } - } - - public async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - tools: MCPToolDefinition[] - ): AsyncGenerator { - const endpointType = this.resolveEndpointType(modelId) - - switch 
(endpointType) { - case 'anthropic': { - const delegate = await this.ensureAnthropicDelegateReady() - yield* delegate.coreStream(messages, modelId, modelConfig, temperature, maxTokens, tools) - return - } - case 'gemini': - yield* this.geminiDelegate.coreStream( - messages, - modelId, - modelConfig, - temperature, - maxTokens, - tools - ) - return - case 'openai-response': - yield* this.openaiResponsesDelegate.coreStream( - messages, - modelId, - modelConfig, - temperature, - maxTokens, - tools - ) - return - case 'image-generation': - yield* this.openaiChatDelegate.coreStream( - messages, - modelId, - this.buildImageModelConfig(modelId, modelConfig), - temperature, - maxTokens, - tools - ) - return - case 'openai': - default: - yield* this.openaiChatDelegate.coreStream( - messages, - modelId, - modelConfig, - temperature, - maxTokens, - tools - ) - return - } - } - - public async getEmbeddings(modelId: string, texts: string[]): Promise { - return this.openaiChatDelegate.getEmbeddings(modelId, texts) - } - - public async getDimensions(modelId: string): Promise { - return this.openaiChatDelegate.getDimensions(modelId) - } - - public async getKeyStatus(): Promise { - return this.openaiChatDelegate.getKeyStatus() - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/o3fanProvider.ts b/src/main/presenter/llmProviderPresenter/providers/o3fanProvider.ts deleted file mode 100644 index fc92712c4..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/o3fanProvider.ts +++ /dev/null @@ -1,105 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - ChatMessage, - IConfigPresenter -} from '@shared/presenter' -import { ModelType } from '@shared/model' -import { - resolveModelContextLength, - resolveModelFunctionCall, - resolveModelMaxTokens -} from '@shared/modelConfigDefaults' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { providerDbLoader } from '../../configPresenter/providerDbLoader' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class O3fanProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - protected async fetchOpenAIModels(): Promise { - const resolvedId = modelCapabilities.resolveProviderId(this.provider.id) || this.provider.id - const provider = providerDbLoader.getProvider(resolvedId) - if (!provider || !Array.isArray(provider.models)) { - return [] - } - - return provider.models.map((model) => { - const inputs = model.modalities?.input - const outputs = model.modalities?.output - const hasImageInput = Array.isArray(inputs) && inputs.includes('image') - const hasImageOutput = Array.isArray(outputs) && outputs.includes('image') - const modelType = hasImageOutput ? 
ModelType.ImageGeneration : ModelType.Chat - - return { - id: model.id, - name: model.display_name || model.name || model.id, - group: 'o3fan', - providerId: this.provider.id, - isCustom: false, - contextLength: resolveModelContextLength(model.limit?.context), - maxTokens: resolveModelMaxTokens(model.limit?.output), - vision: hasImageInput, - functionCall: resolveModelFunctionCall(model.tool_call), - reasoning: Boolean(model.reasoning?.supported), - enableSearch: Boolean(model.search?.supported), - type: modelType - } - }) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `You need to summarize the user's conversation into a title of no more than 10 words, with the title language matching the user's primary language, without using punctuation or other special symbols:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts b/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts index d1a21e4c0..078c85510 100644 --- a/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts +++ b/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts @@ -1,79 +1,65 @@ import { + ChatMessage, + IConfigPresenter, + LLM_EMBEDDING_ATTRS, LLM_PROVIDER, + LLMCoreStreamEvent, LLMResponse, - MODEL_META, - OllamaModel, - ProgressResponse, MCPToolDefinition, + MODEL_META, ModelConfig, - LLMCoreStreamEvent, - ChatMessage, - LLM_EMBEDDING_ATTRS, - IConfigPresenter + OllamaModel, + ProgressResponse } from '@shared/presenter' import { ModelType } from '@shared/model' import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults' -import { createStreamEvent } from '@shared/types/core/llm-events' import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' -import { Ollama, Message, ShowResponse } from 'ollama' -import { EMBEDDING_TEST_KEY, isNormalized } from '@/utils/vector' +import { Ollama, ShowResponse } from 'ollama' +import { + runAiSdkCoreStream, + runAiSdkDimensions, + runAiSdkEmbeddings, + runAiSdkGenerateText, + type AiSdkRuntimeContext +} from '../aiSdk' import type { ProviderMcpRuntimePort } from '../runtimePorts' -// Define Ollama tool type -interface OllamaTool { - type: 'function' - function: { - name: string - description: string - parameters: { - type: 'object' - properties: { - [key: string]: { - type: string - description: string - enum?: string[] - } - } - required: string[] - } - } -} - export class OllamaProvider extends BaseLLMProvider { private ollama: Ollama + constructor( provider: LLM_PROVIDER, configPresenter: IConfigPresenter, mcpRuntime?: ProviderMcpRuntimePort ) { super(provider, configPresenter, mcpRuntime) + this.ollama = this.createOllamaClient() + this.init() + } + + private createOllamaClient(): Ollama { if (this.provider.apiKey) { - this.ollama = new Ollama({ + return new Ollama({ host: this.provider.baseUrl, headers: { 
Authorization: `Bearer ${this.provider.apiKey}` } }) - } else { - this.ollama = new Ollama({ - host: this.provider.baseUrl - }) } - this.init() - } - private getOllamaBaseUrl(): string { - const raw = this.provider.baseUrl?.trim() - return raw && raw.length > 0 ? raw.replace(/\/+$/, '') : 'http://localhost:11434' + return new Ollama({ + host: this.provider.baseUrl + }) } - private buildOllamaTraceHeaders(): Record { - const headers: Record = { - 'Content-Type': 'application/json', - ...this.defaultHeaders - } - if (this.provider.apiKey) { - headers.Authorization = `Bearer ${this.provider.apiKey}` + protected getAiSdkRuntimeContext(): AiSdkRuntimeContext { + return { + providerKind: 'ollama', + provider: this.provider, + configPresenter: this.configPresenter, + defaultHeaders: this.defaultHeaders, + buildLegacyFunctionCallPrompt: (tools) => this.getFunctionCallWrapPrompt(tools), + emitRequestTrace: (modelConfig, payload) => this.emitRequestTrace(modelConfig, payload), + supportsNativeTools: (_modelId, modelConfig) => modelConfig.functionCall === true } - return headers } private mergeCapabilities(...sources: Array): string[] { @@ -168,10 +154,10 @@ export class OllamaProvider extends BaseLLMProvider { } } - // Basic Provider functionality implementation + public onProxyResolved(): void {} + protected async fetchProviderModels(): Promise { try { - console.log('Ollama service check', this.ollama, this.provider) const [localModels, runningModels] = await Promise.all([ this.listModels(), this.listRunningModels() @@ -201,77 +187,28 @@ export class OllamaProvider extends BaseLLMProvider { return resolvedModels } catch (error) { console.error('Failed to fetch Ollama models:', error) - // Fallback to aggregated Provider DB curated list for Ollama - const dbModels = this.configPresenter.getDbProviderModels(this.provider.id).map((m) => ({ - id: m.id, - name: m.name, + return this.configPresenter.getDbProviderModels(this.provider.id).map((model) => ({ + id: model.id, + name: model.name, providerId: this.provider.id, - contextLength: m.contextLength, - maxTokens: m.maxTokens, + contextLength: model.contextLength, + maxTokens: model.maxTokens, isCustom: false, - group: m.group || 'default', + group: model.group || 'default', description: undefined, - vision: m.vision || false, - functionCall: m.functionCall || false, - reasoning: m.reasoning || false, - ...(m.type ? { type: m.type } : {}) + vision: model.vision || false, + functionCall: model.functionCall || false, + reasoning: model.reasoning || false, + ...(model.type ? { type: model.type } : {}) })) - return dbModels } } - // Helper method: format messages - private formatMessages(messages: ChatMessage[]): Message[] { - return messages.map((msg) => { - if (typeof msg.content === 'string') { - return { - role: msg.role, - content: msg.content - } - } else { - // Separate text and image content - const text = - msg.content && Array.isArray(msg.content) - ? msg.content - .filter((c) => c.type === 'text') - .map((c) => c.text) - .join('\n') - : '' - - const rawImages = - msg.content && Array.isArray(msg.content) - ? 
(msg.content - .filter((c) => c.type === 'image_url') - .map((c) => c.image_url?.url) - .filter(Boolean) as string[]) - : [] - - // Extract base64 data from data URIs (Ollama expects just the base64 string, not the full data URI) - const images: string[] = rawImages.map((imgUrl) => { - // If it's a data URI (data:image/...;base64,...), extract just the base64 part - if (imgUrl.startsWith('data:image') && imgUrl.includes('base64,')) { - return imgUrl.split(',')[1] - } - // For regular URLs, pass as-is - return imgUrl - }) - - return { - role: msg.role, - content: text, - ...(images.length > 0 && { images }) - } - } - }) - } - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { try { - // Try to get model list to check if Ollama service is available await this.ollama.list() return { isOk: true, errorMsg: null } } catch (error) { - console.error('Ollama service check failed:', error) return { isOk: false, errorMsg: `Unable to connect to Ollama service: ${(error as Error).message}` @@ -280,23 +217,17 @@ export class OllamaProvider extends BaseLLMProvider { } public async summaryTitles(messages: ChatMessage[], modelId: string): Promise { - try { - const prompt = `${SUMMARY_TITLES_PROMPT}\n\n${messages.map((m) => `${m.role}: ${m.content}`).join('\n')}` - - const response = await this.ollama.generate({ - model: modelId, - prompt: prompt, - options: { - temperature: 0.3, - num_predict: 30 - } - }) + const prompt = `${SUMMARY_TITLES_PROMPT}\n\n${messages.map((m) => `${m.role}: ${m.content}`).join('\n')}` + const response = await runAiSdkGenerateText( + this.getAiSdkRuntimeContext(), + [{ role: 'user', content: prompt }], + modelId, + this.configPresenter.getModelConfig(modelId, this.provider.id), + 0.3, + 30 + ) - return response.response.trim() - } catch (error) { - console.error('Failed to generate title with Ollama:', error) - return 'New Conversation' - } + return response.content.trim() || 'New Conversation' } public async completions( @@ -305,64 +236,14 @@ export class OllamaProvider extends BaseLLMProvider { temperature?: number, maxTokens?: number ): Promise { - try { - const response = await this.ollama.chat({ - model: modelId, - messages: this.formatMessages(messages), - options: { - temperature: temperature ?? 
0.7, - num_predict: maxTokens - } - }) - - const resultResp: LLMResponse = { - content: '' - } - - // Ollama may not provide complete token counts - if (response.prompt_eval_count !== undefined || response.eval_count !== undefined) { - resultResp.totalUsage = { - prompt_tokens: response.prompt_eval_count || 0, - completion_tokens: response.eval_count || 0, - total_tokens: (response.prompt_eval_count || 0) + (response.eval_count || 0) - } - } - - // 处理 thinking 字段 - const content = response.message?.content || '' - const thinking = response.message?.thinking || '' - - if (thinking) { - resultResp.reasoning_content = thinking - resultResp.content = content - } - // 处理 <think> 标签(其他模型) - else if (content.includes('<think>')) { - const thinkStart = content.indexOf('<think>') - const thinkEnd = content.indexOf('</think>') - - if (thinkEnd > thinkStart) { - // 提取reasoning_content - resultResp.reasoning_content = content.substring(thinkStart + 7, thinkEnd).trim() - - // 合并前后的普通内容 - const beforeThink = content.substring(0, thinkStart).trim() - const afterThink = content.substring(thinkEnd + 8).trim() - resultResp.content = [beforeThink, afterThink].filter(Boolean).join('\n') - } else { - // 如果没有找到配对的结束标签,将所有内容作为普通内容 - resultResp.content = content - } - } else { - // 没有特殊格式,所有内容作为普通内容 - resultResp.content = content - } - - return resultResp - } catch (error) { - console.error('Ollama completions failed:', error) - throw error - } + return runAiSdkGenerateText( + this.getAiSdkRuntimeContext(), + messages, + modelId, + this.configPresenter.getModelConfig(modelId, this.provider.id), + temperature, + maxTokens + ) } public async summaries( @@ -371,26 +252,14 @@ text: string, modelId: string, temperature?: number, maxTokens?: number ): Promise { - try { - const prompt = `Please summarize the following content:\n\n${text}` - - const response = await this.ollama.generate({ - model: modelId, - prompt: prompt, - options: { - temperature: temperature ?? 0.5, - num_predict: maxTokens - } - }) - - return { - content: response.response, - reasoning_content: undefined - } - } catch (error) { - console.error('Ollama summaries failed:', error) - throw error - } + return runAiSdkGenerateText( + this.getAiSdkRuntimeContext(), + [{ role: 'user', content: `Please summarize the following content:\n\n${text}` }], + modelId, + this.configPresenter.getModelConfig(modelId, this.provider.id), + temperature ?? 0.5, + maxTokens + ) } public async generateText( @@ -399,24 +268,14 @@ prompt: string, modelId: string, temperature?: number, maxTokens?: number ): Promise { - try { - const response = await this.ollama.generate({ - model: modelId, - prompt: prompt, - options: { - temperature: temperature ?? 
0.7, - num_predict: maxTokens - } - }) - - return { - content: response.response, - reasoning_content: undefined - } - } catch (error) { - console.error('Ollama generate text failed:', error) - throw error - } + return runAiSdkGenerateText( + this.getAiSdkRuntimeContext(), + [{ role: 'user', content: prompt }], + modelId, + this.configPresenter.getModelConfig(modelId, this.provider.id), + temperature, + maxTokens + ) } public async suggestions( @@ -425,28 +284,18 @@ export class OllamaProvider extends BaseLLMProvider { temperature?: number, maxTokens?: number ): Promise { - try { - const prompt = `Based on the following context, generate 5 possible follow-up questions or suggestions:\n\n${context}` - - const response = await this.ollama.generate({ - model: modelId, - prompt: prompt, - options: { - temperature: temperature ?? 0.8, - num_predict: maxTokens || 200 - } - }) + const response = await this.generateText( + `Based on the following context, generate 5 possible follow-up questions or suggestions, one per line:\n\n${context}`, + modelId, + temperature ?? 0.8, + maxTokens ?? 200 + ) - // 简单处理返回的文本,按行分割,并过滤掉空行 - return response.response - .split('\n') - .map((line) => line.trim()) - .filter((line) => line && line.length > 0) - .slice(0, 5) // 最多返回5个建议 - } catch (error) { - console.error('Ollama suggestions failed:', error) - return [] - } + return response.content + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .slice(0, 5) } private async attachModelInfo(model: OllamaModel): Promise { @@ -458,7 +307,6 @@ export class OllamaProvider extends BaseLLMProvider { const embedding_length = info?.[family + '.embedding_length'] ?? 512 const capabilities = showResponse.capabilities ?? ['chat'] - // Merge customConfig properties to model return { ...model, model_info: { @@ -468,7 +316,6 @@ export class OllamaProvider extends BaseLLMProvider { capabilities } } catch (error) { - // If showModelInfo fails, return the model with default info console.warn( `Failed to get info for model ${model.name}, using defaults:`, (error as Error).message @@ -484,12 +331,10 @@ export class OllamaProvider extends BaseLLMProvider { } } - // Ollama 特有的模型管理功能 public async listModels(): Promise { try { const response = await this.ollama.list() const models = response.models as unknown as OllamaModel[] - // FIXME: Merge model properties, optimize after ollama list API is improved return await Promise.all(models.map(async (model) => this.attachModelInfo(model))) } catch (error) { console.error('Failed to list Ollama models:', (error as Error).message) @@ -501,7 +346,6 @@ export class OllamaProvider extends BaseLLMProvider { try { const response = await this.ollama.ps() const runningModels = response.models as unknown as OllamaModel[] - // FIXME: Merge model properties, optimize after ollama list API is improved return await Promise.all(runningModels.map(async (model) => this.attachModelInfo(model))) } catch (error) { console.error('Failed to list running Ollama models:', (error as Error).message) @@ -521,9 +365,7 @@ export class OllamaProvider extends BaseLLMProvider { }) for await (const chunk of stream) { - if (onProgress) { - onProgress(chunk as ProgressResponse) - } + onProgress?.(chunk as ProgressResponse) } return true @@ -535,61 +377,15 @@ export class OllamaProvider extends BaseLLMProvider { public async showModelInfo(modelName: string): Promise { try { - const response = await this.ollama.show({ + return await this.ollama.show({ model: modelName }) - return response } catch (error) { 
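            // Log the failure together with the model name for easier tracing, then rethrow so callers can handle it.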
console.error(`Failed to show Ollama model info for ${modelName}:`, (error as Error).message) throw error } } - // 辅助方法:将 MCP 工具转换为 Ollama 工具格式 - private async convertToOllamaTools(mcpTools: MCPToolDefinition[]): Promise { - const openAITools = - (await this.mcpRuntime?.mcpToolsToOpenAITools(mcpTools, this.provider.id)) ?? [] - return openAITools.map((rawTool) => { - const tool = rawTool as unknown as { - function: { - name: string - description?: string - parameters: { properties: Record; required?: string[] } - } - } - const properties = tool.function.parameters.properties || {} - const convertedProperties: Record< - string, - { type: string; description: string; enum?: string[] } - > = {} - - for (const [key, value] of Object.entries(properties)) { - if (typeof value === 'object' && value !== null) { - const param = value as { type: unknown; description: unknown; enum?: string[] } - convertedProperties[key] = { - type: String(param.type || 'string'), - description: String(param.description || ''), - ...(param.enum ? { enum: param.enum } : {}) - } - } - } - - return { - type: 'function' as const, - function: { - name: tool.function.name, - description: tool.function.description || '', - parameters: { - type: 'object' as const, - properties: convertedProperties, - required: tool.function.parameters.required || [] - } - } - } - }) - } - - // 实现BaseLLMProvider抽象方法 - 核心流处理 async *coreStream( messages: ChatMessage[], modelId: string, @@ -598,9 +394,8 @@ export class OllamaProvider extends BaseLLMProvider { maxTokens: number, mcpTools: MCPToolDefinition[] ): AsyncGenerator { - if (!modelId) throw new Error('Model ID is required') - // Ollama 不需要图片生成分支,直接处理聊天完成 - yield* this.handleChatCompletion( + yield* runAiSdkCoreStream( + this.getAiSdkRuntimeContext(), messages, modelId, modelConfig, @@ -610,734 +405,11 @@ export class OllamaProvider extends BaseLLMProvider { ) } - /////////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * 处理 Ollama 聊天补全模型请求的内部方法。 - * @param messages 聊天消息数组。 - * @param modelId 模型ID。 - * @param modelConfig 模型配置。 - * @param temperature 温度参数。 - * @param maxTokens 最大 token 数。 - * @param mcpTools MCP 工具定义数组。 - * @returns AsyncGenerator 流式事件。 - */ - private async *handleChatCompletion( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - try { - const tools = mcpTools || [] - const supportsFunctionCall = modelConfig?.functionCall || false - let processedMessages = this.formatMessages(messages) - - // 工具参数准备 - let ollamaTools: OllamaTool[] | undefined = undefined - if (tools.length > 0) { - if (supportsFunctionCall) { - // 支持原生函数调用,转换工具定义 - ollamaTools = await this.convertToOllamaTools(tools) - } else { - // 不支持原生函数调用,使用提示词包装 - processedMessages = this.prepareFunctionCallPrompt(processedMessages, tools) - // Ollama对于非原生支持通常情况下也不需要传递tools参数 - ollamaTools = undefined - } - } - - // Ollama聊天参数 - const chatParams = { - model: modelId, - messages: processedMessages, - options: { - temperature: temperature ?? 0.7, - num_predict: maxTokens - }, - stream: true as const, - ...(modelConfig?.reasoningEffort && { reasoning_effort: modelConfig.reasoningEffort }), - ...(supportsFunctionCall && ollamaTools && ollamaTools.length > 0 - ? 
{ tools: ollamaTools } - : {}) - } - - await this.emitRequestTrace(modelConfig, { - endpoint: `${this.getOllamaBaseUrl()}/api/chat`, - headers: this.buildOllamaTraceHeaders(), - body: chatParams - }) - - // 创建流 - const stream = await this.ollama.chat(chatParams) - - // --- 状态变量 --- - type TagState = 'none' | 'start' | 'inside' | 'end' - let thinkState: TagState = 'none' - let funcState: TagState = 'none' - - let pendingBuffer = '' // 用于标签匹配和潜在文本输出的缓冲区 - let thinkBuffer = '' // 思考内容缓冲区 - let funcCallBuffer = '' // 非原生函数调用内容的缓冲区 - let codeBlockBuffer = '' // 代码块内容的缓冲区 - - const thinkStartMarker = '' - const thinkEndMarker = '' - const funcStartMarker = '' - const funcEndMarker = '' - - // 代码块标记变体 - const codeBlockMarkers = [ - '```tool_code', - '```tool', - '``` tool_code', - '``` tool', - '```function_call', - '``` function_call' - ] - const codeBlockEndMarker = '```' - - let isInCodeBlock = false - - // 用于跟踪原生工具调用 - const nativeToolCalls: Record< - string, - { name: string; arguments: string; completed?: boolean } - > = {} - let stopReason: LLMCoreStreamEvent['stop_reason'] = 'complete' - let toolUseDetected = false - let usage: - | { - prompt_tokens: number - completion_tokens: number - total_tokens: number - } - | undefined = undefined - - // --- 流处理循环 --- - for await (const chunk of stream) { - // 处理使用统计 - if (chunk.prompt_eval_count !== undefined || chunk.eval_count !== undefined) { - usage = { - prompt_tokens: chunk.prompt_eval_count || 0, - completion_tokens: chunk.eval_count || 0, - total_tokens: (chunk.prompt_eval_count || 0) + (chunk.eval_count || 0) - } - } - - // 处理原生工具调用 - if ( - supportsFunctionCall && - chunk.message?.tool_calls && - chunk.message.tool_calls.length > 0 - ) { - toolUseDetected = true - for (const toolCall of chunk.message.tool_calls) { - const toolId = toolCall.function?.name || `ollama-tool-${Date.now()}` - if (!nativeToolCalls[toolId]) { - nativeToolCalls[toolId] = { - name: toolCall.function?.name || '', - arguments: JSON.stringify(toolCall.function?.arguments || {}) - } - - // 发送工具调用开始事件 - yield createStreamEvent.toolCallStart(toolId, toolCall.function?.name || '') - - // 发送工具调用参数块事件 - yield createStreamEvent.toolCallChunk( - toolId, - JSON.stringify(toolCall.function?.arguments || {}) - ) - - // 发送工具调用结束事件 - yield createStreamEvent.toolCallEnd( - toolId, - JSON.stringify(toolCall.function?.arguments || {}) - ) - } - } - - stopReason = 'tool_use' - continue - } - - // 处理 thinking 字段 - const currentThinking = chunk.message?.thinking || '' - if (currentThinking) { - yield createStreamEvent.reasoning(currentThinking) - } - - // 获取当前内容 - const currentContent = chunk.message?.content || '' - if (!currentContent) continue - - // 逐字符处理 - for (const char of currentContent) { - pendingBuffer += char - let processedChar = false // 标记字符是否被状态逻辑处理 - - // --- 处理代码块 --- - if (isInCodeBlock) { - codeBlockBuffer += char - - // 检查代码块结束 - if (codeBlockBuffer.endsWith(codeBlockEndMarker)) { - isInCodeBlock = false - const codeContent = codeBlockBuffer - .substring(0, codeBlockBuffer.length - codeBlockEndMarker.length) - .trim() - - try { - // 尝试解析JSON - let parsedCall - try { - // 移除可能的语言标识和开头的空白 - const cleanContent = codeContent.replace(/^tool_code\s*/i, '').trim() - parsedCall = JSON.parse(cleanContent) - } catch { - // 尝试修复通用JSON格式问题 - const cleanContent = codeContent.replace(/^tool_code\s*/i, '').trim() - const repaired = cleanContent - .replace(/,\s*}/g, '}') // 移除对象末尾的逗号 - .replace(/,\s*\]/g, ']') // 移除数组末尾的逗号 - .replace(/(['"])?([a-zA-Z0-9_]+)(['"])?:/g, '"$2":') // 
确保所有键都有双引号 - - parsedCall = JSON.parse(repaired) - } - - // 提取函数名和参数 - let functionName, functionArgs - - if (parsedCall.function_call && typeof parsedCall.function_call === 'object') { - functionName = parsedCall.function_call.name - functionArgs = parsedCall.function_call.arguments - } else if (parsedCall.name && parsedCall.arguments !== undefined) { - functionName = parsedCall.name - functionArgs = parsedCall.arguments - } else if ( - parsedCall.function && - typeof parsedCall.function === 'object' && - parsedCall.function.name - ) { - functionName = parsedCall.function.name - functionArgs = parsedCall.function.arguments - } else { - throw new Error('Unable to recognize function call format from code block') - } - - // 确保参数是字符串 - if (typeof functionArgs !== 'string') { - functionArgs = JSON.stringify(functionArgs) - } - - // 生成唯一ID - const id = parsedCall.id || `ollama-tool-${Date.now()}` - - // 发送工具调用 - toolUseDetected = true - yield { - type: 'tool_call_start', - tool_call_id: id, - tool_call_name: functionName - } - - yield { - type: 'tool_call_chunk', - tool_call_id: id, - tool_call_arguments_chunk: functionArgs - } - - yield { - type: 'tool_call_end', - tool_call_id: id, - tool_call_arguments_complete: functionArgs - } - - stopReason = 'tool_use' - } catch { - // 解析失败,将内容作为普通文本输出 - yield { - type: 'text', - content: '```tool_code\n' + codeContent + '\n```' - } - } - - // 重置状态和缓冲区 - codeBlockBuffer = '' - pendingBuffer = '' - processedChar = true - } - - continue - } - if (usage) { - yield createStreamEvent.usage(usage) - } - - // --- 思考标签处理 --- - if (thinkState === 'inside') { - if (pendingBuffer.endsWith(thinkEndMarker)) { - thinkState = 'none' - if (thinkBuffer) { - yield createStreamEvent.reasoning(thinkBuffer) - thinkBuffer = '' - } - pendingBuffer = '' - processedChar = true - } else if (thinkEndMarker.startsWith(pendingBuffer)) { - thinkState = 'end' - processedChar = true - } else if (pendingBuffer.length >= thinkEndMarker.length) { - const charsToYield = pendingBuffer.slice(0, -thinkEndMarker.length + 1) - if (charsToYield) { - thinkBuffer += charsToYield - yield createStreamEvent.reasoning(charsToYield) - } - pendingBuffer = pendingBuffer.slice(-thinkEndMarker.length + 1) - if (thinkEndMarker.startsWith(pendingBuffer)) { - thinkState = 'end' - } else { - thinkBuffer += pendingBuffer - yield createStreamEvent.reasoning(pendingBuffer) - pendingBuffer = '' - thinkState = 'inside' - } - processedChar = true - } else { - thinkBuffer += char - yield createStreamEvent.reasoning(char) - pendingBuffer = '' - processedChar = true - } - } else if (thinkState === 'end') { - if (pendingBuffer.endsWith(thinkEndMarker)) { - thinkState = 'none' - if (thinkBuffer) { - yield createStreamEvent.reasoning(thinkBuffer) - thinkBuffer = '' - } - pendingBuffer = '' - processedChar = true - } else if (!thinkEndMarker.startsWith(pendingBuffer)) { - const failedTagChars = pendingBuffer - thinkBuffer += failedTagChars - yield createStreamEvent.reasoning(failedTagChars) - pendingBuffer = '' - thinkState = 'inside' - processedChar = true - } else { - processedChar = true - } - } - - // --- 函数调用标签处理 --- - else if ( - !supportsFunctionCall && - tools.length > 0 && - (funcState === 'inside' || funcState === 'end') - ) { - processedChar = true // 假设已处理,除非下面的逻辑改变状态 - if (funcState === 'inside') { - if (pendingBuffer.endsWith(funcEndMarker)) { - funcState = 'none' - funcCallBuffer += pendingBuffer.slice(0, -funcEndMarker.length) - pendingBuffer = '' - toolUseDetected = true - - const parsedCalls = 
this.parseFunctionCalls( - `${funcStartMarker}${funcCallBuffer}${funcEndMarker}`, - `non-native-${this.provider.id}` - ) - for (const parsedCall of parsedCalls) { - yield { - type: 'tool_call_start', - tool_call_id: parsedCall.id, - tool_call_name: parsedCall.function.name - } - yield { - type: 'tool_call_chunk', - tool_call_id: parsedCall.id, - tool_call_arguments_chunk: parsedCall.function.arguments - } - yield { - type: 'tool_call_end', - tool_call_id: parsedCall.id, - tool_call_arguments_complete: parsedCall.function.arguments - } - } - funcCallBuffer = '' - } else if (funcEndMarker.startsWith(pendingBuffer)) { - funcState = 'end' - } else if (pendingBuffer.length >= funcEndMarker.length) { - const charsToAdd = pendingBuffer.slice(0, -funcEndMarker.length + 1) - funcCallBuffer += charsToAdd - pendingBuffer = pendingBuffer.slice(-funcEndMarker.length + 1) - if (funcEndMarker.startsWith(pendingBuffer)) { - funcState = 'end' - } else { - funcCallBuffer += pendingBuffer - pendingBuffer = '' - funcState = 'inside' - } - } else { - funcCallBuffer += char - pendingBuffer = '' - } - } else { - // funcState === 'end' - if (pendingBuffer.endsWith(funcEndMarker)) { - funcState = 'none' - pendingBuffer = '' - toolUseDetected = true - - const parsedCalls = this.parseFunctionCalls( - `${funcStartMarker}${funcCallBuffer}${funcEndMarker}`, - `non-native-${this.provider.id}` - ) - for (const parsedCall of parsedCalls) { - yield { - type: 'tool_call_start', - tool_call_id: parsedCall.id, - tool_call_name: parsedCall.function.name - } - yield { - type: 'tool_call_chunk', - tool_call_id: parsedCall.id, - tool_call_arguments_chunk: parsedCall.function.arguments - } - yield { - type: 'tool_call_end', - tool_call_id: parsedCall.id, - tool_call_arguments_complete: parsedCall.function.arguments - } - } - funcCallBuffer = '' - } else if (!funcEndMarker.startsWith(pendingBuffer)) { - funcCallBuffer += pendingBuffer - pendingBuffer = '' - funcState = 'inside' - } - } - } - - // --- 处理一般文本/标签检测(当不在任何标签内时)--- - if (!processedChar) { - let potentialThink = thinkStartMarker.startsWith(pendingBuffer) - let potentialFunc = - !supportsFunctionCall && tools.length > 0 && funcStartMarker.startsWith(pendingBuffer) - const matchedThink = pendingBuffer.endsWith(thinkStartMarker) - const matchedFunc = - !supportsFunctionCall && tools.length > 0 && pendingBuffer.endsWith(funcStartMarker) - - // 检查代码块标记 - let codeBlockDetected = false - for (const marker of codeBlockMarkers) { - if (pendingBuffer.endsWith(marker)) { - codeBlockDetected = true - break - } - } - - // --- 首先处理完整匹配 --- - if (matchedThink) { - const textBefore = pendingBuffer.slice(0, -thinkStartMarker.length) - if (textBefore) { - yield createStreamEvent.text(textBefore) - } - thinkState = 'inside' - funcState = 'none' // 重置其他状态 - pendingBuffer = '' - } else if (matchedFunc) { - const textBefore = pendingBuffer.slice(0, -funcStartMarker.length) - if (textBefore) { - yield createStreamEvent.text(textBefore) - } - funcState = 'inside' - thinkState = 'none' // 重置其他状态 - pendingBuffer = '' - } else if (codeBlockDetected) { - // 找到代码块开始标记,提取前面的文本 - let markerText = '' - for (const marker of codeBlockMarkers) { - if (pendingBuffer.endsWith(marker)) { - markerText = marker - break - } - } - - const textBefore = pendingBuffer.slice(0, -markerText.length) - if (textBefore) { - yield createStreamEvent.text(textBefore) - } - - isInCodeBlock = true - codeBlockBuffer = '' - pendingBuffer = '' - } - // --- 处理部分匹配(继续累积)--- - else if (potentialThink || potentialFunc) { - // 
如果可能匹配任一标签,只保留缓冲区并等待更多字符 - thinkState = potentialThink ? 'start' : 'none' - funcState = potentialFunc ? 'start' : 'none' - } - // --- 处理不匹配/失败 --- - else if (pendingBuffer.length > 0) { - // 缓冲区不以'<'开头,或以'<'开头但不再匹配任何标签的开始 - const charToYield = pendingBuffer[0] - yield createStreamEvent.text(charToYield) - pendingBuffer = pendingBuffer.slice(1) - // 使用缩短的缓冲区立即重新评估潜在匹配 - potentialThink = - pendingBuffer.length > 0 && thinkStartMarker.startsWith(pendingBuffer) - potentialFunc = - pendingBuffer.length > 0 && - !supportsFunctionCall && - tools.length > 0 && - funcStartMarker.startsWith(pendingBuffer) - thinkState = potentialThink ? 'start' : 'none' - funcState = potentialFunc ? 'start' : 'none' - } - } - } // 字符循环结束 - } // 块循环结束 - - // --- 完成处理 --- - - // 输出缓冲区中剩余的文本 - if (pendingBuffer) { - // 根据最终状态决定如何输出 - if (thinkState === 'inside' || thinkState === 'end') { - yield { type: 'reasoning', reasoning_content: pendingBuffer } - thinkBuffer += pendingBuffer - } else if (funcState === 'inside' || funcState === 'end') { - // 将剩余内容添加到函数缓冲区 - 稍后处理 - funcCallBuffer += pendingBuffer - } else { - yield createStreamEvent.text(pendingBuffer) - } - pendingBuffer = '' - } - - // 处理不完整的非原生函数调用 - if (funcCallBuffer) { - const potentialContent = `${funcStartMarker}${funcCallBuffer}` - try { - const parsedCalls = this.parseFunctionCalls( - potentialContent, - `non-native-incomplete-${this.provider.id}` - ) - if (parsedCalls.length > 0) { - toolUseDetected = true - for (const parsedCall of parsedCalls) { - yield { - type: 'tool_call_start', - tool_call_id: parsedCall.id + '-incomplete', - tool_call_name: parsedCall.function.name - } - yield { - type: 'tool_call_chunk', - tool_call_id: parsedCall.id + '-incomplete', - tool_call_arguments_chunk: parsedCall.function.arguments - } - yield { - type: 'tool_call_end', - tool_call_id: parsedCall.id + '-incomplete', - tool_call_arguments_complete: parsedCall.function.arguments - } - } - } else { - // 如果解析失败或没有结果,将缓冲区作为文本输出 - yield createStreamEvent.text(potentialContent) - } - } catch (e) { - console.error('Error parsing incomplete function call buffer:', e) - yield { type: 'text', content: potentialContent } - } - funcCallBuffer = '' - } - - // 处理不完整的代码块 - if (isInCodeBlock && codeBlockBuffer) { - yield { - type: 'text', - content: '```' + codeBlockBuffer - } - } - - // 最终检查和发出原生工具调用 - if (supportsFunctionCall && toolUseDetected) { - for (const toolId in nativeToolCalls) { - const tool = nativeToolCalls[toolId] - if (tool.name && tool.arguments && !tool.completed) { - try { - JSON.parse(tool.arguments) // 检查有效性 - yield { - type: 'tool_call_end', - tool_call_id: toolId, - tool_call_arguments_complete: tool.arguments - } - } catch (e) { - console.error( - `[handleChatCompletion] Tool ${toolId} parameter parsing error: ${tool.arguments}`, - e - ) - yield { - type: 'tool_call_end', - tool_call_id: toolId, - tool_call_arguments_complete: tool.arguments - } - } - } - } - } - - // 记录状态警告 - if (thinkState !== 'none') console.warn(`Stream ended in thinkState: ${thinkState}`) - if (funcState !== 'none') console.warn(`Stream ended in funcState: ${funcState}`) - - // 输出使用情况 - if (usage) { - yield createStreamEvent.usage(usage) - } - - // 如果检测到工具使用,则覆盖停止原因 - const finalStopReason = toolUseDetected ? 'tool_use' : stopReason - yield createStreamEvent.stop(finalStopReason) - } catch (error: unknown) { - yield createStreamEvent.error(error instanceof Error ? 
error.message : String(error)) - yield createStreamEvent.stop('error') - } - } - - // 用于包装不支持函数调用的提示词 - private prepareFunctionCallPrompt(messages: Message[], mcpTools: MCPToolDefinition[]): Message[] { - // 创建消息副本 - const result = [...messages] - - const functionCallPrompt = this.getFunctionCallWrapPrompt(mcpTools) - const userMessageIndex = result.findLastIndex((message) => message.role === 'user') - - if (userMessageIndex !== -1) { - const userMessage = result[userMessageIndex] - // 添加提示词到用户消息 - result[userMessageIndex] = { - ...userMessage, - content: `${functionCallPrompt}\n\n${userMessage.content || ''}` - } - } - - return result - } - - // 解析函数调用标签 - protected parseFunctionCalls( - response: string, - fallbackIdPrefix: string = 'ollama-tool' - ): Array<{ id: string; type: string; function: { name: string; arguments: string } }> { - try { - // 使用非贪婪模式匹配function_call标签对 - const functionCallMatches = response.match(/([\s\S]*?)<\/function_call>/gs) - if (!functionCallMatches) { - return [] - } - - const toolCalls = functionCallMatches - .map((match, index) => { - const content = match.replace(/<\/?function_call>/g, '').trim() - try { - // 尝试解析JSON - let parsedCall - try { - parsedCall = JSON.parse(content) - } catch { - // 尝试修复格式问题 - const repaired = content - .replace(/,\s*}/g, '}') // 移除对象末尾的逗号 - .replace(/,\s*\]/g, ']') // 移除数组末尾的逗号 - .replace(/(['"])?([a-zA-Z0-9_]+)(['"])?:/g, '"$2":') // 确保所有键都有双引号 - - parsedCall = JSON.parse(repaired) - } - - // 提取函数名和参数 - let functionName, functionArgs - - if (parsedCall.function_call && typeof parsedCall.function_call === 'object') { - functionName = parsedCall.function_call.name - functionArgs = parsedCall.function_call.arguments - } else if (parsedCall.name && parsedCall.arguments !== undefined) { - functionName = parsedCall.name - functionArgs = parsedCall.arguments - } else if ( - parsedCall.function && - typeof parsedCall.function === 'object' && - parsedCall.function.name - ) { - functionName = parsedCall.function.name - functionArgs = parsedCall.function.arguments - } else { - return null - } - - // 确保参数是字符串 - if (typeof functionArgs !== 'string') { - functionArgs = JSON.stringify(functionArgs) - } - - // 生成唯一ID - const id = parsedCall.id || `${fallbackIdPrefix}-${index}-${Date.now()}` - - return { - id: String(id), - type: 'function', - function: { - name: String(functionName), - arguments: functionArgs - } - } - } catch { - return null - } - }) - .filter((call) => call !== null) as Array<{ - id: string - type: string - function: { name: string; arguments: string } - }> - - return toolCalls - } catch { - return [] - } - } - - public onProxyResolved(): void { - console.log('ollama onProxyResolved') - } - async getEmbeddings(modelId: string, texts: string[]): Promise { - // Ollama embedding API: 只支持单条文本 - const results: number[][] = [] - for (const text of texts) { - const resp = await this.ollama.embeddings({ - model: modelId, - prompt: text - }) - if (resp && Array.isArray(resp.embedding)) { - results.push(resp.embedding) - } else { - results.push([]) - } - } - return results + return runAiSdkEmbeddings(this.getAiSdkRuntimeContext(), modelId, texts) } async getDimensions(modelId: string): Promise { - const res = await this.getEmbeddings(modelId, [EMBEDDING_TEST_KEY]) - return { - dimensions: res[0].length, - normalized: isNormalized(res[0]) - } + return runAiSdkDimensions(this.getAiSdkRuntimeContext(), modelId) } } diff --git a/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts 
b/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts deleted file mode 100644 index 039e19397..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider.ts +++ /dev/null @@ -1,2078 +0,0 @@ -import { EMBEDDING_TEST_KEY, isNormalized } from '@/utils/vector' -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - MCPToolDefinition, - LLMCoreStreamEvent, - ModelConfig, - ChatMessage, - LLM_EMBEDDING_ATTRS, - IConfigPresenter -} from '@shared/presenter' -import { ApiEndpointType } from '@shared/model' -import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults' -import { createStreamEvent } from '@shared/types/core/llm-events' -import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' -import OpenAI, { AzureOpenAI } from 'openai' -import { - ChatCompletionContentPart, - ChatCompletionContentPartText, - ChatCompletionMessage, - ChatCompletionMessageParam -} from 'openai/resources' -import { presenter } from '@/presenter' -import { eventBus, SendTarget } from '@/eventbus' -import { NOTIFICATION_EVENTS } from '@/events' -import { jsonrepair } from 'jsonrepair' -import { app } from 'electron' -import path from 'path' -import fs from 'fs' -import sharp from 'sharp' -import { proxyConfig } from '../../proxyConfig' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' -import { ProxyAgent } from 'undici' -import type { ProviderMcpRuntimePort } from '../runtimePorts' -import { - applyOpenAIChatExplicitCacheBreakpoint, - applyOpenAIPromptCacheKey, - resolvePromptCachePlan -} from '../promptCacheStrategy' - -const OPENAI_REASONING_MODELS = [ - 'o4-mini', - 'o1-pro', - 'o3', - 'o3-pro', - 'o3-mini', - 'o3-preview', - 'o1-mini', - 'o1-pro', - 'o1-preview', - 'o1', - 'gpt-5', - 'gpt-5-mini', - 'gpt-5-nano', - 'gpt-5-chat' -] -const OPENAI_IMAGE_GENERATION_MODELS = ['gpt-4o-all', 'gpt-4o-image'] -const OPENAI_IMAGE_GENERATION_MODEL_PREFIXES = ['dall-e-', 'gpt-image-'] -const isOpenAIImageGenerationModel = (modelId: string): boolean => - OPENAI_IMAGE_GENERATION_MODELS.includes(modelId) || - OPENAI_IMAGE_GENERATION_MODEL_PREFIXES.some((prefix) => modelId.startsWith(prefix)) - -// Add supported image size constants -const SUPPORTED_IMAGE_SIZES = { - SQUARE: '1024x1024', - LANDSCAPE: '1536x1024', - PORTRAIT: '1024x1536' -} as const - -// Add list of models with configurable sizes -const SIZE_CONFIGURABLE_MODELS = ['gpt-image-1', 'gpt-4o-image', 'gpt-4o-all'] - -export function normalizeExtractedImageText(content: string): string { - const normalized = content - .replace(/\r\n/g, '\n') - .replace(/\n\s*\n/g, '\n') - .trim() - if (!normalized) { - return '' - } - - const semanticText = normalized.replace(/[\`*_~!\[\]\(\)]/g, '').trim() - return semanticText.length > 0 ? 
normalized : '' -} - -function getOpenAIChatCachedTokens(usage: unknown): number | undefined { - return getOpenAIChatUsageDetail(usage, 'cached_tokens') -} - -function getOpenAIChatCacheWriteTokens(usage: unknown): number | undefined { - return getOpenAIChatUsageDetail(usage, 'cache_write_tokens') -} - -function getOpenAIChatUsageDetail( - usage: unknown, - key: 'cached_tokens' | 'cache_write_tokens' -): number | undefined { - if (!usage || typeof usage !== 'object') { - return undefined - } - - const promptTokensDetails = (usage as { prompt_tokens_details?: unknown }).prompt_tokens_details - const inputTokensDetails = (usage as { input_tokens_details?: unknown }).input_tokens_details - const promptCachedTokens = - promptTokensDetails && typeof promptTokensDetails === 'object' - ? (promptTokensDetails as Record)[key] - : undefined - const inputCachedTokens = - inputTokensDetails && typeof inputTokensDetails === 'object' - ? (inputTokensDetails as Record)[key] - : undefined - const cachedTokens = - typeof promptCachedTokens === 'number' ? promptCachedTokens : inputCachedTokens - return typeof cachedTokens === 'number' && Number.isFinite(cachedTokens) - ? cachedTokens - : undefined -} - -export class OpenAICompatibleProvider extends BaseLLMProvider { - protected openai!: OpenAI - protected isNoModelsApi: boolean = false - // Add blacklist of providers that don't support OpenAI standard interface - private static readonly NO_MODELS_API_LIST: string[] = [] - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - this.createOpenAIClient() - if (OpenAICompatibleProvider.NO_MODELS_API_LIST.includes(this.provider.id.toLowerCase())) { - this.isNoModelsApi = true - } - this.init() - } - - private supportsEffortParameter(modelId: string): boolean { - return modelCapabilities.supportsReasoningEffort(this.getCapabilityProviderId(), modelId) - } - - private supportsVerbosityParameter(modelId: string): boolean { - return modelCapabilities.supportsVerbosity(this.getCapabilityProviderId(), modelId) - } - - private resolveTraceAuthToken(): string { - return this.provider.oauthToken || this.provider.apiKey || 'MISSING_API_KEY' - } - - private buildChatCompletionsTraceHeaders(): Record { - const headers: Record = { - 'Content-Type': 'application/json', - ...this.defaultHeaders - } - - if (this.provider.id === 'azure-openai') { - headers['api-key'] = this.resolveTraceAuthToken() - } else { - headers.Authorization = `Bearer ${this.resolveTraceAuthToken()}` - } - - return headers - } - - private buildChatCompletionsEndpoint(): string { - const baseUrl = (this.provider.baseUrl || 'https://api.openai.com/v1').replace(/\/+$/, '') - return `${baseUrl}/chat/completions` - } - - private getEffectiveApiEndpoint(modelId: string): ApiEndpointType { - const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - - if (modelConfig?.apiEndpoint) { - return modelConfig.apiEndpoint - } - - if (isOpenAIImageGenerationModel(modelId)) { - return ApiEndpointType.Image - } - - return ApiEndpointType.Chat - } - - protected createOpenAIClient(): void { - // Get proxy configuration - const proxyUrl = proxyConfig.getProxyUrl() - const fetchOptions: { dispatcher?: ProxyAgent } = {} - - if (proxyUrl) { - console.log(`[OpenAI Compatible Provider] Using proxy: ${proxyUrl}`) - const proxyAgent = new ProxyAgent(proxyUrl) - fetchOptions.dispatcher = proxyAgent - } - - // Check if this is official OpenAI 
or Azure OpenAI - const isOfficialOpenAI = this.isOfficialOpenAIService() - const isAzureOpenAI = this.provider.id === 'azure-openai' - - // Only use custom fetch for third-party services to avoid triggering 403 - // Keep original behavior for official OpenAI and Azure OpenAI for best compatibility - const shouldUseCleanFetch = !isOfficialOpenAI && !isAzureOpenAI - const customFetch = shouldUseCleanFetch ? this.createCleanFetch() : undefined - - if (isAzureOpenAI) { - try { - const apiVersion = this.configPresenter.getSetting('azureApiVersion') - const azureConfig: any = { - apiKey: this.provider.apiKey, - baseURL: this.provider.baseUrl, - apiVersion: apiVersion || '2024-02-01', - defaultHeaders: { - ...this.defaultHeaders - } - } - - // Use fetchOptions for proxy (original behavior for Azure) - if (fetchOptions.dispatcher) { - azureConfig.fetchOptions = fetchOptions - } - - this.openai = new AzureOpenAI(azureConfig) - } catch (e) { - console.warn('create azure openai failed', e) - } - } else { - const openaiConfig: any = { - apiKey: this.provider.apiKey, - baseURL: this.provider.baseUrl, - defaultHeaders: { - ...this.defaultHeaders - } - } - - if (customFetch) { - // Third-party service: use custom fetch to avoid 403 - openaiConfig.fetch = customFetch - // Also apply proxy via fetchOptions for third-party services - if (fetchOptions.dispatcher) { - openaiConfig.fetchOptions = fetchOptions - } - console.log( - `[OpenAI Compatible Provider] Using custom fetch for third-party service: ${this.provider.baseUrl}` - ) - } else { - // Official OpenAI: use original behavior with fetchOptions - if (fetchOptions.dispatcher) { - openaiConfig.fetchOptions = fetchOptions - } - console.log(`[OpenAI Compatible Provider] Using original fetch for official OpenAI`) - } - - this.openai = new OpenAI(openaiConfig) - } - } - - /** - * Check if this is the official OpenAI service by provider ID - */ - private isOfficialOpenAIService(): boolean { - return this.provider.id === 'openai' - } - - /** - * Creates a custom fetch function that removes OpenAI SDK headers that may trigger 403 - * This ensures all OpenAI SDK requests (including streaming) use clean headers - */ - private createCleanFetch() { - return async (url: RequestInfo | URL, init?: RequestInit): Promise => { - // Create a copy of init to avoid modifying the original - const cleanInit = { ...init } - - if (cleanInit.headers) { - // Convert headers to a plain object for easier manipulation - const headers = new Headers(cleanInit.headers) - const cleanHeaders: Record = {} - - // Only keep essential headers, remove SDK-specific ones that trigger 403 - const allowedHeaders = [ - 'authorization', - 'content-type', - 'accept', - 'http-referer', - 'x-title' - ] - - headers.forEach((value, key) => { - const lowerKey = key.toLowerCase() - // Keep only allowed headers and avoid X-Stainless-* headers - if ( - allowedHeaders.includes(lowerKey) || - (!lowerKey.startsWith('x-stainless-') && - !lowerKey.includes('user-agent') && - !lowerKey.includes('openai-')) - ) { - cleanHeaders[key] = value - } - }) - - // Ensure we have Authorization header - if (!cleanHeaders['Authorization'] && !cleanHeaders['authorization']) { - cleanHeaders['Authorization'] = `Bearer ${this.provider.apiKey}` - } - - // Add our default headers - Object.assign(cleanHeaders, this.defaultHeaders) - - cleanInit.headers = cleanHeaders - } - - // Use regular fetch - proxy is already handled by OpenAI SDK's fetchOptions - return fetch(url, cleanInit) - } - } - - public onProxyResolved(): void { - 
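    // Recreate the OpenAI client so subsequent requests pick up the newly resolved proxy settings.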
this.createOpenAIClient() - } - - // Implement abstract method fetchProviderModels from BaseLLMProvider - protected async fetchProviderModels(options?: { timeout: number }): Promise { - // Check if provider is in blacklist - if (this.isNoModelsApi) { - // console.log(`Provider ${this.provider.name} does not support OpenAI models API`) - return this.models - } - return this.fetchOpenAIModels(options) - } - - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - // Now using the clean fetch function via OpenAI SDK - const response = await this.openai.models.list(options) - return response.data.map((model) => ({ - id: model.id, - name: model.id, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, - maxTokens: DEFAULT_MODEL_MAX_TOKENS - })) - } - - /** - * User messages: Upper layer will insert image_url based on whether vision exists - * Assistant messages: Need to judge and convert images to correct context, as models can be switched - * Tool calls and tool responses: - * - If supportsFunctionCall=true: Use standard OpenAI format (tool_calls + role:tool) - * - If supportsFunctionCall=false: Convert to mock user messages with function_call_record format - * @param messages - Chat messages array - * @param supportsFunctionCall - Whether the model supports native function calling - * @returns Formatted messages for OpenAI API - */ - protected formatMessages( - messages: ChatMessage[], - supportsFunctionCall: boolean = false - ): ChatCompletionMessageParam[] { - // console.log('formatMessages', messages) - const result: ChatCompletionMessageParam[] = [] - // Track pending tool calls for non-FC models (to pair with tool responses) - const pendingToolCalls: Map< - string, - { name: string; arguments: string; assistantContent?: string } - > = new Map() - const pendingToolCallOrder: string[] = [] - - // Track expected tool_call_ids for native function calling models - const pendingNativeToolCallIds: string[] = [] - const pendingNativeToolCallSet: Set = new Set() - - const serializeContent = (content: unknown): string => { - if (content === undefined) return '' - if (typeof content === 'string') return content - return JSON.stringify(content) - } - - const enqueueNativeToolCallId = (toolCallId?: string) => { - if (!toolCallId) return - pendingNativeToolCallIds.push(toolCallId) - pendingNativeToolCallSet.add(toolCallId) - } - - const consumeNativeToolCallId = (preferredId?: string): string | undefined => { - if (preferredId && pendingNativeToolCallSet.has(preferredId)) { - pendingNativeToolCallSet.delete(preferredId) - const idx = pendingNativeToolCallIds.indexOf(preferredId) - if (idx !== -1) pendingNativeToolCallIds.splice(idx, 1) - return preferredId - } - - while (pendingNativeToolCallIds.length > 0) { - const candidate = pendingNativeToolCallIds.shift() - if (!candidate) continue - if (!pendingNativeToolCallSet.has(candidate)) continue - pendingNativeToolCallSet.delete(candidate) - return candidate - } - - return undefined - } - - const snapshotPendingNativeToolCallIds = () => - pendingNativeToolCallIds.filter((id) => pendingNativeToolCallSet.has(id)) - - const removePendingMockToolCallId = (toolCallId: string) => { - pendingToolCalls.delete(toolCallId) - const idx = pendingToolCallOrder.indexOf(toolCallId) - if (idx !== -1) pendingToolCallOrder.splice(idx, 1) - } - - const getPendingMockToolCallEntries = () => - pendingToolCallOrder - .map((id) => { - const meta = pendingToolCalls.get(id) - if (!meta) 
return undefined - return { id, meta } - }) - .filter( - ( - entry - ): entry is { - id: string - meta: { name: string; arguments: string; assistantContent?: string } - } => Boolean(entry) - ) - - const pushMockToolResponse = ( - toolCallId: string, - pendingCall: { name: string; arguments: string; assistantContent?: string }, - responseContent: string - ) => { - let argsObj - try { - argsObj = - typeof pendingCall.arguments === 'string' - ? JSON.parse(pendingCall.arguments) - : pendingCall.arguments - } catch { - argsObj = {} - } - - const mockRecord = { - function_call_record: { - name: pendingCall.name, - arguments: argsObj, - response: responseContent - } - } - - result.push({ - role: 'user', - content: `${JSON.stringify(mockRecord)}` - } as ChatCompletionMessageParam) - - removePendingMockToolCallId(toolCallId) - } - - for (let i = 0; i < messages.length; i++) { - const msg = messages[i] - - // Handle basic message structure - const baseMessage: Partial = { - role: msg.role as 'system' | 'user' | 'assistant' | 'tool' - } - - // Handle content conversion to string for non-user messages - if (msg.content !== undefined && msg.role !== 'user') { - if (typeof msg.content === 'string') { - baseMessage.content = msg.content - } else if (Array.isArray(msg.content)) { - // Handle multimodal content arrays - const textParts: string[] = [] - for (const part of msg.content) { - if (part.type === 'text' && part.text) { - textParts.push(part.text) - } - if (part.type === 'image_url' && part.image_url?.url) { - textParts.push(`image: ${part.image_url.url}`) - } - } - baseMessage.content = textParts.join('\n') - } - } - - // Handle user messages (keep multimodal content structure) - if (msg.role === 'user') { - if (typeof msg.content === 'string') { - baseMessage.content = msg.content - } else if (Array.isArray(msg.content)) { - baseMessage.content = msg.content as ChatCompletionContentPart[] - } - result.push(baseMessage as ChatCompletionMessageParam) - continue - } - - // Handle assistant messages with tool_calls - if (msg.role === 'assistant' && msg.tool_calls && msg.tool_calls.length > 0) { - const reasoningContent = (msg as any).reasoning_content as string | undefined - if (supportsFunctionCall) { - // Standard OpenAI format - preserve tool_calls structure - const normalizedToolCalls = msg.tool_calls.map((toolCall) => { - const toolCallId = toolCall.id || `tool-${Date.now()}-${Math.random()}` - enqueueNativeToolCallId(toolCallId) - return { - ...toolCall, - id: toolCallId - } - }) - - result.push({ - role: 'assistant', - content: baseMessage.content || null, - tool_calls: normalizedToolCalls, - ...(reasoningContent !== undefined ? { reasoning_content: reasoningContent } : {}) - } as ChatCompletionMessageParam) - } else { - // Mock format: Store tool calls and assistant content, wait for tool responses - // First add the assistant message if it has content - if (baseMessage.content) { - result.push({ - role: 'assistant', - content: baseMessage.content, - ...(reasoningContent !== undefined ? { reasoning_content: reasoningContent } : {}) - } as ChatCompletionMessageParam) - } - - // Store tool calls for pairing with responses - for (const toolCall of msg.tool_calls) { - const toolCallId = toolCall.id || `tool-${Date.now()}-${Math.random()}` - pendingToolCalls.set(toolCallId, { - name: toolCall.function?.name || 'unknown', - arguments: - typeof toolCall.function?.arguments === 'string' - ? 
toolCall.function.arguments - : JSON.stringify(toolCall.function?.arguments || {}), - assistantContent: baseMessage.content as string | undefined - }) - pendingToolCallOrder.push(toolCallId) - } - } - continue - } - - // Handle tool messages - if (msg.role === 'tool') { - if (supportsFunctionCall) { - const serializedContent = serializeContent(msg.content) - const pendingIds = snapshotPendingNativeToolCallIds() - - // Only attempt to heuristically split merged tool responses when we don't - // already have an explicit tool_call_id from upstream. If tool_call_id is - // present, we trust the upstream pairing and pass the content through as-is. - if (pendingIds.length > 1 && !msg.tool_call_id && serializedContent) { - const splitParts = this.splitMergedToolContent(serializedContent, pendingIds.length) - if (splitParts && splitParts.length === pendingIds.length) { - splitParts.forEach((part, index) => { - const toolCallId = pendingIds[index] - if (!toolCallId) return - consumeNativeToolCallId(toolCallId) - result.push({ - role: 'tool', - content: part, - tool_call_id: toolCallId - } as ChatCompletionMessageParam) - }) - continue - } - } - - let resolvedToolCallId = msg.tool_call_id - if (resolvedToolCallId && pendingNativeToolCallSet.has(resolvedToolCallId)) { - consumeNativeToolCallId(resolvedToolCallId) - } else if (!resolvedToolCallId) { - resolvedToolCallId = consumeNativeToolCallId() - } - - result.push({ - role: 'tool', - content: serializedContent, - tool_call_id: resolvedToolCallId || msg.tool_call_id || '' - } as ChatCompletionMessageParam) - } else { - // Mock format: Create user message with function_call_record - const serializedContent = serializeContent(msg.content) - const pendingEntries = getPendingMockToolCallEntries() - - // 同样地,在 legacy mock 模式下,仅当没有明确的 tool_call_id 时才尝试 - // 对单条 tool 响应做多段拆分;一旦上游已经指定了 tool_call_id,就不要再 - // 进行基于内容的猜测拆分,以避免错误地跨工具混淆响应。 - if (pendingEntries.length > 1 && !msg.tool_call_id && serializedContent) { - const splitParts = this.splitMergedToolContent(serializedContent, pendingEntries.length) - if (splitParts && splitParts.length === pendingEntries.length) { - splitParts.forEach((part, index) => { - const entry = pendingEntries[index] - pushMockToolResponse(entry.id, entry.meta, part) - }) - continue - } - } - - let toolCallId = msg.tool_call_id || '' - if (!toolCallId && pendingEntries.length > 0) { - toolCallId = pendingEntries[0].id - } - - const pendingCall = toolCallId ? 
pendingToolCalls.get(toolCallId) : undefined - - if (toolCallId && pendingCall) { - pushMockToolResponse(toolCallId, pendingCall, serializedContent) - } else { - // Fallback: tool response without matching call, still format as user message - const mockRecord = { - function_call_record: { - name: 'unknown', - arguments: {}, - response: serializedContent - } - } - - result.push({ - role: 'user', - content: `${JSON.stringify(mockRecord)}` - } as ChatCompletionMessageParam) - } - } - continue - } - - // Handle other messages (system, assistant without tool_calls) - if (msg.role === 'assistant') { - const reasoningContent = (msg as any).reasoning_content as string | undefined - if (reasoningContent !== undefined) { - ;(baseMessage as any).reasoning_content = reasoningContent - } - } - result.push(baseMessage as ChatCompletionMessageParam) - } - - return result - } - - /** - * Some upstream MCP layers merge multiple tool responses into a single assistant message when - * `pendingIds.length > 1`; this is outside of the standard OpenAI tool_call flow, so we attempt - * to recover individual payloads by trying several splitting heuristics. Supported formats - * include JSON arrays of strings, delimiter blocks formed by lines of three or more hyphens/equals/asterisks, - * blank-line separation, and repeated header markers. Each strategy is attempted in order so we - * favor structured formats first and fall back to progressively looser parsing when strict - * patterns fail. - */ - private splitMergedToolContent(content: string, expectedParts: number): string[] | null { - if (!content || expectedParts <= 1) return null - const trimmed = content.trim() - if (!trimmed) return null - - const strategies: Array<() => string[] | null> = [ - () => this.trySplitJsonArray(trimmed, expectedParts), - () => this.trySplitByDelimiter(trimmed, /\n-{3,}\n+/g, expectedParts), - () => this.trySplitByDelimiter(trimmed, /\n={3,}\n+/g, expectedParts), - () => this.trySplitByDelimiter(trimmed, /\n\*{3,}\n+/g, expectedParts), - () => this.trySplitByDelimiter(trimmed, /\n\s*\n+/g, expectedParts), - () => this.trySplitByHeaderRepeats(trimmed, expectedParts) - ] - - for (const strategy of strategies) { - const parts = strategy() - if (parts) { - return parts - } - } - - return null - } - - private trySplitJsonArray(content: string, expectedParts: number): string[] | null { - if (!content.startsWith('[')) return null - - try { - const parsed = JSON.parse(content) - if (Array.isArray(parsed) && parsed.length === expectedParts) { - return parsed.map((entry) => (typeof entry === 'string' ? entry : JSON.stringify(entry))) - } - } catch { - return null - } - - return null - } - - private trySplitByDelimiter( - content: string, - delimiter: RegExp, - expectedParts: number - ): string[] | null { - const parts = content - .split(delimiter) - .map((part) => part.trim()) - .filter((part) => part.length > 0) - - if (parts.length === expectedParts) { - return parts - } - - return null - } - - private trySplitByHeaderRepeats(content: string, expectedParts: number): string[] | null { - const headerRegex = /(?:^|\n)([-*]?\s*[A-Za-z][A-Za-z0-9\s,'"-]{0,80}?:)/g - const matches = [...content.matchAll(headerRegex)] - if (matches.length === 0) { - return null - } - - const grouped = new Map() - for (const match of matches) { - const rawHeader = match[1] - if (!rawHeader) continue - const normalized = rawHeader.replace(/\d+/g, '').trim().toLowerCase() - if (!normalized || normalized.length < 3) continue - const startIndex = (match.index ?? 
0) + (match[0].startsWith('\n') ? 1 : 0) - if (!grouped.has(normalized)) { - grouped.set(normalized, []) - } - grouped.get(normalized)!.push(startIndex) - } - - for (const [, indices] of grouped) { - if (indices.length === expectedParts) { - const segments: string[] = [] - for (let i = 0; i < indices.length; i++) { - const start = indices[i] - const end = i + 1 < indices.length ? indices[i + 1] : content.length - const segment = content.slice(start, end).trim() - if (!segment) { - return null - } - segments.push(segment) - } - - if (segments.length === expectedParts) { - return segments - } - } - } - - if (matches.length === expectedParts) { - const segments: string[] = [] - for (let i = 0; i < matches.length; i++) { - const match = matches[i] - const start = (match.index ?? 0) + (match[0].startsWith('\n') ? 1 : 0) - const end = - i + 1 < matches.length ? (matches[i + 1].index ?? content.length) : content.length - const segment = content.slice(start, end).trim() - if (!segment) { - return null - } - segments.push(segment) - } - - if (segments.length === expectedParts) { - return segments - } - } - - return null - } - - // OpenAI completion method - protected async openAICompletion( - messages: ChatMessage[], - modelId?: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - if (!modelId) { - throw new Error('Model ID is required') - } - - // Check if model supports function calling - const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - const supportsFunctionCall = modelConfig?.functionCall || false - - const requestParams: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming = { - messages: this.formatMessages(messages, supportsFunctionCall), - model: modelId, - stream: false, - temperature: temperature, - ...(modelId.startsWith('o1') || - modelId.startsWith('o3') || - modelId.startsWith('o4') || - modelId.includes('gpt-4.1') || - modelId.includes('gpt-5') - ? 
{ max_completion_tokens: maxTokens } - : { max_tokens: maxTokens }) - } - const promptCachePlan = resolvePromptCachePlan({ - providerId: this.provider.id, - apiType: 'openai_chat', - modelId, - messages: requestParams.messages as unknown[], - conversationId: modelConfig?.conversationId - }) - requestParams.messages = applyOpenAIChatExplicitCacheBreakpoint( - requestParams.messages as ChatCompletionMessageParam[], - promptCachePlan - ) - OPENAI_REASONING_MODELS.forEach((noTempId) => { - if (modelId.startsWith(noTempId)) { - delete requestParams.temperature - } - }) - const cachedRequestParams = applyOpenAIPromptCacheKey( - requestParams as unknown as Record, - promptCachePlan - ) as unknown as OpenAI.Chat.ChatCompletionCreateParamsNonStreaming - const completion = await this.openai.chat.completions.create(cachedRequestParams) - - const message = completion.choices[0].message as ChatCompletionMessage & { - reasoning_content?: string - } - const resultResp: LLMResponse = { - content: '' - } - - // Handle native reasoning_content - if (message.reasoning_content) { - resultResp.reasoning_content = message.reasoning_content - resultResp.content = message.content || '' - return resultResp - } - - // Handle tags - if (message.content) { - const content = message.content.trimStart() - if (content.includes('')) { - const thinkStart = content.indexOf('') - const thinkEnd = content.indexOf('') - - if (thinkEnd > thinkStart) { - // 提取 reasoning_content - resultResp.reasoning_content = content.substring(thinkStart + 7, thinkEnd).trim() - - // 合并 前后的普通内容 - const beforeThink = content.substring(0, thinkStart).trim() - const afterThink = content.substring(thinkEnd + 8).trim() - resultResp.content = [beforeThink, afterThink].filter(Boolean).join('\n') - } else { - // 如果没有找到配对的结束标签,将所有内容作为普通内容 - resultResp.content = message.content - } - } else { - // 没有 think 标签,所有内容作为普通内容 - resultResp.content = message.content - } - } - - return resultResp - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * 处理图片生成模型请求的内部方法。 - * @param messages 聊天消息数组。 - * @param modelId 模型ID。 - * @returns AsyncGenerator 流式事件。 - */ - private async *handleImgGeneration( - messages: ChatMessage[], - modelId: string - ): AsyncGenerator { - // 获取最后几条消息,检查是否有图片 - let prompt = '' - const imageUrls: string[] = [] - // 获取最后的用户消息内容作为提示词 - const lastUserMessage = messages.findLast((m) => m.role === 'user') - - if (lastUserMessage?.content) { - if (typeof lastUserMessage.content === 'string') { - prompt = lastUserMessage.content - } else if (Array.isArray(lastUserMessage.content)) { - // 处理多模态内容,提取文本 - const textParts: string[] = [] - for (const part of lastUserMessage.content) { - if (part.type === 'text' && part.text) { - textParts.push(part.text) - } - } - prompt = textParts.join('\n') - } - } - - // 检查最后几条消息中是否有图片 - // 通常我们只需要检查最后两条消息:最近的用户消息和最近的助手消息 - const lastMessages = messages.slice(-2) - for (const message of lastMessages) { - if (message.content) { - if (Array.isArray(message.content)) { - for (const part of message.content) { - if (part.type === 'image_url' && part.image_url?.url) { - imageUrls.push(part.image_url.url) - } - } - } - } - } - - if (!prompt) { - console.error('[handleImgGeneration] Could not extract prompt for image generation.') - yield createStreamEvent.error('Could not extract prompt for image generation.') - yield createStreamEvent.stop('error') - return - } - - try { - let result: OpenAI.Images.ImagesResponse - - if (imageUrls.length > 0) { - // 使用 
images.edit 接口处理带有图片的请求 - let imageBuffer: Buffer - - if (imageUrls[0].startsWith('imgcache://')) { - const filePath = imageUrls[0].slice('imgcache://'.length) - const fullPath = path.join(app.getPath('userData'), 'images', filePath) - imageBuffer = fs.readFileSync(fullPath) - } else { - const imageResponse = await fetch(imageUrls[0]) - const imageBlob = await imageResponse.blob() - imageBuffer = Buffer.from(await imageBlob.arrayBuffer()) - } - - // 创建临时文件 - const imagePath = `/tmp/openai_image_${Date.now()}.png` - await new Promise((resolve, reject) => { - fs.writeFile(imagePath, imageBuffer, (err: Error | null) => { - if (err) { - reject(err) - } else { - resolve() - } - }) - }) - - // 使用文件路径创建 Readable 流 - const imageFile = fs.createReadStream(imagePath) - const params: OpenAI.Images.ImageEditParams = { - model: modelId, - image: imageFile, - prompt: prompt, - n: 1 - } - - // 如果是支持尺寸配置的模型,检测图片尺寸并设置合适的参数 - if (SIZE_CONFIGURABLE_MODELS.includes(modelId)) { - try { - const metadata = await sharp(imageBuffer).metadata() - if (metadata.width && metadata.height) { - const aspectRatio = metadata.width / metadata.height - - // 根据宽高比选择最接近的尺寸 - if (Math.abs(aspectRatio - 1) < 0.1) { - // 接近正方形 - params.size = SUPPORTED_IMAGE_SIZES.SQUARE - } else if (aspectRatio > 1) { - // 横向图片 - params.size = SUPPORTED_IMAGE_SIZES.LANDSCAPE - } else { - // 纵向图片 - params.size = SUPPORTED_IMAGE_SIZES.PORTRAIT - } - } else { - // 如果无法获取宽高,使用默认参数 - params.size = '1024x1536' - } - params.quality = 'high' - } catch (error) { - console.warn( - '[handleImgGeneration] Failed to detect image dimensions, using default size:', - error - ) - // 检测失败时使用默认参数 - params.size = '1024x1536' - params.quality = 'high' - } - } - - result = await this.openai.images.edit(params) - - // 清理临时文件 - try { - fs.unlinkSync(imagePath) - } catch (e) { - console.error('[handleImgGeneration] Failed to delete temporary file:', e) - } - } else { - // 使用原来的 images.generate 接口处理没有图片的请求 - console.log( - `[handleImgGeneration] Generating image with model ${modelId} and prompt: "${prompt}"` - ) - const params: OpenAI.Images.ImageGenerateParams = { - model: modelId, - prompt: prompt, - n: 1, - response_format: 'b64_json' // 请求 base64 格式 - } - if (modelId === 'gpt-image-1' || modelId === 'gpt-4o-image' || modelId === 'gpt-4o-all') { - params.size = '1024x1536' - params.quality = 'high' - } - result = await this.openai.images.generate(params, { - timeout: 300_000 - }) - } - - if (result.data && (result.data[0]?.url || result.data[0]?.b64_json)) { - // 使用devicePresenter缓存图片URL - try { - let imageUrl: string = '' - if (result.data[0]?.b64_json) { - // 处理 base64 数据 - const base64Data = result.data[0].b64_json - // 直接使用 devicePresenter 缓存 base64 数据 - imageUrl = await presenter.devicePresenter.cacheImage( - base64Data.startsWith('data:image/png;base64,') - ? 
base64Data - : 'data:image/png;base64,' + base64Data - ) - } else { - // 原有的 URL 处理逻辑 - imageUrl = result.data[0]?.url || '' - imageUrl = await presenter.devicePresenter.cacheImage(imageUrl) - } - - // 返回缓存后的URL - yield createStreamEvent.imageData({ data: imageUrl, mimeType: 'deepchat/image-url' }) - - // 处理 usage 信息 - if (result.usage) { - yield createStreamEvent.usage({ - prompt_tokens: result.usage.input_tokens || 0, - completion_tokens: result.usage.output_tokens || 0, - total_tokens: result.usage.total_tokens || 0, - cached_tokens: getOpenAIChatCachedTokens(result.usage), - cache_write_tokens: getOpenAIChatCacheWriteTokens(result.usage) - }) - } - - yield createStreamEvent.stop('complete') - } catch (cacheError) { - // 缓存失败时降级为使用原始URL - console.warn( - '[handleImgGeneration] Failed to cache image, using original data/URL:', - cacheError - ) - yield createStreamEvent.imageData({ - data: result.data[0]?.url || result.data[0]?.b64_json || '', - mimeType: result.data[0]?.url ? 'deepchat/image-url' : 'deepchat/image-base64' - }) - yield createStreamEvent.stop('complete') - } - } else { - console.error('[handleImgGeneration] No image data received from API.', result) - yield createStreamEvent.error('No image data received from API.') - yield createStreamEvent.stop('error') - } - } catch (error: unknown) { - const errorMessage = error instanceof Error ? error.message : String(error) - console.error('[handleImgGeneration] Error during image generation:', errorMessage) - yield createStreamEvent.error(`Image generation failed: ${errorMessage}`) - yield createStreamEvent.stop('error') - } - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * 处理 OpenAI 聊天补全模型请求的内部方法。 - * @param messages 聊天消息数组。 - * @param modelId 模型ID。 - * @param modelConfig 模型配置。 - * @param temperature 温度参数。 - * @param maxTokens 最大 token 数。 - * @param mcpTools MCP 工具定义数组。 - * @returns AsyncGenerator 流式事件。 - */ - private async *handleChatCompletion( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - //----------------------------------------------------------------------------------------------------- - // 为 OpenAI 聊天补全准备消息和工具 - const tools = mcpTools || [] - const supportsFunctionCall = modelConfig?.functionCall || false // 判断是否支持原生函数调用 - let processedMessages = [ - ...this.formatMessages(messages, supportsFunctionCall) - ] as ChatCompletionMessageParam[] - // console.log('processedMessages', JSON.stringify(processedMessages)) - // 如果不支持原生函数调用但存在工具,则准备非原生函数调用提示 - if (tools.length > 0 && !supportsFunctionCall) { - processedMessages = this.prepareFunctionCallPrompt(processedMessages, tools) - } - - // 如果支持原生函数调用,则转换工具定义为 OpenAI 格式 - const apiTools = - tools.length > 0 && supportsFunctionCall - ? await this.mcpRuntime?.mcpToolsToOpenAITools(tools, this.provider.id) - : undefined - - // 构建请求参数 - const requestParams: OpenAI.Chat.ChatCompletionCreateParamsStreaming = { - messages: processedMessages, - model: modelId, - stream: true, - temperature, - ...(modelId.startsWith('o1') || - modelId.startsWith('o3') || - modelId.startsWith('o4') || - modelId.includes('gpt-4.1') || - modelId.includes('gpt-5') - ? 
{ max_completion_tokens: maxTokens } - : { max_tokens: maxTokens }) - } - - // 添加 stream_options 以获取 token usages(适用于如 Qwen 等模型) - requestParams.stream_options = { include_usage: true } - - // 防止某些模型(如 Qwen)以 JSON 形式输出结果正文,Grok 系列模型和供应商无需设置 - if (this.provider.id.toLowerCase().includes('dashscope')) { - requestParams.response_format = { type: 'text' } - } - - // openrouter deepseek-v3-0324:free 特定模型处理 - if ( - this.provider.id.toLowerCase().includes('openrouter') && - modelId.startsWith('deepseek/deepseek-chat-v3-0324:free') - ) { - // 限定服务供应商为chutes,sorry for hack... - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ;(requestParams as any).provider = { - only: ['chutes'] - } - } - - if (modelConfig.reasoningEffort && this.supportsEffortParameter(modelId)) { - ;(requestParams as any).reasoning_effort = modelConfig.reasoningEffort - } - - // 仅当模型能力集声明支持时,才添加 verbosity - if (modelConfig.verbosity && this.supportsVerbosityParameter(modelId)) { - ;(requestParams as any).verbosity = modelConfig.verbosity - } - - // 移除推理模型的温度参数 - OPENAI_REASONING_MODELS.forEach((noTempId) => { - if (modelId.startsWith(noTempId)) delete requestParams.temperature - }) - - // 如果存在 API 工具且支持函数调用,则添加到请求参数中 - if (apiTools && apiTools.length > 0 && supportsFunctionCall) requestParams.tools = apiTools - - const promptCachePlan = resolvePromptCachePlan({ - providerId: this.provider.id, - apiType: 'openai_chat', - modelId, - messages: processedMessages as unknown[], - tools, - conversationId: modelConfig?.conversationId - }) - requestParams.messages = applyOpenAIChatExplicitCacheBreakpoint( - requestParams.messages as ChatCompletionMessageParam[], - promptCachePlan - ) - const cachedRequestParams = applyOpenAIPromptCacheKey( - requestParams as unknown as Record, - promptCachePlan - ) as unknown as OpenAI.Chat.ChatCompletionCreateParamsStreaming - - await this.emitRequestTrace(modelConfig, { - endpoint: this.buildChatCompletionsEndpoint(), - headers: this.buildChatCompletionsTraceHeaders(), - body: cachedRequestParams - }) - - // console.log('[handleChatCompletion] requestParams', JSON.stringify(requestParams)) - // 发起 OpenAI 聊天补全请求 - const stream = await this.openai.chat.completions.create(cachedRequestParams) - - //----------------------------------------------------------------------------------------------------- - // 流处理状态定义 (已将相关变量声明提升到顶部,确保可见性) - let pendingBuffer = '' // 累积来自 delta.content 的字符,用于匹配标签和内容 - let currentTextOutputBuffer = '' // 用于累积在所有标签之外的纯文本,准备输出 - - const thinkStartMarker = '' - const thinkEndMarker = '' - const funcStartMarker = '' - const funcEndMarker = '' - - // 标记当前解析状态 - let inThinkBlock = false // 是否在 块内部 - let inFunctionCallBlock = false // 是否在 块内部(非原生) - - // 定义一个辅助函数,检查 buffer 是否可能是任何已知标签的有效前缀 - const hasPotentialMarkerPrefix = (buffer: string) => { - return ( - thinkStartMarker.startsWith(buffer) || - thinkEndMarker.startsWith(buffer) || - funcStartMarker.startsWith(buffer) || - funcEndMarker.startsWith(buffer) - ) - } - - const nativeToolCalls: Record< - string, - { name: string; arguments: string; completed?: boolean } - > = {} - const indexToIdMap: Record = {} - let stopReason: LLMCoreStreamEvent['stop_reason'] = 'complete' - let toolUseDetected = false // 标记是否检测到工具使用(原生或非原生) - let usage: - | { - prompt_tokens: number - completion_tokens: number - total_tokens: number - cached_tokens?: number - cache_write_tokens?: number - } - | undefined = undefined - - //----------------------------------------------------------------------------------------------------- - // 流处理循环 
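A minimal, standalone sketch of the prefix-buffering idea that these state variables and the character loop below implement, assuming `<think>` / `</think>` as the tag literals and a hypothetical `Sink` callback pair in place of the provider's `createStreamEvent.text` / `createStreamEvent.reasoning` events; it is an illustration of the technique, not the removed loop itself, and the `<function_call>` branch is omitted for brevity:

```ts
// Sketch under the assumptions stated above; not the removed provider implementation.
type Sink = {
  text: (chunk: string) => void
  reasoning: (chunk: string) => void
}

const THINK_START = '<think>'
const THINK_END = '</think>'

// True while some suffix of `buf` is still a proper prefix of `marker`,
// i.e. more streamed characters could still complete the tag.
function couldGrowIntoMarker(buf: string, marker: string): boolean {
  const max = Math.min(buf.length, marker.length - 1)
  for (let len = max; len > 0; len--) {
    if (marker.startsWith(buf.slice(-len))) return true
  }
  return false
}

export function createThinkSplitter(sink: Sink) {
  let buffer = ''
  let inThink = false

  const emit = (chunk: string) => (inThink ? sink.reasoning(chunk) : sink.text(chunk))

  return {
    // Feed one streamed delta; characters are released as soon as they can no
    // longer be part of a start/end tag.
    push(chunk: string): void {
      for (const char of chunk) {
        buffer += char
        while (buffer.length > 0) {
          const marker = inThink ? THINK_END : THINK_START
          if (buffer.endsWith(marker)) {
            const before = buffer.slice(0, -marker.length)
            if (before) emit(before)
            buffer = ''
            inThink = !inThink
            break
          }
          if (couldGrowIntoMarker(buffer, marker)) break // wait for more characters
          emit(buffer[0]) // head character can no longer start a tag, release it
          buffer = buffer.slice(1)
        }
      }
    },
    // Flush whatever remains when the stream ends (e.g. an unclosed tag body).
    flush(): void {
      if (buffer) emit(buffer)
      buffer = ''
    }
  }
}

// Example wiring (hypothetical):
// const splitter = createThinkSplitter({ text: console.log, reasoning: console.debug })
// for await (const delta of stream) splitter.push(delta)
// splitter.flush()
```

The removed loop applies the same pattern to the non-native function-call tags, except that the tag body is withheld in full until the closing marker so the accumulated content can be handed to `parseFunctionCalls` as one JSON document.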
- for await (const chunk of stream) { - // console.log('[handleChatCompletion] chunk', JSON.stringify(chunk)) - const choice = chunk.choices[0] - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const delta = choice?.delta as any - const currentContent = delta?.content || '' - - // 1. 处理非内容事件 (如 usage, reasoning, tool_calls) - if (chunk.usage) { - usage = { - ...chunk.usage, - cached_tokens: getOpenAIChatCachedTokens(chunk.usage), - cache_write_tokens: getOpenAIChatCacheWriteTokens(chunk.usage) - } - } - - // 原生 reasoning 内容处理(直接产出) - if (delta?.reasoning_content || delta?.reasoning) { - yield createStreamEvent.reasoning(delta.reasoning_content || delta.reasoning) - continue - } - - // 处理图片数据(OpenRouter Gemini 格式) - if (delta?.images && Array.isArray(delta.images)) { - for (const image of delta.images) { - if (image.type === 'image_url' && image.image_url?.url) { - try { - const cachedUrl = await presenter.devicePresenter.cacheImage(image.image_url.url) - yield createStreamEvent.imageData({ data: cachedUrl, mimeType: 'deepchat/image-url' }) - } catch (cacheError) { - console.warn('[handleChatCompletion] Failed to cache image:', cacheError) - yield createStreamEvent.imageData({ - data: image.image_url.url, - mimeType: 'deepchat/image-url' - }) - } - } - } - continue - } - - // 处理 Gemini 原生格式的图片数据(inlineData) - if (delta?.content?.parts && Array.isArray(delta.content.parts)) { - for (const part of delta.content.parts) { - if (part.inlineData && part.inlineData.data) { - yield createStreamEvent.imageData({ - data: part.inlineData.data, - mimeType: part.inlineData.mimeType || 'image/png' - }) - } - } - continue - } - - // 处理 Gemini multi_mod_content 格式的图片数据 - if (delta?.multi_mod_content && Array.isArray(delta.multi_mod_content)) { - for (const item of delta.multi_mod_content) { - if (item.inline_data && item.inline_data.data) { - const base64Data = item.inline_data.data - const mimeType = item.inline_data.mime_type || 'image/png' - - // 将纯 base64 数据转换为 data:image/...;base64,xxx 格式 - const dataUri = base64Data.startsWith('data:image/') - ? base64Data - : `data:${mimeType};base64,${base64Data}` - - try { - // 缓存图片并获取URL - const cachedUrl = await presenter.devicePresenter.cacheImage(dataUri) - yield createStreamEvent.imageData({ data: cachedUrl, mimeType: 'deepchat/image-url' }) - } catch (cacheError) { - console.warn( - '[handleChatCompletion] Failed to cache image from multi_mod_content:', - cacheError - ) - // 缓存失败时,直接使用原始 base64 数据 - yield createStreamEvent.imageData({ - data: dataUri, - mimeType: mimeType - }) - } - } - } - continue - } - - // 处理 content 中直接包含 base64 图片的情况 - let processedCurrentContent = currentContent - if (currentContent && currentContent.includes('![image](data:image/')) { - try { - // 使用正则表达式匹配 markdown 格式的 base64 图片 - const base64ImageRegex = /!\[image\]\((data:image\/[^;]+;base64,[^)]+)\)/g - let hasImages = false - - let match - while ((match = base64ImageRegex.exec(currentContent)) !== null) { - const base64Data = match[1] // 完整的 data:image/...;base64,... 
格式 - - try { - // 缓存图片并获取URL - const cachedUrl = await presenter.devicePresenter.cacheImage(base64Data) - - // 发送图片数据事件 - yield createStreamEvent.imageData({ data: cachedUrl, mimeType: 'deepchat/image-url' }) - - // 从内容中完全移除图片部分,避免重复显示(image_data事件已经处理了图片显示) - processedCurrentContent = processedCurrentContent.replace(match[0], '') - hasImages = true - - console.log( - `[handleChatCompletion] Successfully cached base64 image from content and removed from text` - ) - } catch (cacheError) { - console.warn( - '[handleChatCompletion] Failed to cache base64 image from content:', - cacheError - ) - // 缓存失败时保持原始内容不变 - } - } - - // 如果处理了图片,清理多余的空行并记录日志 - if (hasImages) { - // 清理移除图片后可能留下的多余空行 - processedCurrentContent = normalizeExtractedImageText(processedCurrentContent) - console.log( - `[handleChatCompletion] Processed ${currentContent.length} chars -> ${processedCurrentContent.length} chars (images removed)` - ) - } - } catch (error) { - console.error('[handleChatCompletion] Error processing base64 images in content:', error) - // 处理失败时继续正常流程 - } - } - - // 原生 tool_calls 处理 - if (supportsFunctionCall && delta?.tool_calls?.length > 0) { - toolUseDetected = true - // console.log('[handleChatCompletion] Handling native tool_calls', JSON.stringify(delta.tool_calls)) - for (const toolCallDelta of delta.tool_calls) { - const id = toolCallDelta.id ? toolCallDelta.id : toolCallDelta.function?.name - const index = toolCallDelta.index - const functionName = toolCallDelta.function?.name - const argumentChunk = toolCallDelta.function?.arguments - - let currentToolCallId: string | undefined = undefined - - if (id) { - currentToolCallId = id - if (index !== undefined) indexToIdMap[index] = id - } else if (index !== undefined && indexToIdMap[index]) { - currentToolCallId = indexToIdMap[index] - } else { - console.warn( - '[handleChatCompletion] Received tool call delta chunk without id/mapping:', - toolCallDelta - ) - continue - } - - if (currentToolCallId) { - if (!nativeToolCalls[currentToolCallId]) { - nativeToolCalls[currentToolCallId] = { name: '', arguments: '', completed: false } - // console.log(`[handleChatCompletion] Initialized nativeToolCalls entry for id ${currentToolCallId}`) - } - - const currentCallState = nativeToolCalls[currentToolCallId] - - // 处理增量更新 - // console.log(`[handleChatCompletion] Handling incremental update for ${currentToolCallId}.`) - if (functionName && !currentCallState.name) { - currentCallState.name = functionName - yield createStreamEvent.toolCallStart(currentToolCallId, functionName) - } - if (argumentChunk) { - currentCallState.arguments += argumentChunk - yield createStreamEvent.toolCallChunk(currentToolCallId, argumentChunk) - } - } - } - continue // 处理完原生工具调用后继续下一个 chunk - } - - // 处理停止原因 - if (choice?.finish_reason) { - const reasonFromAPI = choice.finish_reason - // console.log('[handleChatCompletion] Finish Reason from API:', reasonFromAPI) - if (reasonFromAPI === 'tool_calls') { - stopReason = 'tool_use' - toolUseDetected = true - } else if (toolUseDetected) { - // 如果之前已经有工具调用,那么 finish_reason 'stop' 意味着工具调用完成 - stopReason = - reasonFromAPI === 'stop' - ? 'complete' - : reasonFromAPI === 'length' - ? 'max_tokens' - : 'error' - } else if (reasonFromAPI === 'stop') { - stopReason = 'complete' - } else if (reasonFromAPI === 'length') { - stopReason = 'max_tokens' - } else { - console.warn(`[handleChatCompletion] Unhandled finish reason: ${reasonFromAPI}`) - stopReason = 'error' - } - /* - choice { - finish_reason: 'stop', - delta: { content: '!' 
}, - index: 0, - logprobs: null - } - */ - // continue - } - - // 如果没有内容,则继续下一个 chunk - if (!processedCurrentContent) continue - - // 2. 字符级流式处理内容 - for (const char of processedCurrentContent) { - pendingBuffer += char - - // 循环处理 pendingBuffer 直到它为空,或者不足以继续匹配 - while (pendingBuffer.length > 0) { - if (inThinkBlock) { - // 在 内部,所有内容都视为 reasoning - if ( - pendingBuffer.length >= thinkEndMarker.length && - pendingBuffer.endsWith(thinkEndMarker) - ) { - const content = pendingBuffer.slice(0, -thinkEndMarker.length) - if (content) { - yield createStreamEvent.reasoning(content) - } - inThinkBlock = false - pendingBuffer = '' // 清空 buffer,退出当前 while 循环 - } else { - // 如果 pendingBuffer 长度不足以匹配 ,或者已经超过 长度但不是其有效前缀 - // 意味着 pendingBuffer 的首字符是推理内容的一部分,可以产出。 - // 否则,pendingBuffer 是 的有效前缀,继续累积等待完整标签。 - if ( - pendingBuffer.length > thinkEndMarker.length || - !thinkEndMarker.startsWith(pendingBuffer) - ) { - const charToYield = pendingBuffer[0] - pendingBuffer = pendingBuffer.slice(1) - yield createStreamEvent.reasoning(charToYield) - } else { - break // 跳出 while 循环,等待更多字符以形成完整的结束标签 - } - } - } else if (inFunctionCallBlock) { - // 在非原生 内部 - if ( - pendingBuffer.length >= funcEndMarker.length && - pendingBuffer.endsWith(funcEndMarker) - ) { - const content = pendingBuffer.slice(0, -funcEndMarker.length) - // 解析非原生函数调用 - const parsedCalls = this.parseFunctionCalls( - `${funcStartMarker}${content}${funcEndMarker}`, // 确保完整标签以进行解析 - `non-native-${this.provider.id}` - ) - for (const parsedCall of parsedCalls) { - yield createStreamEvent.toolCallStart(parsedCall.id, parsedCall.function.name) - yield createStreamEvent.toolCallChunk(parsedCall.id, parsedCall.function.arguments) - yield createStreamEvent.toolCallEnd(parsedCall.id, parsedCall.function.arguments) - } - toolUseDetected = true // 标记检测到工具使用 - inFunctionCallBlock = false - pendingBuffer = '' // 清空 buffer,退出当前 while 循环 - } else { - // 在 内部,直到结束标签出现前,所有内容都应无条件累积。 - // 不进行字符产出,因为函数调用参数需要完整性。 - // 仅等待完整标签的出现。 - break // 跳出 while 循环,等待更多字符以形成完整的结束标签 - } - } else { - // 不在任何特殊块内部,检查开始标签或输出纯文本 - // 优先尝试匹配完整的开始标签 - if ( - pendingBuffer.length >= thinkStartMarker.length && - pendingBuffer.endsWith(thinkStartMarker) - ) { - const textBeforeTag = pendingBuffer.slice(0, -thinkStartMarker.length) - if (textBeforeTag) { - currentTextOutputBuffer += textBeforeTag - yield createStreamEvent.text(currentTextOutputBuffer) - currentTextOutputBuffer = '' - } - inThinkBlock = true - pendingBuffer = '' // 清空 buffer,继续内层 while 循环 - } else if ( - pendingBuffer.length >= funcStartMarker.length && - pendingBuffer.endsWith(funcStartMarker) - ) { - const textBeforeTag = pendingBuffer.slice(0, -funcStartMarker.length) - if (textBeforeTag) { - currentTextOutputBuffer += textBeforeTag - yield createStreamEvent.text(currentTextOutputBuffer) - currentTextOutputBuffer = '' - } - inFunctionCallBlock = true - pendingBuffer = '' // 清空 buffer,继续内层 while 循环 - } else { - // 如果没有匹配到完整的开始标签,并且 pendingBuffer 不再是任何已知标签的有效前缀, - // 那么 pendingBuffer 的首字符就是纯文本,可以安全产出。 - // 否则,pendingBuffer 仍可能是某个标签的有效前缀,继续累积。 - if (!hasPotentialMarkerPrefix(pendingBuffer)) { - const charToYield = pendingBuffer[0] - pendingBuffer = pendingBuffer.slice(1) - currentTextOutputBuffer += charToYield - yield createStreamEvent.text(currentTextOutputBuffer) - currentTextOutputBuffer = '' - } else { - break // 跳出 while 循环,等待更多字符以形成完整的标签 - } - } - } - } // 字符循环内部的 while 循环结束 - } // 字符循环结束 - } // chunk 循环结束 - - //----------------------------------------------------------------------------------------------------- - // 
最终处理:流结束时处理任何剩余的缓冲内容 - // 1. 处理 pendingBuffer 中剩余的任何内容 - if (pendingBuffer.length > 0) { - if (inThinkBlock) { - // 如果流结束时 未闭合,将其内容作为 reasoning 产出 - console.warn( - `[handleChatCompletion] Stream ended while inside unclosed tag. Remaining content: "${pendingBuffer}"` - ) - yield createStreamEvent.reasoning(pendingBuffer) - } else if (inFunctionCallBlock) { - // 如果流结束时非原生函数调用未闭合,尝试解析并作为工具调用事件(不发出 end 事件) - console.warn( - `[handleChatCompletion] Stream ended while inside unclosed tag. Content: "${pendingBuffer}"` - ) - const parsedCalls = this.parseFunctionCalls( - `${funcStartMarker}${pendingBuffer}`, // 只尝试解析已有的部分,即使不完整,并以incomplete标记,以便下游发现 - `non-native-incomplete-${this.provider.id}` - ) - if (parsedCalls.length > 0) { - for (const parsedCall of parsedCalls) { - yield createStreamEvent.toolCallStart( - parsedCall.id + '-incomplete', - parsedCall.function.name - ) - yield createStreamEvent.toolCallChunk( - parsedCall.id + '-incomplete', - parsedCall.function.arguments || '' - ) - // 不会发出 tool_call_end,因为标签未闭合 - // 不发出 tool_call_end 的理由在于,提醒下游发现未完成的function调用 - } - toolUseDetected = true - } else { - // 如果解析失败,则作为纯文本输出,并附带开始标签 - yield createStreamEvent.text(`${funcStartMarker}${pendingBuffer}`) - } - } else { - // 否则,作为普通纯文本输出 - currentTextOutputBuffer += pendingBuffer - } - pendingBuffer = '' - } - - // 2. 处理 currentTextOutputBuffer 中剩余的任何纯文本 - if (currentTextOutputBuffer) { - yield createStreamEvent.text(currentTextOutputBuffer) - currentTextOutputBuffer = '' - } - - // 3. 最终检查和产出原生工具调用 - // 这里假设原生工具调用在流结束时其 arguments 都已完整。 - if (supportsFunctionCall && toolUseDetected) { - for (const toolId in nativeToolCalls) { - const tool = nativeToolCalls[toolId] - // 只有当工具名称和参数都存在且未标记为已完成时才尝试结束事件 - if (tool.name && tool.arguments && !tool.completed) { - try { - JSON.parse(tool.arguments) // 检查参数是否是有效的 JSON - yield createStreamEvent.toolCallEnd(toolId, tool.arguments) - tool.completed = true // 标记为已完成 - } catch (e) { - console.error( - `[handleChatCompletion] Error parsing arguments for native tool ${toolId} during finalization: ${tool.arguments}`, - e - ) - // 即使解析失败,也尝试发送,以提供尽可能多的信息 - yield createStreamEvent.toolCallEnd(toolId, tool.arguments) - tool.completed = true // 标记为已完成 - } - } else if (!tool.completed) { - // 记录警告,如果工具调用不完整且未被处理 - console.warn( - `[handleChatCompletion] Native tool call ${toolId} is incomplete and will not have an end event during finalization. Name: ${tool.name}, Args: ${tool.arguments}` - ) - } - } - } - - // 4. 产出 usage 信息 - if (usage) { - yield createStreamEvent.usage(usage) - } - - // 5. 产出最终停止原因 - const finalStopReason = toolUseDetected ? 
'tool_use' : stopReason - yield createStreamEvent.stop(finalStopReason) - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * 核心流处理方法,根据模型类型分发请求。 - * @param messages 聊天消息数组。 - * @param modelId 模型ID。 - * @param modelConfig 模型配置。 - * @param temperature 温度参数。 - * @param maxTokens 最大 token 数。 - * @param mcpTools MCP 工具定义数组。 - * @returns AsyncGenerator 流式事件。 - */ - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - - const apiEndpoint = this.getEffectiveApiEndpoint(modelId) - - switch (apiEndpoint) { - case ApiEndpointType.Image: - yield* this.handleImgGeneration(messages, modelId) - break - case ApiEndpointType.Video: - yield* this.handleChatCompletion( - messages, - modelId, - modelConfig, - temperature, - maxTokens, - mcpTools - ) - break - case ApiEndpointType.Chat: - default: - yield* this.handleChatCompletion( - messages, - modelId, - modelConfig, - temperature, - maxTokens, - mcpTools - ) - } - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////// - // ... [prepareFunctionCallPrompt remains unchanged] ... - private prepareFunctionCallPrompt( - messages: ChatCompletionMessageParam[], - mcpTools: MCPToolDefinition[] - ): ChatCompletionMessageParam[] { - // 创建消息副本而不是直接修改原始消息 - const result = messages.map((message) => ({ ...message })) - - const functionCallPrompt = this.getFunctionCallWrapPrompt(mcpTools) - const userMessage = result.findLast((message) => message.role === 'user') - - if (userMessage?.role === 'user') { - if (Array.isArray(userMessage.content)) { - // 创建content数组的深拷贝 - userMessage.content = [...userMessage.content] - const firstTextIndex = userMessage.content.findIndex((content) => content.type === 'text') - if (firstTextIndex !== -1) { - // 创建文本内容的副本 - const textContent = { - ...userMessage.content[firstTextIndex] - } as ChatCompletionContentPartText - textContent.text = `${functionCallPrompt}\n\n${(userMessage.content[firstTextIndex] as ChatCompletionContentPartText).text}` - userMessage.content[firstTextIndex] = textContent - } - } else { - userMessage.content = `${functionCallPrompt}\n\n${userMessage.content}` - } - } - return result - } - - // Updated parseFunctionCalls signature and implementation - protected parseFunctionCalls( - response: string, - // Pass a prefix for creating fallback IDs - fallbackIdPrefix: string = 'tool-call' - ): Array<{ id: string; type: string; function: { name: string; arguments: string } }> { - // console.log('[parseFunctionCalls] Received raw response:', response) // Log raw input - try { - // 使用非贪婪模式匹配function_call标签对,能够处理多行内容 - const functionCallMatches = response.match(/([\s\S]*?)<\/function_call>/gs) - if (!functionCallMatches) { - // console.log('[parseFunctionCalls] No tags found.') // Log no match - return [] - } - // console.log(`[parseFunctionCalls] Found ${functionCallMatches.length} potential matches.`) // Log match count - - const toolCalls = functionCallMatches - .map((match, index) => { - // console.log(`[parseFunctionCalls] Processing match ${index}:`, match) // Log each match - // Add index for unique fallback ID generation - const content = match.replace(/<\/?function_call>/g, '').trim() // Fixed regex escaping - // 
console.log(`[parseFunctionCalls] Extracted content for match ${index}:`, content) // Log extracted content - if (!content) { - // console.log(`[parseFunctionCalls] Match ${index} has empty content, skipping.`) - return null // Skip empty content between tags - } - - try { - let parsedCall - let repairedJson: string | undefined - try { - // Attempt standard JSON parse first - parsedCall = JSON.parse(content) - // console.log(`[parseFunctionCalls] Standard JSON.parse successful for match ${index}.`) // Log success - } catch (initialParseError) { - console.warn( - `[parseFunctionCalls] Standard JSON.parse failed for match ${index}, attempting jsonrepair. Error:`, - (initialParseError as Error).message - ) // Log failure and attempt repair - try { - // Fallback to jsonrepair for robustness - repairedJson = jsonrepair(content) - // console.log( - // `[parseFunctionCalls] jsonrepair result for match ${index}:`, - // repairedJson - // ) // Log repaired JSON - parsedCall = JSON.parse(repairedJson) - // console.log( - // `[parseFunctionCalls] JSON.parse successful after jsonrepair for match ${index}.` - // ) // Log repair success - } catch (repairError) { - console.error( - `[parseFunctionCalls] Failed to parse content for match ${index} even with jsonrepair:`, - repairError, - 'Original content:', - content, - 'Repaired content attempt:', - repairedJson ?? 'N/A' - ) // Log final failure - return null // Skip this malformed call - } - } - // console.log(`[parseFunctionCalls] Parsed object for match ${index}:`, parsedCall) // Log parsed object - - // Extract name and arguments, handling various potential structures - let functionName, functionArgs - if (parsedCall.function_call && typeof parsedCall.function_call === 'object') { - functionName = parsedCall.function_call.name - functionArgs = parsedCall.function_call.arguments - } else if (parsedCall.name && parsedCall.arguments !== undefined) { - functionName = parsedCall.name - functionArgs = parsedCall.arguments - } else if ( - parsedCall.function && - typeof parsedCall.function === 'object' && - parsedCall.function.name - ) { - functionName = parsedCall.function.name - functionArgs = parsedCall.function.arguments - } else { - // Attempt to find the function call structure if nested under a single key - const keys = Object.keys(parsedCall) - if (keys.length === 1) { - const potentialToolCall = parsedCall[keys[0]] - if (potentialToolCall && typeof potentialToolCall === 'object') { - if (potentialToolCall.name && potentialToolCall.arguments !== undefined) { - functionName = potentialToolCall.name - functionArgs = potentialToolCall.arguments - } else if ( - potentialToolCall.function && - typeof potentialToolCall.function === 'object' && - potentialToolCall.function.name - ) { - functionName = potentialToolCall.function.name - functionArgs = potentialToolCall.function.arguments - } - } - } - - // If still not found, log an error - if (!functionName) { - console.error( - '[parseFunctionCalls] Could not determine function name from parsed call:', - parsedCall - ) // Log name extraction failure - return null - } - } - // console.log( - // `[parseFunctionCalls] Extracted for match ${index}: Name='${functionName}', Args=`, - // functionArgs - // ) // Log extracted name/args - - // Ensure arguments are stringified if they are not already - if (typeof functionArgs !== 'string') { - // console.log( - // `[parseFunctionCalls] Arguments for match ${index} are not a string, stringifying.` - // ) // Log stringify attempt - try { - functionArgs = 
JSON.stringify(functionArgs) - } catch (stringifyError) { - console.error( - '[parseFunctionCalls] Failed to stringify function arguments:', - stringifyError, - functionArgs - ) // Log stringify failure - functionArgs = '{"error": "failed to stringify arguments"}' - } - } - - // Generate a unique ID if not provided in the parsed content - const id = - parsedCall.id ?? - (functionName - ? `${functionName}-${index}-${Date.now()}` - : `${fallbackIdPrefix}-${index}-${Date.now()}`) - - // console.log( - // `[parseFunctionCalls] Finalizing tool call for match ${index}: ID='${id}', Name='${functionName}', Args='${functionArgs}'` - // ) // Log final object details - - return { - id: String(id), // Ensure ID is string - type: 'function', // Standardize type - function: { - name: String(functionName), // Ensure name is string - arguments: functionArgs // Already ensured string - } - } - } catch (processingError) { - // Catch errors during the extraction/validation logic - console.error( - '[parseFunctionCalls] Error processing parsed function call JSON:', - processingError, - 'Content:', - content - ) // Log processing error - return null // Skip this call on error - } - }) - .filter( - ( - call - ): call is { id: string; type: string; function: { name: string; arguments: string } } => - // Type guard ensures correct structure - call !== null && - typeof call.id === 'string' && - typeof call.function === 'object' && - call.function !== null && - typeof call.function.name === 'string' && - typeof call.function.arguments === 'string' - ) - console.log(`[parseFunctionCalls] Returning ${toolCalls.length} parsed tool calls.`) // Log final count - return toolCalls - } catch (error) { - console.error( - '[parseFunctionCalls] Unexpected error during execution:', - error, - 'Input:', - response - ) // Log unexpected error - return [] // Return empty array on unexpected errors - } - } - - // ... [check, summaryTitles, completions, summaries, generateText, suggestions remain unchanged] ... - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - if (!this.isNoModelsApi) { - // Use unified timeout configuration from base class - const models = await this.fetchOpenAIModels({ timeout: this.getModelFetchTimeout() }) - this.models = models // Store fetched models - } - // Potentially add a simple API call test here if needed, e.g., list models even for no-API list to check key/endpoint - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - // Use unknown for type safety - let errorMessage = 'An unknown error occurred during provider check.' 
- if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - // Optionally log the full error object for debugging - console.error('OpenAICompatibleProvider check failed:', error) - - eventBus.sendToRenderer(NOTIFICATION_EVENTS.SHOW_ERROR, SendTarget.ALL_WINDOWS, { - title: 'API Check Failed', // More specific title - message: errorMessage, - id: `openai-check-error-${Date.now()}`, - type: 'error' - }) - return { isOk: false, errorMsg: errorMessage } - } - } - public async summaryTitles(messages: ChatMessage[], modelId: string): Promise { - const summaryText = `${SUMMARY_TITLES_PROMPT}\n\n${messages.map((m) => `${m.role}: ${m.content}`).join('\n')}` - const fullMessage: ChatMessage[] = [{ role: 'user', content: summaryText }] - const response = await this.openAICompletion(fullMessage, modelId, 0.5) - return response.content.replace(/["']/g, '').trim() - } - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - // Simple completion, no specific system prompt needed unless required by base class or future design - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - const systemPrompt = `Summarize the following text concisely:` - // Create messages based on the input text - const requestMessages: ChatMessage[] = [ - { role: 'system', content: systemPrompt }, - { role: 'user', content: text } // Use the input text directly - ] - return this.openAICompletion(requestMessages, modelId, temperature, maxTokens) - } - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - // Use the prompt directly as the user message content - const requestMessages: ChatMessage[] = [{ role: 'user', content: prompt }] - // Note: formatMessages might not be needed here if it's just a single prompt string, - // but keeping it for consistency in case formatMessages adds system prompts or other logic. - return this.openAICompletion(requestMessages, modelId, temperature, maxTokens) - } - async suggestions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - const systemPrompt = `Based on the last user message in the conversation history, provide 3 brief, relevant follow-up suggestions or questions. Output ONLY the suggestions, each on a new line. Do not include numbering, bullet points, or introductory text like "Here are some suggestions:".` - const lastUserMessage = messages.filter((m) => m.role === 'user').pop() // Get the most recent user message - - if (!lastUserMessage) { - console.warn('suggestions called without user messages.') - return [] // Return empty array if no user message found - } - - // Provide some context if possible, e.g., last few messages - const contextMessages = messages.slice(-5) // Last 5 messages as context - - const requestMessages: ChatMessage[] = [ - { role: 'system', content: systemPrompt }, - // Include context leading up to the last user message - ...contextMessages - ] - - try { - const response = await this.openAICompletion( - requestMessages, - modelId, - temperature ?? 0.7, - maxTokens ?? 
60 - ) // Adjusted temp/tokens - // Split, trim, and filter results robustly - return response.content - .split('\n') - .map((s) => s.trim()) - .filter((s) => s.length > 0 && !s.match(/^[0-9.\-*\s]*/)) // Fixed regex range - } catch (error) { - console.error('Failed to get suggestions:', error) - return [] // Return empty on error - } - } - - async getEmbeddings(modelId: string, texts: string[]): Promise { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - // OpenAI embeddings API - const response = await this.openai.embeddings.create({ - model: modelId, - input: texts, - encoding_format: 'float' - }) - // 兼容 OpenAI 返回格式 - return response.data.map((item) => item.embedding) - } - - async getDimensions(modelId: string): Promise { - switch (modelId) { - case 'text-embedding-3-small': - case 'text-embedding-ada-002': - return { - dimensions: 1536, - normalized: true - } - case 'text-embedding-3-large': - return { - dimensions: 3072, - normalized: true - } - default: - try { - const embeddings = await this.getEmbeddings(modelId, [EMBEDDING_TEST_KEY]) - return { - dimensions: embeddings[0].length, - normalized: isNormalized(embeddings[0]) - } - } catch (error) { - console.error( - `[OpenAICompatibleProvider] Failed to get dimensions for model ${modelId}:`, - error - ) - // Return sensible defaults or rethrow - throw new Error( - `Unable to determine embedding dimensions for model ${modelId}: ${ - error instanceof Error ? error.message : String(error) - }` - ) - } - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/openAIProvider.ts b/src/main/presenter/llmProviderPresenter/providers/openAIProvider.ts deleted file mode 100644 index 9b97fe60c..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/openAIProvider.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { IConfigPresenter, LLM_PROVIDER, LLMResponse } from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' -export class OpenAIProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: { role: 'system' | 'user' | 'assistant'; content: string }[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `Please summarize the following content using concise language and highlighting key points:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts b/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts deleted file mode 100644 index d10ebed0a..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider.ts +++ /dev/null @@ -1,1476 +0,0 @@ -import { - LLM_PROVIDER, - 
LLMResponse, - MODEL_META, - MCPToolDefinition, - LLMCoreStreamEvent, - ModelConfig, - ChatMessage, - IConfigPresenter -} from '@shared/presenter' -import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults' -import { createStreamEvent } from '@shared/types/core/llm-events' -import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' -import OpenAI, { AzureOpenAI } from 'openai' -import { presenter } from '@/presenter' -import { eventBus, SendTarget } from '@/eventbus' -import { NOTIFICATION_EVENTS } from '@/events' -import { jsonrepair } from 'jsonrepair' -import { app } from 'electron' -import path from 'path' -import fs from 'fs' -import sharp from 'sharp' -import { proxyConfig } from '../../proxyConfig' -import { ProxyAgent } from 'undici' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' -import type { ProviderMcpRuntimePort } from '../runtimePorts' -import { applyOpenAIPromptCacheKey, resolvePromptCachePlan } from '../promptCacheStrategy' - -const OPENAI_REASONING_MODELS = [ - 'o4-mini', - 'o1-pro', - 'o3', - 'o3-pro', - 'o3-mini', - 'o3-preview', - 'o1-mini', - 'o1-pro', - 'o1-preview', - 'o1', - 'gpt-5', - 'gpt-5-mini', - 'gpt-5-nano', - 'gpt-5-chat' -] -const OPENAI_IMAGE_GENERATION_MODELS = ['gpt-4o-all', 'gpt-4o-image'] -const OPENAI_IMAGE_GENERATION_MODEL_PREFIXES = ['dall-e-', 'gpt-image-'] -const isOpenAIImageGenerationModel = (modelId: string): boolean => - OPENAI_IMAGE_GENERATION_MODELS.includes(modelId) || - OPENAI_IMAGE_GENERATION_MODEL_PREFIXES.some((prefix) => modelId.startsWith(prefix)) - -// 添加支持的图片尺寸常量 -const SUPPORTED_IMAGE_SIZES = { - SQUARE: '1024x1024', - LANDSCAPE: '1536x1024', - PORTRAIT: '1024x1536' -} as const - -// 添加可设置尺寸的模型列表 -const SIZE_CONFIGURABLE_MODELS = ['gpt-image-1', 'gpt-4o-image', 'gpt-4o-all'] - -function getOpenAIResponseCachedTokens( - usage: - | { - input_tokens_details?: { - cached_tokens?: number - cache_write_tokens?: number - } - cache_write_tokens?: number - } - | null - | undefined -): number | undefined { - const cachedTokens = usage?.input_tokens_details?.cached_tokens - return typeof cachedTokens === 'number' && Number.isFinite(cachedTokens) - ? cachedTokens - : undefined -} - -function getOpenAIResponseCacheWriteTokens(usage: unknown): number | undefined { - if (!usage || typeof usage !== 'object') { - return undefined - } - - const inputTokensDetails = (usage as { input_tokens_details?: unknown }).input_tokens_details - const nestedCacheWriteTokens = - inputTokensDetails && typeof inputTokensDetails === 'object' - ? (inputTokensDetails as Record).cache_write_tokens - : undefined - const topLevelCacheWriteTokens = (usage as Record).cache_write_tokens - const cacheWriteTokens = - typeof nestedCacheWriteTokens === 'number' ? nestedCacheWriteTokens : topLevelCacheWriteTokens - return typeof cacheWriteTokens === 'number' && Number.isFinite(cacheWriteTokens) - ? 
cacheWriteTokens - : undefined -} - -export class OpenAIResponsesProvider extends BaseLLMProvider { - protected openai!: OpenAI - private isNoModelsApi: boolean = false - // 添加不支持 OpenAI 标准接口的供应商黑名单 - private static readonly NO_MODELS_API_LIST: string[] = [] - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - this.createOpenAIClient() - if (OpenAIResponsesProvider.NO_MODELS_API_LIST.includes(this.provider.id.toLowerCase())) { - this.isNoModelsApi = true - } - this.init() - } - - private supportsEffortParameter(modelId: string): boolean { - return modelCapabilities.supportsReasoningEffort(this.getCapabilityProviderId(), modelId) - } - - private supportsVerbosityParameter(modelId: string): boolean { - return modelCapabilities.supportsVerbosity(this.getCapabilityProviderId(), modelId) - } - - private resolveTraceAuthToken(): string { - return this.provider.oauthToken || this.provider.apiKey || 'MISSING_API_KEY' - } - - private buildResponsesTraceHeaders(): Record { - const headers: Record = { - 'Content-Type': 'application/json', - ...this.defaultHeaders - } - - if (this.provider.id === 'azure-openai') { - headers['api-key'] = this.resolveTraceAuthToken() - } else { - headers.Authorization = `Bearer ${this.resolveTraceAuthToken()}` - } - - return headers - } - - private buildResponsesEndpoint(): string { - const baseUrl = (this.provider.baseUrl || 'https://api.openai.com/v1').replace(/\/+$/, '') - return `${baseUrl}/responses` - } - - private createOpenAIClient(): void { - // Get proxy configuration - const proxyUrl = proxyConfig.getProxyUrl() - const fetchOptions: { dispatcher?: ProxyAgent } = {} - - if (proxyUrl) { - console.log(`[OpenAI Responses Provider] Using proxy: ${proxyUrl}`) - const proxyAgent = new ProxyAgent(proxyUrl) - fetchOptions.dispatcher = proxyAgent - } - - if (this.provider.id === 'azure-openai') { - try { - const apiVersion = this.configPresenter.getSetting('azureApiVersion') - this.openai = new AzureOpenAI({ - apiKey: this.provider.apiKey, - baseURL: this.provider.baseUrl, - apiVersion: apiVersion || '2024-02-01', - defaultHeaders: { - ...this.defaultHeaders - }, - fetchOptions - }) - } catch (e) { - console.warn('create azure openai failed', e) - } - } else { - this.openai = new OpenAI({ - apiKey: this.provider.apiKey, - baseURL: this.provider.baseUrl, - defaultHeaders: { - ...this.defaultHeaders - }, - fetchOptions - }) - } - } - - public onProxyResolved(): void { - this.createOpenAIClient() - } - - // 实现BaseLLMProvider中的抽象方法fetchProviderModels - protected async fetchProviderModels(options?: { timeout: number }): Promise { - // 检查供应商是否在黑名单中 - if (this.isNoModelsApi) { - return this.models - } - return this.fetchOpenAIModels(options) - } - - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - const response = await this.openai.models.list(options) - - return response.data.map((model) => ({ - id: model.id, - name: model.id, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, - maxTokens: DEFAULT_MODEL_MAX_TOKENS - })) - } - - /** - * User消息,上层会根据是否存在 vision 去插入 image_url - * Ass 消息,需要判断一下,把图片转换成正确的上下文,因为模型可以切换 - * @param messages - * @returns - */ - protected formatMessages(messages: ChatMessage[]): OpenAI.Responses.ResponseInput { - const result: OpenAI.Responses.ResponseInput = [] - - for (const msg of messages) { - if (msg.role === 'tool') { - result.push({ 
- type: 'function_call_output', - call_id: msg.tool_call_id || '', - output: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) - }) - continue - } - - if (msg.role === 'assistant' && msg.tool_calls) { - for (const toolCall of msg.tool_calls) { - result.push({ - type: 'function_call', - call_id: toolCall.id, - name: toolCall.function.name, - arguments: toolCall.function.arguments - }) - } - continue - } - - if (msg.role === 'assistant') { - const assistantContent = this.flattenAssistantContent(msg.content) - if (!assistantContent) { - continue - } - - // Responses API assistant history does not accept input_text content parts. - result.push({ - role: 'assistant', - content: assistantContent - }) - continue - } - - const content: OpenAI.Responses.ResponseInputMessageContentList = [] - - if (msg.content !== undefined) { - if (typeof msg.content === 'string') { - content.push({ - type: 'input_text', - text: msg.content - }) - } else if (Array.isArray(msg.content)) { - for (const part of msg.content) { - if (part.type === 'text' && part.text) { - content.push({ - type: 'input_text', - text: part.text - }) - } - if (part.type === 'image_url' && part.image_url?.url) { - content.push({ - type: 'input_image', - image_url: part.image_url.url, - detail: 'auto' - }) - } - } - } - } - - result.push({ - role: msg.role as 'system' | 'user' | 'assistant', - content - }) - } - - return result - } - - private flattenAssistantContent(content: ChatMessage['content']): string | null { - if (typeof content === 'string') { - return content.length > 0 ? content : null - } - - if (!Array.isArray(content)) { - return null - } - - const textContent = content.reduce((result, part) => { - if (part.type !== 'text' || part.text.length === 0) { - return result - } - return `${result}${part.text}` - }, '') - - return textContent.length > 0 ? 
textContent : null - } - - // OpenAI完成方法 - protected async openAICompletion( - messages: ChatMessage[], - modelId?: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - if (!modelId) { - throw new Error('Model ID is required') - } - - const formattedMessages = this.formatMessages(messages) - const requestParams: OpenAI.Responses.ResponseCreateParamsNonStreaming = { - model: modelId, - input: formattedMessages, - temperature: temperature, - max_output_tokens: maxTokens, - stream: false - } - - const modelConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - const promptCachePlan = resolvePromptCachePlan({ - providerId: this.provider.id, - apiType: 'openai_responses', - modelId, - messages: formattedMessages as unknown[], - conversationId: modelConfig?.conversationId - }) - if (modelConfig.reasoningEffort && this.supportsEffortParameter(modelId)) { - ;(requestParams as any).reasoning = { - effort: modelConfig.reasoningEffort - } - } - - // 仅当模型能力集声明支持时,才添加 verbosity - if (modelConfig.verbosity && this.supportsVerbosityParameter(modelId)) { - ;(requestParams as any).text = { - verbosity: modelConfig.verbosity - } - } - - OPENAI_REASONING_MODELS.forEach((noTempId) => { - if (modelId.startsWith(noTempId)) { - delete requestParams.temperature - } - }) - - const cachedRequestParams = applyOpenAIPromptCacheKey( - requestParams as unknown as Record, - promptCachePlan - ) as unknown as OpenAI.Responses.ResponseCreateParamsNonStreaming - - const response = await this.openai.responses.create(cachedRequestParams) - const resultResp: LLMResponse = { - content: '' - } - - // Use the SDK-provided aggregated assistant text for Responses API. - if (typeof response.output_text === 'string') { - resultResp.content = response.output_text - } - - // 处理 reasoning 内容 - if (response.reasoning?.summary) { - resultResp.reasoning_content = response.reasoning.summary - } - - return resultResp - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * 核心流处理方法,根据模型类型分发请求。 - * @param messages 聊天消息数组。 - * @param modelId 模型ID。 - * @param modelConfig 模型配置。 - * @param temperature 温度参数。 - * @param maxTokens 最大 token 数。 - * @param mcpTools MCP 工具定义数组。 - * @returns AsyncGenerator 流式事件。 - */ - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - - if (isOpenAIImageGenerationModel(modelId)) { - yield* this.handleImgGeneration(messages, modelId) - } else { - yield* this.handleChatCompletion( - messages, - modelId, - modelConfig, - temperature, - maxTokens, - mcpTools - ) - } - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * 处理图片生成模型请求的内部方法。 - * @param messages 聊天消息数组。 - * @param modelId 模型ID。 - * @returns AsyncGenerator 流式事件。 - */ - private async *handleImgGeneration( - messages: ChatMessage[], - modelId: string - ): AsyncGenerator { - // 获取最后几条消息,检查是否有图片 - let prompt = '' - const imageUrls: string[] = [] - // 获取最后的用户消息内容作为提示词 - const lastUserMessage = messages.findLast((m) => m.role === 'user') - if (lastUserMessage?.content) { - if (typeof lastUserMessage.content === 'string') { - prompt = lastUserMessage.content - } 
else if (Array.isArray(lastUserMessage.content)) { - // 处理多模态内容,提取文本 - const textParts: string[] = [] - for (const part of lastUserMessage.content) { - if (part.type === 'text' && part.text) { - textParts.push(part.text) - } - } - prompt = textParts.join('\n') - } - } - - // 检查最后几条消息中是否有图片 - // 通常我们只需要检查最后两条消息:最近的用户消息和最近的助手消息 - const lastMessages = messages.slice(-2) - for (const message of lastMessages) { - if (message.content) { - if (Array.isArray(message.content)) { - for (const part of message.content) { - if (part.type === 'image_url' && part.image_url?.url) { - imageUrls.push(part.image_url.url) - } - } - } - } - } - - if (!prompt) { - console.error('[handleImgGeneration] Could not extract prompt for image generation.') - yield createStreamEvent.error('Could not extract prompt for image generation.') - yield createStreamEvent.stop('error') - return - } - - try { - let result - - if (imageUrls.length > 0) { - // 使用 images.edit 接口处理带有图片的请求 - let imageBuffer: Buffer - - if (imageUrls[0].startsWith('imgcache://')) { - const filePath = imageUrls[0].slice('imgcache://'.length) - const fullPath = path.join(app.getPath('userData'), 'images', filePath) - imageBuffer = fs.readFileSync(fullPath) - } else { - const imageResponse = await fetch(imageUrls[0]) - const imageBlob = await imageResponse.blob() - imageBuffer = Buffer.from(await imageBlob.arrayBuffer()) - } - - // 创建临时文件 - const imagePath = `/tmp/openai_image_${Date.now()}.png` - await new Promise((resolve, reject) => { - fs.writeFile(imagePath, imageBuffer, (err: Error | null) => { - if (err) { - reject(err) - } else { - resolve() - } - }) - }) - - // 使用文件路径创建 Readable 流 - const imageFile = fs.createReadStream(imagePath) - const params: OpenAI.Images.ImageEditParams = { - model: modelId, - image: imageFile, - prompt: prompt, - n: 1 - } - - // 如果是支持尺寸配置的模型,检测图片尺寸并设置合适的参数 - if (SIZE_CONFIGURABLE_MODELS.includes(modelId)) { - try { - const metadata = await sharp(imageBuffer).metadata() - if (metadata.width && metadata.height) { - const aspectRatio = metadata.width / metadata.height - - // 根据宽高比选择最接近的尺寸 - if (Math.abs(aspectRatio - 1) < 0.1) { - // 接近正方形 - params.size = SUPPORTED_IMAGE_SIZES.SQUARE - } else if (aspectRatio > 1) { - // 横向图片 - params.size = SUPPORTED_IMAGE_SIZES.LANDSCAPE - } else { - // 纵向图片 - params.size = SUPPORTED_IMAGE_SIZES.PORTRAIT - } - } else { - // 如果无法获取宽高,使用默认参数 - params.size = '1024x1536' - } - params.quality = 'high' - } catch (error) { - console.warn( - '[handleImgGeneration] Failed to detect image dimensions, using default size:', - error - ) - // 检测失败时使用默认参数 - params.size = '1024x1536' - params.quality = 'high' - } - } - - result = await this.openai.images.edit(params) - - // 清理临时文件 - try { - fs.unlinkSync(imagePath) - } catch (e) { - console.error('[handleImgGeneration] Failed to delete temporary file:', e) - } - } else { - // 使用原来的 images.generate 接口处理没有图片的请求 - console.log( - `[handleImgGeneration] Generating image with model ${modelId} and prompt: "${prompt}"` - ) - const params: OpenAI.Images.ImageGenerateParams = { - model: modelId, - prompt: prompt, - n: 1, - output_format: 'png' - } - if (modelId === 'gpt-image-1' || modelId === 'gpt-4o-image' || modelId === 'gpt-4o-all') { - params.size = '1024x1536' - params.quality = 'high' - } - result = await this.openai.images.generate(params, { - timeout: 300_000 - }) - } - if (result.data && (result.data[0]?.url || result.data[0]?.b64_json)) { - // 使用devicePresenter缓存图片URL - try { - let imageUrl: string - if (result.data[0]?.b64_json) { - // 处理 base64 数据 - 
const base64Data = result.data[0].b64_json - // 直接使用 devicePresenter 缓存 base64 数据 - imageUrl = await presenter.devicePresenter.cacheImage( - base64Data.startsWith('data:image/png;base64,') - ? base64Data - : 'data:image/png;base64,' + base64Data - ) - } else { - // 原有的 URL 处理逻辑 - imageUrl = result.data[0]?.url || '' - } - - const cachedUrl = await presenter.devicePresenter.cacheImage(imageUrl) - - // 返回缓存后的URL - yield createStreamEvent.imageData({ - data: cachedUrl, - mimeType: 'deepchat/image-url' - }) - - // 处理 usage 信息 - if (result.usage) { - yield createStreamEvent.usage({ - prompt_tokens: result.usage.input_tokens || 0, - completion_tokens: result.usage.output_tokens || 0, - total_tokens: result.usage.total_tokens || 0, - cached_tokens: getOpenAIResponseCachedTokens(result.usage), - cache_write_tokens: getOpenAIResponseCacheWriteTokens(result.usage) - }) - } - - yield createStreamEvent.stop('complete') - } catch (cacheError) { - // 缓存失败时降级为使用原始URL - console.warn( - '[handleImgGeneration] Failed to cache image, using original URL:', - cacheError - ) - yield createStreamEvent.imageData({ - data: result.data[0]?.url || result.data[0]?.b64_json || '', - mimeType: 'deepchat/image-url' - }) - yield createStreamEvent.stop('complete') - } - } else { - console.error('[handleImgGeneration] No image data received from API.', result) - yield createStreamEvent.error('No image data received from API.') - yield createStreamEvent.stop('error') - } - } catch (error: unknown) { - const errorMessage = error instanceof Error ? error.message : String(error) - console.error('[handleImgGeneration] Error during image generation:', errorMessage) - yield createStreamEvent.error(`Image generation failed: ${errorMessage}`) - yield createStreamEvent.stop('error') - } - } - - /////////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * 处理 OpenAI Responses 聊天补全模型请求的内部方法。 - * @param messages 聊天消息数组。 - * @param modelId 模型ID。 - * @param modelConfig 模型配置。 - * @param temperature 温度参数。 - * @param maxTokens 最大 token 数。 - * @param mcpTools MCP 工具定义数组。 - * @returns AsyncGenerator 流式事件。 - */ - private async *handleChatCompletion( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - const tools = mcpTools || [] - const supportsFunctionCall = modelConfig?.functionCall || false - let processedMessages = this.formatMessages(messages) - if (tools.length > 0 && !supportsFunctionCall) { - processedMessages = this.prepareFunctionCallPrompt(processedMessages, tools) - } - const apiTools = - tools.length > 0 && supportsFunctionCall - ? 
await this.mcpRuntime?.mcpToolsToOpenAIResponsesTools(tools, this.provider.id) - : undefined - - const requestParams: OpenAI.Responses.ResponseCreateParamsStreaming = { - model: modelId, - input: processedMessages, - temperature, - max_output_tokens: maxTokens, - stream: true - } - const promptCachePlan = resolvePromptCachePlan({ - providerId: this.provider.id, - apiType: 'openai_responses', - modelId, - messages: processedMessages as unknown[], - tools, - conversationId: modelConfig?.conversationId - }) - - // 如果模型支持函数调用且有工具,添加 tools 参数 - if (tools.length > 0 && supportsFunctionCall && apiTools) { - requestParams.tools = apiTools - } - if (modelConfig.reasoningEffort && this.supportsEffortParameter(modelId)) { - ;(requestParams as any).reasoning = { - effort: modelConfig.reasoningEffort - } - } - - // 仅当模型能力集声明支持时,才添加 verbosity - if (modelConfig.verbosity && this.supportsVerbosityParameter(modelId)) { - ;(requestParams as any).text = { - verbosity: modelConfig.verbosity - } - } - - OPENAI_REASONING_MODELS.forEach((noTempId) => { - if (modelId.startsWith(noTempId)) delete requestParams.temperature - }) - - const cachedRequestParams = applyOpenAIPromptCacheKey( - requestParams as unknown as Record, - promptCachePlan - ) as unknown as OpenAI.Responses.ResponseCreateParamsStreaming - - await this.emitRequestTrace(modelConfig, { - endpoint: this.buildResponsesEndpoint(), - headers: this.buildResponsesTraceHeaders(), - body: cachedRequestParams - }) - - const stream = await this.openai.responses.create(cachedRequestParams) - - // --- State Variables --- - type TagState = 'none' | 'start' | 'inside' | 'end' - let thinkState: TagState = 'none' - let funcState: TagState = 'none' // Only relevant if !supportsFunctionCall - - let pendingBuffer = '' // Buffer for tag matching and potential text output - let thinkBuffer = '' // Buffer for reasoning content - let funcCallBuffer = '' // Buffer for non-native function call content - - const thinkStartMarker = '' - const thinkEndMarker = '' - const funcStartMarker = '' - const funcEndMarker = '' - - const nativeToolCalls: Record< - string, - { name: string; arguments: string; completed?: boolean; itemId?: string } - > = {} - const nativeToolCallIdByItemId: Record = {} - const nativeToolCallIdByOutputIndex: Record = {} - const stopReason: LLMCoreStreamEvent['stop_reason'] = 'complete' - let toolUseDetected = false - let usage: - | { - prompt_tokens: number - completion_tokens: number - total_tokens: number - cached_tokens?: number - cache_write_tokens?: number - } - | undefined = undefined - - // --- Stream Processing Loop --- - for await (const chunk of stream) { - // 处理函数调用相关事件 - if (supportsFunctionCall && tools.length > 0) { - if (chunk.type === 'response.output_item.added') { - const item = chunk.item - if (item.type === 'function_call') { - toolUseDetected = true - const callId = item.call_id - if (callId) { - nativeToolCalls[callId] = { - name: item.name, - arguments: item.arguments || '', - completed: false, - itemId: item.id - } - nativeToolCallIdByOutputIndex[chunk.output_index] = callId - if (item.id) { - nativeToolCallIdByItemId[item.id] = callId - } - yield { - type: 'tool_call_start', - tool_call_id: callId, - tool_call_name: item.name - } - } - } - } else if (chunk.type === 'response.function_call_arguments.delta') { - const itemId = chunk.item_id - const delta = chunk.delta - const callId = - nativeToolCallIdByItemId[itemId] || nativeToolCallIdByOutputIndex[chunk.output_index] - if (callId && !nativeToolCallIdByItemId[itemId]) { - 
nativeToolCallIdByItemId[itemId] = callId - } - const toolCall = callId ? nativeToolCalls[callId] : undefined - if (toolCall) { - toolCall.arguments += delta - yield { - type: 'tool_call_chunk', - tool_call_id: callId, - tool_call_arguments_chunk: delta - } - } - } else if (chunk.type === 'response.function_call_arguments.done') { - const itemId = chunk.item_id - const argsData = chunk.arguments - const callId = - nativeToolCallIdByItemId[itemId] || nativeToolCallIdByOutputIndex[chunk.output_index] - if (callId && !nativeToolCallIdByItemId[itemId]) { - nativeToolCallIdByItemId[itemId] = callId - } - const toolCall = callId ? nativeToolCalls[callId] : undefined - if (toolCall) { - toolCall.arguments = argsData - toolCall.completed = true - yield { - type: 'tool_call_end', - tool_call_id: callId, - tool_call_arguments_complete: argsData - } - } - } else if (chunk.type === 'response.output_item.done') { - const item = chunk.item - if (item.type === 'function_call') { - nativeToolCallIdByOutputIndex[chunk.output_index] = item.call_id - if (item.id) { - nativeToolCallIdByItemId[item.id] = item.call_id - } - const toolCall = nativeToolCalls[item.call_id] - if (toolCall && !toolCall.completed) { - toolCall.completed = true - yield { - type: 'tool_call_end', - tool_call_id: item.call_id, - tool_call_arguments_complete: item.arguments - } - } - } - } - } - - // 处理文本增量 - if (chunk.type === 'response.output_text.delta') { - const content = chunk.delta - for (const char of content) { - pendingBuffer += char - let processedChar = false - - // --- Thinking Tag Processing (Inside or End states) --- - if (thinkState === 'inside') { - if (pendingBuffer.endsWith(thinkEndMarker)) { - thinkState = 'none' - if (thinkBuffer) { - yield createStreamEvent.reasoning(thinkBuffer) - thinkBuffer = '' - } - pendingBuffer = '' - processedChar = true - } else if (thinkEndMarker.startsWith(pendingBuffer)) { - thinkState = 'end' - processedChar = true - } else if (pendingBuffer.length >= thinkEndMarker.length) { - const charsToYield = pendingBuffer.slice(0, -thinkEndMarker.length + 1) - if (charsToYield) { - thinkBuffer += charsToYield - yield createStreamEvent.reasoning(charsToYield) - } - pendingBuffer = pendingBuffer.slice(-thinkEndMarker.length + 1) - if (thinkEndMarker.startsWith(pendingBuffer)) { - thinkState = 'end' - } else { - thinkBuffer += pendingBuffer - yield createStreamEvent.reasoning(pendingBuffer) - pendingBuffer = '' - thinkState = 'inside' - } - processedChar = true - } else { - thinkBuffer += char - yield createStreamEvent.reasoning(char) - pendingBuffer = '' - processedChar = true - } - } else if (thinkState === 'end') { - if (pendingBuffer.endsWith(thinkEndMarker)) { - thinkState = 'none' - if (thinkBuffer) { - yield createStreamEvent.reasoning(thinkBuffer) - thinkBuffer = '' - } - pendingBuffer = '' - processedChar = true - } else if (!thinkEndMarker.startsWith(pendingBuffer)) { - const failedTagChars = pendingBuffer - thinkBuffer += failedTagChars - yield createStreamEvent.reasoning(failedTagChars) - pendingBuffer = '' - thinkState = 'inside' - processedChar = true - } else { - processedChar = true - } - } - - // --- Function Call Tag Processing (Inside or End states, if applicable) --- - else if ( - !supportsFunctionCall && - tools.length > 0 && - (funcState === 'inside' || funcState === 'end') - ) { - processedChar = true // Assume processed unless logic below changes state back - if (funcState === 'inside') { - if (pendingBuffer.endsWith(funcEndMarker)) { - funcState = 'none' - funcCallBuffer += 
pendingBuffer.slice(0, -funcEndMarker.length) - pendingBuffer = '' - toolUseDetected = true - console.log( - `[handleChatCompletion] Non-native end tag detected. Buffer to parse:`, - funcCallBuffer - ) - const parsedCalls = this.parseFunctionCalls( - `${funcStartMarker}${funcCallBuffer}${funcEndMarker}` - ) - for (const parsedCall of parsedCalls) { - yield { - type: 'tool_call_start', - tool_call_id: parsedCall.id, - tool_call_name: parsedCall.function.name - } - yield { - type: 'tool_call_chunk', - tool_call_id: parsedCall.id, - tool_call_arguments_chunk: parsedCall.function.arguments - } - yield { - type: 'tool_call_end', - tool_call_id: parsedCall.id, - tool_call_arguments_complete: parsedCall.function.arguments - } - } - funcCallBuffer = '' - } else if (funcEndMarker.startsWith(pendingBuffer)) { - funcState = 'end' - } else if (pendingBuffer.length >= funcEndMarker.length) { - const charsToAdd = pendingBuffer.slice(0, -funcEndMarker.length + 1) - funcCallBuffer += charsToAdd - pendingBuffer = pendingBuffer.slice(-funcEndMarker.length + 1) - if (funcEndMarker.startsWith(pendingBuffer)) { - funcState = 'end' - } else { - funcCallBuffer += pendingBuffer - pendingBuffer = '' - funcState = 'inside' - } - } else { - funcCallBuffer += char - pendingBuffer = '' - } - } else { - // funcState === 'end' - if (pendingBuffer.endsWith(funcEndMarker)) { - funcState = 'none' - pendingBuffer = '' - toolUseDetected = true - console.log( - `[handleChatCompletion] Non-native end tag detected (from end state). Buffer to parse:`, - funcCallBuffer - ) - const parsedCalls = this.parseFunctionCalls( - `${funcStartMarker}${funcCallBuffer}${funcEndMarker}` - ) - for (const parsedCall of parsedCalls) { - yield { - type: 'tool_call_start', - tool_call_id: parsedCall.id, - tool_call_name: parsedCall.function.name - } - yield { - type: 'tool_call_chunk', - tool_call_id: parsedCall.id, - tool_call_arguments_chunk: parsedCall.function.arguments - } - yield { - type: 'tool_call_end', - tool_call_id: parsedCall.id, - tool_call_arguments_complete: parsedCall.function.arguments - } - } - funcCallBuffer = '' - } else if (!funcEndMarker.startsWith(pendingBuffer)) { - funcCallBuffer += pendingBuffer - pendingBuffer = '' - funcState = 'inside' - } - } - } - - // --- General Text / Start Tag Detection (When not inside any tag) --- - if (!processedChar) { - let potentialThink = thinkStartMarker.startsWith(pendingBuffer) - let potentialFunc = - !supportsFunctionCall && tools.length > 0 && funcStartMarker.startsWith(pendingBuffer) - const matchedThink = pendingBuffer.endsWith(thinkStartMarker) - const matchedFunc = - !supportsFunctionCall && tools.length > 0 && pendingBuffer.endsWith(funcStartMarker) - - // --- Handle Full Matches First --- - if (matchedThink) { - const textBefore = pendingBuffer.slice(0, -thinkStartMarker.length) - if (textBefore) { - yield createStreamEvent.text(textBefore) - } - console.log( - '[handleChatCompletion] start tag matched. Entering inside state.' - ) - thinkState = 'inside' - funcState = 'none' // Reset other state - pendingBuffer = '' - } else if (matchedFunc) { - const textBefore = pendingBuffer.slice(0, -funcStartMarker.length) - if (textBefore) { - yield createStreamEvent.text(textBefore) - } - console.log( - '[handleChatCompletion] Non-native start tag detected. Entering inside state.' 
- ) - funcState = 'inside' - thinkState = 'none' // Reset other state - pendingBuffer = '' - } - // --- Handle Partial Matches (Keep Accumulating) --- - else if (potentialThink || potentialFunc) { - // If potentially matching either, just keep the buffer and wait for more chars - // Update state but don't yield anything - thinkState = potentialThink ? 'start' : 'none' - funcState = potentialFunc ? 'start' : 'none' - } - // --- Handle No Match / Failure --- - else if (pendingBuffer.length > 0) { - // Buffer doesn't start with '<', or starts with '<' but doesn't match start of either tag anymore - const charToYield = pendingBuffer[0] - yield createStreamEvent.text(charToYield) - pendingBuffer = pendingBuffer.slice(1) - // Re-evaluate potential matches with the shortened buffer immediately - potentialThink = - pendingBuffer.length > 0 && thinkStartMarker.startsWith(pendingBuffer) - potentialFunc = - pendingBuffer.length > 0 && - !supportsFunctionCall && - tools.length > 0 && - funcStartMarker.startsWith(pendingBuffer) - thinkState = potentialThink ? 'start' : 'none' - funcState = potentialFunc ? 'start' : 'none' - } - } - } - } - - if (chunk.type === 'response.completed') { - const response = chunk.response - if (response.usage) { - usage = { - prompt_tokens: response.usage.input_tokens || 0, - completion_tokens: response.usage.output_tokens || 0, - total_tokens: response.usage.total_tokens || 0, - cached_tokens: getOpenAIResponseCachedTokens(response.usage), - cache_write_tokens: getOpenAIResponseCacheWriteTokens(response.usage) - } - yield createStreamEvent.usage(usage) - } - - if (response.reasoning?.summary) { - yield createStreamEvent.reasoning(response.reasoning.summary) - } - - yield createStreamEvent.stop(toolUseDetected ? 'tool_use' : stopReason) - return - } - - if ('error' in chunk) { - const errorChunk = chunk as { error: { message?: string } } - yield createStreamEvent.error(errorChunk.error?.message || 'Unknown error occurred') - yield createStreamEvent.stop('error') - return - } - } - - // --- Finalization --- - // Yield any remaining text in the buffer - if (pendingBuffer) { - console.warn('[handleChatCompletion] Finalizing with non-empty pendingBuffer:', pendingBuffer) - // Decide how to yield based on final state - if (thinkState === 'inside' || thinkState === 'end') { - yield createStreamEvent.reasoning(pendingBuffer) - thinkBuffer += pendingBuffer - } else if (funcState === 'inside' || funcState === 'end') { - // Add remaining to func buffer - it will be handled below - funcCallBuffer += pendingBuffer - } else { - yield createStreamEvent.text(pendingBuffer) - } - pendingBuffer = '' - } - - // Yield remaining reasoning content - if (thinkBuffer) { - console.warn( - '[handleChatCompletion] Finalizing with non-empty thinkBuffer (should have been yielded):', - thinkBuffer - ) - } - - // Handle incomplete non-native function call - if (funcCallBuffer) { - console.warn( - '[handleChatCompletion] Finalizing with non-empty function call buffer (likely incomplete tag):', - funcCallBuffer - ) - // Attempt to parse what we have, might fail - const potentialContent = `${funcStartMarker}${funcCallBuffer}` - try { - const parsedCalls = this.parseFunctionCalls(potentialContent) - if (parsedCalls.length > 0) { - toolUseDetected = true - for (const parsedCall of parsedCalls) { - yield { - type: 'tool_call_start', - tool_call_id: parsedCall.id + '-incomplete', - tool_call_name: parsedCall.function.name - } - yield { - type: 'tool_call_chunk', - tool_call_id: parsedCall.id + '-incomplete', - 
tool_call_arguments_chunk: parsedCall.function.arguments - } - yield { - type: 'tool_call_end', - tool_call_id: parsedCall.id + '-incomplete', - tool_call_arguments_complete: parsedCall.function.arguments - } - } - } else { - console.log( - '[handleChatCompletion] Incomplete function call buffer parsing yielded no calls. Emitting as text.' - ) - yield createStreamEvent.text(potentialContent) - } - } catch (e) { - console.error('[handleChatCompletion] Error parsing incomplete function call buffer:', e) - yield createStreamEvent.text(potentialContent) - } - funcCallBuffer = '' - } - } - - private prepareFunctionCallPrompt( - messages: OpenAI.Responses.ResponseInput, - mcpTools: MCPToolDefinition[] - ): OpenAI.Responses.ResponseInput { - console.log('prepareFunc') - // 创建消息副本而不是直接修改原始消息 - const result = [...messages] - - const functionCallPrompt = this.getFunctionCallWrapPrompt(mcpTools) - - // 找到最后一条用户消息 - const lastUserMessageIndex = result.findLastIndex( - (message) => 'role' in message && message.role === 'user' - ) - - if (lastUserMessageIndex !== -1) { - const userMessage = result[lastUserMessageIndex] - if ('content' in userMessage) { - if (Array.isArray(userMessage.content)) { - // 创建新的 content 数组 - const newContent: OpenAI.Responses.ResponseInputMessageContentList = [] - let hasAddedPrompt = false - - // 遍历现有的 content 数组 - for (const content of userMessage.content) { - if (content.type === 'input_text' && !hasAddedPrompt) { - // 为第一个文本内容添加提示词 - newContent.push({ - type: 'input_text', - text: `${functionCallPrompt}\n\n${content.text}` - } as OpenAI.Responses.ResponseInputText) - hasAddedPrompt = true - } else if (content.type === 'input_text' || content.type === 'input_image') { - // 其他内容直接复制 - newContent.push(content as OpenAI.Responses.ResponseInputContent) - } - } - - // 如果没有找到文本内容,在开头添加提示词 - if (!hasAddedPrompt) { - newContent.unshift({ - type: 'input_text', - text: functionCallPrompt - } as OpenAI.Responses.ResponseInputText) - } - - // 更新消息的 content - result[lastUserMessageIndex] = { - ...userMessage, - content: newContent - } as OpenAI.Responses.ResponseInput[number] - } else if (typeof userMessage.content === 'string') { - // 如果 content 是字符串,直接添加提示词 - result[lastUserMessageIndex] = { - ...userMessage, - content: [ - { - type: 'input_text', - text: `${functionCallPrompt}\n\n${userMessage.content}` - } as OpenAI.Responses.ResponseInputText - ] - } as OpenAI.Responses.ResponseInput[number] - } - } - } - - return result - } - - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - if (!this.isNoModelsApi) { - // Use unified timeout configuration from base class - const models = await this.fetchOpenAIModels({ timeout: this.getModelFetchTimeout() }) - this.models = models // Store fetched models - } - // Potentially add a simple API call test here if needed, e.g., list models even for no-API list to check key/endpoint - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - // Use unknown for type safety - let errorMessage = 'An unknown error occurred during provider check.' 
-      if (error instanceof Error) {
-        errorMessage = error.message
-      } else if (typeof error === 'string') {
-        errorMessage = error
-      }
-      // Optionally log the full error object for debugging
-      console.error('OpenAIResponsesProvider check failed:', error)
-
-      eventBus.sendToRenderer(NOTIFICATION_EVENTS.SHOW_ERROR, SendTarget.ALL_WINDOWS, {
-        title: 'API Check Failed', // More specific title
-        message: errorMessage,
-        id: `openai-check-error-${Date.now()}`,
-        type: 'error'
-      })
-      return { isOk: false, errorMsg: errorMessage }
-    }
-  }
-
-  public async summaryTitles(messages: ChatMessage[], modelId: string): Promise<string> {
-    const summaryText = `${SUMMARY_TITLES_PROMPT}\n\n${messages.map((m) => `${m.role}: ${m.content}`).join('\n')}`
-    const fullMessage: ChatMessage[] = [{ role: 'user', content: summaryText }]
-    const response = await this.openAICompletion(fullMessage, modelId, 0.5)
-    return response.content.replace(/["']/g, '').trim()
-  }
-
-  async completions(
-    messages: ChatMessage[],
-    modelId: string,
-    temperature?: number,
-    maxTokens?: number
-  ): Promise<LLMResponse> {
-    // Simple completion, no specific system prompt needed unless required by base class or future design
-    return this.openAICompletion(messages, modelId, temperature, maxTokens)
-  }
-
-  async summaries(
-    text: string,
-    modelId: string,
-    temperature?: number,
-    maxTokens?: number
-  ): Promise<LLMResponse> {
-    const systemPrompt = `Summarize the following text concisely:`
-    // Create messages based on the input text
-    const requestMessages: ChatMessage[] = [
-      { role: 'system', content: systemPrompt },
-      { role: 'user', content: text } // Use the input text directly
-    ]
-    return this.openAICompletion(requestMessages, modelId, temperature, maxTokens)
-  }
-
-  async generateText(
-    prompt: string,
-    modelId: string,
-    temperature?: number,
-    maxTokens?: number
-  ): Promise<LLMResponse> {
-    // Use the prompt directly as the user message content
-    const requestMessages: ChatMessage[] = [{ role: 'user', content: prompt }]
-    // Note: formatMessages might not be needed here if it's just a single prompt string,
-    // but keeping it for consistency in case formatMessages adds system prompts or other logic.
-    return this.openAICompletion(requestMessages, modelId, temperature, maxTokens)
-  }
-
-  async suggestions(
-    messages: ChatMessage[],
-    modelId: string,
-    temperature?: number,
-    maxTokens?: number
-  ): Promise<string[]> {
-    const systemPrompt = `Based on the last user message in the conversation history, provide 3 brief, relevant follow-up suggestions or questions. Output ONLY the suggestions, each on a new line. Do not include numbering, bullet points, or introductory text like "Here are some suggestions:".`
-    const lastUserMessage = messages.filter((m) => m.role === 'user').pop() // Get the most recent user message
-
-    if (!lastUserMessage) {
-      console.warn('suggestions called without user messages.')
-      return [] // Return empty array if no user message found
-    }
-
-    // Provide some context if possible, e.g., last few messages
-    const contextMessages = messages.slice(-5) // Last 5 messages as context
-
-    const requestMessages: ChatMessage[] = [
-      { role: 'system', content: systemPrompt },
-      // Include context leading up to the last user message
-      ...contextMessages
-    ]
-
-    try {
-      const response = await this.openAICompletion(
-        requestMessages,
-        modelId,
-        temperature ?? 0.7,
-        maxTokens ?? 60
-      ) // Adjusted temp/tokens
-      // Split, trim, and filter results robustly
-      return response.content
-        .split('\n')
-        .map((s) => s.trim())
-        .filter((s) => s.length > 0 && !s.match(/^[0-9.\-*\s]*/)) // Fixed regex range
-    } catch (error) {
-      console.error('Failed to get suggestions:', error)
-      return [] // Return empty on error
-    }
-  }
-
-  protected parseFunctionCalls(
-    response: string,
-    fallbackIdPrefix: string = 'tool-call'
-  ): Array<{ id: string; type: string; function: { name: string; arguments: string } }> {
-    try {
-      // 使用非贪婪模式匹配function_call标签对,能够处理多行内容
-      const functionCallMatches = response.match(/<function_call>([\s\S]*?)<\/function_call>/gs)
-      if (!functionCallMatches) {
-        return []
-      }
-
-      const toolCalls = functionCallMatches
-        .map((match, index) => {
-          const content = match.replace(/<\/?function_call>/g, '').trim()
-          if (!content) {
-            return null // Skip empty content between tags
-          }
-
-          try {
-            let parsedCall
-            let repairedJson: string | undefined
-            try {
-              // 首先尝试标准 JSON 解析
-              parsedCall = JSON.parse(content)
-            } catch {
-              try {
-                // 如果标准解析失败,使用 jsonrepair 进行修复
-                repairedJson = jsonrepair(content)
-                parsedCall = JSON.parse(repairedJson)
-              } catch (repairError) {
-                console.error(
-                  `[parseFunctionCalls] Failed to parse content for match ${index} even with jsonrepair:`,
-                  repairError,
-                  'Original content:',
-                  content,
-                  'Repaired content attempt:',
-                  repairedJson ?? 'N/A'
-                )
-                return null // Skip this malformed call
-              }
-            }
-
-            // 提取名称和参数,处理各种可能的结构
-            let functionName, functionArgs
-            if (parsedCall.function_call && typeof parsedCall.function_call === 'object') {
-              functionName = parsedCall.function_call.name
-              functionArgs = parsedCall.function_call.arguments
-            } else if (parsedCall.name && parsedCall.arguments !== undefined) {
-              functionName = parsedCall.name
-              functionArgs = parsedCall.arguments
-            } else if (
-              parsedCall.function &&
-              typeof parsedCall.function === 'object' &&
-              parsedCall.function.name
-            ) {
-              functionName = parsedCall.function.name
-              functionArgs = parsedCall.function.arguments
-            } else {
-              // 尝试在单个键下查找函数调用结构
-              const keys = Object.keys(parsedCall)
-              if (keys.length === 1) {
-                const potentialToolCall = parsedCall[keys[0]]
-                if (potentialToolCall && typeof potentialToolCall === 'object') {
-                  if (potentialToolCall.name && potentialToolCall.arguments !== undefined) {
-                    functionName = potentialToolCall.name
-                    functionArgs = potentialToolCall.arguments
-                  } else if (
-                    potentialToolCall.function &&
-                    typeof potentialToolCall.function === 'object' &&
-                    potentialToolCall.function.name
-                  ) {
-                    functionName = potentialToolCall.function.name
-                    functionArgs = potentialToolCall.function.arguments
-                  }
-                }
-              }
-
-              if (!functionName) {
-                console.error(
-                  '[parseFunctionCalls] Could not determine function name from parsed call:',
-                  parsedCall
-                )
-                return null
-              }
-            }
-
-            // 确保参数是字符串格式
-            if (typeof functionArgs !== 'string') {
-              try {
-                functionArgs = JSON.stringify(functionArgs)
-              } catch (stringifyError) {
-                console.error(
-                  '[parseFunctionCalls] Failed to stringify function arguments:',
-                  stringifyError,
-                  functionArgs
-                )
-                functionArgs = '{"error": "failed to stringify arguments"}'
-              }
-            }
-
-            // 生成唯一ID
-            const id = parsedCall.id || functionName || `${fallbackIdPrefix}-${index}-${Date.now()}`
-
-            return {
-              id: String(id),
-              type: 'function',
-              function: {
-                name: String(functionName),
-                arguments: functionArgs
-              }
-            }
-          } catch (processingError) {
-            console.error(
-              '[parseFunctionCalls] Error processing parsed function call JSON:',
-              processingError,
-              'Content:',
content - ) - return null - } - }) - .filter( - ( - call - ): call is { id: string; type: string; function: { name: string; arguments: string } } => - call !== null && - typeof call.id === 'string' && - typeof call.function === 'object' && - call.function !== null && - typeof call.function.name === 'string' && - typeof call.function.arguments === 'string' - ) - - return toolCalls - } catch (error) { - console.error( - '[parseFunctionCalls] Unexpected error during execution:', - error, - 'Input:', - response - ) - return [] - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/openRouterProvider.ts b/src/main/presenter/llmProviderPresenter/providers/openRouterProvider.ts deleted file mode 100644 index 04b274ef5..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/openRouterProvider.ts +++ /dev/null @@ -1,298 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - ChatMessage, - KeyStatus, - MODEL_META, - IConfigPresenter -} from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for OpenRouter API key response -interface OpenRouterKeyResponse { - data: { - label: string - usage: number - is_free_tier: boolean - is_provisioning_key: boolean - limit: number | null - limit_remaining: number | null - rate_limit: { - requests: number - interval: string - } - } -} - -// Define interface for OpenRouter model response based on their API documentation -interface OpenRouterModelResponse { - id: string - name: string - description: string - created: number - context_length: number - architecture?: { - input_modalities: string[] // ["file", "image", "text"] - output_modalities: string[] // ["text"] - tokenizer: string - instruct_type: string | null - } - pricing: { - prompt: string - completion: string - request: string - image: string - web_search: string - internal_reasoning: string - input_cache_read: string - input_cache_write: string - } - top_provider?: { - context_length: number - max_completion_tokens: number - is_moderated: boolean - } - per_request_limits: any - supported_parameters?: string[] -} - -export class OpenRouterProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `You need to summarize the user's conversation into a title of no more than 10 words, with the title language matching the user's primary language, without using punctuation or other special symbols:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Get current API key status from OpenRouter - * @returns Promise API key status information - */ - public async getKeyStatus(): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required') - 
} - - const response = await fetch('https://openrouter.ai/api/v1/key', { - method: 'GET', - headers: { - Authorization: `Bearer ${this.provider.apiKey}`, - 'Content-Type': 'application/json' - } - }) - - if (response.status !== 200) { - const errorText = await response.text() - throw new Error( - `OpenRouter API key check failed: ${response.status} ${response.statusText} - ${errorText}` - ) - } - - const responseText = await response.text() - if (!responseText || responseText.trim().length === 0) { - throw new Error('OpenRouter API returned empty response') - } - - const keyResponse: OpenRouterKeyResponse = JSON.parse(responseText) - if (!keyResponse.data) { - throw new Error(`OpenRouter API response missing 'data' field`) - } - - // Build KeyStatus based on available data - const keyStatus: KeyStatus = { - usage: '$' + keyResponse.data.usage - } - - // Only include limit_remaining if it's not null (has actual limit) - if (keyResponse.data.limit_remaining !== null) { - keyStatus.limit_remaining = '$' + keyResponse.data.limit_remaining - keyStatus.remainNum = keyResponse.data.limit_remaining - } - - return keyStatus - } - - /** - * Override check method to use OpenRouter's API key status endpoint - * @returns Promise<{ isOk: boolean; errorMsg: string | null }> - */ - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - const keyStatus = await this.getKeyStatus() - - // Check if there's remaining quota (only if limit_remaining exists) - if (keyStatus.remainNum !== undefined && keyStatus.remainNum <= 0) { - return { - isOk: false, - errorMsg: `API key quota exhausted. Remaining: ${keyStatus.limit_remaining}` - } - } - - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - let errorMessage = 'An unknown error occurred during OpenRouter API key check.' - if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('OpenRouter API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } - - /** - * Override fetchOpenAIModels to parse OpenRouter specific model data and update model configurations - * @param options - Request options - * @returns Promise - Array of model metadata - */ - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - try { - const response = await this.openai.models.list(options) - // console.log('OpenRouter models response:', JSON.stringify(response, null, 2)) - - const models: MODEL_META[] = [] - - for (const model of response.data) { - // Type the model as OpenRouter specific response - const openRouterModel = model as unknown as OpenRouterModelResponse - - // Extract model information - const modelId = openRouterModel.id - const supportedParameters = openRouterModel.supported_parameters || [] - const inputModalities = openRouterModel.architecture?.input_modalities || [] - - // Check capabilities based on supported parameters and architecture - const hasFunctionCalling = supportedParameters.includes('tools') - const hasVision = inputModalities.includes('image') - const hasReasoning = - supportedParameters.includes('reasoning') || - supportedParameters.includes('include_reasoning') - - // Get existing model configuration first - const existingConfig = - this.configPresenter.getModelConfig(modelId, this.provider.id) ?? 
({} as const) - - // Extract configuration values with proper fallback priority: API -> existing config -> default - const contextLength = - openRouterModel.context_length || - openRouterModel.top_provider?.context_length || - existingConfig.contextLength || - 4096 - const maxTokens = - openRouterModel.top_provider?.max_completion_tokens || existingConfig.maxTokens || 2048 - - // Build new configuration based on API response - const newConfig = { - contextLength: contextLength, - maxTokens: maxTokens, - functionCall: hasFunctionCalling, - vision: hasVision, - reasoning: hasReasoning || existingConfig.reasoning, // Use API info or keep existing - temperature: existingConfig.temperature, // Keep existing temperature - type: existingConfig.type // Keep existing type - } - - // Check if configuration has changed - const configChanged = - existingConfig.contextLength !== newConfig.contextLength || - existingConfig.maxTokens !== newConfig.maxTokens || - existingConfig.functionCall !== newConfig.functionCall || - existingConfig.vision !== newConfig.vision || - existingConfig.reasoning !== newConfig.reasoning - - // Update configuration if changed - if (configChanged) { - // console.log(`Updating OpenRouter configuration for model ${modelId}:`, { - // old: { - // contextLength: existingConfig.contextLength, - // maxTokens: existingConfig.maxTokens, - // functionCall: existingConfig.functionCall, - // vision: existingConfig.vision, - // reasoning: existingConfig.reasoning - // }, - // new: newConfig - // }) - - this.configPresenter.setModelConfig(modelId, this.provider.id, newConfig, { - source: 'provider' - }) - } - - // Create MODEL_META object - const modelMeta: MODEL_META = { - id: modelId, - name: openRouterModel.name || modelId, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: contextLength, - maxTokens: maxTokens, - description: openRouterModel.description, - vision: hasVision, - functionCall: hasFunctionCalling, - reasoning: hasReasoning || existingConfig.reasoning || false - } - - models.push(modelMeta) - } - - console.log(`Processed ${models.length} OpenRouter models with dynamic configuration updates`) - return models - } catch (error) { - console.error('Error fetching OpenRouter models:', error) - // Fallback to parent implementation - return super.fetchOpenAIModels(options) - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/poeProvider.ts b/src/main/presenter/llmProviderPresenter/providers/poeProvider.ts deleted file mode 100644 index a9f73b8c2..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/poeProvider.ts +++ /dev/null @@ -1,30 +0,0 @@ -import { LLM_PROVIDER, MODEL_META, IConfigPresenter } from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -/** - * PoeProvider integrates Poe's OpenAI-compatible API surface with the shared - * BaseLLMProvider contract so the rest of the app can treat it just like - * any other OpenAI-style backend. - * - * Poe exposes hundreds of community and frontier models through a single - * endpoint. We reuse the OpenAICompatibleProvider implementation and only - * tweak metadata so the renderer can present a clearer group name. 
- */ -export class PoeProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - const models = await super.fetchOpenAIModels(options) - return models.map((model) => ({ - ...model, - group: 'Poe' - })) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/ppioProvider.ts b/src/main/presenter/llmProviderPresenter/providers/ppioProvider.ts deleted file mode 100644 index eb9476466..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/ppioProvider.ts +++ /dev/null @@ -1,243 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - ChatMessage, - KeyStatus, - MODEL_META, - IConfigPresenter -} from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for PPIO API key response -interface PPIOKeyResponse { - credit_balance: number -} - -// Define interface for PPIO model response -interface PPIOModelResponse { - id: string - object: string - owned_by: string - created: number - display_name: string - description: string - context_size: number - max_output_tokens: number - features?: string[] - status: number - model_type: string -} - -export class PPIOProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `You need to summarize the user's conversation into a title of no more than 10 words, with the title language matching the user's primary language, without using punctuation or other special symbols:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Get current API key status from PPIO - * @returns Promise API key status information - */ - public async getKeyStatus(): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required') - } - - const response = await fetch('https://api.ppinfra.com/v3/user', { - method: 'GET', - headers: { - Authorization: this.provider.apiKey, - 'Content-Type': 'application/json' - } - }) - - if (!response.ok) { - const errorText = await response.text() - throw new Error( - `PPIO API key check failed: ${response.status} ${response.statusText} - ${errorText}` - ) - } - - const keyResponse: PPIOKeyResponse = await response.json() - const remaining = '¥' + keyResponse.credit_balance / 10000 - return { - limit_remaining: remaining, - remainNum: keyResponse.credit_balance - } - } - - /** - * Override check method to use PPIO's API key status endpoint - * @returns Promise<{ isOk: boolean; errorMsg: string | null }> - */ - public async check(): Promise<{ isOk: 
boolean; errorMsg: string | null }> { - try { - const keyStatus = await this.getKeyStatus() - - // Check if there's remaining quota - if (keyStatus.remainNum !== undefined && keyStatus.remainNum <= 0) { - return { - isOk: false, - errorMsg: `API key quota exhausted. Remaining: ${keyStatus.limit_remaining}` - } - } - - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - let errorMessage = 'An unknown error occurred during PPIO API key check.' - if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('PPIO API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } - - /** - * Override fetchOpenAIModels to parse PPIO specific model data and update model configurations - * @param options - Request options - * @returns Promise - Array of model metadata - */ - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - try { - const response = await this.openai.models.list(options) - // console.log('PPIO models response:', JSON.stringify(response, null, 2)) - - const models: MODEL_META[] = [] - - for (const model of response.data) { - // Type the model as PPIO specific response - const ppioModel = model as unknown as PPIOModelResponse - - // Extract model information - const modelId = ppioModel.id - const features = ppioModel.features || [] - - // Check features for capabilities - const hasFunctionCalling = features.includes('function-calling') - const hasVision = features.includes('vision') - // const hasStructuredOutputs = features.includes('structured-outputs') - - // Get existing model configuration first - const existingConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - - // Extract configuration values with proper fallback priority: API -> existing config -> default - const contextLength = ppioModel.context_size || existingConfig.contextLength || 4096 - const maxTokens = ppioModel.max_output_tokens || existingConfig.maxTokens || 2048 - - // Build new configuration based on API response - const newConfig = { - contextLength: contextLength, - maxTokens: maxTokens, - functionCall: hasFunctionCalling, - vision: hasVision, - reasoning: existingConfig.reasoning, // Keep existing reasoning setting - temperature: existingConfig.temperature, // Keep existing temperature - type: existingConfig.type // Keep existing type - } - - // Check if configuration has changed - const configChanged = - existingConfig.contextLength !== newConfig.contextLength || - existingConfig.maxTokens !== newConfig.maxTokens || - existingConfig.functionCall !== newConfig.functionCall || - existingConfig.vision !== newConfig.vision - - // Update configuration if changed - if (configChanged) { - // console.log(`Updating configuration for model ${modelId}:`, { - // old: { - // contextLength: existingConfig.contextLength, - // maxTokens: existingConfig.maxTokens, - // functionCall: existingConfig.functionCall, - // vision: existingConfig.vision - // }, - // new: newConfig - // }) - - this.configPresenter.setModelConfig(modelId, this.provider.id, newConfig, { - source: 'provider' - }) - } - - // Create MODEL_META object - const modelMeta: MODEL_META = { - id: modelId, - name: ppioModel.display_name || modelId, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: contextLength, - maxTokens: maxTokens, - description: ppioModel.description, - vision: hasVision, - functionCall: hasFunctionCalling, - reasoning: 
existingConfig.reasoning || false - } - - models.push(modelMeta) - } - - console.log(`Processed ${models.length} PPIO models with dynamic configuration updates`) - return models - } catch (error) { - console.error('Error fetching PPIO models:', error) - // Fallback to parent implementation - return super.fetchOpenAIModels(options) - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/siliconcloudProvider.ts b/src/main/presenter/llmProviderPresenter/providers/siliconcloudProvider.ts deleted file mode 100644 index d94e6c148..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/siliconcloudProvider.ts +++ /dev/null @@ -1,246 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - ChatMessage, - KeyStatus, - IConfigPresenter, - LLMCoreStreamEvent, - ModelConfig, - MCPToolDefinition -} from '@shared/presenter' -import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for SiliconCloud API key response -interface SiliconCloudKeyResponse { - code: number - message: string - status: boolean - data: { - id: string - name: string - image: string - email: string - isAdmin: boolean - balance: string - status: string - introduction: string - role: string - chargeBalance: string - totalBalance: string - } -} - -export class SiliconcloudProvider extends OpenAICompatibleProvider { - // 支持 enable_thinking 参数的模型列表 - private static readonly ENABLE_THINKING_MODELS: string[] = [ - 'qwen/qwen3-8b', - 'qwen/qwen3-14b', - 'qwen/qwen3-32b', - 'qwen/qwen3-30b-a3b', - 'qwen/qwen3-235b-a22b', - 'tencent/hunyuan-a13b-instruct', - 'zai-org/glm-4.5v', - 'deepseek-ai/deepseek-v3.1', - 'pro/deepseek-ai/deepseek-v3.1' - ] - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - /** - * 检查模型是否支持 enable_thinking 参数 - * @param modelId 模型ID - * @returns boolean 是否支持 enable_thinking - */ - private supportsEnableThinking(modelId: string): boolean { - const normalizedModelId = modelId.toLowerCase() - return SiliconcloudProvider.ENABLE_THINKING_MODELS.some((supportedModel) => - normalizedModelId.includes(supportedModel) - ) - } - - /** - * 重写 coreStream 方法以支持 SiliconCloud 的 enable_thinking 参数 - */ - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - - const shouldAddEnableThinking = this.supportsEnableThinking(modelId) && modelConfig?.reasoning - - if (shouldAddEnableThinking) { - // 原始的 create 方法 - const originalCreate = this.openai.chat.completions.create.bind(this.openai.chat.completions) - // 替换 create 方法以添加 enable_thinking 参数 - this.openai.chat.completions.create = ((params: any, options?: any) => { - const modifiedParams = { - ...params, - enable_thinking: true - } - return originalCreate(modifiedParams, options) - }) as any - - try { - const effectiveModelConfig = { ...modelConfig, reasoning: false } - yield* super.coreStream( - messages, - modelId, - effectiveModelConfig, - temperature, - maxTokens, - mcpTools - ) - } finally { - this.openai.chat.completions.create = originalCreate - } - } else { - 
yield* super.coreStream(messages, modelId, modelConfig, temperature, maxTokens, mcpTools) - } - } - - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - const response = await this.openai.models.list({ - ...options - }) - return response.data.map((model) => ({ - id: model.id, - name: model.id, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, - maxTokens: DEFAULT_MODEL_MAX_TOKENS - })) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `请总结以下内容,使用简洁的语言,突出重点:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Get current API key status from SiliconCloud - * @returns Promise API key status information - */ - public async getKeyStatus(): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required') - } - - const response = await fetch('https://api.siliconflow.cn/v1/user/info', { - method: 'GET', - headers: { - Authorization: `Bearer ${this.provider.apiKey}`, - 'Content-Type': 'application/json' - } - }) - - if (!response.ok) { - const errorText = await response.text() - throw new Error( - `SiliconCloud API key check failed: ${response.status} ${response.statusText} - ${errorText}` - ) - } - - const keyResponse: SiliconCloudKeyResponse = await response.json() - - if (keyResponse.code !== 20000 || !keyResponse.status) { - throw new Error(`SiliconCloud API error: ${keyResponse.message}`) - } - - const totalBalance = parseFloat(keyResponse.data.totalBalance) - - // Map to unified KeyStatus format - return { - limit_remaining: `¥${totalBalance}`, - remainNum: totalBalance - } - } - - /** - * Override check method to use SiliconCloud's API key status endpoint - * @returns Promise<{ isOk: boolean; errorMsg: string | null }> - */ - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - const keyStatus = await this.getKeyStatus() - - // Check if there's remaining quota - if (keyStatus.remainNum !== undefined && keyStatus.remainNum <= 0) { - return { - isOk: false, - errorMsg: `API key quota exhausted. Remaining: ${keyStatus.limit_remaining}` - } - } - - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - let errorMessage = 'An unknown error occurred during SiliconCloud API key check.' 
- if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('SiliconCloud API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/togetherProvider.ts b/src/main/presenter/llmProviderPresenter/providers/togetherProvider.ts deleted file mode 100644 index 4ce142258..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/togetherProvider.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { IConfigPresenter, LLM_PROVIDER, LLMResponse, MODEL_META } from '@shared/presenter' -import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import Together from 'together-ai' -import type { ProviderMcpRuntimePort } from '../runtimePorts' -export class TogetherProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: { role: 'system' | 'user' | 'assistant'; content: string }[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `请总结以下内容,使用简洁的语言,突出重点:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - protected async fetchProviderModels(options?: { timeout: number }): Promise { - // 检查供应商是否在黑名单中 - if (this.isNoModelsApi) { - // console.log(`Provider ${this.provider.name} does not support OpenAI models API`) - return this.models - } - return this.fetchTogetherAIModels(options) - } - - protected async fetchTogetherAIModels(options?: { timeout: number }): Promise { - const togetherai = new Together({ - apiKey: this.provider.apiKey - }) - const response = await togetherai.models.list(options) - return response - .filter((model) => model.type === 'chat' || model.type === 'language') - .map((model) => ({ - id: model.id, - name: model.id, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: DEFAULT_MODEL_CONTEXT_LENGTH, - maxTokens: DEFAULT_MODEL_MAX_TOKENS - })) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/tokenfluxProvider.ts b/src/main/presenter/llmProviderPresenter/providers/tokenfluxProvider.ts deleted file mode 100644 index 3947acfa3..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/tokenfluxProvider.ts +++ /dev/null @@ -1,226 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - ChatMessage, - KeyStatus, - MODEL_META, - IConfigPresenter -} from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Define interface for TokenFlux API model response -interface TokenFluxModelResponse { - id: string - name: string - description: string - provider: string - pricing: { - input: number - output: number - } 
- context_length: number - supports_streaming: boolean - supports_vision: boolean -} - -// Define interface for TokenFlux models list response -interface TokenFluxModelsResponse { - object: string - data: TokenFluxModelResponse[] -} - -export class TokenFluxProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `You need to summarize the user's conversation into a title of no more than 10 words, with the title language matching the user's primary language, without using punctuation or other special symbols:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } - - /** - * Get current API key status from TokenFlux - * @returns Promise API key status information - */ - public async getKeyStatus(): Promise { - if (!this.provider.apiKey) { - throw new Error('API key is required') - } - - // TokenFlux uses OpenAI-compatible API, so we can use the models endpoint for key validation - const response = await fetch(`${this.provider.baseUrl}/models`, { - method: 'GET', - headers: { - Authorization: `Bearer ${this.provider.apiKey}`, - 'Content-Type': 'application/json' - } - }) - - if (!response.ok) { - const errorText = await response.text() - throw new Error( - `TokenFlux API key check failed: ${response.status} ${response.statusText} - ${errorText}` - ) - } - - // TokenFlux doesn't provide quota information in the models endpoint response - // So we return a simple success status - return { - limit_remaining: 'Available', - remainNum: undefined - } - } - - /** - * Override check method to use TokenFlux's API key status endpoint - * @returns Promise<{ isOk: boolean; errorMsg: string | null }> - */ - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - await this.getKeyStatus() - return { isOk: true, errorMsg: null } - } catch (error: unknown) { - let errorMessage = 'An unknown error occurred during TokenFlux API key check.' 
- if (error instanceof Error) { - errorMessage = error.message - } else if (typeof error === 'string') { - errorMessage = error - } - - console.error('TokenFlux API key check failed:', error) - return { isOk: false, errorMsg: errorMessage } - } - } - - /** - * Override fetchOpenAIModels to parse TokenFlux specific model data and update model configurations - * @param options - Request options - * @returns Promise - Array of model metadata - */ - protected async fetchOpenAIModels(options?: { timeout: number }): Promise { - try { - const response = await this.openai.models.list(options) - // console.log('TokenFlux models response:', JSON.stringify(response, null, 2)) - - const models: MODEL_META[] = [] - - // Cast response to TokenFlux format - const tokenfluxResponse = response as unknown as TokenFluxModelsResponse - - for (const model of tokenfluxResponse.data) { - // Extract model information - const modelId = model.id - const modelName = model.name || modelId - const description = model.description || '' - - // Determine capabilities based on TokenFlux model data - const hasVision = model.supports_vision || false - const hasFunctionCalling = true // Most TokenFlux models should support function calling - - // Get existing model configuration first - const existingConfig = this.configPresenter.getModelConfig(modelId, this.provider.id) - - // Extract configuration values with proper fallback priority: API -> existing config -> default - const contextLength = model.context_length || existingConfig.contextLength || 4096 - const maxTokens = existingConfig.maxTokens || Math.min(contextLength / 2, 4096) - - // Build new configuration based on API response - const newConfig = { - contextLength: contextLength, - maxTokens: maxTokens, - functionCall: hasFunctionCalling, - vision: hasVision, - reasoning: existingConfig.reasoning, // Keep existing reasoning setting - temperature: existingConfig.temperature, // Keep existing temperature - type: existingConfig.type // Keep existing type - } - - // Check if configuration has changed - const configChanged = - existingConfig.contextLength !== newConfig.contextLength || - existingConfig.maxTokens !== newConfig.maxTokens || - existingConfig.functionCall !== newConfig.functionCall || - existingConfig.vision !== newConfig.vision - - // Update configuration if changed - if (configChanged) { - this.configPresenter.setModelConfig(modelId, this.provider.id, newConfig, { - source: 'provider' - }) - } - - // Create MODEL_META object - const modelMeta: MODEL_META = { - id: modelId, - name: modelName, - group: 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: contextLength, - maxTokens: maxTokens, - description: description, - vision: hasVision, - functionCall: hasFunctionCalling, - reasoning: existingConfig.reasoning || false - } - - models.push(modelMeta) - } - - console.log(`Processed ${models.length} TokenFlux models with dynamic configuration updates`) - return models - } catch (error) { - console.error('Error fetching TokenFlux models:', error) - // Fallback to parent implementation - return super.fetchOpenAIModels(options) - } - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/vercelAIGatewayProvider.ts b/src/main/presenter/llmProviderPresenter/providers/vercelAIGatewayProvider.ts deleted file mode 100644 index d7a42b8de..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/vercelAIGatewayProvider.ts +++ /dev/null @@ -1,60 +0,0 @@ -import { LLM_PROVIDER, LLMResponse, ChatMessage, IConfigPresenter } 
from '@shared/presenter' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class VercelAIGatewayProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `请总结以下内容,使用简洁的语言,突出重点:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/vertexProvider.ts b/src/main/presenter/llmProviderPresenter/providers/vertexProvider.ts deleted file mode 100644 index 7c02e094c..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/vertexProvider.ts +++ /dev/null @@ -1,1217 +0,0 @@ -import { - Content, - FunctionCallingConfigMode, - GenerateContentParameters, - GenerateContentResponseUsageMetadata, - GoogleGenAI, - HarmBlockThreshold, - HarmCategory, - Modality, - Part, - SafetySetting, - Tool, - GenerateContentConfig -} from '@google/genai' -import { ModelType } from '@shared/model' -import { - ChatMessage, - IConfigPresenter, - LLM_PROVIDER, - LLMCoreStreamEvent, - LLMResponse, - MCPToolDefinition, - MODEL_META, - ModelConfig -} from '@shared/presenter' -import { VERTEX_PROVIDER } from '@shared/types/presenters/llmprovider.presenter' -import { createStreamEvent } from '@shared/types/core/llm-events' -import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' -import { eventBus, SendTarget } from '@/eventbus' -import { CONFIG_EVENTS } from '@/events' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -// Mapping from simple keys to API HarmCategory constants -const keyToHarmCategoryMap: Record = { - harassment: HarmCategory.HARM_CATEGORY_HARASSMENT, - hateSpeech: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - sexuallyExplicit: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - dangerousContent: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT -} - -// Value mapping from config storage to API HarmBlockThreshold constants -// Assuming config stores 'BLOCK_NONE', 'BLOCK_LOW_AND_ABOVE', etc. 
directly -const valueToHarmBlockThresholdMap: Record = { - BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE, - BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED -} -const safetySettingKeys = Object.keys(keyToHarmCategoryMap) - -export class VertexProvider extends BaseLLMProvider { - private genAI: GoogleGenAI - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - this.genAI = this.createGenAIClient() - this.init() - } - - public onProxyResolved(): void { - this.genAI = this.createGenAIClient() - this.init() - } - - private get vertexProvider(): VERTEX_PROVIDER { - return this.provider as VERTEX_PROVIDER - } - - private resolveProjectId(): string | undefined { - return this.vertexProvider.projectId || process.env.GOOGLE_CLOUD_PROJECT - } - - private resolveLocation(): string | undefined { - return this.vertexProvider.location || process.env.GOOGLE_CLOUD_LOCATION - } - - private buildGoogleAuthOptions() { - const privateKey = this.vertexProvider.accountPrivateKey - const clientEmail = this.vertexProvider.accountClientEmail - if (privateKey && clientEmail) { - return { - projectId: this.resolveProjectId(), - credentials: { - client_email: clientEmail, - private_key: privateKey.replace(/\\n/g, '\n') - }, - scopes: ['https://www.googleapis.com/auth/cloud-platform'] - } - } - return undefined - } - - private buildBaseUrl(): string { - const customBaseUrl = this.vertexProvider.baseUrl?.trim() - if (customBaseUrl) return customBaseUrl - - const location = this.resolveLocation() || 'us-central1' - if (this.vertexProvider.endpointMode === 'express') { - return 'https://aiplatform.googleapis.com/' - } - return `https://${location}-aiplatform.googleapis.com/` - } - - private getApiVersion(): 'v1' | 'v1beta1' { - return (this.vertexProvider.apiVersion as 'v1' | 'v1beta1') || 'v1' - } - - private createGenAIClient(): GoogleGenAI { - const project = this.resolveProjectId() - const location = this.resolveLocation() - const apiVersion = this.getApiVersion() - - return new GoogleGenAI({ - vertexai: true, - project, - location, - apiVersion, - googleAuthOptions: this.buildGoogleAuthOptions(), - httpOptions: { - baseUrl: this.buildBaseUrl(), - apiVersion - } - }) - } - - // 确保带有 Vertex 格式的模型路径 - private ensureVertexModelName(modelId: string): string { - if (!modelId) return modelId - if (modelId.startsWith('projects/')) return modelId - const normalized = modelId - .replace(/^models\//i, '') - .replace(/^publishers\/google\/models\//i, '') - return `publishers/google/models/${normalized}` - } - - private buildVertexStreamEndpoint(modelId: string): string { - const baseUrl = this.buildBaseUrl().replace(/\/+$/, '') - const apiVersion = this.getApiVersion() - const modelPath = this.ensureVertexModelName(modelId).replace(/^\/+/, '') - return `${baseUrl}/${apiVersion}/${modelPath}:streamGenerateContent` - } - - private buildVertexTraceHeaders(): Record { - const headers: Record = { - 'Content-Type': 'application/json' - } - if (this.provider.apiKey) { - headers['x-goog-api-key'] = this.provider.apiKey - } - return headers - } - - // Implement abstract method fetchProviderModels from BaseLLMProvider - protected async fetchProviderModels(): Promise { - try { - const modelsResponse = await 
this.genAI.models.list() - // console.log('gemini models response:', modelsResponse) - - // 将 pager 转换为数组 - const models: any[] = [] - for await (const model of modelsResponse) { - models.push(model) - } - - if (models.length === 0) { - console.warn('No models found in Vertex AI response, using Provider DB models') - const dbModels = this.configPresenter.getDbProviderModels(this.provider.id).map((m) => ({ - id: m.id, - name: m.name, - group: m.group || 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: m.contextLength, - maxTokens: m.maxTokens, - vision: m.vision || false, - functionCall: m.functionCall || false, - reasoning: m.reasoning || false, - ...(m.type ? { type: m.type } : {}) - })) - return dbModels - } - - // 映射 API 返回的模型数据(能力统一读 Provider DB) - const normalizeModelId = (mid: string): string => - String(mid || '') - .replace(/^models\//i, '') - .replace(/^publishers\/google\/models\//i, '') - const apiModels: MODEL_META[] = models - .filter((model: any) => { - const name = String(model.name || '').toLowerCase() - return ( - !name.includes('embedding') && - !name.includes('aqa') && - !name.includes('text-embedding') && - !name.includes('gemma-3n-e4b-it') - ) - }) - .map((model: any) => { - const apiModelId: string = model.name - const displayName: string = model.displayName || apiModelId - - const normalizedId = normalizeModelId(apiModelId) - - const vision = modelCapabilities.supportsVision(this.provider.id, normalizedId) - const functionCall = modelCapabilities.supportsToolCall(this.provider.id, normalizedId) - const reasoning = modelCapabilities.supportsReasoning(this.provider.id, normalizedId) - const isImageOutput = modelCapabilities.supportsImageOutput( - this.provider.id, - normalizedId - ) - const modelType = isImageOutput ? ModelType.ImageGeneration : ModelType.Chat - - let group = 'default' - if (/\b(exp|preview)\b/i.test(apiModelId)) group = 'experimental' - else if (/\bgemma\b/i.test(apiModelId)) group = 'gemma' - - return { - id: apiModelId, - name: displayName, - group, - providerId: this.provider.id, - isCustom: false, - contextLength: model.inputTokenLimit, - maxTokens: model.outputTokenLimit, - vision, - functionCall, - reasoning, - ...(modelType !== ModelType.Chat && { type: modelType }) - } as MODEL_META - }) - - // console.log('Mapped Vertex models:', apiModels) - return apiModels - } catch (error) { - console.warn('Failed to fetch models from Vertex AI:', error) - // If API call fails, fallback to Provider DB mapping - const dbModels = this.configPresenter.getDbProviderModels(this.provider.id).map((m) => ({ - id: m.id, - name: m.name, - group: m.group || 'default', - providerId: this.provider.id, - isCustom: false, - contextLength: m.contextLength, - maxTokens: m.maxTokens, - vision: m.vision || false, - functionCall: m.functionCall || false, - reasoning: m.reasoning || false, - ...(m.type ? 
{ type: m.type } : {}) - })) - return dbModels - } - } - - // Implement summaryTitles abstract method from BaseLLMProvider - public async summaryTitles( - messages: { role: 'system' | 'user' | 'assistant'; content: string }[], - modelId: string - ): Promise { - console.log('vertex summary check, ignore modelId', modelId) - // Use Vertex AI to generate conversation titles - try { - const conversationText = messages.map((m) => `${m.role}: ${m.content}`).join('\n') - const prompt = `${SUMMARY_TITLES_PROMPT}\n\n${conversationText}` - - const result = await this.genAI.models.generateContent({ - model: this.ensureVertexModelName(modelId), - contents: [{ role: 'user', parts: [{ text: prompt }] }], - config: this.getGenerateContentConfig(0.4, undefined, modelId, false) - }) - - return result.text?.trim() || 'New Conversation' - } catch (error) { - console.error('Failed to generate conversation title:', error) - return 'New Conversation' - } - } - - // Override fetchModels method since Vertex AI will reuse the cached model list - async fetchModels(): Promise { - // Vertex AI will reuse the cached model list - return this.models - } - - // Override check method to use the first default model for testing - async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - try { - const projectId = this.resolveProjectId() - const location = this.resolveLocation() - if (!projectId || !location) { - return { isOk: false, errorMsg: 'projectId and location are required for Vertex AI' } - } - if (this.vertexProvider.accountPrivateKey && !this.vertexProvider.accountClientEmail) { - return { - isOk: false, - errorMsg: 'accountClientEmail is required when accountPrivateKey is provided' - } - } - this.genAI = this.createGenAIClient() - - // Use the first model for simple testing - const testModelId = - this.models.find((m) => m.type !== ModelType.ImageGeneration)?.id || - this.models[0]?.id || - 'gemini-1.5-flash-001' - - const result = await this.genAI.models.generateContent({ - model: this.ensureVertexModelName(testModelId), - contents: [{ role: 'user', parts: [{ text: 'Hello from Vertex AI' }] }] - }) - return { - isOk: Boolean(result && (result.text || result.candidates?.length)), - errorMsg: null - } - } catch (error) { - console.error('Vertex provider check failed:', this.provider.name, error) - return { isOk: false, errorMsg: error instanceof Error ? 
error.message : String(error) } - } - } - - protected async init() { - if (this.provider.enable) { - try { - const projectId = this.resolveProjectId() - const location = this.resolveLocation() - if (!projectId || !location) { - console.warn('Vertex provider missing projectId or location, skip initialization') - return - } - this.genAI = this.createGenAIClient() - this.isInitialized = true - // Use API to get model list, fallback to static list if failed - this.models = await this.fetchProviderModels() - await this.autoEnableModelsIfNeeded() - // Vertex AI is relatively slow, special compensation - eventBus.sendToRenderer( - CONFIG_EVENTS.MODEL_LIST_CHANGED, - SendTarget.ALL_WINDOWS, - this.provider.id - ) - console.info('Provider initialized successfully:', this.provider.name) - } catch (error) { - console.warn('Provider initialization failed:', this.provider.name, error) - } - } - } - - /** - * 重写 autoEnableModelsIfNeeded 方法 - * 不自动启用模型,交由用户手动选择。 - */ - protected async autoEnableModelsIfNeeded() { - if (!this.models || this.models.length === 0) return - const providerId = this.provider.id - - // 检查是否有自定义模型 - const customModels = this.configPresenter.getCustomModels(providerId) - if (customModels && customModels.length > 0) return - - // 检查是否有任何模型的状态被手动修改过 - const hasManuallyModifiedModels = this.models.some((model) => - this.configPresenter.getModelStatus(providerId, model.id) - ) - if (hasManuallyModifiedModels) return - - // 检查是否有任何已启用的模型 - const hasEnabledModels = this.models.some((model) => - this.configPresenter.getModelStatus(providerId, model.id) - ) - - // 不再自动启用模型,让用户手动选择启用需要的模型 - if (!hasEnabledModels) { - console.info( - `Provider ${this.provider.name} models loaded, please manually enable the models you need` - ) - } - } - - // Helper function to get and format safety settings - private async getFormattedSafetySettings(): Promise { - const safetySettings: SafetySetting[] = [] - - for (const key of safetySettingKeys) { - try { - // Use configPresenter to get the setting value for the 'gemini' provider - // Assuming getSetting returns the string value like 'BLOCK_MEDIUM_AND_ABOVE' - const settingValue = - (await this.configPresenter.getSetting( - `geminiSafety_${key}` // Match the key used in settings store - )) || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' // Default if not set - - const threshold = valueToHarmBlockThresholdMap[settingValue] - const category = keyToHarmCategoryMap[key] - - // Only add if threshold is defined, category is defined, and threshold is not BLOCK_NONE - if ( - threshold && - category && - threshold !== 'BLOCK_NONE' && - threshold !== 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' - ) { - safetySettings.push({ category, threshold }) - } - } catch (error) { - console.warn(`Failed to retrieve or map safety setting for ${key}:`, error) - } - } - - return safetySettings.length > 0 ? 
safetySettings : undefined - } - - // 判断模型是否支持 thinkingBudget - private supportsThinkingBudget(modelId: string): boolean { - const normalized = modelId.replace(/^models\//i, '') - const range = modelCapabilities.getThinkingBudgetRange(this.provider.id, normalized) - return ( - typeof range.default === 'number' || - typeof range.min === 'number' || - typeof range.max === 'number' - ) - } - - // 获取生成配置,不再创建模型实例 - private getGenerateContentConfig( - temperature?: number, - maxTokens?: number, - modelId?: string, - reasoning?: boolean, - thinkingBudget?: number - ): GenerateContentConfig { - const config: GenerateContentConfig = { - temperature, - maxOutputTokens: maxTokens, - topP: 1 // topP默认为1.0 - } - - // 从当前模型列表中查找指定的模型 - if (modelId && this.models) { - const model = this.models.find((m) => m.id === modelId) - if (model && model.type === ModelType.ImageGeneration) { - config.responseModalities = [Modality.TEXT, Modality.IMAGE] - } - } - - // 正确配置思考功能 - if (reasoning) { - config.thinkingConfig = { - includeThoughts: true - } - - // 仅对支持 thinkingBudget 的 Gemini 2.5 系列模型添加 thinkingBudget 参数 - if (modelId && this.supportsThinkingBudget(modelId) && thinkingBudget !== undefined) { - config.thinkingConfig.thinkingBudget = thinkingBudget - } - } - - return config - } - - // 将 ChatMessage 转换为 Gemini 格式的消息 - private formatVertexMessages(messages: ChatMessage[]): { - systemInstruction: string - contents: Content[] - } { - // 提取系统消息 - const systemMessages = messages.filter((msg) => msg.role === 'system') - let systemContent = '' - if (systemMessages.length > 0) { - systemContent = systemMessages.map((msg) => msg.content).join('\n') - } - - // 创建Gemini内容数组 - const formattedContents: Content[] = [] - - // 处理非系统消息 - const nonSystemMessages = messages.filter((msg) => msg.role !== 'system') - for (let i = 0; i < nonSystemMessages.length; i++) { - const message = nonSystemMessages[i] - - // 检查是否是带有tool_calls的assistant消息 - if (message.role === 'assistant' && 'tool_calls' in message) { - // 处理tool_calls消息 - for (const toolCall of message.tool_calls || []) { - // 添加模型发出的函数调用 - formattedContents.push({ - role: 'model', - parts: [ - { - functionCall: { - name: toolCall.function.name, - args: JSON.parse(toolCall.function.arguments || '{}') - } - } - ] - }) - - // 查找对应的工具响应消息 - const nextMessage = i + 1 < nonSystemMessages.length ? nonSystemMessages[i + 1] : null - if ( - nextMessage && - nextMessage.role === 'tool' && - 'tool_call_id' in nextMessage && - nextMessage.tool_call_id === toolCall.id - ) { - // 添加用户角色的函数响应 - formattedContents.push({ - role: 'user', - parts: [ - { - functionResponse: { - name: toolCall.function.name, - response: { - result: - typeof nextMessage.content === 'string' - ? nextMessage.content - : JSON.stringify(nextMessage.content) - } - } - } - ] - }) - - // 跳过下一条消息,因为已经处理过了 - i++ - } - } - continue - } - - // 为每条消息创建parts数组 - const parts: Part[] = [] - - // 检查消息是否包含工具调用或工具响应 - if (message.role === 'tool' && Array.isArray(message.content)) { - // 处理工具消息 - for (const part of message.content) { - // @ts-ignore - 处理类型兼容性 - if (part.type === 'function_call' && part.function_call) { - // 处理函数调用 - parts.push({ - // @ts-ignore - 处理类型兼容性 - functionCall: { - // @ts-ignore - 处理类型兼容性 - name: part.function_call.name || '', - // @ts-ignore - 处理类型兼容性 - args: part.function_call.arguments ? 
JSON.parse(part.function_call.arguments) : {} - } - }) - // @ts-ignore - 处理类型兼容性 - } else if (part.type === 'function_response') { - // 处理函数响应 - // @ts-ignore - 处理类型兼容性 - parts.push({ text: part.function_response || '' }) - } - } - } else if (typeof message.content === 'string') { - // 处理消息内容 - 可能是字符串或包含图片的数组 - // 处理纯文本消息 - // 只添加非空文本 - if (message.content.trim() !== '') { - parts.push({ text: message.content }) - } - } else if (Array.isArray(message.content)) { - // 处理多模态消息(带图片等) - for (const part of message.content) { - if (part.type === 'text') { - // 只添加非空文本 - if (part.text && part.text.trim() !== '') { - parts.push({ text: part.text }) - } - } else if (part.type === 'image_url' && part.image_url) { - // 处理图片(假设是 base64 格式) - const matches = part.image_url.url.match(/^data:([^;]+);base64,(.+)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - parts.push({ - inlineData: { - data: base64Data, - mimeType: mimeType - } - }) - } - } - } - } - - // 只有当parts不为空时,才添加到formattedContents中 - if (parts.length > 0) { - // 将消息角色转换为Gemini支持的角色 - let role: 'user' | 'model' = 'user' - if (message.role === 'assistant') { - role = 'model' - } else if (message.role === 'tool') { - // 工具消息作为用户消息处理 - role = 'user' - } - - formattedContents.push({ - role: role, - parts: parts - }) - } - } - - return { systemInstruction: systemContent, contents: formattedContents } - } - - // 处理 Vertex API 响应,支持新旧格式的思考内容 - private processVertexResponse(result: any): LLMResponse { - const resultResp: LLMResponse = { - content: '' - } - - let textContent = '' - let thoughtContent = '' - - // 检查是否有候选响应和 parts - if (result.candidates && result.candidates[0]?.content?.parts) { - for (const part of result.candidates[0].content.parts) { - // 检查是否是思考内容 (新格式) - if ((part as any).thought === true && part.text) { - thoughtContent += part.text - } else if (part.text) { - textContent += part.text - } - } - } else { - // 回退到使用 result.text - textContent = result.text || '' - } - - // 如果没有检测到新格式的思考内容,检查旧格式的 标签 - if (!thoughtContent && textContent.includes('')) { - const thinkStart = textContent.indexOf('') - const thinkEnd = textContent.indexOf('') - - if (thinkEnd > thinkStart) { - // 提取reasoning_content - thoughtContent = textContent.substring(thinkStart + 7, thinkEnd).trim() - - // 合并前后的普通内容 - const beforeThink = textContent.substring(0, thinkStart).trim() - const afterThink = textContent.substring(thinkEnd + 8).trim() - textContent = [beforeThink, afterThink].filter(Boolean).join('\n') - } - } - - resultResp.content = textContent - if (thoughtContent) { - resultResp.reasoning_content = thoughtContent - } - - return resultResp - } - - // 实现抽象方法 - async completions( - messages: { role: 'system' | 'user' | 'assistant'; content: string }[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - try { - if (!this.genAI) { - throw new Error('Google Generative AI client is not initialized') - } - - const { systemInstruction, contents } = this.formatVertexMessages(messages) - - // 创建 GenerateContentConfig - const generateContentConfig: GenerateContentConfig = this.getGenerateContentConfig( - temperature ?? 
0.7, - maxTokens, - modelId, - false // completions 方法中不处理 reasoning - ) - - if (systemInstruction) { - generateContentConfig.systemInstruction = systemInstruction - } - - // 一次性创建 requestParams - const requestParams: GenerateContentParameters = { - model: modelId, - contents, - config: generateContentConfig - } - - const result = await this.genAI.models.generateContent({ - ...requestParams, - model: this.ensureVertexModelName(requestParams.model as string) - }) - - const resultResp: LLMResponse = { - content: '' - } - - // 尝试获取tokens信息 - 使用新SDK的usageMetadata结构 - try { - if (result.usageMetadata) { - const usage = result.usageMetadata - resultResp.totalUsage = { - prompt_tokens: usage.promptTokenCount || 0, - completion_tokens: usage.candidatesTokenCount || 0, - total_tokens: usage.totalTokenCount || 0 - } - } else { - // 估算token数量 - 简单方法,可以根据实际需要调整 - const promptText = messages.map((m) => m.content).join(' ') - const responseText = result.text || '' - - // 简单估算: 英文约1个token/4个字符,中文约1个token/1.5个字符 - const estimateTokens = (text: string): number => { - const chineseCharCount = (text.match(/[\u4e00-\u9fa5]/g) || []).length - const otherCharCount = text.length - chineseCharCount - return Math.ceil(chineseCharCount / 1.5 + otherCharCount / 4) - } - - const promptTokens = estimateTokens(promptText) - const completionTokens = estimateTokens(responseText) - - resultResp.totalUsage = { - prompt_tokens: promptTokens, - completion_tokens: completionTokens, - total_tokens: promptTokens + completionTokens - } - } - } catch (e) { - console.warn('Failed to estimate token count for Vertex response', e) - } - - // 处理响应内容,支持新格式的思考内容 - let textContent = '' - let thoughtContent = '' - - // 检查是否有候选响应和 parts - if (result.candidates && result.candidates[0]?.content?.parts) { - for (const part of result.candidates[0].content.parts) { - // 检查是否是思考内容 (新格式) - if ((part as any).thought === true && part.text) { - thoughtContent += part.text - } else if (part.text) { - textContent += part.text - } - } - } else { - // 回退到使用 result.text - textContent = result.text || '' - } - - // 如果没有检测到新格式的思考内容,检查旧格式的 标签 - if (!thoughtContent && textContent.includes('')) { - const thinkStart = textContent.indexOf('') - const thinkEnd = textContent.indexOf('') - - if (thinkEnd > thinkStart) { - // 提取reasoning_content - thoughtContent = textContent.substring(thinkStart + 7, thinkEnd).trim() - - // 合并前后的普通内容 - const beforeThink = textContent.substring(0, thinkStart).trim() - const afterThink = textContent.substring(thinkEnd + 8).trim() - textContent = [beforeThink, afterThink].filter(Boolean).join('\n') - } - } - - resultResp.content = textContent - if (thoughtContent) { - resultResp.reasoning_content = thoughtContent - } - - return resultResp - } catch (error) { - console.error('Vertex completions error:', error) - throw error - } - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - if (!modelId) { - throw new Error('Model ID is required') - } - - try { - const prompt = `Please generate a concise summary for the following content:\n\n${text}` - - const result = await this.genAI.models.generateContent({ - model: this.ensureVertexModelName(modelId), - contents: [{ role: 'user', parts: [{ text: prompt }] }], - config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false) - }) - - return this.processVertexResponse(result) - } catch (error) { - console.error('Vertex summaries 
error:', error) - throw error - } - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - if (!modelId) { - throw new Error('Model ID is required') - } - - try { - const result = await this.genAI.models.generateContent({ - model: this.ensureVertexModelName(modelId), - contents: [{ role: 'user', parts: [{ text: prompt }] }], - config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false) - }) - - return this.processVertexResponse(result) - } catch (error) { - console.error('Vertex generateText error:', error) - throw error - } - } - - async suggestions( - context: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (!this.isInitialized) { - throw new Error('Provider not initialized') - } - - if (!modelId) { - throw new Error('Model ID is required') - } - - try { - const prompt = `Based on the following context, please provide up to 5 reasonable suggestion options, each not exceeding 100 characters. Please return in JSON array format without other explanations:\n\n${context}` - - const result = await this.genAI.models.generateContent({ - model: this.ensureVertexModelName(modelId), - contents: [{ role: 'user', parts: [{ text: prompt }] }], - config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false) - }) - - const responseText = result.text || '' - - // 尝试从响应中解析出JSON数组 - try { - const cleanedText = responseText.replace(/```json|```/g, '').trim() - const suggestions = JSON.parse(cleanedText) - if (Array.isArray(suggestions)) { - return suggestions.map((item) => item.toString()) - } - } catch (parseError) { - console.error('Vertex suggestions parseError:', parseError) - // 如果解析失败,尝试分行处理 - const lines = responseText - .split('\n') - .map((line) => line.trim()) - .filter((line) => line && !line.startsWith('```') && !line.includes(':')) - .map((line) => line.replace(/^[0-9]+\.\s*/, '').replace(/^-\s*/, '')) - - if (lines.length > 0) { - return lines.slice(0, 5) - } - } - - // If all fail, return a default prompt - return ['Unable to generate suggestions'] - } catch (error) { - console.error('Vertex suggestions error:', error) - return ['Error occurred, unable to get suggestions'] - } - } - /** - * 核心流式处理方法 - * 实现BaseLLMProvider中的抽象方法 - */ - async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - mcpTools: MCPToolDefinition[] - ): AsyncGenerator { - if (!this.isInitialized) throw new Error('Provider not initialized') - if (!modelId) throw new Error('Model ID is required') - console.log('modelConfig', modelConfig, modelId) - - // 检查是否是图片生成模型 - const isImageGenerationModel = modelConfig?.type === ModelType.ImageGeneration - - // 如果是图片生成模型,使用特殊处理 - if (isImageGenerationModel) { - yield* this.handleImageGenerationStream(messages, modelId, temperature, maxTokens) - return - } - - const safetySettings = await this.getFormattedSafetySettings() - console.log('safetySettings', safetySettings) - - // 添加Gemini工具调用 - let geminiTools: Tool[] = [] - - // Load MCP tools if available - if (mcpTools.length > 0) - geminiTools = (await this.mcpRuntime?.mcpToolsToGeminiTools(mcpTools, this.provider.id)) ?? [] - - // 格式化消息为Gemini格式 - const formattedParts = this.formatVertexMessages(messages) - - // 1. 
获取基础 config - const generateContentConfig: GenerateContentConfig = this.getGenerateContentConfig( - temperature, - maxTokens, - modelId, - modelConfig.reasoning, - modelConfig.thinkingBudget - ) - - // 2. 在本地变量上添加其他属性 - if (formattedParts.systemInstruction) { - generateContentConfig.systemInstruction = formattedParts.systemInstruction - } - - if (geminiTools.length > 0) { - generateContentConfig.tools = geminiTools - // 仅当存在 functionDeclarations 时才配置 functionCallingConfig - const hasFunctionDeclarations = geminiTools.some((t: any) => { - const fns = t?.functionDeclarations - return Array.isArray(fns) && fns.length > 0 - }) - if (hasFunctionDeclarations) { - generateContentConfig.toolConfig = { - functionCallingConfig: { - mode: FunctionCallingConfigMode.AUTO // 允许模型自动决定是否调用工具 - } - } - } - } - - if (safetySettings) { - generateContentConfig.safetySettings = safetySettings - } - - // 3. 一次性创建完整的 requestParams - const requestParams: GenerateContentParameters = { - model: modelId, - contents: formattedParts.contents, - config: generateContentConfig - } - - const streamRequestParams = { - ...requestParams, - model: this.ensureVertexModelName(requestParams.model as string) - } - - await this.emitRequestTrace(modelConfig, { - endpoint: this.buildVertexStreamEndpoint(modelId), - headers: this.buildVertexTraceHeaders(), - body: streamRequestParams - }) - - // 发送流式请求 - const result = await this.genAI.models.generateContentStream(streamRequestParams) - - // 状态变量 - let buffer = '' - let isInThinkTag = false - let toolUseDetected = false - let usageMetadata: GenerateContentResponseUsageMetadata | undefined - let isNewThoughtFormatDetected = modelConfig.reasoning === true - - // 流处理循环 - for await (const chunk of result) { - // 处理用量统计 - if (chunk.usageMetadata) { - usageMetadata = chunk.usageMetadata - } - - // 检查是否包含函数调用 - if (chunk.candidates && chunk.candidates[0]?.content?.parts?.[0]?.functionCall) { - const functionCall = chunk.candidates[0].content.parts[0].functionCall - const functionName = functionCall.name - const functionArgs = functionCall.args || {} - const toolCallId = `gemini-tool-${Date.now()}` - - toolUseDetected = true - - // 发送工具调用开始事件 - yield createStreamEvent.toolCallStart(toolCallId, functionName || '') - - // 发送工具调用参数 - const argsString = JSON.stringify(functionArgs) - yield createStreamEvent.toolCallChunk(toolCallId, argsString) - - // 发送工具调用结束事件 - yield createStreamEvent.toolCallEnd(toolCallId, argsString) - - // 设置停止原因为工具使用 - break - } - - // 处理内容块 - let content = '' - let thoughtContent = '' - - // 处理文本和图像内容 - if (chunk.candidates && chunk.candidates[0]?.content?.parts) { - for (const part of chunk.candidates[0].content.parts) { - // 检查是否是思考内容 (新格式) - if ((part as any).thought === true && part.text) { - isNewThoughtFormatDetected = true - thoughtContent += part.text - } else if (part.text) { - content += part.text - } else if (part.inlineData && part.inlineData.data && part.inlineData.mimeType) { - // 处理图像数据 - yield createStreamEvent.imageData({ - data: part.inlineData.data, - mimeType: part.inlineData.mimeType - }) - } - } - } else { - // 兼容处理 - content = chunk.text || '' - } - - // 如果检测到思考内容,直接发送 - if (thoughtContent) { - yield createStreamEvent.reasoning(thoughtContent) - } - - if (!content) continue - - if (isNewThoughtFormatDetected) { - yield createStreamEvent.text(content) - } else { - buffer += content - - if (buffer.includes('') && !isInThinkTag) { - const thinkStart = buffer.indexOf('') - if (thinkStart > 0) { - yield createStreamEvent.text(buffer.substring(0, 
thinkStart)) - } - buffer = buffer.substring(thinkStart + 7) - isInThinkTag = true - } - - if (isInThinkTag && buffer.includes('')) { - const thinkEnd = buffer.indexOf('') - const reasoningContent = buffer.substring(0, thinkEnd) - if (reasoningContent) { - yield createStreamEvent.reasoning(reasoningContent) - } - buffer = buffer.substring(thinkEnd + 8) - isInThinkTag = false - } - - if (!isInThinkTag && buffer) { - yield createStreamEvent.text(buffer) - buffer = '' - } - } - } - - if (usageMetadata) { - yield createStreamEvent.usage({ - prompt_tokens: usageMetadata.promptTokenCount || 0, - completion_tokens: usageMetadata.candidatesTokenCount || 0, - total_tokens: usageMetadata.totalTokenCount || 0 - }) - } - - // 处理剩余缓冲区内容 - if (!isNewThoughtFormatDetected && buffer) { - if (isInThinkTag) { - yield createStreamEvent.reasoning(buffer) - } else { - yield createStreamEvent.text(buffer) - } - } - - // 发送停止事件 - yield createStreamEvent.stop(toolUseDetected ? 'tool_use' : 'complete') - } - - /** - * 处理图片生成模型的流式输出 - */ - private async *handleImageGenerationStream( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): AsyncGenerator { - try { - // 提取用户消息并构建parts数组 - const userMessage = messages.findLast((msg) => msg.role === 'user') - if (!userMessage) { - throw new Error('No user message found for image generation') - } - - // 构建包含文本和图片的parts数组,参考formatVertexMessages的逻辑 - const parts: Part[] = [] - - if (typeof userMessage.content === 'string') { - // 处理纯文本消息 - if (userMessage.content.trim() !== '') { - parts.push({ text: userMessage.content }) - } - } else if (Array.isArray(userMessage.content)) { - // 处理多模态消息(带图片等) - for (const part of userMessage.content) { - if (part.type === 'text') { - // 只添加非空文本 - if (part.text && part.text.trim() !== '') { - parts.push({ text: part.text }) - } - } else if (part.type === 'image_url' && part.image_url) { - // 处理图片(假设是 base64 格式) - const matches = part.image_url.url.match(/^data:([^;]+);base64,(.+)$/) - if (matches && matches.length === 3) { - const mimeType = matches[1] - const base64Data = matches[2] - parts.push({ - inlineData: { - data: base64Data, - mimeType: mimeType - } - }) - } - } - } - } - - // 如果没有有效的parts,抛出错误 - if (parts.length === 0) { - throw new Error('No valid content found for image generation') - } - - // 发送生成请求 - const result = await this.genAI.models.generateContentStream({ - model: this.ensureVertexModelName(modelId), - contents: [{ role: 'user', parts }], - config: this.getGenerateContentConfig(temperature, maxTokens, modelId, false) // 图像生成不需要reasoning - }) - - // 处理流式响应 - for await (const chunk of result) { - if (chunk.candidates && chunk.candidates[0]?.content?.parts) { - for (const part of chunk.candidates[0].content.parts) { - if (part.text) { - // 输出文本内容 - yield createStreamEvent.text(part.text) - } else if (part.inlineData) { - // 输出图像数据 - yield createStreamEvent.imageData({ - data: part.inlineData.data || '', - mimeType: part.inlineData.mimeType || '' - }) - } - } - } - } - - // 发送停止事件 - yield createStreamEvent.stop('complete') - } catch (error) { - console.error('Image generation stream error:', error) - yield createStreamEvent.error( - error instanceof Error ? 
error.message : 'Image generation failed' - ) - yield createStreamEvent.stop('error') - } - } - - async getEmbeddings(modelId: string, texts: string[]): Promise { - if (!this.genAI) throw new Error('Google Generative AI client is not initialized') - // Vertex embedContent 支持批量输入 - const resp = await this.genAI.models.embedContent({ - model: this.ensureVertexModelName(modelId), - contents: texts.map((text) => ({ - parts: [{ text }] - })) - }) - // resp.embeddings?: ContentEmbedding[] - if (resp && Array.isArray(resp.embeddings)) { - return resp.embeddings.map((e) => (Array.isArray(e.values) ? e.values : [])) - } - // 若无返回,抛出异常 - throw new Error('Vertex AI embedding API did not return embeddings') - } -} diff --git a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts b/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts deleted file mode 100644 index 37b4d7efc..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/zenmuxProvider.ts +++ /dev/null @@ -1,228 +0,0 @@ -import Anthropic from '@anthropic-ai/sdk' -import { - ChatMessage, - IConfigPresenter, - KeyStatus, - LLM_EMBEDDING_ATTRS, - LLM_PROVIDER, - LLMResponse, - LLMCoreStreamEvent, - MCPToolDefinition, - MODEL_META, - ModelConfig -} from '@shared/presenter' -import { ProxyAgent } from 'undici' -import { BaseLLMProvider } from '../baseProvider' -import { proxyConfig } from '../../proxyConfig' -import { AnthropicProvider } from './anthropicProvider' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -const ZENMUX_ANTHROPIC_BASE_URL = 'https://zenmux.ai/api/anthropic' - -class ZenmuxOpenAIDelegate extends OpenAICompatibleProvider { - protected override async init() { - this.isInitialized = true - } - - public async fetchZenmuxModels(options?: { timeout: number }): Promise { - return super.fetchOpenAIModels(options) - } -} - -class ZenmuxAnthropicDelegate extends AnthropicProvider { - private clientInitialized = false - - protected override async init() {} - - public async ensureClientInitialized(): Promise { - const apiKey = this.provider.apiKey || process.env.ANTHROPIC_API_KEY || null - if (!apiKey) { - this.clientInitialized = false - this.isInitialized = false - return - } - - const proxyUrl = proxyConfig.getProxyUrl() - const fetchOptions: { dispatcher?: ProxyAgent } = {} - - if (proxyUrl) { - const proxyAgent = new ProxyAgent(proxyUrl) - fetchOptions.dispatcher = proxyAgent - } - - const self = this as unknown as { anthropic?: Anthropic } - self.anthropic = new Anthropic({ - apiKey, - baseURL: this.provider.baseUrl || ZENMUX_ANTHROPIC_BASE_URL, - defaultHeaders: this.defaultHeaders, - fetchOptions - }) - - this.clientInitialized = true - this.isInitialized = true - } - - public isClientInitialized(): boolean { - return this.clientInitialized - } - - public override onProxyResolved(): void { - void this.ensureClientInitialized() - } -} - -export class ZenmuxProvider extends BaseLLMProvider { - private readonly openaiDelegate: ZenmuxOpenAIDelegate - private readonly anthropicDelegate: ZenmuxAnthropicDelegate - - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - super(provider, configPresenter, mcpRuntime) - - this.openaiDelegate = new ZenmuxOpenAIDelegate(provider, configPresenter, mcpRuntime) - this.anthropicDelegate = new ZenmuxAnthropicDelegate( - { - ...provider, - apiType: 'anthropic', - baseUrl: ZENMUX_ANTHROPIC_BASE_URL - }, - 
configPresenter, - mcpRuntime - ) - - this.init() - } - - private isAnthropicModel(modelId: string): boolean { - return modelId.trim().toLowerCase().startsWith('anthropic/') - } - - private async ensureAnthropicDelegateReady(): Promise { - await this.anthropicDelegate.ensureClientInitialized() - - if (!this.anthropicDelegate.isClientInitialized()) { - throw new Error('Anthropic SDK not initialized') - } - - return this.anthropicDelegate - } - - protected async fetchProviderModels(): Promise { - const models = await this.openaiDelegate.fetchZenmuxModels() - return models.map((model) => ({ - ...model, - group: 'ZenMux' - })) - } - - public onProxyResolved(): void { - this.openaiDelegate.onProxyResolved() - - if (this.anthropicDelegate.isClientInitialized()) { - this.anthropicDelegate.onProxyResolved() - } - } - - public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> { - return this.openaiDelegate.check() - } - - public async summaryTitles(messages: ChatMessage[], modelId: string): Promise { - if (this.isAnthropicModel(modelId)) { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.summaryTitles(messages, modelId) - } - - return this.openaiDelegate.summaryTitles(messages, modelId) - } - - public async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (this.isAnthropicModel(modelId)) { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.completions(messages, modelId, temperature, maxTokens) - } - - return this.openaiDelegate.completions(messages, modelId, temperature, maxTokens) - } - - public async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (this.isAnthropicModel(modelId)) { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.summaries(text, modelId, temperature, maxTokens) - } - - return this.openaiDelegate.summaries(text, modelId, temperature, maxTokens) - } - - public async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - if (this.isAnthropicModel(modelId)) { - const delegate = await this.ensureAnthropicDelegateReady() - return delegate.generateText(prompt, modelId, temperature, maxTokens) - } - - return this.openaiDelegate.generateText(prompt, modelId, temperature, maxTokens) - } - - public async *coreStream( - messages: ChatMessage[], - modelId: string, - modelConfig: ModelConfig, - temperature: number, - maxTokens: number, - tools: MCPToolDefinition[] - ): AsyncGenerator { - if (this.isAnthropicModel(modelId)) { - const delegate = await this.ensureAnthropicDelegateReady() - yield* delegate.coreStream(messages, modelId, modelConfig, temperature, maxTokens, tools) - return - } - - yield* this.openaiDelegate.coreStream( - messages, - modelId, - modelConfig, - temperature, - maxTokens, - tools - ) - } - - public async getEmbeddings(modelId: string, texts: string[]): Promise { - if (this.isAnthropicModel(modelId)) { - throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`) - } - - return this.openaiDelegate.getEmbeddings(modelId, texts) - } - - public async getDimensions(modelId: string): Promise { - if (this.isAnthropicModel(modelId)) { - throw new Error(`Embeddings not supported for Anthropic models: ${modelId}`) - } - - return this.openaiDelegate.getDimensions(modelId) - } - - public async getKeyStatus(): Promise { - return this.openaiDelegate.getKeyStatus() - } -} 
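The retired `ZenmuxProvider` deleted above routed every call by model id prefix: ids starting with `anthropic/` went to the Anthropic delegate (pointed at `https://zenmux.ai/api/anthropic`), everything else fell through to the OpenAI-compatible delegate. A minimal sketch of that dispatch rule, distilled from the deleted `isAnthropicModel` check (the example model ids are illustrative, not taken from the ZenMux catalogue):

```typescript
// Distilled from the deleted ZenmuxProvider.isAnthropicModel(): ZenMux model ids
// carry a vendor prefix, and only the "anthropic/" namespace used the Anthropic
// delegate; all other ids went through the OpenAI-compatible delegate.
type ZenmuxRoute = 'anthropic-delegate' | 'openai-compatible-delegate'

function routeZenmuxModel(modelId: string): ZenmuxRoute {
  return modelId.trim().toLowerCase().startsWith('anthropic/')
    ? 'anthropic-delegate'
    : 'openai-compatible-delegate'
}

// Illustrative ids only:
// routeZenmuxModel('anthropic/claude-sonnet') -> 'anthropic-delegate'
// routeZenmuxModel('openai/gpt-4o-mini')      -> 'openai-compatible-delegate'
```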
diff --git a/src/main/presenter/llmProviderPresenter/providers/zhipuProvider.ts b/src/main/presenter/llmProviderPresenter/providers/zhipuProvider.ts deleted file mode 100644 index a460184fd..000000000 --- a/src/main/presenter/llmProviderPresenter/providers/zhipuProvider.ts +++ /dev/null @@ -1,106 +0,0 @@ -import { - LLM_PROVIDER, - LLMResponse, - MODEL_META, - ChatMessage, - IConfigPresenter -} from '@shared/presenter' -import { ModelType } from '@shared/model' -import { - resolveModelContextLength, - resolveModelFunctionCall, - resolveModelMaxTokens -} from '@shared/modelConfigDefaults' -import { OpenAICompatibleProvider } from './openAICompatibleProvider' -import { providerDbLoader } from '../../configPresenter/providerDbLoader' -import { modelCapabilities } from '../../configPresenter/modelCapabilities' -import type { ProviderMcpRuntimePort } from '../runtimePorts' - -export class ZhipuProvider extends OpenAICompatibleProvider { - constructor( - provider: LLM_PROVIDER, - configPresenter: IConfigPresenter, - mcpRuntime?: ProviderMcpRuntimePort - ) { - // Initialize Zhipu AI model configuration - super(provider, configPresenter, mcpRuntime) - } - - protected async fetchOpenAIModels(): Promise { - const resolvedId = modelCapabilities.resolveProviderId(this.provider.id) || this.provider.id - const provider = providerDbLoader.getProvider(resolvedId) - if (!provider || !Array.isArray(provider.models)) { - return [] - } - - return provider.models.map((model) => { - const inputs = model.modalities?.input - const outputs = model.modalities?.output - const hasImageInput = Array.isArray(inputs) && inputs.includes('image') - const hasImageOutput = Array.isArray(outputs) && outputs.includes('image') - const modelType = hasImageOutput ? ModelType.ImageGeneration : ModelType.Chat - - return { - id: model.id, - name: model.display_name || model.name || model.id, - group: 'zhipu', - providerId: this.provider.id, - isCustom: false, - contextLength: resolveModelContextLength(model.limit?.context), - maxTokens: resolveModelMaxTokens(model.limit?.output), - vision: hasImageInput, - functionCall: resolveModelFunctionCall(model.tool_call), - reasoning: Boolean(model.reasoning?.supported), - enableSearch: Boolean(model.search?.supported), - type: modelType - } - }) - } - - async completions( - messages: ChatMessage[], - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion(messages, modelId, temperature, maxTokens) - } - - async summaries( - text: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: `You need to summarize the user's conversation into a title of no more than 10 words, with the title language matching the user's primary language, without using punctuation or other special symbols:\n${text}` - } - ], - modelId, - temperature, - maxTokens - ) - } - - async generateText( - prompt: string, - modelId: string, - temperature?: number, - maxTokens?: number - ): Promise { - return this.openAICompletion( - [ - { - role: 'user', - content: prompt - } - ], - modelId, - temperature, - maxTokens - ) - } -} diff --git a/src/main/presenter/llmProviderPresenter/runtimePorts.ts b/src/main/presenter/llmProviderPresenter/runtimePorts.ts index 7d780af79..21fb6a0e2 100644 --- a/src/main/presenter/llmProviderPresenter/runtimePorts.ts +++ b/src/main/presenter/llmProviderPresenter/runtimePorts.ts @@ -1,10 +1,6 @@ import type { IMCPPresenter } from '@shared/presenter' 
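The hunk below trims `ProviderMcpRuntimePort` down to the registry lookups. Rendered as a file, the port after this change should look roughly as follows (reconstructed from the surviving context lines, not copied from the new source):

```typescript
import type { IMCPPresenter } from '@shared/presenter'

// After the retirement, providers no longer ask the MCP runtime to convert tool
// definitions into per-vendor formats; only the optional registry lookups remain.
export interface ProviderMcpRuntimePort {
  getNpmRegistry?: IMCPPresenter['getNpmRegistry']
  getUvRegistry?: IMCPPresenter['getUvRegistry']
}
```

The four removed members correspond to the `mcpToolsTo*` conversion helpers deleted from `McpPresenter` further down in this patch.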
export interface ProviderMcpRuntimePort { - mcpToolsToAnthropicTools: IMCPPresenter['mcpToolsToAnthropicTools'] - mcpToolsToGeminiTools: IMCPPresenter['mcpToolsToGeminiTools'] - mcpToolsToOpenAITools: IMCPPresenter['mcpToolsToOpenAITools'] - mcpToolsToOpenAIResponsesTools: IMCPPresenter['mcpToolsToOpenAIResponsesTools'] getNpmRegistry?: IMCPPresenter['getNpmRegistry'] getUvRegistry?: IMCPPresenter['getUvRegistry'] } diff --git a/src/main/presenter/mcpPresenter/index.ts b/src/main/presenter/mcpPresenter/index.ts index f3eac05b7..9e0465d05 100644 --- a/src/main/presenter/mcpPresenter/index.ts +++ b/src/main/presenter/mcpPresenter/index.ts @@ -1,5 +1,6 @@ import { IMCPPresenter, + IConfigPresenter, MCPServerConfig, MCPToolDefinition, MCPToolCall, @@ -17,68 +18,9 @@ import { ToolManager } from './toolManager' import { McpRouterManager } from './mcprouterManager' import { eventBus, SendTarget } from '@/eventbus' import { MCP_EVENTS, NOTIFICATION_EVENTS } from '@/events' -import { IConfigPresenter } from '@shared/presenter' import { getErrorMessageLabels } from '@shared/i18n' -import { OpenAI } from 'openai' -import { ToolListUnion, Type, FunctionDeclaration } from '@google/genai' import { presenter } from '@/presenter' -// Define MCP tool interface -interface MCPTool { - id: string - name: string - type: string - description: string - serverName: string - inputSchema: { - properties: Record> - required: string[] - [key: string]: unknown - } -} - -// Define tool type interfaces for various LLM providers -interface OpenAIToolCall { - function: { - name: string - arguments: string - } -} - -interface AnthropicToolUse { - name: string - input: Record -} - -interface GeminiFunctionCall { - name: string - args: Record -} - -// Define tool conversion interfaces -interface OpenAITool { - type: 'function' - function: { - name: string - description: string - parameters: { - type: string - properties: Record> - required: string[] - } - } -} - -interface AnthropicTool { - name: string - description: string - input_schema: { - type: string - properties: Record> - required: string[] - } -} - // Complete McpPresenter implementation export class McpPresenter implements IMCPPresenter { private serverManager: ServerManager @@ -661,495 +603,6 @@ export class McpPresenter implements IMCPPresenter { }) } - // Convert MCPToolDefinition to MCPTool - private mcpToolDefinitionToMcpTool( - toolDefinition: MCPToolDefinition, - serverName: string - ): MCPTool { - const toolParameters = toolDefinition.function.parameters - const mcpTool = { - id: toolDefinition.function.name, - name: toolDefinition.function.name, - type: toolDefinition.type, - description: toolDefinition.function.description, - serverName, - inputSchema: { - properties: (toolParameters?.properties ?? {}) as Record>, - type: toolParameters?.type ?? 'object', - required: toolParameters?.required ?? [] - } - } as MCPTool - return mcpTool - } - - // Tool properties filter function - private filterPropertieAttributes(tool: MCPTool): Record> { - const supportedAttributes = [ - 'type', - 'nullable', - 'description', - 'properties', - 'items', - 'enum', - 'anyOf', - '$def' - ] - - const properties = tool.inputSchema.properties ?? 
{} - - // Recursive cleanup function to ensure all values are serializable - const cleanValue = (value: unknown): unknown => { - if (value === null || value === undefined) { - return value - } - - if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { - return value - } - - if (Array.isArray(value)) { - return value.map(cleanValue) - } - - if (typeof value === 'object') { - const cleaned: Record = {} - for (const [k, v] of Object.entries(value as Record)) { - cleaned[k] = cleanValue(v) - } - return cleaned - } - - // For functions, Symbols and other non-serializable values, return string representation - return String(value) - } - - const getSubMap = (obj: Record, keys: string[]): Record => { - const filtered = Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key))) - const cleaned: Record = {} - for (const [key, value] of Object.entries(filtered)) { - if (key === 'type') { - // Handle type property specifically to ensure it has a valid value - const typeValue = cleanValue(value) - if ( - !typeValue || - typeof typeValue !== 'string' || - typeValue.trim() === '' || - typeValue.toLowerCase() === 'any' || - typeValue.toLowerCase() === 'unknown' - ) { - // Set to 'string' if type is missing, empty, 'any', or 'unknown' - cleaned[key] = 'string' - } else { - // Validate that it's a supported JSON Schema type - const supportedTypes = [ - 'string', - 'number', - 'integer', - 'boolean', - 'array', - 'object', - 'null' - ] - if (supportedTypes.includes(typeValue.toLowerCase())) { - cleaned[key] = typeValue.toLowerCase() - } else { - // If it's not a supported type, default to 'string' - cleaned[key] = 'string' - } - } - } else { - cleaned[key] = cleanValue(value) - } - } - - // Ensure type property exists if it was supposed to be included - if (keys.includes('type') && !cleaned.hasOwnProperty('type')) { - cleaned.type = 'string' - } - - return cleaned - } - - const result: Record> = {} - for (const [key, val] of Object.entries(properties)) { - if (typeof val === 'object' && val !== null) { - result[key] = getSubMap(val as Record, supportedAttributes) - } - } - - return result - } - - // New tool conversion methods - /** - * Convert MCP tool definitions to OpenAI tool format - * @param mcpTools Array of MCP tool definitions - * @param serverName Server name - * @returns Tool definitions in OpenAI tool format - */ - async mcpToolsToOpenAITools( - mcpTools: MCPToolDefinition[], - serverName: string - ): Promise { - const openaiTools: OpenAITool[] = mcpTools.map((toolDef) => { - const tool = this.mcpToolDefinitionToMcpTool(toolDef, serverName) - return { - type: 'function', - function: { - name: tool.name, - description: tool.description, - parameters: { - type: 'object', - properties: this.filterPropertieAttributes(tool), - required: tool.inputSchema.required || [] - } - } - } - }) - // console.log('openaiTools', JSON.stringify(openaiTools)) - return openaiTools - } - - /** - * Convert OpenAI tool call back to MCP tool call - * @param mcpTools Array of MCP tool definitions - * @param llmTool OpenAI tool call - * @param serverName Server name - * @returns Matching MCP tool call - */ - async openAIToolsToMcpTool( - llmTool: OpenAIToolCall, - providerId: string - ): Promise { - const mcpTools = await this.getAllToolDefinitions() - const tool = mcpTools.find((tool) => tool.function.name === llmTool.function.name) - if (!tool) { - return undefined - } - - // Create MCP tool call - const mcpToolCall: MCPToolCall = { - id: 
`${providerId}:${tool.function.name}-${Date.now()}`, // Generate unique ID including server name - type: tool.type, - function: { - name: tool.function.name, - arguments: llmTool.function.arguments - }, - server: { - name: tool.server.name, - icons: tool.server.icons, - description: tool.server.description - } - } - // console.log('mcpToolCall', mcpToolCall, tool) - - return mcpToolCall - } - - /** - * Convert MCP tool definitions to Anthropic tool format - * @param mcpTools Array of MCP tool definitions - * @param serverName Server name - * @returns Tool definitions in Anthropic tool format - */ - async mcpToolsToAnthropicTools( - mcpTools: MCPToolDefinition[], - serverName: string - ): Promise { - return mcpTools.map((toolDef) => { - const tool = this.mcpToolDefinitionToMcpTool(toolDef, serverName) - return { - name: tool.name, - description: tool.description, - input_schema: { - type: 'object', - properties: this.filterPropertieAttributes(tool), - required: tool.inputSchema.required as string[] - } - } - }) - } - - /** - * Convert Anthropic tool use back to MCP tool call - * @param mcpTools Array of MCP tool definitions - * @param toolUse Anthropic tool use - * @param serverName Server name - * @returns Matching MCP tool call - */ - async anthropicToolUseToMcpTool( - toolUse: AnthropicToolUse, - providerId: string - ): Promise { - const mcpTools = await this.getAllToolDefinitions() - - const tool = mcpTools.find((tool) => tool.function.name === toolUse.name) - // console.log('tool', tool, toolUse) - if (!tool) { - return undefined - } - - // Create MCP tool call - const mcpToolCall: MCPToolCall = { - id: `${providerId}:${tool.function.name}-${Date.now()}`, // Generate unique ID including server name - type: tool.type, - function: { - name: tool.function.name, - arguments: JSON.stringify(toolUse.input) - }, - server: { - name: tool.server.name, - icons: tool.server.icons, - description: tool.server.description - } - } - - return mcpToolCall - } - - /** - * Convert MCP tool definitions to Gemini tool format - * @param mcpTools Array of MCP tool definitions - * @param serverName Server name - * @returns Tool definitions in Gemini tool format - */ - async mcpToolsToGeminiTools( - mcpTools: MCPToolDefinition[] | undefined, - serverName: string - ): Promise { - if (!mcpTools || mcpTools.length === 0) { - return [] - } - - // Recursively clean Schema objects to ensure compliance with Gemini API requirements - const cleanSchema = (schema: Record): Record => { - const cleanedSchema: Record = {} - - // Handle type field - ensure always has valid value - if ('type' in schema) { - const type = schema.type - if (typeof type === 'string' && type.trim() !== '') { - cleanedSchema.type = type - } else if (Array.isArray(type) && type.length > 0) { - // If it's a type array, take the first non-empty type - const validType = type.find((t) => typeof t === 'string' && t.trim() !== '') - if (validType) { - cleanedSchema.type = validType - } else { - cleanedSchema.type = 'string' // Default type - } - } else { - // If no valid type, infer from other attributes - if ('enum' in schema) { - cleanedSchema.type = 'string' - } else if ('properties' in schema) { - cleanedSchema.type = 'object' - } else if ('items' in schema) { - cleanedSchema.type = 'array' - } else { - cleanedSchema.type = 'string' // Default type - } - } - } else { - // If there's no type field at all, infer from other attributes - if ('enum' in schema) { - cleanedSchema.type = 'string' - } else if ('properties' in schema) { - cleanedSchema.type = 
'object' - } else if ('items' in schema) { - cleanedSchema.type = 'array' - } else if ('anyOf' in schema || 'oneOf' in schema) { - // For union types, try to infer the most appropriate type - cleanedSchema.type = 'string' // Default to string - } else { - cleanedSchema.type = 'string' // Final default type - } - } - - // Handle description - if ('description' in schema && typeof schema.description === 'string') { - cleanedSchema.description = schema.description - } - - // Handle enum - if ('enum' in schema && Array.isArray(schema.enum)) { - cleanedSchema.enum = schema.enum - // Ensure enum type is string - if (!cleanedSchema.type || cleanedSchema.type === '') { - cleanedSchema.type = 'string' - } - } - - // Handle properties - if ( - 'properties' in schema && - typeof schema.properties === 'object' && - schema.properties !== null - ) { - const properties = schema.properties as Record - const cleanedProperties: Record = {} - - for (const [propName, propValue] of Object.entries(properties)) { - if (typeof propValue === 'object' && propValue !== null) { - cleanedProperties[propName] = cleanSchema(propValue as Record) - } - } - - if (Object.keys(cleanedProperties).length > 0) { - cleanedSchema.properties = cleanedProperties - cleanedSchema.type = 'object' - } - } - - // Handle items (array type) - if ('items' in schema && typeof schema.items === 'object' && schema.items !== null) { - cleanedSchema.items = cleanSchema(schema.items as Record) - cleanedSchema.type = 'array' - } - - // Handle nullable - if ('nullable' in schema && typeof schema.nullable === 'boolean') { - cleanedSchema.nullable = schema.nullable - } - - // Handle anyOf/oneOf (union types) - simplify to single type - if ('anyOf' in schema && Array.isArray(schema.anyOf)) { - const anyOfOptions = schema.anyOf as Array> - - // Try to find the most suitable type - let bestOption = anyOfOptions[0] - - // Prefer options with enum - for (const option of anyOfOptions) { - if ('enum' in option && Array.isArray(option.enum)) { - bestOption = option - break - } - } - - // If no enum, prefer string type - if (!('enum' in bestOption)) { - for (const option of anyOfOptions) { - if (option.type === 'string') { - bestOption = option - break - } - } - } - - // Recursively clean the selected option - const cleanedOption = cleanSchema(bestOption) - Object.assign(cleanedSchema, cleanedOption) - } - - // Handle oneOf similar to anyOf - if ('oneOf' in schema && Array.isArray(schema.oneOf)) { - const oneOfOptions = schema.oneOf as Array> - const bestOption = oneOfOptions[0] || {} - const cleanedOption = cleanSchema(bestOption) - Object.assign(cleanedSchema, cleanedOption) - } - - // Final check: ensure type field is mandatory - if (!cleanedSchema.type || cleanedSchema.type === '') { - cleanedSchema.type = 'string' - } - - return cleanedSchema - } - - // Process each tool definition to build function declarations that comply with Gemini API - const functionDeclarations = mcpTools.map((toolDef) => { - // Convert to internal tool representation - const tool = this.mcpToolDefinitionToMcpTool(toolDef, serverName) - - // Get parameter properties - const properties = tool.inputSchema.properties - const processedProperties: Record> = {} - - // Process each property and apply cleanup function - for (const [propName, propValue] of Object.entries(properties)) { - if (typeof propValue === 'object' && propValue !== null) { - const cleaned = cleanSchema(propValue as Record) - // Ensure cleaned property has valid type - if (cleaned.type && cleaned.type !== '') { - 
processedProperties[propName] = cleaned - } else { - console.warn(`[MCP] Skipping property ${propName} due to invalid type`) - } - } - } - - // Prepare function declaration structure - const functionDeclaration: FunctionDeclaration = { - name: tool.id, - description: tool.description - } - - if (Object.keys(processedProperties).length > 0) { - functionDeclaration.parameters = { - type: Type.OBJECT, - properties: processedProperties, - required: tool.inputSchema.required || [] - } - } - - // Log functions without parameters - if (Object.keys(processedProperties).length === 0) { - console.log( - `[MCP] Function ${tool.id} has no parameters, providing minimal parameter structure` - ) - } - - return functionDeclaration - }) - - // Return result in Gemini tool format - return [ - { - functionDeclarations - } - ] - } - - /** - * Convert Gemini function call back to MCP tool call - * @param mcpTools Array of MCP tool definitions - * @param fcall Gemini function call - * @param serverName Server name - * @returns Matching MCP tool call - */ - async geminiFunctionCallToMcpTool( - fcall: GeminiFunctionCall | undefined, - providerId: string - ): Promise { - const mcpTools = await this.getAllToolDefinitions() - if (!fcall) return undefined - if (!mcpTools) return undefined - - const tool = mcpTools.find((tool) => tool.function.name === fcall.name) - if (!tool) { - return undefined - } - - // Create MCP tool call - const mcpToolCall: MCPToolCall = { - id: `${providerId}:${tool.function.name}-${Date.now()}`, // Generate unique ID including server name - type: tool.type, - function: { - name: tool.function.name, - arguments: JSON.stringify(fcall.args) - }, - server: { - name: tool.server.name, - icons: tool.server.icons, - description: tool.server.description - } - } - - return mcpToolCall - } - // Get MCP enabled status async getMcpEnabled(): Promise { return this.configPresenter.getMcpEnabled() @@ -1238,33 +691,6 @@ export class McpPresenter implements IMCPPresenter { return this.toolManager.readResourceByClient(resource.client.name, resource.uri) } - /** - * Convert MCP tool definitions to OpenAI Responses API tool format - * @param mcpTools Array of MCP tool definitions - * @param serverName Server name - * @returns Tool definitions in OpenAI Responses API tool format - */ - async mcpToolsToOpenAIResponsesTools( - mcpTools: MCPToolDefinition[], - serverName: string - ): Promise { - const openaiTools: OpenAI.Responses.Tool[] = mcpTools.map((toolDef) => { - const tool = this.mcpToolDefinitionToMcpTool(toolDef, serverName) - return { - type: 'function', - name: tool.name, - description: tool.description, - parameters: { - type: 'object', - properties: this.filterPropertieAttributes(tool), - required: tool.inputSchema.required || [] - }, - strict: false - } - }) - return openaiTools - } - async grantPermission( serverName: string, permissionType: 'read' | 'write' | 'all', diff --git a/src/main/presenter/sessionPresenter/messageFormatter.ts b/src/main/presenter/sessionPresenter/messageFormatter.ts index ec4f5e4ff..a5ceed0f9 100644 --- a/src/main/presenter/sessionPresenter/messageFormatter.ts +++ b/src/main/presenter/sessionPresenter/messageFormatter.ts @@ -15,6 +15,29 @@ const FILE_CONTENT_TRUNCATION_SUFFIX = '…(truncated)' type VisionUserMessageContent = UserMessageContent & { images?: string[] } +function parseProviderOptionsJson( + value: unknown +): Record> | undefined { + if (typeof value !== 'string' || !value) { + return undefined + } + + try { + const parsed = JSON.parse(value) + if 
(isRecord(parsed) && !Array.isArray(parsed)) { + return parsed as Record> + } + } catch {} + + return undefined +} + +function getBlockProviderOptions( + block: AssistantMessageBlock +): Record> | undefined { + return parseProviderOptionsJson(block.extra?.providerOptionsJson) +} + function isRecord(value: unknown): value is Record { return typeof value === 'object' && value !== null } @@ -253,13 +276,15 @@ export function addContextMessages( if (block.type === 'tool_call' && block.tool_call) { if (block.tool_call.response) { const toolCallId = block.tool_call.id || nanoid(8) + const providerOptions = getBlockProviderOptions(block) toolCalls.push({ id: toolCallId, type: 'function', function: { name: block.tool_call.name, arguments: block.tool_call.params || '' - } + }, + ...(providerOptions ? { provider_options: providerOptions } : {}) }) toolResponses.push({ id: toolCallId, @@ -267,14 +292,24 @@ export function addContextMessages( }) } } else if (block.type === 'content' && block.content) { - messageContent.push({ type: 'text', text: block.content }) + const providerOptions = getBlockProviderOptions(block) + messageContent.push({ + type: 'text', + text: block.content, + ...(providerOptions ? { provider_options: providerOptions } : {}) + }) } }) if (toolCalls.length > 0) { const assistantMessage: ChatMessage = { role: 'assistant', - content: messageContent.length > 0 ? messageContent : undefined, + content: + messageContent.length > 0 + ? messageContent.some((part) => part.type === 'text' && part.provider_options) + ? messageContent + : messageContent.map((part) => ('text' in part ? part.text : '')).join('') + : undefined, tool_calls: toolCalls } resultMessages.push(assistantMessage) @@ -289,7 +324,9 @@ export function addContextMessages( } else if (messageContent.length > 0) { const assistantMessage: ChatMessage = { role: 'assistant', - content: messageContent + content: messageContent.some((part) => part.type === 'text' && part.provider_options) + ? messageContent + : messageContent.map((part) => ('text' in part ? 
part.text : '')).join('') } resultMessages.push(assistantMessage) } diff --git a/src/shared/types/core/chat-message.ts b/src/shared/types/core/chat-message.ts index 5347619b4..ecdf50a1d 100644 --- a/src/shared/types/core/chat-message.ts +++ b/src/shared/types/core/chat-message.ts @@ -1,14 +1,21 @@ export type ChatMessageRole = 'system' | 'user' | 'assistant' | 'tool' +export type ChatMessageProviderOptions = Record> + export type ChatMessageToolCall = { id: string type: 'function' function: { name: string; arguments: string } + provider_options?: ChatMessageProviderOptions } export type ChatMessageContent = - | { type: 'text'; text: string } - | { type: 'image_url'; image_url: { url: string; detail?: 'auto' | 'low' | 'high' } } + | { type: 'text'; text: string; provider_options?: ChatMessageProviderOptions } + | { + type: 'image_url' + image_url: { url: string; detail?: 'auto' | 'low' | 'high' } + provider_options?: ChatMessageProviderOptions + } export type ChatMessage = { role: ChatMessageRole @@ -16,4 +23,6 @@ export type ChatMessage = { tool_calls?: ChatMessageToolCall[] tool_call_id?: string reasoning_content?: string + reasoning_provider_options?: ChatMessageProviderOptions + provider_options?: ChatMessageProviderOptions } diff --git a/src/shared/types/core/chat.ts b/src/shared/types/core/chat.ts index 142b56468..b55992fa6 100644 --- a/src/shared/types/core/chat.ts +++ b/src/shared/types/core/chat.ts @@ -106,6 +106,7 @@ export type AssistantMessageBlock = { export type { ChatMessage, ChatMessageContent, + ChatMessageProviderOptions, ChatMessageRole, ChatMessageToolCall } from './chat-message' diff --git a/src/shared/types/core/llm-events.ts b/src/shared/types/core/llm-events.ts index b8645253d..a18e7f034 100644 --- a/src/shared/types/core/llm-events.ts +++ b/src/shared/types/core/llm-events.ts @@ -1,5 +1,7 @@ // Strong-typed LLM core stream events (discriminated union) +import type { ChatMessageProviderOptions } from './chat-message' + export type StreamEventType = | 'text' | 'reasoning' @@ -16,29 +18,34 @@ export type StreamEventType = export interface TextStreamEvent { type: 'text' content: string + provider_options?: ChatMessageProviderOptions } export interface ReasoningStreamEvent { type: 'reasoning' reasoning_content: string + provider_options?: ChatMessageProviderOptions } export interface ToolCallStartEvent { type: 'tool_call_start' tool_call_id: string tool_call_name: string + provider_options?: ChatMessageProviderOptions } export interface ToolCallChunkEvent { type: 'tool_call_chunk' tool_call_id: string tool_call_arguments_chunk: string + provider_options?: ChatMessageProviderOptions } export interface ToolCallEndEvent { type: 'tool_call_end' tool_call_id: string tool_call_arguments_complete?: string + provider_options?: ChatMessageProviderOptions } export interface PermissionRequestEvent { @@ -102,30 +109,54 @@ export type LLMCoreStreamEvent = export type { ChatMessage, ChatMessageContent, + ChatMessageProviderOptions, ChatMessageRole, ChatMessageToolCall } from './chat-message' export const createStreamEvent = { - text: (content: string): TextStreamEvent => ({ type: 'text', content }), - reasoning: (reasoning_content: string): ReasoningStreamEvent => ({ + text: (content: string, provider_options?: ChatMessageProviderOptions): TextStreamEvent => ({ + type: 'text', + content, + ...(provider_options ? 
{ provider_options } : {}) + }), + reasoning: ( + reasoning_content: string, + provider_options?: ChatMessageProviderOptions + ): ReasoningStreamEvent => ({ type: 'reasoning', - reasoning_content + reasoning_content, + ...(provider_options ? { provider_options } : {}) }), - toolCallStart: (tool_call_id: string, tool_call_name: string): ToolCallStartEvent => ({ + toolCallStart: ( + tool_call_id: string, + tool_call_name: string, + provider_options?: ChatMessageProviderOptions + ): ToolCallStartEvent => ({ type: 'tool_call_start', tool_call_id, - tool_call_name + tool_call_name, + ...(provider_options ? { provider_options } : {}) }), - toolCallChunk: (tool_call_id: string, tool_call_arguments_chunk: string): ToolCallChunkEvent => ({ + toolCallChunk: ( + tool_call_id: string, + tool_call_arguments_chunk: string, + provider_options?: ChatMessageProviderOptions + ): ToolCallChunkEvent => ({ type: 'tool_call_chunk', tool_call_id, - tool_call_arguments_chunk + tool_call_arguments_chunk, + ...(provider_options ? { provider_options } : {}) }), - toolCallEnd: (tool_call_id: string, tool_call_arguments_complete?: string): ToolCallEndEvent => ({ + toolCallEnd: ( + tool_call_id: string, + tool_call_arguments_complete?: string, + provider_options?: ChatMessageProviderOptions + ): ToolCallEndEvent => ({ type: 'tool_call_end', tool_call_id, - tool_call_arguments_complete + tool_call_arguments_complete, + ...(provider_options ? { provider_options } : {}) }), permission: (permission: PermissionRequestPayload): PermissionRequestEvent => ({ type: 'permission', diff --git a/src/shared/types/presenters/legacy.presenters.d.ts b/src/shared/types/presenters/legacy.presenters.d.ts index c73198570..13ad22a70 100644 --- a/src/shared/types/presenters/legacy.presenters.d.ts +++ b/src/shared/types/presenters/legacy.presenters.d.ts @@ -1852,20 +1852,6 @@ export interface IMCPPresenter { setMcpRouterApiKey?(key: string): Promise isServerInstalled?(source: string, sourceId: string): Promise updateMcpRouterServersAuth?(apiKey: string): Promise - - mcpToolsToAnthropicTools( - mcpTools: MCPToolDefinition[], - serverName: string - ): Promise - mcpToolsToGeminiTools( - mcpTools: MCPToolDefinition[] | undefined, - serverName: string - ): Promise - mcpToolsToOpenAITools(mcpTools: MCPToolDefinition[], serverName: string): Promise - mcpToolsToOpenAIResponsesTools( - mcpTools: MCPToolDefinition[], - serverName: string - ): Promise } export interface IDeeplinkPresenter { diff --git a/test/main/presenter/agentRuntimePresenter/contextBuilder.test.ts b/test/main/presenter/agentRuntimePresenter/contextBuilder.test.ts index 3184c6340..2f99461ce 100644 --- a/test/main/presenter/agentRuntimePresenter/contextBuilder.test.ts +++ b/test/main/presenter/agentRuntimePresenter/contextBuilder.test.ts @@ -163,6 +163,45 @@ function makeAssistantWithReasoningAndToolRecord( } } +function makeAssistantWithToolProviderOptionsRecord( + orderSeq: number, + text: string, + toolResponse: string +) { + return { + id: `asst-${orderSeq}`, + sessionId: 's1', + orderSeq, + role: 'assistant' as const, + content: JSON.stringify([ + { type: 'content', content: text, status: 'success', timestamp: Date.now() }, + { + type: 'tool_call', + status: 'success', + timestamp: Date.now(), + extra: { + providerOptionsJson: JSON.stringify({ + vertex: { + thoughtSignature: 'tool-thought-signature' + } + }) + }, + tool_call: { + id: `tc-${orderSeq}`, + name: 'example_tool', + params: '{"foo":"bar"}', + response: toolResponse + } + } + ]), + status: 'sent' as const, + 
isContextEdge: 0, + metadata: '{}', + createdAt: Date.now(), + updatedAt: Date.now() + } +} + describe('truncateContext', () => { it('returns all messages when within budget', () => { const history = [ @@ -513,6 +552,37 @@ describe('buildContext', () => { ]) }) + it('replays settled tool call provider options for follow-up turns', () => { + const messages = [ + makeUserRecord(1, 'check this'), + makeAssistantWithToolProviderOptionsRecord(2, 'Tool finished', 'All good') + ] + const store = createMockMessageStore(messages) + const result = buildContext('s1', 'next', '', 10000, 4096, store) + + expect(result).toEqual([ + { role: 'user', content: 'check this' }, + { + role: 'assistant', + content: 'Tool finished', + tool_calls: [ + { + id: 'tc-2', + type: 'function', + function: { name: 'example_tool', arguments: '{"foo":"bar"}' }, + provider_options: { + vertex: { + thoughtSignature: 'tool-thought-signature' + } + } + } + ] + }, + { role: 'tool', tool_call_id: 'tc-2', content: 'All good' }, + { role: 'user', content: 'next' } + ]) + }) + it('includes non-image file context in user content', () => { const store = createMockMessageStore([]) const result = buildContext( diff --git a/test/main/presenter/agentRuntimePresenter/dispatch.test.ts b/test/main/presenter/agentRuntimePresenter/dispatch.test.ts index 1ee50a6da..e3ed0a2af 100644 --- a/test/main/presenter/agentRuntimePresenter/dispatch.test.ts +++ b/test/main/presenter/agentRuntimePresenter/dispatch.test.ts @@ -528,6 +528,72 @@ describe('dispatch', () => { expect(assistantMsg.reasoning_content).toBe('Let me think...') }) + it('preserves tool call provider options in the follow-up assistant message', async () => { + const tools = [makeTool('exec')] + const toolPresenter = createMockToolPresenter({ exec: 'done' }) + const conversation: any[] = [] + + state.blocks.push({ + type: 'tool_call', + content: '', + status: 'pending', + timestamp: Date.now(), + tool_call: { + id: 'tc1', + name: 'exec', + params: '{"command":"tree"}', + response: '' + }, + extra: { + providerOptionsJson: JSON.stringify({ + vertex: { + thoughtSignature: 'tool-thought-signature' + } + }) + } + }) + state.completedToolCalls = [ + { + id: 'tc1', + name: 'exec', + arguments: '{"command":"tree"}', + providerOptions: { + vertex: { + thoughtSignature: 'tool-thought-signature' + } + } + } + ] + + await executeTools( + state, + conversation, + 0, + tools, + toolPresenter, + 'gemini-3.1-flash-lite-preview', + io, + 'full_access', + new ToolOutputGuard(), + 32000, + 1024 + ) + + const assistantMsg = conversation.find((message: any) => message.role === 'assistant') + expect(assistantMsg.tool_calls).toEqual([ + { + id: 'tc1', + type: 'function', + function: { name: 'exec', arguments: '{"command":"tree"}' }, + provider_options: { + vertex: { + thoughtSignature: 'tool-thought-signature' + } + } + } + ]) + }) + it('does not include reasoning_content when compatibility is disabled', async () => { const tools = [makeTool('search')] const toolPresenter = createMockToolPresenter({ search: 'result' }) diff --git a/test/main/presenter/llmProviderPresenter.test.ts b/test/main/presenter/llmProviderPresenter.test.ts index 8e2322686..f4c3b81ae 100644 --- a/test/main/presenter/llmProviderPresenter.test.ts +++ b/test/main/presenter/llmProviderPresenter.test.ts @@ -2,7 +2,19 @@ import { describe, it, expect, beforeEach, vi, beforeAll, afterEach } from 'vite import { LLMProviderPresenter } from '../../../src/main/presenter/llmProviderPresenter/index' import { ConfigPresenter } from 
'../../../src/main/presenter/configPresenter/index' import { LLM_PROVIDER, ChatMessage, ISQLitePresenter } from '../../../src/shared/presenter' -import { OpenAICompatibleProvider } from '../../../src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider' +import { AiSdkProvider } from '../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' + +const { + mockRunAiSdkCoreStream, + mockRunAiSdkDimensions, + mockRunAiSdkEmbeddings, + mockRunAiSdkGenerateText +} = vi.hoisted(() => ({ + mockRunAiSdkCoreStream: vi.fn(), + mockRunAiSdkDimensions: vi.fn(), + mockRunAiSdkEmbeddings: vi.fn(), + mockRunAiSdkGenerateText: vi.fn().mockResolvedValue({ content: 'mock completion' }) +})) // Ensure electron is mocked for this suite to avoid CJS named export issues vi.mock('electron', () => { @@ -77,6 +89,13 @@ vi.mock('@/presenter/proxyConfig', () => ({ } })) +vi.mock('../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({ + runAiSdkCoreStream: mockRunAiSdkCoreStream, + runAiSdkDimensions: mockRunAiSdkDimensions, + runAiSdkEmbeddings: mockRunAiSdkEmbeddings, + runAiSdkGenerateText: mockRunAiSdkGenerateText +})) + describe('LLMProviderPresenter Integration Tests', () => { let llmProviderPresenter: LLMProviderPresenter let mockConfigPresenter: ConfigPresenter @@ -158,6 +177,19 @@ describe('LLMProviderPresenter Integration Tests', () => { beforeEach(() => { // Clear all mocks before each test vi.clearAllMocks() + vi.unstubAllGlobals() + mockRunAiSdkGenerateText.mockResolvedValue({ content: 'mock completion' }) + + vi.stubGlobal( + 'fetch', + vi.fn().mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ id: 'mock-gpt-thinking' }, { id: 'gpt-4-mock' }, { id: 'mock-gpt-markdown' }] + }), + text: vi.fn().mockResolvedValue('') + }) + ) // Reset mock implementations mockConfigPresenter.getProviders = vi.fn().mockReturnValue([mockProvider]) @@ -194,6 +226,7 @@ describe('LLMProviderPresenter Integration Tests', () => { // Wait for any pending async operations to complete await new Promise((resolve) => setTimeout(resolve, 100)) + vi.unstubAllGlobals() }) describe('Basic Provider Management', () => { @@ -237,7 +270,7 @@ describe('LLMProviderPresenter Integration Tests', () => { const providerInstance = llmProviderPresenter.getProviderInstance('novita') - expect(providerInstance).toBeInstanceOf(OpenAICompatibleProvider) + expect(providerInstance).toBeInstanceOf(AiSdkProvider) }) }) @@ -369,6 +402,8 @@ describe('LLMProviderPresenter Integration Tests', () => { }) it('should handle provider check failure for invalid config', async () => { + vi.stubGlobal('fetch', vi.fn().mockRejectedValue(new Error('Network error'))) + // 创建一个无效配置的provider const invalidProvider: LLM_PROVIDER = { id: 'invalid-test', diff --git a/test/main/presenter/llmProviderPresenter/aiSdkMessageMapper.test.ts b/test/main/presenter/llmProviderPresenter/aiSdkMessageMapper.test.ts new file mode 100644 index 000000000..009dcd686 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/aiSdkMessageMapper.test.ts @@ -0,0 +1,38 @@ +import { describe, expect, it } from 'vitest' +import { mapMessagesToModelMessages } from '@/presenter/llmProviderPresenter/aiSdk/messageMapper' + +describe('AI SDK message mapper', () => { + it('skips malformed non-text user content parts instead of throwing', () => { + const result = mapMessagesToModelMessages( + [ + { + role: 'user', + content: [ + { type: 'text', text: 'hello' }, + { type: 'image_url', image_url: { url: 'https://example.com/a.png' } }, + { 
type: 'image_url' }, + { type: 'unknown', value: 'ignored' } + ] as any + } + ], + { + tools: [], + supportsNativeTools: true + } + ) + + expect(result).toEqual([ + { + role: 'user', + content: [ + { type: 'text', text: 'hello' }, + { + type: 'image', + image: new URL('https://example.com/a.png'), + mediaType: 'image/png' + } + ] + } + ]) + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/aiSdkProviderFactory.test.ts b/test/main/presenter/llmProviderPresenter/aiSdkProviderFactory.test.ts new file mode 100644 index 000000000..12b3d7486 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/aiSdkProviderFactory.test.ts @@ -0,0 +1,210 @@ +import { describe, expect, it } from 'vitest' +import { + createAiSdkProviderContext, + normalizeAzureBaseUrl, + normalizeAnthropicBaseUrl, + normalizeVertexRequestBody, + normalizeVertexBaseUrl +} from '@/presenter/llmProviderPresenter/aiSdk/providerFactory' + +describe('AI SDK provider factory', () => { + it('normalizes anthropic-style base urls to a v1 prefix', () => { + expect(normalizeAnthropicBaseUrl('https://api.anthropic.com')).toBe( + 'https://api.anthropic.com/v1' + ) + expect(normalizeAnthropicBaseUrl('https://api.minimaxi.com/anthropic')).toBe( + 'https://api.minimaxi.com/anthropic/v1' + ) + expect(normalizeAnthropicBaseUrl('https://zenmux.ai/api/anthropic/')).toBe( + 'https://zenmux.ai/api/anthropic/v1' + ) + }) + + it('avoids duplicating the messages suffix', () => { + expect(normalizeAnthropicBaseUrl('https://api.anthropic.com/v1')).toBe( + 'https://api.anthropic.com/v1' + ) + expect(normalizeAnthropicBaseUrl('https://zenmux.ai/api/anthropic/v1/messages')).toBe( + 'https://zenmux.ai/api/anthropic/v1' + ) + expect(normalizeAnthropicBaseUrl('https://proxy.example.com/messages')).toBe( + 'https://proxy.example.com' + ) + }) + + it('normalizes vertex express-mode base urls to the publishers/google prefix', () => { + expect(normalizeVertexBaseUrl('https://zenmux.ai/api/vertex-ai', 'api-key', 'v1')).toBe( + 'https://zenmux.ai/api/vertex-ai/v1/publishers/google' + ) + expect(normalizeVertexBaseUrl('https://zenmux.ai/api/vertex-ai/v1', 'api-key', 'v1')).toBe( + 'https://zenmux.ai/api/vertex-ai/v1/publishers/google' + ) + expect( + normalizeVertexBaseUrl( + 'https://zenmux.ai/api/vertex-ai/v1/publishers/google', + 'api-key', + 'v1' + ) + ).toBe('https://zenmux.ai/api/vertex-ai/v1/publishers/google') + }) + + it('removes default AUTO tool config from vertex request bodies', () => { + expect( + normalizeVertexRequestBody({ + contents: [], + tools: [], + toolConfig: { + functionCallingConfig: { + mode: 'AUTO' + } + } + }) + ).toEqual({ + contents: [], + tools: [] + }) + }) + + it('normalizes vertex system instructions and tool schemas to google genai wire format', () => { + expect( + normalizeVertexRequestBody({ + systemInstruction: { + parts: [{ text: 'sys' }] + }, + tools: [ + { + functionDeclarations: [ + { + name: 'skill_manage', + parameters: { + type: 'object', + properties: { + action: { type: 'string' }, + enabled: { type: 'boolean' } + }, + required: ['action'] + } + } + ] + } + ] + }) + ).toEqual({ + systemInstruction: { + role: 'user', + parts: [{ text: 'sys' }] + }, + tools: [ + { + functionDeclarations: [ + { + name: 'skill_manage', + parameters: { + type: 'OBJECT', + properties: { + action: { type: 'STRING' }, + enabled: { type: 'BOOLEAN' } + }, + required: ['action'] + } + } + ] + } + ] + }) + }) + + it('normalizes azure resource base urls to the openai prefix with v1 semantics', () => { + 
expect(normalizeAzureBaseUrl('https://example.openai.azure.com', undefined)).toEqual({ + baseURL: 'https://example.openai.azure.com/openai', + apiVersion: 'v1', + useDeploymentBasedUrls: false + }) + + expect(normalizeAzureBaseUrl('https://example.openai.azure.com/openai/v1', undefined)).toEqual({ + baseURL: 'https://example.openai.azure.com/openai', + apiVersion: 'v1', + useDeploymentBasedUrls: false + }) + }) + + it('preserves deployment-based azure urls and legacy api versions', () => { + expect( + normalizeAzureBaseUrl( + 'https://example.openai.azure.com/openai/deployments/deepchat-prod', + undefined + ) + ).toEqual({ + baseURL: 'https://example.openai.azure.com/openai', + apiVersion: '2024-02-01', + useDeploymentBasedUrls: true, + deploymentName: 'deepchat-prod' + }) + + expect( + normalizeAzureBaseUrl( + 'https://example.openai.azure.com/openai/deployments/deepchat-prod', + '2025-04-01-preview' + ) + ).toEqual({ + baseURL: 'https://example.openai.azure.com/openai', + apiVersion: '2025-04-01-preview', + useDeploymentBasedUrls: true, + deploymentName: 'deepchat-prod' + }) + }) + + it('builds azure responses endpoints without duplicating v1 segments', () => { + const context = createAiSdkProviderContext({ + providerKind: 'azure', + provider: { + id: 'azure-openai', + name: 'Azure OpenAI', + apiKey: 'test-key', + baseUrl: 'https://example.openai.azure.com/openai/v1', + enable: false + } as any, + configPresenter: { + getSetting: () => undefined + } as any, + defaultHeaders: {}, + modelId: 'my-gpt-4.1-deployment' + }) + + expect(context.apiType).toBe('azure_responses') + expect(context.providerOptionsKey).toBe('azure') + expect(context.endpoint).toBe( + 'https://example.openai.azure.com/openai/v1/responses?api-version=v1' + ) + expect(context.embeddingEndpoint).toBe( + 'https://example.openai.azure.com/openai/v1/embeddings?api-version=v1' + ) + expect(context.imageEndpoint).toBe( + 'https://example.openai.azure.com/openai/v1/images/generations?api-version=v1' + ) + expect(context.resolvedModelId).toBe('my-gpt-4.1-deployment') + }) + + it('uses deployment ids from azure deployment-scoped urls', () => { + const context = createAiSdkProviderContext({ + providerKind: 'azure', + provider: { + id: 'azure-openai', + name: 'Azure OpenAI', + apiKey: 'test-key', + baseUrl: 'https://example.openai.azure.com/openai/deployments/deepchat-prod', + enable: false + } as any, + configPresenter: { + getSetting: () => undefined + } as any, + defaultHeaders: {}, + modelId: 'ignored-model-id' + }) + + expect(context.endpoint).toBe( + 'https://example.openai.azure.com/openai/deployments/deepchat-prod/responses?api-version=2024-02-01' + ) + expect(context.resolvedModelId).toBe('deepchat-prod') + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/aiSdkProviderOptionsMapper.test.ts b/test/main/presenter/llmProviderPresenter/aiSdkProviderOptionsMapper.test.ts new file mode 100644 index 000000000..d42ee3c83 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/aiSdkProviderOptionsMapper.test.ts @@ -0,0 +1,238 @@ +import { describe, expect, it, vi } from 'vitest' + +const { mockGetThinkingBudgetRange, mockGetModel, mockSupportsReasoning } = vi.hoisted(() => ({ + mockGetThinkingBudgetRange: vi.fn().mockReturnValue({}), + mockGetModel: vi.fn().mockReturnValue(undefined), + mockSupportsReasoning: vi.fn().mockReturnValue(false) +})) + +vi.mock('@/presenter/configPresenter/providerDbLoader', () => ({ + providerDbLoader: { + getModel: mockGetModel + } +})) + 
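+// The surrounding vi.mock calls stub the provider-DB and model-capability
+// lookups that buildProviderOptions consults. As a hedged sketch of how the
+// mapper's result is meant to be consumed downstream (assuming the AI SDK's
+// `streamText` and an already-constructed `model`; this wiring is illustrative,
+// not the runtime's actual code):
+//
+//   const { providerOptions } = buildProviderOptions({
+//     providerId: 'anthropic',
+//     providerOptionsKey: 'anthropic',
+//     apiType: 'anthropic',
+//     modelId: 'claude-sonnet-4-5',
+//     modelConfig,
+//     tools,
+//     messages
+//   })
+//   await streamText({ model, messages, providerOptions })
+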
+vi.mock('@/presenter/configPresenter/modelCapabilities', () => ({ + modelCapabilities: { + getThinkingBudgetRange: mockGetThinkingBudgetRange, + supportsReasoning: mockSupportsReasoning + } +})) + +import { buildProviderOptions } from '@/presenter/llmProviderPresenter/aiSdk/providerOptionsMapper' + +describe('AI SDK provider options', () => { + const baseModelConfig = { + reasoning: true, + reasoningEffort: 'high' as const, + thinkingBudget: 2048, + conversationId: 'conv-1' + } + + it('keeps official anthropic beta features enabled', () => { + const result = buildProviderOptions({ + providerId: 'anthropic', + providerOptionsKey: 'anthropic', + apiType: 'anthropic', + modelId: 'claude-sonnet-4-5', + modelConfig: baseModelConfig, + tools: [], + messages: [] + }) + + expect(result.providerOptions?.anthropic).toMatchObject({ + toolStreaming: true, + sendReasoning: true, + effort: 'high', + thinking: { + type: 'enabled', + budgetTokens: 2048 + } + }) + }) + + it('disables anthropic beta-only options for compatible providers', () => { + const result = buildProviderOptions({ + providerId: 'zenmux', + providerOptionsKey: 'anthropic', + apiType: 'anthropic', + modelId: 'anthropic/claude-sonnet-4.5', + modelConfig: baseModelConfig, + tools: [], + messages: [] + }) + + expect(result.providerOptions?.anthropic).toMatchObject({ + toolStreaming: false, + thinking: { + type: 'enabled', + budgetTokens: 2048 + } + }) + expect(result.providerOptions?.anthropic).not.toHaveProperty('sendReasoning') + expect(result.providerOptions?.anthropic).not.toHaveProperty('effort') + }) + + it('disables anthropic beta-only options for custom anthropic providers', () => { + const result = buildProviderOptions({ + providerId: 'my-anthropic-proxy', + providerOptionsKey: 'anthropic', + apiType: 'anthropic', + modelId: 'claude-sonnet-4-5', + modelConfig: { + reasoningEffort: 'medium' as const + }, + tools: [], + messages: [] + }) + + expect(result.providerOptions?.anthropic).toMatchObject({ + toolStreaming: false + }) + expect(result.providerOptions?.anthropic).not.toHaveProperty('effort') + }) + + it('adds doubao thinking options through providerOptions instead of monkey-patching the sdk client', () => { + mockGetModel.mockReturnValue({ + extra_capabilities: { + reasoning: { + notes: ['doubao-thinking-parameter'] + } + } + }) + + const result = buildProviderOptions({ + providerId: 'doubao', + providerOptionsKey: 'openai', + apiType: 'openai_chat', + modelId: 'doubao-seed-2.0-pro', + modelConfig: { + reasoning: true + }, + tools: [], + messages: [] + }) + + expect(result.providerOptions).toEqual({ + openai: { + thinking: { + type: 'enabled' + } + } + }) + }) + + it('adds siliconcloud thinking flags through providerOptions for supported models', () => { + const result = buildProviderOptions({ + providerId: 'siliconcloud', + providerOptionsKey: 'openai', + apiType: 'openai_chat', + modelId: 'Qwen/Qwen3-32B', + modelConfig: { + reasoning: true + }, + tools: [], + messages: [] + }) + + expect(result.providerOptions).toEqual({ + openai: { + enable_thinking: true + } + }) + }) + + it('maps grok reasoning effort to the vendor-specific body field', () => { + const result = buildProviderOptions({ + providerId: 'grok', + providerOptionsKey: 'openai', + apiType: 'openai_chat', + modelId: 'grok-3-mini', + modelConfig: { + reasoningEffort: 'medium' as const + }, + tools: [], + messages: [] + }) + + expect(result.providerOptions).toEqual({ + openai: { + reasoning_effort: 'medium' + } + }) + }) + + it('disables vertex function-call 
argument streaming when no tools are present', () => { + const result = buildProviderOptions({ + providerId: 'vertex', + providerOptionsKey: 'vertex', + apiType: 'vertex', + modelId: 'gemini-2.5-flash', + modelConfig: {}, + tools: [], + messages: [] + }) + + expect(result.providerOptions?.vertex).toMatchObject({ + streamFunctionCallArguments: false + }) + }) + + it('enables vertex function-call argument streaming when tools are present', () => { + const result = buildProviderOptions({ + providerId: 'vertex', + providerOptionsKey: 'vertex', + apiType: 'vertex', + modelId: 'gemini-2.5-flash', + modelConfig: {}, + tools: [ + { + type: 'function', + function: { + name: 'skill_manage', + description: 'Manage a skill', + parameters: { + type: 'object', + properties: { + name: { + type: 'string' + } + } + } + } + } + ] as any, + messages: [] + }) + + expect(result.providerOptions?.vertex).toMatchObject({ + streamFunctionCallArguments: true + }) + }) + + it('keeps azure responses options under the azure namespace without prompt cache keys', () => { + const result = buildProviderOptions({ + providerId: 'azure-openai', + providerOptionsKey: 'azure', + apiType: 'azure_responses', + modelId: 'my-gpt-4.1-deployment', + modelConfig: { + reasoningEffort: 'medium' as const, + verbosity: 'high' as const, + maxCompletionTokens: 2048, + conversationId: 'conv-1' + }, + tools: [], + messages: [] + }) + + expect(result.providerOptions).toEqual({ + azure: { + reasoningEffort: 'medium', + textVerbosity: 'high', + maxCompletionTokens: 2048 + } + }) + expect(result.providerOptions?.azure).not.toHaveProperty('promptCacheKey') + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/aiSdkRuntime.test.ts b/test/main/presenter/llmProviderPresenter/aiSdkRuntime.test.ts new file mode 100644 index 000000000..2197f3058 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/aiSdkRuntime.test.ts @@ -0,0 +1,115 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockGenerateImage, mockCreateAiSdkProviderContext, mockCacheImage } = vi.hoisted(() => ({ + mockGenerateImage: vi.fn(), + mockCreateAiSdkProviderContext: vi.fn(), + mockCacheImage: vi.fn() +})) + +vi.mock('ai', () => ({ + generateId: vi.fn(() => 'generated-id'), + generateImage: mockGenerateImage, + generateText: vi.fn(), + streamText: vi.fn(), + embedMany: vi.fn() +})) + +vi.mock('@/presenter', () => ({ + presenter: { + devicePresenter: { + cacheImage: mockCacheImage + } + } +})) + +vi.mock('@/presenter/llmProviderPresenter/aiSdk/providerFactory', () => ({ + createAiSdkProviderContext: mockCreateAiSdkProviderContext +})) + +import { runAiSdkCoreStream } from '@/presenter/llmProviderPresenter/aiSdk/runtime' + +describe('AI SDK runtime', () => { + beforeEach(() => { + vi.clearAllMocks() + mockCreateAiSdkProviderContext.mockReturnValue({ + providerOptionsKey: 'openai', + apiType: 'openai_chat', + model: {}, + imageModel: {}, + endpoint: 'https://image.example.com' + }) + mockGenerateImage.mockResolvedValue({ + images: [ + { + mediaType: 'image/png', + base64: 'ZmFrZQ==' + } + ] + }) + mockCacheImage.mockResolvedValue('cached://image') + }) + + it('builds image prompts from text-like content instead of object stringification', async () => { + const context = { + providerKind: 'openai-compatible', + provider: { + id: 'openai', + apiType: 'openai-compatible' + }, + configPresenter: {}, + defaultHeaders: {}, + shouldUseImageGeneration: () => true + } as any + + const events = [] + for await (const event of runAiSdkCoreStream( + context, 
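+      // The message list below deliberately mixes well-formed parts, bare
+      // strings and junk objects; the prompt assertion after the loop checks
+      // that only the text-like pieces are flattened into the image prompt.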
+ [ + { + role: 'user', + content: [ + { type: 'text', text: 'draw a cat' }, + { type: 'image_url', image_url: { url: 'data:image/png;base64,AAA=' } }, + 'with neon lights', + { text: 'in the rain' }, + { foo: 'ignored' } + ] as any + }, + { + role: 'user', + content: { + text: 'cinematic' + } as any + } + ], + 'gpt-image-1', + { + apiEndpoint: 'image' + } as any, + 0.7, + 1024, + [] + )) { + events.push(event) + } + + expect(mockGenerateImage).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: 'draw a cat\nwith neon lights\nin the rain\n\ncinematic' + }) + ) + expect(events).toEqual([ + { + type: 'image_data', + image_data: { + data: 'cached://image', + mimeType: 'image/png' + } + }, + { + type: 'stop', + stop_reason: 'complete' + } + ]) + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/aiSdkStreamAdapter.test.ts b/test/main/presenter/llmProviderPresenter/aiSdkStreamAdapter.test.ts new file mode 100644 index 000000000..f78a09840 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/aiSdkStreamAdapter.test.ts @@ -0,0 +1,267 @@ +import { describe, expect, it, vi } from 'vitest' +import { adaptAiSdkStream } from '@/presenter/llmProviderPresenter/aiSdk/streamAdapter' +import type { LLMCoreStreamEvent } from '@shared/types/core/llm-events' + +async function collectEvents(parts: any[], options: Parameters[1]) { + async function* stream() { + for (const part of parts) { + yield part + } + } + + const events: LLMCoreStreamEvent[] = [] + for await (const event of adaptAiSdkStream(stream(), options)) { + events.push(event) + } + return events +} + +describe('AI SDK stream adapter', () => { + it('maps native tool streaming events to DeepChat core events', async () => { + const events = await collectEvents( + [ + { + type: 'text-delta', + id: 'text-1', + text: 'hello ', + providerMetadata: { vertex: { thoughtSignature: 'text-signature' } } + }, + { + type: 'reasoning-delta', + id: 'reason-1', + text: 'thinking', + providerMetadata: { vertex: { thoughtSignature: 'reason-signature' } } + }, + { + type: 'tool-input-start', + id: 'call-1', + toolName: 'getWeather', + providerMetadata: { vertex: { thoughtSignature: 'tool-signature' } } + }, + { + type: 'tool-input-delta', + id: 'call-1', + delta: '{"city":"', + providerMetadata: { vertex: { thoughtSignature: 'tool-signature' } } + }, + { type: 'tool-input-delta', id: 'call-1', delta: 'Beijing"}' }, + { type: 'tool-input-end', id: 'call-1' }, + { + type: 'finish', + finishReason: 'tool-calls', + rawFinishReason: 'tool_calls', + totalUsage: { + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + inputTokenDetails: { + cacheReadTokens: 3 + } + } + } + ], + { supportsNativeTools: true } + ) + + expect(events).toEqual([ + { + type: 'text', + content: 'hello ', + provider_options: { vertex: { thoughtSignature: 'text-signature' } } + }, + { + type: 'reasoning', + reasoning_content: 'thinking', + provider_options: { vertex: { thoughtSignature: 'reason-signature' } } + }, + { + type: 'tool_call_start', + tool_call_id: 'call-1', + tool_call_name: 'getWeather', + provider_options: { vertex: { thoughtSignature: 'tool-signature' } } + }, + { + type: 'tool_call_chunk', + tool_call_id: 'call-1', + tool_call_arguments_chunk: '{"city":"', + provider_options: { vertex: { thoughtSignature: 'tool-signature' } } + }, + { type: 'tool_call_chunk', tool_call_id: 'call-1', tool_call_arguments_chunk: 'Beijing"}' }, + { + type: 'tool_call_end', + tool_call_id: 'call-1', + tool_call_arguments_complete: '{"city":"Beijing"}' + }, + { + type: 'usage', 
+ usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + cached_tokens: 3 + } + }, + { type: 'stop', stop_reason: 'tool_use' } + ]) + }) + + it('parses legacy function_call blocks from text deltas', async () => { + const events = await collectEvents( + [ + { + type: 'text-delta', + id: 'text-1', + text: 'before {"function_call":{"name":"search","arguments":{"q":"deepchat"}}} after' + }, + { + type: 'finish', + finishReason: 'stop', + rawFinishReason: 'stop', + totalUsage: { + inputTokens: 2, + outputTokens: 4, + totalTokens: 6 + } + } + ], + { supportsNativeTools: false } + ) + + expect(events[0]).toEqual({ type: 'text', content: 'before ' }) + expect(events[1].type).toBe('tool_call_start') + expect(events[2].type).toBe('tool_call_chunk') + expect(events[3].type).toBe('tool_call_end') + expect(events[4]).toEqual({ type: 'text', content: ' after' }) + expect(events[5]).toEqual({ + type: 'usage', + usage: { + prompt_tokens: 2, + completion_tokens: 4, + total_tokens: 6 + } + }) + expect(events[6]).toEqual({ type: 'stop', stop_reason: 'tool_use' }) + }) + + it('maps image file parts and caches the emitted data url', async () => { + const cacheImage = vi.fn().mockResolvedValue('cached://image') + const events = await collectEvents( + [ + { + type: 'file', + file: { + mediaType: 'image/png', + base64: 'ZmFrZQ==' + } + }, + { + type: 'finish', + finishReason: 'stop', + rawFinishReason: 'stop', + totalUsage: { + inputTokens: 1, + outputTokens: 1, + totalTokens: 2 + } + } + ], + { supportsNativeTools: true, cacheImage } + ) + + expect(cacheImage).toHaveBeenCalledWith('data:image/png;base64,ZmFrZQ==') + expect(events[0]).toEqual({ + type: 'image_data', + image_data: { + data: 'cached://image', + mimeType: 'image/png' + } + }) + expect(events[2]).toEqual({ type: 'stop', stop_reason: 'stop_sequence' }) + }) + + it('falls back to the original image data url when image caching fails', async () => { + const cacheImage = vi.fn().mockRejectedValue(new Error('cache failed')) + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + + const events = await collectEvents( + [ + { + type: 'file', + file: { + mediaType: 'image/jpeg', + base64: 'YWJjZA==' + } + }, + { + type: 'finish', + finishReason: 'stop', + rawFinishReason: 'stop', + totalUsage: { + inputTokens: 1, + outputTokens: 1, + totalTokens: 2 + } + } + ], + { supportsNativeTools: true, cacheImage } + ) + + expect(cacheImage).toHaveBeenCalledWith('data:image/jpeg;base64,YWJjZA==') + expect(warnSpy).toHaveBeenCalled() + expect(events[0]).toEqual({ + type: 'image_data', + image_data: { + data: 'data:image/jpeg;base64,YWJjZA==', + mimeType: 'image/jpeg' + } + }) + + warnSpy.mockRestore() + }) + + it('skips file parts with missing or non-image media types', async () => { + const cacheImage = vi.fn() + const events = await collectEvents( + [ + { + type: 'file', + file: { + mediaType: undefined, + base64: 'ZmFrZQ==' + } + }, + { + type: 'file', + file: { + mediaType: 'application/pdf', + base64: 'ZmFrZQ==' + } + }, + { + type: 'finish', + finishReason: 'stop', + rawFinishReason: 'stop', + totalUsage: { + inputTokens: 1, + outputTokens: 1, + totalTokens: 2 + } + } + ], + { supportsNativeTools: true, cacheImage } + ) + + expect(cacheImage).not.toHaveBeenCalled() + expect(events).toEqual([ + { + type: 'usage', + usage: { + prompt_tokens: 1, + completion_tokens: 1, + total_tokens: 2 + } + }, + { type: 'stop', stop_reason: 'stop_sequence' } + ]) + }) +}) diff --git 
a/test/main/presenter/llmProviderPresenter/aiSdkToolMapper.test.ts b/test/main/presenter/llmProviderPresenter/aiSdkToolMapper.test.ts new file mode 100644 index 000000000..986437275 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/aiSdkToolMapper.test.ts @@ -0,0 +1,192 @@ +import { describe, expect, it } from 'vitest' +import { + mcpToolsToAISDKTools, + normalizeToolInputSchema +} from '@/presenter/llmProviderPresenter/aiSdk/toolMapper' + +describe('AI SDK tool schema normalization', () => { + it('normalizes discriminated union schemas to a top-level object schema', () => { + const schema = { + anyOf: [ + { + type: 'object', + properties: { + action: { type: 'string', const: 'create' }, + content: { type: 'string' } + }, + required: ['action', 'content'], + additionalProperties: false + }, + { + type: 'object', + properties: { + action: { type: 'string', const: 'edit' }, + draftId: { type: 'string' }, + content: { type: 'string' } + }, + required: ['action', 'draftId', 'content'], + additionalProperties: false + } + ], + $schema: 'http://json-schema.org/draft-07/schema#' + } + + const normalized = normalizeToolInputSchema(schema) + + expect(normalized.type).toBe('object') + expect(normalized.properties).toMatchObject({ + action: { type: 'string', enum: ['create', 'edit'] }, + content: { type: 'string' }, + draftId: { type: 'string' } + }) + expect(normalized.required).toEqual(['action', 'content']) + expect(normalized.additionalProperties).toBe(false) + expect(normalized).not.toHaveProperty('anyOf') + expect(normalized).not.toHaveProperty('oneOf') + expect(normalized).not.toHaveProperty('allOf') + }) + + it('converts invalid root schemas into empty object schemas', () => { + const normalized = normalizeToolInputSchema({ + type: 'None' + }) + + expect(normalized).toEqual({ + type: 'object', + properties: {} + }) + }) + + it('drops non-object root fields when falling back to an object schema', () => { + const normalized = normalizeToolInputSchema({ + type: 'array', + items: { + type: 'string' + }, + properties: { + query: { + type: 'string' + } + }, + required: ['query'], + additionalProperties: false + }) + + expect(normalized).toEqual({ + type: 'object', + properties: { + query: { + type: 'string' + } + }, + required: ['query'], + additionalProperties: false + }) + expect(normalized).not.toHaveProperty('items') + }) + + it('uses the union of required keys for allOf branches', () => { + const normalized = normalizeToolInputSchema({ + allOf: [ + { + type: 'object', + properties: { + query: { type: 'string' } + }, + required: ['query'] + }, + { + type: 'object', + properties: { + limit: { type: 'number' } + }, + required: ['limit'] + } + ] + }) + + expect(normalized).toMatchObject({ + type: 'object', + properties: { + query: { type: 'string' }, + limit: { type: 'number' } + }, + required: ['query', 'limit'] + }) + }) + + it('uses a safe dictionary when merging variant properties', () => { + const normalized = normalizeToolInputSchema({ + anyOf: [ + { + type: 'object', + properties: { + __proto__: { + type: 'string' + }, + safe: { + type: 'string' + } + } + }, + { + type: 'object', + properties: { + constructor: { + type: 'string' + }, + safe: { + type: 'string' + } + } + } + ] + }) + + expect(Object.getPrototypeOf(normalized.properties as object)).toBeNull() + expect(normalized.properties).not.toHaveProperty('__proto__') + expect(normalized.properties).not.toHaveProperty('constructor') + expect(normalized.properties).toHaveProperty('safe') + }) + + it('uses a safe dictionary and 
skips unsafe tool names', () => { + const tools = mcpToolsToAISDKTools([ + { + type: 'function', + function: { + name: '__proto__', + description: 'unsafe', + parameters: { + type: 'object', + properties: {} + } + }, + server: { + name: 'unsafe-server', + icons: '', + description: 'unsafe' + } + }, + { + type: 'function', + function: { + name: 'safe_tool', + description: 'safe', + parameters: { + type: 'object', + properties: {} + } + }, + server: { + name: 'safe-server', + icons: '', + description: 'safe' + } + } + ]) + + expect(Object.getPrototypeOf(tools)).toBeNull() + expect(tools).not.toHaveProperty('__proto__') + expect(tools).toHaveProperty('safe_tool') + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/aihubmixProvider.test.ts b/test/main/presenter/llmProviderPresenter/aihubmixProvider.test.ts new file mode 100644 index 000000000..71e29fb63 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/aihubmixProvider.test.ts @@ -0,0 +1,121 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import type { IConfigPresenter, LLM_PROVIDER, ModelConfig } from '../../../../src/shared/presenter' +import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' + +const { mockRunAiSdkCoreStream } = vi.hoisted(() => ({ + mockRunAiSdkCoreStream: vi.fn() +})) + +vi.mock('electron', () => ({ + app: { + getName: vi.fn(() => 'DeepChat'), + getVersion: vi.fn(() => '0.0.0-test'), + getPath: vi.fn(() => '/mock/path'), + isReady: vi.fn(() => true), + on: vi.fn() + } +})) + +vi.mock('@/eventbus', () => ({ + eventBus: { + on: vi.fn(), + sendToRenderer: vi.fn(), + sendToMain: vi.fn(), + emit: vi.fn(), + send: vi.fn() + }, + SendTarget: { + ALL_WINDOWS: 'ALL_WINDOWS' + } +})) + +vi.mock('@/events', () => ({ + CONFIG_EVENTS: { + PROXY_RESOLVED: 'PROXY_RESOLVED', + PROVIDER_ATOMIC_UPDATE: 'PROVIDER_ATOMIC_UPDATE', + PROVIDER_BATCH_UPDATE: 'PROVIDER_BATCH_UPDATE', + MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED' + }, + PROVIDER_DB_EVENTS: { + LOADED: 'LOADED', + UPDATED: 'UPDATED' + }, + NOTIFICATION_EVENTS: { + SHOW_ERROR: 'SHOW_ERROR' + } +})) + +vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ + proxyConfig: { + getProxyUrl: vi.fn().mockReturnValue(null) + } +})) + +vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({ + runAiSdkCoreStream: mockRunAiSdkCoreStream, + runAiSdkDimensions: vi.fn(), + runAiSdkEmbeddings: vi.fn(), + runAiSdkGenerateText: vi.fn() +})) + +const createConfigPresenter = (): IConfigPresenter => + ({ + getProviders: vi.fn().mockReturnValue([]), + getProviderModels: vi.fn().mockReturnValue([]), + getCustomModels: vi.fn().mockReturnValue([]), + getModelConfig: vi.fn().mockReturnValue(undefined), + getSetting: vi.fn().mockReturnValue(undefined), + setProviderModels: vi.fn(), + getModelStatus: vi.fn().mockReturnValue(true) + }) as unknown as IConfigPresenter + +const createProvider = (): LLM_PROVIDER => + ({ + id: 'aihubmix', + name: 'Aihubmix', + apiType: 'openai-compatible', + apiKey: 'test-key', + baseUrl: 'https://aihubmix.com/v1', + enable: false + }) as LLM_PROVIDER + +describe('AihubmixProvider AI SDK runtime headers', () => { + beforeEach(() => { + vi.clearAllMocks() + mockRunAiSdkCoreStream.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { type: 'stop', stop_reason: 'complete' } + } + }) + }) + + it('preserves the DeepChat APP-Code header in AI SDK mode', async () => { + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + ;(provider as 
any).isInitialized = true + + for await (const _event of provider.coreStream( + [{ role: 'user', content: 'hello' }], + 'gpt-4o', + { + maxTokens: 1024, + contextLength: 8192, + vision: false, + functionCall: false, + reasoning: false, + type: 'chat' + } as ModelConfig, + 0.7, + 256, + [] + )) { + break + } + + const context = mockRunAiSdkCoreStream.mock.calls.at(-1)?.[0] + + expect(context.defaultHeaders).toMatchObject({ + 'APP-Code': 'SMUE7630', + 'X-Title': 'DeepChat' + }) + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts b/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts index 825f12644..5acd1f3e9 100644 --- a/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/anthropicProvider.test.ts @@ -1,36 +1,26 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' -import type { IConfigPresenter, LLM_PROVIDER, ModelConfig } from '../../../../src/shared/presenter' -import { AnthropicProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/anthropicProvider' +import type { IConfigPresenter, LLM_PROVIDER } from '../../../../src/shared/presenter' +import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' -const { mockAnthropicConstructor, mockMessagesCreate, mockModelsList, mockGetProxyUrl } = - vi.hoisted(() => ({ - mockAnthropicConstructor: vi.fn(), - mockMessagesCreate: vi.fn().mockResolvedValue({ content: [], usage: undefined }), - mockModelsList: vi.fn().mockResolvedValue({ data: [] }), - mockGetProxyUrl: vi.fn().mockReturnValue(null) - })) +const { mockRunAiSdkCoreStream, mockRunAiSdkGenerateText } = vi.hoisted(() => ({ + mockRunAiSdkCoreStream: vi.fn(), + mockRunAiSdkGenerateText: vi.fn().mockResolvedValue({ content: 'ok' }) +})) -vi.mock('@anthropic-ai/sdk', () => ({ - default: vi.fn().mockImplementation((options: Record) => { - mockAnthropicConstructor(options) - return { - messages: { - create: mockMessagesCreate - }, - models: { - list: mockModelsList - } - } - }) +vi.mock('electron', () => ({ + app: { + getName: vi.fn(() => 'DeepChat'), + getVersion: vi.fn(() => '0.0.0-test'), + getPath: vi.fn(() => '/mock/path'), + isReady: vi.fn(() => true), + on: vi.fn() + } })) vi.mock('@/eventbus', () => ({ eventBus: { on: vi.fn(), - sendToRenderer: vi.fn(), - sendToMain: vi.fn(), - emit: vi.fn(), - send: vi.fn() + sendToRenderer: vi.fn() }, SendTarget: { ALL_WINDOWS: 'ALL_WINDOWS' @@ -39,63 +29,56 @@ vi.mock('@/eventbus', () => ({ vi.mock('@/events', () => ({ CONFIG_EVENTS: { - PROXY_RESOLVED: 'PROXY_RESOLVED', MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED' + }, + PROVIDER_DB_EVENTS: { + LOADED: 'LOADED', + UPDATED: 'UPDATED' } })) -vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ - proxyConfig: { - getProxyUrl: mockGetProxyUrl - } +vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({ + runAiSdkCoreStream: mockRunAiSdkCoreStream, + runAiSdkGenerateText: mockRunAiSdkGenerateText })) -const createConfigPresenter = () => - ({ - getProviderModels: vi.fn().mockReturnValue([]), - getCustomModels: vi.fn().mockReturnValue([]), - getModelConfig: vi.fn().mockReturnValue(undefined), - getDbProviderModels: vi.fn().mockReturnValue([]), - getSetting: vi.fn().mockReturnValue(undefined), - setProviderModels: vi.fn(), - getModelStatus: vi.fn().mockReturnValue(true) - }) as unknown as IConfigPresenter - -const createAsyncStream = (chunks: Array>) => ({ - async 
*[Symbol.asyncIterator]() { - for (const chunk of chunks) { - yield chunk - } - } -}) - const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ id: 'anthropic', name: 'Anthropic', apiType: 'anthropic', apiKey: 'test-key', baseUrl: 'https://api.anthropic.com', - enable: true, + enable: false, ...overrides }) -describe('AnthropicProvider API-only behavior', () => { +const createConfigPresenter = (): IConfigPresenter => + ({ + getProviderModels: vi.fn().mockReturnValue([]), + getCustomModels: vi.fn().mockReturnValue([]), + getDbProviderModels: vi.fn().mockReturnValue([ + { + id: 'claude-sonnet-4-5-20250929', + name: 'Claude Sonnet 4.5', + group: 'Claude', + contextLength: 200000, + maxTokens: 64000, + vision: true, + functionCall: true, + reasoning: true + } + ]), + getModelConfig: vi.fn().mockReturnValue(undefined), + setProviderModels: vi.fn(), + getModelStatus: vi.fn().mockReturnValue(true) + }) as unknown as IConfigPresenter + +describe('AiSdkProvider anthropic', () => { const originalEnvKey = process.env.ANTHROPIC_API_KEY - const streamModelConfig: ModelConfig = { - maxTokens: 1024, - contextLength: 8192, - vision: false, - functionCall: false, - reasoning: false, - type: 'chat', - conversationId: 'session-1' - } beforeEach(() => { vi.clearAllMocks() - mockMessagesCreate.mockResolvedValue({ content: [], usage: undefined }) - mockModelsList.mockResolvedValue({ data: [] }) - mockGetProxyUrl.mockReturnValue(null) + mockRunAiSdkGenerateText.mockResolvedValue({ content: 'ok' }) delete process.env.ANTHROPIC_API_KEY }) @@ -104,199 +87,84 @@ describe('AnthropicProvider API-only behavior', () => { delete process.env.ANTHROPIC_API_KEY return } - process.env.ANTHROPIC_API_KEY = originalEnvKey - }) - it('initializes with env API key when provider apiKey is empty', () => { - process.env.ANTHROPIC_API_KEY = 'env-key' - - new AnthropicProvider(createProvider({ apiKey: '' }), createConfigPresenter()) - - expect(mockAnthropicConstructor).toHaveBeenCalledWith( - expect.objectContaining({ - apiKey: 'env-key', - baseURL: 'https://api.anthropic.com' - }) - ) + process.env.ANTHROPIC_API_KEY = originalEnvKey }) - it('does not initialize when no API key is available', async () => { - const provider = new AnthropicProvider( - createProvider({ apiKey: '', enable: true }), + it('fails fast when no Anthropic API key is available', async () => { + const provider = new AiSdkProvider( + createProvider({ + apiKey: '' + }), createConfigPresenter() ) - expect(mockAnthropicConstructor).not.toHaveBeenCalled() await expect(provider.check()).resolves.toEqual({ isOk: false, - errorMsg: 'Anthropic SDK not initialized' + errorMsg: 'Missing API key' }) + expect(mockRunAiSdkGenerateText).not.toHaveBeenCalled() }) - it('does not send Claude Code system prompt or oauth beta header', async () => { - const provider = new AnthropicProvider( - createProvider({ enable: false }), + it('uses the AI SDK runtime for provider health checks', async () => { + process.env.ANTHROPIC_API_KEY = 'env-key' + const provider = new AiSdkProvider( + createProvider({ + apiKey: '' + }), createConfigPresenter() ) - ;(provider as any).anthropic = { - messages: { create: mockMessagesCreate }, - models: { list: mockModelsList } - } + ;(provider as any).isInitialized = true - const headers = (provider as any).buildAnthropicApiKeyHeaders() - expect(headers).toEqual( - expect.objectContaining({ - 'Content-Type': 'application/json', - 'anthropic-version': '2023-06-01', - 'x-api-key': 'test-key' - }) - ) - 
expect(headers).not.toHaveProperty('anthropic-beta') - - await provider.check() - - const request = mockMessagesCreate.mock.calls.at(-1)?.[0] - expect(request).toEqual( + await expect(provider.check()).resolves.toEqual({ + isOk: true, + errorMsg: null + }) + expect(mockRunAiSdkGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - model: 'claude-sonnet-4-5-20250929', - max_tokens: 10 - }) + providerKind: 'anthropic' + }), + [{ role: 'user', content: 'Hello' }], + 'claude-sonnet-4-5-20250929', + expect.any(Object), + 0.2, + 16 ) - expect(request).not.toHaveProperty('system') }) - it('passes through the caller system prompt for text generation', async () => { - mockMessagesCreate.mockResolvedValue({ - content: [{ type: 'text', text: 'hello' }], - usage: undefined - }) - - const provider = new AnthropicProvider( - createProvider({ enable: false }), - createConfigPresenter() - ) - ;(provider as any).anthropic = { - messages: { create: mockMessagesCreate }, - models: { list: mockModelsList } - } + it('passes system prompts through the AI SDK text path', async () => { + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + ;(provider as any).isInitialized = true await provider.generateText('hi', 'claude-sonnet-4-5-20250929', 0.2, 32, 'Real system prompt') - expect(mockMessagesCreate).toHaveBeenLastCalledWith( + expect(mockRunAiSdkGenerateText).toHaveBeenLastCalledWith( expect.objectContaining({ - system: 'Real system prompt', - messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }] - }) - ) - }) - - it('adds top-level cache_control for Claude streaming requests', async () => { - mockMessagesCreate.mockResolvedValue( - createAsyncStream([ - { - type: 'message_start', - message: { - usage: { - input_tokens: 10, - output_tokens: 2 - } - } - }, - { - type: 'content_block_delta', - delta: { - type: 'text_delta', - text: 'hello' - } - } - ]) - ) - - const provider = new AnthropicProvider( - createProvider({ enable: false }), - createConfigPresenter() - ) - ;(provider as any).anthropic = { - messages: { create: mockMessagesCreate }, - models: { list: mockModelsList } - } - - const events = [] - for await (const event of provider.coreStream( - [{ role: 'user', content: 'hi' }], + providerKind: 'anthropic' + }), + [ + { role: 'system', content: 'Real system prompt' }, + { role: 'user', content: 'hi' } + ], 'claude-sonnet-4-5-20250929', - streamModelConfig, + expect.any(Object), 0.2, - 64, - [] - )) { - events.push(event) - } - - const request = mockMessagesCreate.mock.calls.at(-1)?.[0] - expect(request).toMatchObject({ - cache_control: { - type: 'ephemeral' - } - }) - expect(events.some((event) => event.type === 'text')).toBe(true) - }) - - it('normalizes cache read and cache write usage metadata for streams', async () => { - mockMessagesCreate.mockResolvedValue( - createAsyncStream([ - { - type: 'message_start', - message: { - usage: { - input_tokens: 10, - output_tokens: 5, - cache_read_input_tokens: 20, - cache_creation_input_tokens: 30 - } - } - }, - { - type: 'content_block_delta', - delta: { - type: 'text_delta', - text: 'hello' - } - } - ]) - ) - - const provider = new AnthropicProvider( - createProvider({ enable: false }), - createConfigPresenter() + 32 ) - ;(provider as any).anthropic = { - messages: { create: mockMessagesCreate }, - models: { list: mockModelsList } - } + }) - const events = [] - for await (const event of provider.coreStream( - [{ role: 'user', content: 'hi' }], - 'claude-sonnet-4-5-20250929', - streamModelConfig, - 0.2, - 64, - 
[] - )) { - events.push(event) - } + it('reads model metadata from the provider database snapshot', async () => { + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + const models = await (provider as any).fetchProviderModels() - const usageEvent = events.find((event) => event.type === 'usage') - expect(usageEvent).toMatchObject({ - type: 'usage', - usage: { - prompt_tokens: 60, - completion_tokens: 5, - total_tokens: 65, - cached_tokens: 20, - cache_write_tokens: 30 - } - }) + expect(models).toEqual([ + expect.objectContaining({ + id: 'claude-sonnet-4-5-20250929', + providerId: 'anthropic', + vision: true, + functionCall: true, + reasoning: true + }) + ]) }) }) diff --git a/test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts b/test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts index bad73be9e..75ff22530 100644 --- a/test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/awsBedrockProvider.test.ts @@ -1,58 +1,81 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import type { - AWS_BEDROCK_PROVIDER, - ChatMessage, - IConfigPresenter, - ModelConfig -} from '../../../../src/shared/presenter' -import { AwsBedrockProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/awsBedrockProvider' - -const { mockBedrockRuntimeSend, mockGetProxyUrl } = vi.hoisted(() => ({ - mockBedrockRuntimeSend: vi.fn(), - mockGetProxyUrl: vi.fn().mockReturnValue(null) -})) +import type { AWS_BEDROCK_PROVIDER, IConfigPresenter } from '../../../../src/shared/presenter' +import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' -vi.mock('@aws-sdk/client-bedrock', () => ({ - BedrockClient: vi.fn(), - ListFoundationModelsCommand: class ListFoundationModelsCommand { - input: unknown +const { mockBedrockSend, mockRunAiSdkCoreStream, mockRunAiSdkGenerateText } = vi.hoisted(() => ({ + mockBedrockSend: vi.fn(), + mockRunAiSdkCoreStream: vi.fn(), + mockRunAiSdkGenerateText: vi.fn().mockResolvedValue({ content: 'ok' }) +})) - constructor(input: unknown) { - this.input = input - } +vi.mock('electron', () => ({ + app: { + getName: vi.fn(() => 'DeepChat'), + getVersion: vi.fn(() => '0.0.0-test'), + getPath: vi.fn(() => '/mock/path'), + isReady: vi.fn(() => true), + on: vi.fn() } })) -vi.mock('@aws-sdk/client-bedrock-runtime', () => ({ - BedrockRuntimeClient: vi.fn(), - InvokeModelCommand: class InvokeModelCommand { - input: Record +vi.mock('@/eventbus', () => ({ + eventBus: { + on: vi.fn(), + sendToRenderer: vi.fn() + }, + SendTarget: { + ALL_WINDOWS: 'ALL_WINDOWS' + } +})) - constructor(input: Record) { - this.input = input - } +vi.mock('@/events', () => ({ + CONFIG_EVENTS: { + MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED' }, - InvokeModelWithResponseStreamCommand: class InvokeModelWithResponseStreamCommand { - input: Record + PROVIDER_DB_EVENTS: { + LOADED: 'LOADED', + UPDATED: 'UPDATED' + } +})) - constructor(input: Record) { +vi.mock('@aws-sdk/client-bedrock', () => ({ + BedrockClient: vi.fn().mockImplementation(() => ({ + config: { + region: vi.fn().mockResolvedValue('us-east-1') + }, + send: mockBedrockSend + })), + ListFoundationModelsCommand: class ListFoundationModelsCommand { + input: unknown + + constructor(input: unknown) { this.input = input } } })) -vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ - proxyConfig: { - getProxyUrl: mockGetProxyUrl - } 
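+// Stub the shared AI SDK runtime entry points so this suite can observe calls
+// through mockRunAiSdkCoreStream / mockRunAiSdkGenerateText without touching the
+// network, mirroring the setup used by the other AiSdkProvider suites.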
+vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({ + runAiSdkCoreStream: mockRunAiSdkCoreStream, + runAiSdkGenerateText: mockRunAiSdkGenerateText })) -const createConfigPresenter = () => +const createConfigPresenter = (): IConfigPresenter => ({ getProviderModels: vi.fn().mockReturnValue([]), getCustomModels: vi.fn().mockReturnValue([]), + getDbProviderModels: vi.fn().mockReturnValue([ + { + id: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + name: 'Claude 3.5 Sonnet', + group: 'Bedrock Claude', + contextLength: 200000, + maxTokens: 64000, + vision: false, + functionCall: false, + reasoning: false + } + ]), getModelConfig: vi.fn().mockReturnValue(undefined), - getDbProviderModels: vi.fn().mockReturnValue([]), getSetting: vi.fn().mockReturnValue(undefined), setProviderModels: vi.fn(), getModelStatus: vi.fn().mockReturnValue(true) @@ -71,137 +94,80 @@ const createProvider = (overrides?: Partial): AWS_BEDROCK_ ...overrides }) -const createAsyncStream = (chunks: Array>) => ({ - async *[Symbol.asyncIterator]() { - for (const chunk of chunks) { - yield chunk - } - } -}) - -const createBedrockChunk = (chunk: Record) => ({ - chunk: { - bytes: new TextEncoder().encode(JSON.stringify(chunk)) - } -}) - -describe('AwsBedrockProvider prompt cache behavior', () => { - const modelConfig: ModelConfig = { - maxTokens: 1024, - contextLength: 8192, - vision: false, - functionCall: false, - reasoning: false, - type: 'chat', - conversationId: 'session-1' - } - - const messages: ChatMessage[] = [ - { role: 'system', content: 'system prompt' }, - { role: 'user', content: 'history' }, - { role: 'assistant', content: 'stable reply' }, - { role: 'user', content: 'latest question' } - ] - +describe('AiSdkProvider aws-bedrock', () => { beforeEach(() => { vi.clearAllMocks() - mockGetProxyUrl.mockReturnValue(null) - mockBedrockRuntimeSend.mockResolvedValue({ - body: Promise.resolve( - createAsyncStream([ - createBedrockChunk({ - type: 'message_start', - message: { - usage: { - input_tokens: 10, - output_tokens: 5, - cacheReadInputTokens: 20, - cacheWriteInputTokens: 30 - } - } - }), - createBedrockChunk({ - type: 'content_block_delta', - delta: { - type: 'text_delta', - text: 'hello' - } - }) - ]) - ) - }) + mockRunAiSdkGenerateText.mockResolvedValue({ content: 'ok' }) }) - it('adds an explicit cache_control breakpoint before the latest user turn', async () => { - const provider = new AwsBedrockProvider(createProvider(), createConfigPresenter()) - ;(provider as any).bedrockRuntime = { - send: mockBedrockRuntimeSend - } - - const events = [] - for await (const event of provider.coreStream( - messages, - 'anthropic.claude-3-5-sonnet-20240620-v1:0', - modelConfig, - 0.2, - 64, - [] - )) { - events.push(event) - } - - const command = mockBedrockRuntimeSend.mock.calls[0][0] as { - input: { - body: string - } - } - const payload = JSON.parse(command.input.body) + it('fails fast when credentials are missing', async () => { + const provider = new AiSdkProvider( + createProvider({ + credential: undefined + }), + createConfigPresenter() + ) + + await expect(provider.check()).resolves.toEqual({ + isOk: false, + errorMsg: 'Missing AWS Bedrock credentials' + }) + expect(mockRunAiSdkGenerateText).not.toHaveBeenCalled() + }) - expect(payload).not.toHaveProperty('cache_control') - expect(payload.system).toBe('system prompt\n') - expect(payload.messages[1]).toMatchObject({ - role: 'assistant', - content: [ + it('maps active Claude models from the Bedrock catalog', async () => { + 
mockBedrockSend.mockResolvedValue({ + modelSummaries: [ { - type: 'text', - text: 'stable reply', - cache_control: { - type: 'ephemeral' - } + modelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + modelLifecycle: { status: 'ACTIVE' }, + inferenceTypesSupported: ['ON_DEMAND'] } ] }) - expect(events.some((event) => event.type === 'text')).toBe(true) + + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + const models = await provider.fetchModels() + + expect(models).toEqual([ + expect.objectContaining({ + id: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + providerId: 'aws-bedrock' + }) + ]) }) - it('normalizes cache read and cache write usage fields from Bedrock streams', async () => { - const provider = new AwsBedrockProvider(createProvider(), createConfigPresenter()) - ;(provider as any).bedrockRuntime = { - send: mockBedrockRuntimeSend - } + it('falls back to the provider DB snapshot when the Bedrock catalog lookup fails', async () => { + mockBedrockSend.mockRejectedValue(new Error('catalog unavailable')) - const events = [] - for await (const event of provider.coreStream( - messages, - 'anthropic.claude-3-5-sonnet-20240620-v1:0', - modelConfig, - 0.2, - 64, - [] - )) { - events.push(event) - } + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + const models = await provider.fetchModels() - const usageEvent = events.find((event) => event.type === 'usage') - expect(usageEvent).toMatchObject({ - type: 'usage', - usage: { - prompt_tokens: 60, - completion_tokens: 5, - total_tokens: 65, - cached_tokens: 20, - cache_write_tokens: 30 - } + expect(models).toEqual([ + expect.objectContaining({ + id: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + group: 'Bedrock Claude' + }) + ]) + }) + + it('uses the AI SDK runtime for health checks', async () => { + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + ;(provider as any).isInitialized = true + + await expect(provider.check()).resolves.toEqual({ + isOk: true, + errorMsg: null }) + expect(mockRunAiSdkGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + providerKind: 'aws-bedrock' + }), + [{ role: 'user', content: 'Hi' }], + 'anthropic.claude-3-5-sonnet-20240620-v1:0', + expect.any(Object), + 0.2, + 16 + ) }) }) diff --git a/test/main/presenter/llmProviderPresenter/backgroundModelSync.test.ts b/test/main/presenter/llmProviderPresenter/backgroundModelSync.test.ts index 334074000..5c599d383 100644 --- a/test/main/presenter/llmProviderPresenter/backgroundModelSync.test.ts +++ b/test/main/presenter/llmProviderPresenter/backgroundModelSync.test.ts @@ -5,7 +5,7 @@ import type { LLM_PROVIDER } from '../../../../src/shared/presenter' import { LLMProviderPresenter } from '../../../../src/main/presenter/llmProviderPresenter' -import { OpenAICompatibleProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider' +import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' const eventState = vi.hoisted(() => ({ handlers: new Map void>>() @@ -48,24 +48,6 @@ vi.mock('electron', () => ({ } })) -vi.mock('openai', () => { - class MockOpenAI { - chat = { - completions: { - create: vi.fn() - } - } - models = { - list: mockModelsList - } - } - - return { - default: MockOpenAI, - AzureOpenAI: MockOpenAI - } -}) - vi.mock('@/presenter', () => ({ presenter: { devicePresenter: { @@ -190,7 +172,7 @@ describe('LLMProviderPresenter background model sync', () => { it('does not trigger an 
extra startup refresh for non DB-backed providers', async () => { const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') + .spyOn(AiSdkProvider.prototype, 'refreshModels') .mockResolvedValue(undefined) const presenter = new LLMProviderPresenter(createConfigPresenter(), mockSqlitePresenter) @@ -203,7 +185,7 @@ describe('LLMProviderPresenter background model sync', () => { it('re-syncs enabled DB-backed provider models when provider-db updates', async () => { const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') + .spyOn(AiSdkProvider.prototype, 'refreshModels') .mockResolvedValue(undefined) new LLMProviderPresenter( @@ -231,7 +213,7 @@ describe('LLMProviderPresenter background model sync', () => { it('ignores provider-db updates for providers that do not use the provider DB catalog', async () => { const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') + .spyOn(AiSdkProvider.prototype, 'refreshModels') .mockResolvedValue(undefined) new LLMProviderPresenter(createConfigPresenter(), mockSqlitePresenter) @@ -248,13 +230,11 @@ describe('LLMProviderPresenter background model sync', () => { it('coalesces duplicate background refreshes for the same provider', async () => { let resolveRefresh: (() => void) | null = null - const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') - .mockReturnValue( - new Promise((resolve) => { - resolveRefresh = resolve - }) - ) + const refreshSpy = vi.spyOn(AiSdkProvider.prototype, 'refreshModels').mockReturnValue( + new Promise((resolve) => { + resolveRefresh = resolve + }) + ) new LLMProviderPresenter( createConfigPresenter( @@ -304,7 +284,7 @@ describe('LLMProviderPresenter background model sync', () => { }) const configPresenter = createConfigPresenter(provider) const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') + .spyOn(AiSdkProvider.prototype, 'refreshModels') .mockResolvedValue(undefined) const presenter = new LLMProviderPresenter(configPresenter, mockSqlitePresenter) @@ -332,7 +312,7 @@ describe('LLMProviderPresenter background model sync', () => { message: 'network down' }) const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') + .spyOn(AiSdkProvider.prototype, 'refreshModels') .mockResolvedValue(undefined) const presenter = new LLMProviderPresenter(configPresenter, mockSqlitePresenter) @@ -346,7 +326,7 @@ describe('LLMProviderPresenter background model sync', () => { it('does not refresh provider DB for providers that manage models themselves', async () => { const configPresenter = createConfigPresenter() const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') + .spyOn(AiSdkProvider.prototype, 'refreshModels') .mockResolvedValue(undefined) const presenter = new LLMProviderPresenter(configPresenter, mockSqlitePresenter) @@ -358,7 +338,7 @@ describe('LLMProviderPresenter background model sync', () => { it('logs provider-db refresh failures without blocking presenter initialization', async () => { const refreshSpy = vi - .spyOn(OpenAICompatibleProvider.prototype, 'refreshModels') + .spyOn(AiSdkProvider.prototype, 'refreshModels') .mockRejectedValue(new Error('refresh failed')) const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) diff --git a/test/main/presenter/llmProviderPresenter/baseProvider.test.ts b/test/main/presenter/llmProviderPresenter/baseProvider.test.ts new file mode 100644 index 000000000..5ba53dbac --- /dev/null +++ 
b/test/main/presenter/llmProviderPresenter/baseProvider.test.ts
@@ -0,0 +1,219 @@
+import { describe, expect, it, vi } from 'vitest'
+import type {
+  ChatMessage,
+  IConfigPresenter,
+  LLM_PROVIDER,
+  LLMResponse,
+  MCPToolDefinition,
+  ModelConfig
+} from '../../../../src/shared/presenter'
+import { BaseLLMProvider } from '../../../../src/main/presenter/llmProviderPresenter/baseProvider'
+
+vi.mock('@/eventbus', () => ({
+  eventBus: {
+    on: vi.fn(),
+    sendToRenderer: vi.fn(),
+    sendToMain: vi.fn(),
+    emit: vi.fn(),
+    send: vi.fn()
+  },
+  SendTarget: {
+    ALL_WINDOWS: 'ALL_WINDOWS'
+  }
+}))
+
+vi.mock('@/events', () => ({
+  CONFIG_EVENTS: {
+    MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED'
+  }
+}))
+
+class TestProvider extends BaseLLMProvider {
+  constructor(configPresenter: IConfigPresenter) {
+    super(
+      {
+        id: 'test-provider',
+        name: 'Test Provider',
+        enable: true,
+        apiKey: 'test-key',
+        apiHost: '',
+        apiVersion: '',
+        models: []
+      } as unknown as LLM_PROVIDER,
+      configPresenter
+    )
+  }
+
+  public renderToolsXml(tools: MCPToolDefinition[]): string {
+    return this.convertToolsToXml(tools)
+  }
+
+  public onProxyResolved(): void {}
+
+  public async check(): Promise<{ isOk: boolean; errorMsg: string | null }> {
+    return { isOk: true, errorMsg: null }
+  }
+
+  public async summaryTitles(_messages: ChatMessage[], _modelId: string): Promise<string> {
+    return 'summary'
+  }
+
+  public async completions(
+    _messages: ChatMessage[],
+    _modelId: string,
+    _temperature?: number,
+    _maxTokens?: number,
+    _tools?: MCPToolDefinition[]
+  ): Promise<LLMResponse> {
+    return { content: 'ok' }
+  }
+
+  public async summaries(
+    _text: string,
+    _modelId: string,
+    _temperature?: number,
+    _maxTokens?: number
+  ): Promise<LLMResponse> {
+    return { content: 'ok' }
+  }
+
+  public async generateText(
+    _prompt: string,
+    _modelId: string,
+    _temperature?: number,
+    _maxTokens?: number
+  ): Promise<LLMResponse> {
+    return { content: 'ok' }
+  }
+
+  public async *coreStream(
+    _messages: ChatMessage[],
+    _modelId: string,
+    _modelConfig: ModelConfig,
+    _temperature: number,
+    _maxTokens: number,
+    _tools: MCPToolDefinition[]
+  ) {
+    return
+  }
+
+  protected async fetchProviderModels() {
+    return []
+  }
+}
+
+describe('BaseLLMProvider tool XML conversion', () => {
+  const configPresenter = {
+    getProviderModels: vi.fn().mockReturnValue([]),
+    getCustomModels: vi.fn().mockReturnValue([]),
+    getLanguage: vi.fn().mockReturnValue('zh-CN'),
+    setProviderModels: vi.fn(),
+    getModelStatus: vi.fn().mockReturnValue(false),
+    updateCustomModel: vi.fn()
+  } as unknown as IConfigPresenter
+
+  it('normalizes discriminated union tool schemas before building XML', () => {
+    const provider = new TestProvider(configPresenter)
+    const tools: MCPToolDefinition[] = [
+      {
+        type: 'function',
+        function: {
+          name: 'skill_manage',
+          description: 'Manage draft skills',
+          parameters: {
+            anyOf: [
+              {
+                type: 'object',
+                properties: {
+                  action: { type: 'string', const: 'create' },
+                  content: { type: 'string', description: 'Draft content' }
+                },
+                required: ['action', 'content'],
+                additionalProperties: false
+              },
+              {
+                type: 'object',
+                properties: {
+                  action: { type: 'string', const: 'edit' },
+                  draftId: { type: 'string', description: 'Draft ID' },
+                  content: { type: 'string', description: 'Draft content' }
+                },
+                required: ['action', 'draftId', 'content'],
+                additionalProperties: false
+              }
+            ]
+          } as unknown as MCPToolDefinition['function']['parameters']
+        },
+        server: {
+          name: 'deepchat',
+          icons: 'tool',
+          description: 'DeepChat tools'
+        }
+      }
+    ]
+
+    const xml =
provider.renderToolsXml(tools) + + expect(xml).toContain('') + expect(xml).toContain('') + expect(xml).toContain( + '' + ) + expect(xml).toContain( + '' + ) + }) + + it('keeps tools without properties renderable', () => { + const provider = new TestProvider(configPresenter) + const xml = provider.renderToolsXml([ + { + type: 'function', + function: { + name: 'noop', + description: 'No arguments tool', + parameters: { + type: 'object', + properties: {} + } + }, + server: { + name: 'deepchat', + icons: 'tool', + description: 'DeepChat tools' + } + } + ]) + + expect(xml).toContain('') + }) + + it('escapes XML-sensitive characters in parameter descriptions', () => { + const provider = new TestProvider(configPresenter) + const xml = provider.renderToolsXml([ + { + type: 'function', + function: { + name: 'escape_test', + description: 'Escape test', + parameters: { + type: 'object', + properties: { + query: { + type: 'string', + description: 'He said "hi" & used > output' + } + } + } + }, + server: { + name: 'deepchat', + icons: 'tool', + description: 'DeepChat tools' + } + } + ]) + + expect(xml).toContain('description="He said "hi" & used <tag> > output"') + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/doubaoProvider.test.ts b/test/main/presenter/llmProviderPresenter/doubaoProvider.test.ts index 2c8f63013..10ba5e2f8 100644 --- a/test/main/presenter/llmProviderPresenter/doubaoProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/doubaoProvider.test.ts @@ -1,39 +1,18 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import type { IConfigPresenter, LLM_PROVIDER, ModelConfig } from '../../../../src/shared/presenter' -import { DoubaoProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/doubaoProvider' +import type { IConfigPresenter, LLM_PROVIDER } from '../../../../src/shared/presenter' +import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' -const { mockChatCompletionsCreate, mockGetProvider, mockGetModel, mockGetProxyUrl } = vi.hoisted( - () => ({ - mockChatCompletionsCreate: vi.fn(), - mockGetProvider: vi.fn(), - mockGetModel: vi.fn(), - mockGetProxyUrl: vi.fn().mockReturnValue(null) - }) -) - -vi.mock('openai', () => { - class MockOpenAI { - chat = { - completions: { - create: mockChatCompletionsCreate - } - } - models = { - list: vi.fn().mockResolvedValue({ data: [] }) - } - } - - return { - default: MockOpenAI, - AzureOpenAI: MockOpenAI - } -}) +const { mockGetProvider } = vi.hoisted(() => ({ + mockGetProvider: vi.fn() +})) -vi.mock('@/presenter', () => ({ - presenter: { - devicePresenter: { - cacheImage: vi.fn() - } +vi.mock('electron', () => ({ + app: { + getName: vi.fn(() => 'DeepChat'), + getVersion: vi.fn(() => '0.0.0-test'), + getPath: vi.fn(() => '/mock/path'), + isReady: vi.fn(() => true), + on: vi.fn() } })) @@ -54,6 +33,10 @@ vi.mock('@/events', () => ({ CONFIG_EVENTS: { MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED' }, + PROVIDER_DB_EVENTS: { + LOADED: 'LOADED', + UPDATED: 'UPDATED' + }, NOTIFICATION_EVENTS: { SHOW_ERROR: 'SHOW_ERROR' } @@ -61,33 +44,25 @@ vi.mock('@/events', () => ({ vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ proxyConfig: { - getProxyUrl: mockGetProxyUrl + getProxyUrl: vi.fn().mockReturnValue(null) } })) vi.mock('../../../../src/main/presenter/configPresenter/providerDbLoader', () => ({ providerDbLoader: { + getDb: vi.fn().mockReturnValue(null), getProvider: mockGetProvider, - getModel: mockGetModel + getModel: vi.fn() } })) 
-vi.mock('../../../../src/main/presenter/configPresenter/modelCapabilities', () => ({ - modelCapabilities: { - supportsReasoningEffort: vi.fn().mockReturnValue(false), - supportsVerbosity: vi.fn().mockReturnValue(false), - supportsReasoning: vi.fn().mockReturnValue(false) - } +vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({ + runAiSdkCoreStream: vi.fn(), + runAiSdkDimensions: vi.fn(), + runAiSdkEmbeddings: vi.fn(), + runAiSdkGenerateText: vi.fn() })) -const createAsyncStream = (chunks: Array>) => ({ - async *[Symbol.asyncIterator]() { - for (const chunk of chunks) { - yield chunk - } - } -}) - const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ id: 'doubao', name: 'Doubao', @@ -98,7 +73,7 @@ const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ ...overrides }) -const createConfigPresenter = () => +const createConfigPresenter = (): IConfigPresenter => ({ getProviderModels: vi.fn().mockReturnValue([]), getCustomModels: vi.fn().mockReturnValue([]), @@ -108,33 +83,9 @@ const createConfigPresenter = () => getModelStatus: vi.fn().mockReturnValue(true) }) as unknown as IConfigPresenter -describe('DoubaoProvider', () => { - const modelConfig: ModelConfig = { - maxTokens: 1024, - contextLength: 8192, - vision: true, - functionCall: true, - reasoning: true, - type: 'chat' - } - +describe('AiSdkProvider doubao', () => { beforeEach(() => { vi.clearAllMocks() - mockGetProxyUrl.mockReturnValue(null) - mockChatCompletionsCreate.mockResolvedValue( - createAsyncStream([ - { - choices: [ - { - delta: { - content: 'ok' - }, - finish_reason: 'stop' - } - ] - } - ]) - ) }) it('maps doubao catalog entries into provider models', async () => { @@ -161,8 +112,8 @@ describe('DoubaoProvider', () => { ] }) - const provider = new DoubaoProvider(createProvider(), createConfigPresenter()) - const models = await (provider as any).fetchOpenAIModels() + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + const models = await provider.fetchModels() expect(models).toEqual([ expect.objectContaining({ @@ -175,46 +126,4 @@ describe('DoubaoProvider', () => { }) ]) }) - - it('adds Doubao thinking parameter for reasoning models based on metadata notes', async () => { - mockGetProvider.mockReturnValue({ - id: 'doubao', - name: 'Doubao', - models: [] - }) - mockGetModel.mockReturnValue({ - id: 'doubao-seed-2.0-pro', - extra_capabilities: { - reasoning: { - notes: ['doubao-thinking-parameter'] - } - } - }) - - const provider = new DoubaoProvider(createProvider(), createConfigPresenter()) - ;(provider as any).isInitialized = true - - const events = [] - for await (const event of provider.coreStream( - [{ role: 'user', content: 'hello' }], - 'doubao-seed-2.0-pro', - modelConfig, - 0.7, - 1024, - [] - )) { - events.push(event) - } - - expect(events.some((event) => event.type === 'text')).toBe(true) - expect(mockChatCompletionsCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: 'doubao-seed-2.0-pro', - thinking: { - type: 'enabled' - } - }), - undefined - ) - }) }) diff --git a/test/main/presenter/llmProviderPresenter/legacyFunctionCallMiddleware.test.ts b/test/main/presenter/llmProviderPresenter/legacyFunctionCallMiddleware.test.ts new file mode 100644 index 000000000..00e7e5213 --- /dev/null +++ b/test/main/presenter/llmProviderPresenter/legacyFunctionCallMiddleware.test.ts @@ -0,0 +1,50 @@ +import { describe, expect, it } from 'vitest' +import { applyLegacyFunctionCallPrompt } from 
'@/presenter/llmProviderPresenter/aiSdk/middlewares/legacyFunctionCallMiddleware' + +describe('legacyFunctionCallMiddleware', () => { + it('preserves message fields when converting non-array user content', () => { + const messages = [ + { + role: 'user' as const, + content: 'hello', + providerMetadata: { + vertex: { + cachedContent: 'cache-key' + } + } + } as any + ] + + const result = applyLegacyFunctionCallPrompt( + messages, + [ + { + type: 'function', + function: { + name: 'search', + description: 'Search', + parameters: { type: 'object', properties: {} } + } + } as any + ], + () => 'tool prompt' + ) + + expect(result).toEqual([ + { + role: 'user', + providerMetadata: { + vertex: { + cachedContent: 'cache-key' + } + }, + content: [ + { + type: 'text', + text: 'hello\n\ntool prompt' + } + ] + } + ]) + }) +}) diff --git a/test/main/presenter/llmProviderPresenter/newApiProvider.test.ts b/test/main/presenter/llmProviderPresenter/newApiProvider.test.ts index b27612daa..eb6198f86 100644 --- a/test/main/presenter/llmProviderPresenter/newApiProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/newApiProvider.test.ts @@ -1,28 +1,10 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import type { - ChatMessage, - IConfigPresenter, - LLMCoreStreamEvent, - LLM_PROVIDER, - ModelConfig -} from '../../../../src/shared/presenter' +import type { IConfigPresenter, LLM_PROVIDER, ModelConfig } from '../../../../src/shared/presenter' import { ApiEndpointType, ModelType } from '../../../../src/shared/model' -import { NewApiProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/newApiProvider' +import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' -const { - mockOpenAIChatCreate, - mockOpenAIResponsesCreate, - mockOpenAIModelsList, - mockAnthropicModelsList, - mockAnthropicMessagesCreate, - mockGetProxyUrl -} = vi.hoisted(() => ({ - mockOpenAIChatCreate: vi.fn(), - mockOpenAIResponsesCreate: vi.fn(), - mockOpenAIModelsList: vi.fn().mockResolvedValue({ data: [] }), - mockAnthropicModelsList: vi.fn().mockResolvedValue({ data: [] }), - mockAnthropicMessagesCreate: vi.fn().mockResolvedValue({}), - mockGetProxyUrl: vi.fn().mockReturnValue(null) +const { mockRunAiSdkCoreStream } = vi.hoisted(() => ({ + mockRunAiSdkCoreStream: vi.fn() })) vi.mock('electron', () => ({ @@ -32,109 +14,9 @@ vi.mock('electron', () => ({ getPath: vi.fn(() => '/mock/path'), isReady: vi.fn(() => true), on: vi.fn() - }, - session: {}, - ipcMain: { - on: vi.fn(), - handle: vi.fn(), - removeHandler: vi.fn() - }, - BrowserWindow: vi.fn(() => ({ - loadURL: vi.fn(), - loadFile: vi.fn(), - on: vi.fn(), - webContents: { send: vi.fn(), on: vi.fn(), isDestroyed: vi.fn(() => false) }, - isDestroyed: vi.fn(() => false), - close: vi.fn(), - show: vi.fn(), - hide: vi.fn() - })), - dialog: { - showOpenDialog: vi.fn() - }, - shell: { - openExternal: vi.fn() } })) -vi.mock('openai', () => { - class MockOpenAI { - chat = { - completions: { - create: mockOpenAIChatCreate - } - } - responses = { - create: mockOpenAIResponsesCreate - } - models = { - list: mockOpenAIModelsList - } - } - - return { - default: MockOpenAI, - AzureOpenAI: MockOpenAI - } -}) - -vi.mock('@anthropic-ai/sdk', () => { - class MockAnthropic { - models = { - list: mockAnthropicModelsList - } - messages = { - create: mockAnthropicMessagesCreate - } - - constructor(_: Record) {} - } - - return { - default: MockAnthropic - } -}) - -vi.mock('@google/genai', () => ({ - Content: class {}, - 
GoogleGenAI: class MockGoogleGenAI { - models = { - list: vi.fn().mockResolvedValue([]), - generateContent: vi.fn().mockResolvedValue({ text: 'ok' }) - } - - constructor(_: Record) {} - }, - FunctionCallingConfigMode: { - ANY: 'ANY', - AUTO: 'AUTO', - NONE: 'NONE' - }, - GenerateContentParameters: class {}, - GenerateContentResponseUsageMetadata: class {}, - GenerateContentConfig: class {}, - HarmBlockThreshold: { - BLOCK_NONE: 'BLOCK_NONE', - BLOCK_LOW_AND_ABOVE: 'BLOCK_LOW_AND_ABOVE', - BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE', - BLOCK_ONLY_HIGH: 'BLOCK_ONLY_HIGH', - HARM_BLOCK_THRESHOLD_UNSPECIFIED: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' - }, - HarmCategory: { - HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT', - HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH', - HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT' - }, - Modality: { - TEXT: 'TEXT', - IMAGE: 'IMAGE' - }, - Part: class {}, - SafetySetting: class {}, - Tool: class {} -})) - vi.mock('@/presenter', () => ({ presenter: { devicePresenter: { @@ -163,6 +45,10 @@ vi.mock('@/events', () => ({ PROVIDER_BATCH_UPDATE: 'PROVIDER_BATCH_UPDATE', MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED' }, + PROVIDER_DB_EVENTS: { + LOADED: 'LOADED', + UPDATED: 'UPDATED' + }, NOTIFICATION_EVENTS: { SHOW_ERROR: 'SHOW_ERROR' } @@ -170,21 +56,15 @@ vi.mock('@/events', () => ({ vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ proxyConfig: { - getProxyUrl: mockGetProxyUrl + getProxyUrl: vi.fn().mockReturnValue(null) } })) -vi.mock('../../../../src/main/presenter/configPresenter/modelCapabilities', () => ({ - modelCapabilities: { - supportsReasoningEffort: vi.fn().mockReturnValue(false), - supportsVerbosity: vi.fn().mockReturnValue(false), - supportsReasoning: vi.fn().mockReturnValue(false), - supportsVision: vi.fn().mockReturnValue(false), - supportsToolCall: vi.fn().mockReturnValue(false), - supportsImageOutput: vi.fn().mockReturnValue(false), - getThinkingBudgetRange: vi.fn().mockReturnValue({}), - resolveProviderId: vi.fn((providerId: string) => providerId) - } +vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({ + runAiSdkCoreStream: mockRunAiSdkCoreStream, + runAiSdkDimensions: vi.fn(), + runAiSdkEmbeddings: vi.fn(), + runAiSdkGenerateText: vi.fn() })) const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ @@ -224,33 +104,62 @@ const createConfigPresenter = ( describe('NewApiProvider capability routing', () => { beforeEach(() => { vi.clearAllMocks() + mockRunAiSdkCoreStream.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { type: 'image_data', image_data: { data: 'generated-image', mimeType: 'image/png' } } + } + }) }) it('maps openai-response delegates to openai capability semantics', () => { - const provider = new NewApiProvider(createProvider(), createConfigPresenter()) - const delegateProvider = (provider as any).openaiResponsesDelegate.provider as LLM_PROVIDER + const provider = new AiSdkProvider( + createProvider(), + createConfigPresenter({ + 'gpt-4o': { + endpointType: 'openai-response' + } + }) + ) + const routeDecision = (provider as any).resolveRouteDecision('gpt-4o') + const runtimeProvider = (provider as any).getRuntimeProvider(routeDecision) as LLM_PROVIDER - expect(delegateProvider.id).toBe('new-api') - expect(delegateProvider.capabilityProviderId).toBe('openai') - expect(delegateProvider.apiType).toBe('openai-responses') + expect(runtimeProvider.id).toBe('new-api') + 
expect(runtimeProvider.capabilityProviderId).toBe('openai') + expect(runtimeProvider.apiType).toBe('openai-responses') }) it('maps gemini delegates to gemini capability semantics', () => { - const provider = new NewApiProvider(createProvider(), createConfigPresenter()) - const delegateProvider = (provider as any).geminiDelegate.provider as LLM_PROVIDER + const provider = new AiSdkProvider( + createProvider(), + createConfigPresenter({ + 'gemini-model': { + endpointType: 'gemini' + } + }) + ) + const routeDecision = (provider as any).resolveRouteDecision('gemini-model') + const runtimeProvider = (provider as any).getRuntimeProvider(routeDecision) as LLM_PROVIDER - expect(delegateProvider.id).toBe('new-api') - expect(delegateProvider.capabilityProviderId).toBe('gemini') - expect(delegateProvider.apiType).toBe('gemini') + expect(runtimeProvider.id).toBe('new-api') + expect(runtimeProvider.capabilityProviderId).toBe('gemini') + expect(runtimeProvider.apiType).toBe('gemini') }) it('maps anthropic delegates to anthropic capability semantics', () => { - const provider = new NewApiProvider(createProvider(), createConfigPresenter()) - const delegateProvider = (provider as any).anthropicDelegate.provider as LLM_PROVIDER + const provider = new AiSdkProvider( + createProvider(), + createConfigPresenter({ + 'claude-model': { + endpointType: 'anthropic' + } + }) + ) + const routeDecision = (provider as any).resolveRouteDecision('claude-model') + const runtimeProvider = (provider as any).getRuntimeProvider(routeDecision) as LLM_PROVIDER - expect(delegateProvider.id).toBe('new-api') - expect(delegateProvider.capabilityProviderId).toBe('anthropic') - expect(delegateProvider.apiType).toBe('anthropic') + expect(runtimeProvider.id).toBe('new-api') + expect(runtimeProvider.capabilityProviderId).toBe('anthropic') + expect(runtimeProvider.apiType).toBe('anthropic') }) it('keeps image-generation on the image runtime route while using openai capabilities', async () => { @@ -261,28 +170,21 @@ describe('NewApiProvider capability routing', () => { type: ModelType.Chat } }) - const provider = new NewApiProvider(createProvider(), configPresenter) - const openaiChatDelegate = (provider as any).openaiChatDelegate - const coreStreamSpy = vi - .spyOn(openaiChatDelegate, 'coreStream') - .mockImplementation(async function* ( - _messages: ChatMessage[], - _modelId: string, - modelConfig: ModelConfig - ): AsyncIterable { - expect(modelConfig.apiEndpoint).toBe(ApiEndpointType.Image) - expect(modelConfig.type).toBe(ModelType.ImageGeneration) - expect(modelConfig.endpointType).toBe('image-generation') - yield { type: 'text', content: 'generated-image' } as LLMCoreStreamEvent - }) + const provider = new AiSdkProvider(createProvider(), configPresenter) + ;(provider as any).isInitialized = true const result = await provider.completions( [{ role: 'user', content: 'Draw a cat' }], 'gpt-image-1' ) - expect(openaiChatDelegate.provider.capabilityProviderId).toBe('openai') - expect(coreStreamSpy).toHaveBeenCalledOnce() + const modelConfig = mockRunAiSdkCoreStream.mock.calls.at(-1)?.[3] + const context = mockRunAiSdkCoreStream.mock.calls.at(-1)?.[0] + + expect(context.provider.capabilityProviderId).toBe('openai') + expect(modelConfig.apiEndpoint).toBe(ApiEndpointType.Image) + expect(modelConfig.type).toBe(ModelType.ImageGeneration) + expect(modelConfig.endpointType).toBe('image-generation') expect(result.content).toBe('generated-image') }) }) diff --git a/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts 
b/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts index aa98cfb2d..fff08a1d8 100644 --- a/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/openAICompatibleProvider.test.ts @@ -1,31 +1,22 @@ -import { beforeEach, describe, expect, it, vi } from 'vitest' -import type { - ChatMessage, - IConfigPresenter, - ISQLitePresenter, - LLM_PROVIDER, - MCPToolDefinition, - ModelConfig -} from '../../../../src/shared/presenter' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import type { IConfigPresenter, LLM_PROVIDER, ModelConfig } from '../../../../src/shared/presenter' import { - OpenAICompatibleProvider, + AiSdkProvider, normalizeExtractedImageText -} from '../../../../src/main/presenter/llmProviderPresenter/providers/openAICompatibleProvider' -import { OpenRouterProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/openRouterProvider' -import { LLMProviderPresenter } from '../../../../src/main/presenter/llmProviderPresenter' +} from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' const { - mockChatCompletionsCreate, - mockModelsList, - mockMcpToolsToOpenAITools, mockGetProxyUrl, - mockCacheImage + mockRunAiSdkCoreStream, + mockRunAiSdkDimensions, + mockRunAiSdkEmbeddings, + mockRunAiSdkGenerateText } = vi.hoisted(() => ({ - mockChatCompletionsCreate: vi.fn(), - mockModelsList: vi.fn().mockResolvedValue({ data: [] }), - mockMcpToolsToOpenAITools: vi.fn().mockResolvedValue([]), mockGetProxyUrl: vi.fn().mockReturnValue(null), - mockCacheImage: vi.fn() + mockRunAiSdkCoreStream: vi.fn(), + mockRunAiSdkDimensions: vi.fn(), + mockRunAiSdkEmbeddings: vi.fn(), + mockRunAiSdkGenerateText: vi.fn() })) vi.mock('electron', () => ({ @@ -35,64 +26,13 @@ vi.mock('electron', () => ({ getPath: vi.fn(() => '/mock/path'), isReady: vi.fn(() => true), on: vi.fn() - }, - session: {}, - ipcMain: { - on: vi.fn(), - handle: vi.fn(), - removeHandler: vi.fn() - }, - BrowserWindow: vi.fn(() => ({ - loadURL: vi.fn(), - loadFile: vi.fn(), - on: vi.fn(), - webContents: { send: vi.fn(), on: vi.fn(), isDestroyed: vi.fn(() => false) }, - isDestroyed: vi.fn(() => false), - close: vi.fn(), - show: vi.fn(), - hide: vi.fn() - })), - dialog: { - showOpenDialog: vi.fn() - }, - shell: { - openExternal: vi.fn() - } -})) - -vi.mock('openai', () => { - class MockOpenAI { - chat = { - completions: { - create: mockChatCompletionsCreate - } - } - models = { - list: mockModelsList - } - } - - return { - default: MockOpenAI, - AzureOpenAI: MockOpenAI - } -}) - -vi.mock('@/presenter', () => ({ - presenter: { - devicePresenter: { - cacheImage: mockCacheImage - } } })) vi.mock('@/eventbus', () => ({ eventBus: { on: vi.fn(), - sendToRenderer: vi.fn(), - sendToMain: vi.fn(), - emit: vi.fn(), - send: vi.fn() + sendToRenderer: vi.fn() }, SendTarget: { ALL_WINDOWS: 'ALL_WINDOWS' @@ -101,11 +41,12 @@ vi.mock('@/eventbus', () => ({ vi.mock('@/events', () => ({ CONFIG_EVENTS: { - PROXY_RESOLVED: 'PROXY_RESOLVED', - PROVIDER_ATOMIC_UPDATE: 'PROVIDER_ATOMIC_UPDATE', - PROVIDER_BATCH_UPDATE: 'PROVIDER_BATCH_UPDATE', MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED' }, + PROVIDER_DB_EVENTS: { + LOADED: 'LOADED', + UPDATED: 'UPDATED' + }, NOTIFICATION_EVENTS: { SHOW_ERROR: 'SHOW_ERROR' } @@ -117,47 +58,33 @@ vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ } })) -vi.mock('../../../../src/main/presenter/configPresenter/modelCapabilities', () => ({ - modelCapabilities: { - 
supportsReasoningEffort: vi.fn().mockReturnValue(false),
-    supportsVerbosity: vi.fn().mockReturnValue(false),
-    supportsReasoning: vi.fn().mockReturnValue(false),
-    resolveProviderId: vi.fn((providerId: string) => providerId)
-  }
+vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({
+  runAiSdkCoreStream: mockRunAiSdkCoreStream,
+  runAiSdkDimensions: mockRunAiSdkDimensions,
+  runAiSdkEmbeddings: mockRunAiSdkEmbeddings,
+  runAiSdkGenerateText: mockRunAiSdkGenerateText
 }))
 
-const createAsyncStream = (chunks: Array<Record<string, unknown>>) => ({
+const createStream = (events: Array<Record<string, unknown>>) => ({
   async *[Symbol.asyncIterator]() {
-    for (const chunk of chunks) {
-      yield chunk
+    for (const event of events) {
+      yield event
     }
   }
 })
 
-const collectEvents = async (
-  provider: OpenAICompatibleProvider,
-  providerModel: string,
-  modelConfig: ModelConfig,
-  messages: ChatMessage[],
-  tools: MCPToolDefinition[]
-) => {
-  const events = []
-  for await (const event of provider.coreStream(
-    messages,
-    providerModel,
-    modelConfig,
-    0.7,
-    512,
-    tools
-  )) {
-    events.push(event)
-  }
-  return events
-}
+const createProvider = (overrides?: Partial<LLM_PROVIDER>): LLM_PROVIDER => ({
+  id: 'novita',
+  name: 'Novita',
+  apiType: 'openai-completions',
+  apiKey: 'test-key',
+  baseUrl: 'https://mock.example.com/v1',
+  enable: false,
+  ...overrides
+})
 
-const createConfigPresenter = (providers: LLM_PROVIDER[]) =>
+const createConfigPresenter = (): IConfigPresenter =>
   ({
-    getProviders: vi.fn().mockReturnValue(providers),
     getProviderModels: vi.fn().mockReturnValue([]),
     getCustomModels: vi.fn().mockReturnValue([]),
     getModelConfig: vi.fn().mockReturnValue(undefined),
@@ -166,316 +93,148 @@ const createConfigPresenter = (providers: LLM_PROVIDER[]) =>
     getModelStatus: vi.fn().mockReturnValue(true)
   }) as unknown as IConfigPresenter
 
-const mockSqlitePresenter = {
-  getAcpSession: vi.fn().mockResolvedValue(null),
-  upsertAcpSession: vi.fn().mockResolvedValue(undefined),
-  updateAcpSessionId: vi.fn().mockResolvedValue(undefined),
-  updateAcpWorkdir: vi.fn().mockResolvedValue(undefined),
-  updateAcpSessionStatus: vi.fn().mockResolvedValue(undefined),
-  deleteAcpSession: vi.fn().mockResolvedValue(undefined),
-  deleteAcpSessions: vi.fn().mockResolvedValue(undefined)
-} as unknown as ISQLitePresenter
-
-describe('OpenAICompatibleProvider MCP runtime injection', () => {
-  const convertedTools = [
-    {
-      type: 'function',
-      function: {
-        name: 'get_weather',
-        description: 'Get current weather',
-        parameters: {
-          type: 'object',
-          properties: {
-            city: {
-              type: 'string'
-            }
-          }
-        }
-      }
-    }
-  ]
-
-  const modelConfig: ModelConfig = {
-    maxTokens: 1024,
-    contextLength: 8192,
-    vision: false,
-    functionCall: true,
-    reasoning: false,
-    type: 'chat'
-  }
-
-  const messages: ChatMessage[] = [{ role: 'user', content: 'What is the weather today?'
}] - - const mcpTools: MCPToolDefinition[] = [ - { - type: 'function', - function: { - name: 'get_weather', - description: 'Get current weather', - parameters: { - type: 'object', - properties: { - city: { - type: 'string' - } - }, - required: ['city'] - } - }, - server: { - name: 'weather-server', - icons: '', - description: 'Weather tools' - } - } - ] - - const mcpRuntime = { - mcpToolsToOpenAITools: mockMcpToolsToOpenAITools - } - - const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ - id: 'mock-openai-compatible', - name: 'Mock OpenAI Compatible', - apiType: 'openai-compatible', - apiKey: 'test-key', - baseUrl: 'https://mock.example.com/v1', - enable: false, - ...overrides - }) - +describe('AiSdkProvider openai-compatible', () => { beforeEach(() => { vi.clearAllMocks() - mockModelsList.mockResolvedValue({ data: [] }) + vi.unstubAllGlobals() mockGetProxyUrl.mockReturnValue(null) - mockMcpToolsToOpenAITools.mockResolvedValue(convertedTools) - mockChatCompletionsCreate.mockResolvedValue( - createAsyncStream([ - { - choices: [ - { - delta: { - content: 'ok' - }, - finish_reason: 'stop' - } - ], - usage: { - prompt_tokens: 12, - completion_tokens: 4, - total_tokens: 16 - } - } + mockRunAiSdkCoreStream.mockReturnValue( + createStream([ + { type: 'text', content: 'ok' }, + { type: 'stop', stop_reason: 'complete' } ]) ) }) - it('injects converted tools for direct OpenAICompatibleProvider instances', async () => { - const provider = new OpenAICompatibleProvider( - createProvider(), - createConfigPresenter([]), - mcpRuntime as any - ) - ;(provider as any).isInitialized = true - - const events = await collectEvents(provider, 'gpt-4o', modelConfig, messages, mcpTools) - const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] - - expect(events.some((event) => event.type === 'text')).toBe(true) - expect(events.some((event) => event.type === 'stop')).toBe(true) - expect(mockMcpToolsToOpenAITools).toHaveBeenCalledWith(mcpTools, 'mock-openai-compatible') - expect(requestParams.tools).toEqual(convertedTools) - }) - - it('does not inject tools when mcpRuntime is missing', async () => { - const provider = new OpenAICompatibleProvider(createProvider(), createConfigPresenter([])) - ;(provider as any).isInitialized = true - - await collectEvents(provider, 'gpt-4o', modelConfig, messages, mcpTools) - const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] - - expect(mockMcpToolsToOpenAITools).not.toHaveBeenCalled() - expect(requestParams.tools).toBeUndefined() - }) - - it('forwards mcpRuntime through OpenAICompatibleProvider subclasses', async () => { - const provider = new OpenRouterProvider( - createProvider({ - id: 'openrouter', - name: 'OpenRouter' - }), - createConfigPresenter([]), - mcpRuntime as any - ) - ;(provider as any).isInitialized = true - - await collectEvents(provider, 'gpt-4o', modelConfig, messages, mcpTools) - const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] - - expect(mockMcpToolsToOpenAITools).toHaveBeenCalledWith(mcpTools, 'openrouter') - expect(requestParams.tools).toEqual(convertedTools) + afterEach(() => { + vi.unstubAllGlobals() }) - it('preserves mcpRuntime on the LLMProviderPresenter instantiation path', async () => { - const providerConfig = createProvider({ - id: 'openrouter', - name: 'OpenRouter' + it('fetches models over the provider HTTP endpoint instead of the legacy SDK client', async () => { + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ id: 'gpt-4o' }] + }) }) - 
const llmProviderPresenter = new LLMProviderPresenter( - createConfigPresenter([providerConfig]), - mockSqlitePresenter, - mcpRuntime as any + vi.stubGlobal('fetch', fetchMock) + + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) + const models = await provider.fetchModels() + + expect(fetchMock).toHaveBeenCalledWith( + 'https://mock.example.com/v1/models', + expect.objectContaining({ + method: 'GET', + headers: expect.objectContaining({ + Authorization: 'Bearer test-key' + }) + }) ) - - const provider = llmProviderPresenter.getProviderInstance('openrouter') as OpenRouterProvider - ;(provider as any).isInitialized = true - - await collectEvents(provider, 'gpt-4o', modelConfig, messages, mcpTools) - const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] - - expect(provider).toBeInstanceOf(OpenRouterProvider) - expect(mockMcpToolsToOpenAITools).toHaveBeenCalledWith(mcpTools, 'openrouter') - expect(requestParams.tools).toEqual(convertedTools) + expect(models).toEqual([ + expect.objectContaining({ + id: 'gpt-4o', + providerId: 'novita' + }) + ]) }) -}) -describe('normalizeExtractedImageText', () => { - it('keeps meaningful text after image markdown cleanup', () => { - expect(normalizeExtractedImageText(' Here is the updated image.\n\n')).toBe( - 'Here is the updated image.' - ) - }) - - it('drops markdown residue after image markdown cleanup', () => { - expect(normalizeExtractedImageText('`\n')).toBe('') - expect(normalizeExtractedImageText('[]()')).toBe('') - }) -}) - -describe('OpenAICompatibleProvider prompt cache behavior', () => { - beforeEach(() => { - vi.clearAllMocks() - mockModelsList.mockResolvedValue({ data: [] }) - mockGetProxyUrl.mockReturnValue(null) - mockChatCompletionsCreate.mockResolvedValue( - createAsyncStream([ - { - choices: [ - { - delta: { - content: 'ok' - }, - finish_reason: 'stop' - } - ], - usage: { - prompt_tokens: 80, - completion_tokens: 12, - total_tokens: 92, - prompt_tokens_details: { - cached_tokens: 24, - cache_write_tokens: 16 - } - } - } - ]) - ) - }) - - it('injects prompt_cache_key only for official OpenAI chat completions', async () => { - const provider = new OpenAICompatibleProvider( - { - id: 'openai', - name: 'OpenAI', - apiType: 'openai-compatible', - apiKey: 'test-key', - baseUrl: 'https://api.openai.com/v1', - enable: false - }, - createConfigPresenter([]) - ) + it('forwards streaming requests to the AI SDK runtime', async () => { + const provider = new AiSdkProvider(createProvider(), createConfigPresenter()) ;(provider as any).isInitialized = true const modelConfig: ModelConfig = { maxTokens: 1024, contextLength: 8192, vision: false, - functionCall: false, + functionCall: true, reasoning: false, - type: 'chat', - conversationId: 'session-1' + type: 'chat' } - const events = await collectEvents( - provider, - 'gpt-5', + const events = [] + for await (const event of provider.coreStream( + [{ role: 'user', content: 'hello' }], + 'gpt-4o', modelConfig, - [{ role: 'user', content: 'cache me' }], + 0.7, + 512, [] - ) - const usageEvent = events.find((event) => event.type === 'usage') - const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] + )) { + events.push(event) + } - expect(requestParams.prompt_cache_key).toMatch(/^deepchat:openai:gpt-5:/) - expect(usageEvent).toMatchObject({ - type: 'usage', - usage: { - cached_tokens: 24, - cache_write_tokens: 16 - } - }) + expect(events).toEqual([ + { type: 'text', content: 'ok' }, + { type: 'stop', stop_reason: 'complete' } + ]) + 
expect(mockRunAiSdkCoreStream).toHaveBeenCalledWith( + expect.objectContaining({ + providerKind: 'openai-compatible' + }), + [{ role: 'user', content: 'hello' }], + 'gpt-4o', + modelConfig, + 0.7, + 512, + [] + ) }) - it('adds explicit cache_control breakpoint for OpenRouter Claude without top-level cache_control', async () => { - const provider = new OpenRouterProvider( - { - id: 'openrouter', - name: 'OpenRouter', - apiType: 'openai-compatible', - apiKey: 'test-key', - baseUrl: 'https://openrouter.ai/api/v1', - enable: false - }, - createConfigPresenter([]) + it('builds azure runtime context with azure auth headers and image routing', async () => { + const provider = new AiSdkProvider( + createProvider({ + id: 'azure-openai', + name: 'Azure OpenAI', + apiType: 'openai-completions', + baseUrl: 'https://example.openai.azure.com/openai/deployments/deepchat-prod' + }), + createConfigPresenter() ) ;(provider as any).isInitialized = true - const modelConfig: ModelConfig = { + const modelConfig = { + apiEndpoint: 'image', maxTokens: 1024, contextLength: 8192, vision: false, functionCall: false, reasoning: false, - type: 'chat', - conversationId: 'session-2' - } + type: 'chat' + } as ModelConfig - await collectEvents( - provider, - 'anthropic/claude-sonnet-4', + for await (const _event of provider.coreStream( + [{ role: 'user', content: 'paint' }], + 'gpt-image-1', modelConfig, - [ - { role: 'user', content: 'history' }, - { role: 'assistant', content: 'stable reply' }, - { role: 'user', content: 'latest question' } - ], + 0.7, + 256, [] - ) + )) { + break + } - const requestParams = mockChatCompletionsCreate.mock.calls[0]?.[0] - expect(requestParams).not.toHaveProperty('cache_control') - expect(requestParams).not.toHaveProperty('prompt_cache_key') - expect(requestParams.messages[1]).toMatchObject({ - role: 'assistant', - content: [ - { - type: 'text', - text: 'stable reply', - cache_control: { - type: 'ephemeral' - } - } - ] + const context = mockRunAiSdkCoreStream.mock.calls.at(-1)?.[0] + expect(context.providerKind).toBe('azure') + expect(context.cleanHeaders).toBe(false) + expect(context.buildTraceHeaders()).toMatchObject({ + 'Content-Type': 'application/json', + 'api-key': 'test-key' }) + expect(context.shouldUseImageGeneration('gpt-image-1', modelConfig)).toBe(true) + expect(context.shouldUseImageGeneration('gpt-image-1', {} as ModelConfig)).toBe(false) + }) +}) + +describe('normalizeExtractedImageText', () => { + it('keeps meaningful text after markdown cleanup', () => { + expect(normalizeExtractedImageText(' Here is the updated image.\n\n')).toBe( + 'Here is the updated image.' 
+ ) + }) + + it('drops markdown residue that contains no semantic text', () => { + expect(normalizeExtractedImageText('`\n')).toBe('') + expect(normalizeExtractedImageText('[]()')).toBe('') }) }) diff --git a/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts b/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts index b3a1318a0..7ab4d7bc4 100644 --- a/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts +++ b/test/main/presenter/llmProviderPresenter/openAIResponsesProvider.test.ts @@ -1,48 +1,33 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import { OpenAIResponsesProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/openAIResponsesProvider' -import type { - ChatMessage, - IConfigPresenter, - LLM_PROVIDER, - MCPToolDefinition, - ModelConfig -} from '../../../../src/shared/presenter' - -const { mockResponsesCreate, mockModelsList, mockMcpToolsToOpenAIResponsesTools, mockGetProxyUrl } = - vi.hoisted(() => ({ - mockResponsesCreate: vi.fn(), - mockModelsList: vi.fn().mockResolvedValue({ data: [] }), - mockMcpToolsToOpenAIResponsesTools: vi.fn().mockResolvedValue([]), - mockGetProxyUrl: vi.fn().mockReturnValue(null) - })) - -vi.mock('openai', () => { - class MockOpenAI { - responses = { - create: mockResponsesCreate - } - models = { - list: mockModelsList - } - } +import type { IConfigPresenter, LLM_PROVIDER, ModelConfig } from '../../../../src/shared/presenter' +import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider' + +const { + mockRunAiSdkCoreStream, + mockRunAiSdkDimensions, + mockRunAiSdkEmbeddings, + mockRunAiSdkGenerateText +} = vi.hoisted(() => ({ + mockRunAiSdkCoreStream: vi.fn(), + mockRunAiSdkDimensions: vi.fn(), + mockRunAiSdkEmbeddings: vi.fn(), + mockRunAiSdkGenerateText: vi.fn() +})) - return { - default: MockOpenAI, - AzureOpenAI: MockOpenAI +vi.mock('electron', () => ({ + app: { + getName: vi.fn(() => 'DeepChat'), + getVersion: vi.fn(() => '0.0.0-test'), + getPath: vi.fn(() => '/mock/path'), + isReady: vi.fn(() => true), + on: vi.fn() } -}) - -vi.mock('@/presenter', () => ({ - presenter: {} })) vi.mock('@/eventbus', () => ({ eventBus: { on: vi.fn(), - sendToRenderer: vi.fn(), - sendToMain: vi.fn(), - emit: vi.fn(), - send: vi.fn() + sendToRenderer: vi.fn() }, SendTarget: { ALL_WINDOWS: 'ALL_WINDOWS' @@ -53,6 +38,10 @@ vi.mock('@/events', () => ({ CONFIG_EVENTS: { MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED' }, + PROVIDER_DB_EVENTS: { + LOADED: 'LOADED', + UPDATED: 'UPDATED' + }, NOTIFICATION_EVENTS: { SHOW_ERROR: 'SHOW_ERROR' } @@ -60,519 +49,122 @@ vi.mock('@/events', () => ({ vi.mock('../../../../src/main/presenter/proxyConfig', () => ({ proxyConfig: { - getProxyUrl: mockGetProxyUrl + getProxyUrl: vi.fn().mockReturnValue(null) } })) -vi.mock('../../../../src/main/presenter/configPresenter/modelCapabilities', () => ({ - modelCapabilities: { - supportsReasoningEffort: vi.fn().mockReturnValue(false), - supportsVerbosity: vi.fn().mockReturnValue(false) - } +vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({ + runAiSdkCoreStream: mockRunAiSdkCoreStream, + runAiSdkDimensions: mockRunAiSdkDimensions, + runAiSdkEmbeddings: mockRunAiSdkEmbeddings, + runAiSdkGenerateText: mockRunAiSdkGenerateText })) -const createAsyncStream = (chunks: Array>) => ({ - async *[Symbol.asyncIterator]() { - for (const chunk of chunks) { - yield chunk - } - } +const createProvider = (overrides?: Partial): LLM_PROVIDER => ({ + 
id: 'openai', + name: 'OpenAI', + apiType: 'openai-responses', + apiKey: 'test-key', + baseUrl: 'https://api.openai.com/v1', + enable: false, + ...overrides }) -const mcpRuntime = { - mcpToolsToOpenAIResponsesTools: mockMcpToolsToOpenAIResponsesTools -} - -describe('OpenAIResponsesProvider tool call id mapping', () => { - const mockProvider: LLM_PROVIDER = { - id: 'openai', - name: 'OpenAI', - apiType: 'openai-responses', - apiKey: 'test-key', - baseUrl: 'https://api.openai.com/v1', - enable: false - } - - const mockConfigPresenter = { +const createConfigPresenter = (): IConfigPresenter => + ({ getProviderModels: vi.fn().mockReturnValue([]), getCustomModels: vi.fn().mockReturnValue([]), + getModelConfig: vi.fn().mockReturnValue(undefined), getSetting: vi.fn().mockReturnValue(undefined), setProviderModels: vi.fn(), - getModelStatus: vi.fn().mockReturnValue(true), - addCustomModel: vi.fn(), - removeCustomModel: vi.fn(), - updateCustomModel: vi.fn() - } as unknown as IConfigPresenter - - const modelConfig: ModelConfig = { - maxTokens: 1024, - contextLength: 8192, - vision: false, - functionCall: true, - reasoning: false, - type: 'chat' - } - - const messages: ChatMessage[] = [{ role: 'user', content: 'please call tool' }] - const tools: MCPToolDefinition[] = [ - { - type: 'function', - function: { - name: 'test_tool', - description: 'test', - parameters: { - type: 'object', - properties: {} - } - }, - server: { - name: 'test-server', - icons: '', - description: 'test' - } - } - ] + getModelStatus: vi.fn().mockReturnValue(true) + }) as unknown as IConfigPresenter +describe('OpenAIResponsesProvider', () => { beforeEach(() => { vi.clearAllMocks() - mockModelsList.mockResolvedValue({ data: [] }) - mockMcpToolsToOpenAIResponsesTools.mockResolvedValue([]) - mockGetProxyUrl.mockReturnValue(null) - }) - - it('uses call_id for streamed tool events when item_id differs from call_id', async () => { - mockResponsesCreate.mockResolvedValue( - createAsyncStream([ - { - type: 'response.output_item.added', - output_index: 0, - item: { - type: 'function_call', - id: 'fc_123', - call_id: 'call_123', - name: 'test_tool', - arguments: '' - } - }, - { - type: 'response.function_call_arguments.delta', - item_id: 'fc_123', - output_index: 0, - delta: '{"city":"' - }, - { - type: 'response.function_call_arguments.done', - item_id: 'fc_123', - output_index: 0, - arguments: '{"city":"shanghai"}' - }, - { - type: 'response.completed', - response: { - usage: { - input_tokens: 10, - output_tokens: 5, - total_tokens: 15, - input_tokens_details: { - cached_tokens: 4, - cache_write_tokens: 6 - } - } - } - } - ]) - ) - - const provider = new OpenAIResponsesProvider( - mockProvider, - mockConfigPresenter, - mcpRuntime as any - ) - ;(provider as any).isInitialized = true - - const events = [] - for await (const event of provider.coreStream( - messages, - 'gpt-4o', - modelConfig, - 0.7, - 512, - tools - )) { - events.push(event) - } - - const startEvent = events.find((event) => event.type === 'tool_call_start') - const chunkEvent = events.find((event) => event.type === 'tool_call_chunk') - const endEvent = events.find((event) => event.type === 'tool_call_end') - const usageEvent = events.find((event) => event.type === 'usage') - const stopEvent = events.find((event) => event.type === 'stop') - - expect(startEvent).toBeDefined() - expect(startEvent?.tool_call_id).toBe('call_123') - expect(chunkEvent).toBeDefined() - expect(chunkEvent?.tool_call_id).toBe('call_123') - 
    expect(chunkEvent?.tool_call_arguments_chunk).toBe('{"city":"')
-    expect(endEvent).toBeDefined()
-    expect(endEvent?.tool_call_id).toBe('call_123')
-    expect(endEvent?.tool_call_arguments_complete).toBe('{"city":"shanghai"}')
-    expect(usageEvent).toMatchObject({
-      type: 'usage',
-      usage: expect.objectContaining({
-        cached_tokens: 4,
-        cache_write_tokens: 6
-      })
-    })
-    expect(stopEvent?.stop_reason).toBe('tool_use')
-    expect(mockMcpToolsToOpenAIResponsesTools).toHaveBeenCalledWith(tools, mockProvider.id)
-  })
-
-  it('uses unified fallback defaults when model list lacks capability metadata', async () => {
-    mockModelsList.mockResolvedValue({
-      data: [{ id: 'gpt-4.1' }]
+    mockRunAiSdkCoreStream.mockReturnValue({
+      async *[Symbol.asyncIterator]() {
+        yield { type: 'stop', stop_reason: 'complete' }
+      }
     })
-
-    const provider = new OpenAIResponsesProvider(
-      mockProvider,
-      mockConfigPresenter,
-      mcpRuntime as any
-    )
-    const models = await (provider as any).fetchOpenAIModels()
-
-    expect(models).toEqual([
-      expect.objectContaining({
-        id: 'gpt-4.1',
-        contextLength: 16000,
-        maxTokens: 4096
-      })
-    ])
   })
 
-  it('falls back to output_index mapping when item id is unavailable', async () => {
-    mockResponsesCreate.mockResolvedValue(
-      createAsyncStream([
-        {
-          type: 'response.output_item.added',
-          output_index: 0,
-          item: {
-            type: 'function_call',
-            call_id: 'call_456',
-            name: 'test_tool',
-            arguments: ''
-          }
-        },
-        {
-          type: 'response.function_call_arguments.delta',
-          item_id: 'fc_missing',
-          output_index: 0,
-          delta: '{"topic":"'
-        },
-        {
-          type: 'response.function_call_arguments.done',
-          item_id: 'fc_missing',
-          output_index: 0,
-          arguments: '{"topic":"responses"}'
-        },
-        {
-          type: 'response.completed',
-          response: {
-            usage: {
-              input_tokens: 6,
-              output_tokens: 3,
-              total_tokens: 9
-            }
-          }
-        }
-      ])
-    )
-
-    const provider = new OpenAIResponsesProvider(
-      mockProvider,
-      mockConfigPresenter,
-      mcpRuntime as any
-    )
+  it('uses the responses runtime for official OpenAI providers', async () => {
+    const provider = new AiSdkProvider(createProvider(), createConfigPresenter())
     ;(provider as any).isInitialized = true
 
-    const events = []
-    for await (const event of provider.coreStream(
-      messages,
-      'gpt-4o',
-      modelConfig,
-      0.7,
-      512,
-      tools
-    )) {
-      events.push(event)
-    }
-
-    const chunkEvent = events.find((event) => event.type === 'tool_call_chunk')
-    const endEvent = events.find((event) => event.type === 'tool_call_end')
-    const stopEvent = events.find((event) => event.type === 'stop')
-
-    expect(chunkEvent).toBeDefined()
-    expect(chunkEvent?.tool_call_id).toBe('call_456')
-    expect(endEvent).toBeDefined()
-    expect(endEvent?.tool_call_id).toBe('call_456')
-    expect(endEvent?.tool_call_arguments_complete).toBe('{"topic":"responses"}')
-    expect(stopEvent?.stop_reason).toBe('tool_use')
-  })
-
-  it('serializes assistant history as shorthand text instead of input_text parts', async () => {
-    mockResponsesCreate.mockResolvedValue(
-      createAsyncStream([
+    try {
+      for await (const _event of provider.coreStream(
+        [{ role: 'user', content: 'hello' }],
+        'gpt-4o',
         {
-          type: 'response.completed',
-          response: {
-            usage: {
-              input_tokens: 4,
-              output_tokens: 2,
-              total_tokens: 6
-            }
-          }
-        }
-      ])
-    )
-
-    const provider = new OpenAIResponsesProvider(
-      mockProvider,
-      mockConfigPresenter,
-      mcpRuntime as any
-    )
-    ;(provider as any).isInitialized = true
-
-    const historyMessages: ChatMessage[] = [
-      { role: 'system', content: 'system prompt' },
-      { role: 'user', content: 'hi' },
-      { role: 'assistant', content: 'Hi! What can I help you with today?' },
-      { role: 'user', content: '你是谁' }
-    ]
-
-    for await (const _event of provider.coreStream(
-      historyMessages,
-      'gpt-5.3-codex',
-      modelConfig,
-      0.7,
-      512,
-      []
-    )) {
-      // consume stream
-    }
-
-    const requestParams = mockResponsesCreate.mock.calls[0][0] as {
-      input: Array<Record<string, unknown>>
-    }
-
-    expect(requestParams.input).toEqual([
-      {
-        role: 'system',
-        content: [{ type: 'input_text', text: 'system prompt' }]
-      },
-      {
-        role: 'user',
-        content: [{ type: 'input_text', text: 'hi' }]
-      },
-      {
-        role: 'assistant',
-        content: 'Hi! What can I help you with today?'
-      },
-      {
-        role: 'user',
-        content: [{ type: 'input_text', text: '你是谁' }]
+          maxTokens: 1024,
+          contextLength: 8192,
+          vision: false,
+          functionCall: false,
+          reasoning: false,
+          type: 'chat'
+        } as ModelConfig,
+        0.7,
+        256,
+        []
+      )) {
+        break
       }
-    ])
-  })
-
-  it('flattens assistant content parts to text and omits unsupported images', async () => {
-    mockResponsesCreate.mockResolvedValue(
-      createAsyncStream([
-        {
-          type: 'response.completed',
-          response: {
-            usage: {
-              input_tokens: 7,
-              output_tokens: 2,
-              total_tokens: 9
-            }
-          }
-        }
-      ])
-    )
-
-    const provider = new OpenAIResponsesProvider(
-      mockProvider,
-      mockConfigPresenter,
-      mcpRuntime as any
-    )
-    ;(provider as any).isInitialized = true
-
-    const historyMessages: ChatMessage[] = [
-      { role: 'user', content: 'show history' },
-      {
-        role: 'assistant',
-        content: [
-          { type: 'text', text: 'Line 1. ' },
-          { type: 'text', text: 'Line 2.' }
-        ]
-      },
-      {
-        role: 'assistant',
-        content: [
-          { type: 'text', text: 'Look ' },
-          { type: 'image_url', image_url: { url: 'https://example.com/image.png' } },
-          { type: 'text', text: 'here.' }
-        ]
-      },
-      {
-        role: 'assistant',
-        content: [{ type: 'image_url', image_url: { url: 'https://example.com/only-image.png' } }]
-      },
-      { role: 'user', content: 'continue' }
-    ]
+    } catch {}
 
-    for await (const _event of provider.coreStream(
-      historyMessages,
-      'gpt-5.3-codex',
-      modelConfig,
-      0.7,
-      512,
-      []
-    )) {
-      // consume stream
-    }
+    const context = mockRunAiSdkCoreStream.mock.calls.at(-1)?.[0]
 
-    const requestParams = mockResponsesCreate.mock.calls[0][0] as {
-      input: Array<Record<string, unknown>>
-    }
-
-    expect(requestParams.input).toEqual([
-      {
-        role: 'user',
-        content: [{ type: 'input_text', text: 'show history' }]
-      },
-      {
-        role: 'assistant',
-        content: 'Line 1. Line 2.'
-      },
-      {
-        role: 'assistant',
-        content: 'Look here.'
-      },
-      {
-        role: 'user',
-        content: [{ type: 'input_text', text: 'continue' }]
-      }
-    ])
+    expect(context.providerKind).toBe('openai-responses')
+    expect(context.shouldUseImageGeneration('gpt-image-1', {} as ModelConfig)).toBe(true)
+    expect(context.shouldUseImageGeneration('gpt-4o', {} as ModelConfig)).toBe(false)
   })
 
-  it('emits request trace with final endpoint, headers and body', async () => {
-    const persist = vi.fn()
-    const traceAwareConfig = {
-      ...modelConfig,
-      requestTraceContext: {
-        enabled: true,
-        persist
-      }
-    } as ModelConfig & {
-      requestTraceContext: {
-        enabled: boolean
-        persist: (payload: {
-          endpoint: string
-          headers: Record<string, string>
-          body: unknown
-        }) => void
-      }
-    }
-
-    mockResponsesCreate.mockResolvedValue(
-      createAsyncStream([
-        {
-          type: 'response.completed',
-          response: {
-            usage: {
-              input_tokens: 1,
-              output_tokens: 1,
-              total_tokens: 2
-            }
-          }
-        }
-      ])
-    )
-
-    const provider = new OpenAIResponsesProvider(
-      mockProvider,
-      mockConfigPresenter,
-      mcpRuntime as any
+  it('uses azure runtime semantics for azure-openai responses providers', async () => {
+    const provider = new AiSdkProvider(
+      createProvider({
+        id: 'azure-openai',
+        name: 'Azure OpenAI',
+        baseUrl: 'https://example.openai.azure.com/openai'
+      }),
+      createConfigPresenter()
     )
     ;(provider as any).isInitialized = true
 
-    for await (const _event of provider.coreStream(
-      messages,
-      'gpt-4o',
-      traceAwareConfig,
-      0.7,
-      512,
-      []
-    )) {
-      // consume stream
-    }
-
-    expect(persist).toHaveBeenCalledTimes(1)
-    const payload = persist.mock.calls[0][0] as {
-      endpoint: string
-      headers: Record<string, string>
-      body: Record<string, unknown>
-    }
-
-    expect(payload.endpoint).toContain('/responses')
-    expect(payload.headers).toHaveProperty('Authorization', 'Bearer test-key')
-    expect(payload.body).toMatchObject({
-      model: 'gpt-4o',
-      temperature: 0.7,
-      max_output_tokens: 512,
-      stream: true
-    })
-  })
-
-  it('injects prompt_cache_key for official OpenAI Responses requests', async () => {
-    mockResponsesCreate.mockResolvedValue(
-      createAsyncStream([
+    try {
+      for await (const _event of provider.coreStream(
+        [{ role: 'user', content: 'paint' }],
+        'gpt-image-1',
         {
-          type: 'response.completed',
-          response: {
-            usage: {
-              input_tokens: 8,
-              output_tokens: 2,
-              total_tokens: 10
-            }
-          }
-        }
-      ])
-    )
-
-    const provider = new OpenAIResponsesProvider(
-      mockProvider,
-      mockConfigPresenter,
-      mcpRuntime as any
-    )
-    ;(provider as any).isInitialized = true
-
-    const promptCacheModelConfig: ModelConfig = {
-      ...modelConfig,
-      conversationId: 'session-1'
-    }
+          apiEndpoint: 'image',
+          maxTokens: 1024,
+          contextLength: 8192,
+          vision: false,
+          functionCall: false,
+          reasoning: false,
+          type: 'chat'
+        } as ModelConfig,
+        0.7,
+        256,
+        []
+      )) {
+        break
+      }
+    } catch {}
 
-    for await (const _event of provider.coreStream(
-      [{ role: 'user', content: 'cache me' }],
-      'gpt-5',
-      promptCacheModelConfig,
-      0.7,
-      512,
-      []
-    )) {
-      // consume stream
-    }
+    const context = mockRunAiSdkCoreStream.mock.calls.at(-1)?.[0]
 
-    const requestParams = mockResponsesCreate.mock.calls[0][0] as Record<string, unknown>
-    expect(requestParams.prompt_cache_key).toMatch(/^deepchat:openai:gpt-5:/)
+    expect(context.providerKind).toBe('azure')
+    expect(context.buildTraceHeaders()).toMatchObject({
+      'Content-Type': 'application/json',
+      'api-key': 'test-key'
+    })
+    expect(
+      context.shouldUseImageGeneration('gpt-image-1', {
+        apiEndpoint: 'image'
+      } as ModelConfig)
+    ).toBe(true)
+    expect(context.shouldUseImageGeneration('gpt-image-1', {} as ModelConfig)).toBe(false)
   })
 })
diff --git a/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts b/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts
index b8977106b..33f5579a9 100644
--- a/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts
+++ b/test/main/presenter/llmProviderPresenter/zenmuxProvider.test.ts
@@ -1,28 +1,6 @@
 import { beforeEach, describe, expect, it, vi } from 'vitest'
 import type { IConfigPresenter, LLM_PROVIDER } from '../../../../src/shared/presenter'
-import { ZenmuxProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/zenmuxProvider'
-
-const ZENMUX_ANTHROPIC_BASE_URL = 'https://zenmux.ai/api/anthropic'
-
-const {
-  mockOpenAIConstructor,
-  mockAnthropicConstructor,
-  mockChatCompletionsCreate,
-  mockOpenAIModelsList,
-  mockAnthropicMessagesCreate,
-  mockAnthropicModelsList,
-  mockGetProxyUrl,
-  mockCacheImage
-} = vi.hoisted(() => ({
-  mockOpenAIConstructor: vi.fn(),
-  mockAnthropicConstructor: vi.fn(),
-  mockChatCompletionsCreate: vi.fn(),
-  mockOpenAIModelsList: vi.fn().mockResolvedValue({ data: [] }),
-  mockAnthropicMessagesCreate: vi.fn(),
-  mockAnthropicModelsList: vi.fn().mockResolvedValue({ data: [] }),
-  mockGetProxyUrl: vi.fn().mockReturnValue(null),
-  mockCacheImage: vi.fn()
-}))
+import { AiSdkProvider } from '../../../../src/main/presenter/llmProviderPresenter/providers/aiSdkProvider'
 
 vi.mock('electron', () => ({
   app: {
@@ -31,78 +9,17 @@ vi.mock('electron', () => ({
     getPath: vi.fn(() => '/mock/path'),
     isReady: vi.fn(() => true),
     on: vi.fn()
-  },
-  session: {},
-  ipcMain: {
-    on: vi.fn(),
-    handle: vi.fn(),
-    removeHandler: vi.fn()
-  },
-  BrowserWindow: vi.fn(() => ({
-    loadURL: vi.fn(),
-    loadFile: vi.fn(),
-    on: vi.fn(),
-    webContents: { send: vi.fn(), on: vi.fn(), isDestroyed: vi.fn(() => false) },
-    isDestroyed: vi.fn(() => false),
-    close: vi.fn(),
-    show: vi.fn(),
-    hide: vi.fn()
-  })),
-  dialog: {
-    showOpenDialog: vi.fn()
-  },
-  shell: {
-    openExternal: vi.fn()
   }
 }))
 
-vi.mock('openai', () => {
-  class MockOpenAI {
-    chat = {
-      completions: {
-        create: mockChatCompletionsCreate
-      }
-    }
-    models = {
-      list: mockOpenAIModelsList
-    }
-    embeddings = {
-      create: vi.fn()
-    }
-
-    constructor(options: Record<string, unknown>) {
-      mockOpenAIConstructor(options)
-    }
-  }
-
-  return {
-    default: MockOpenAI,
-    AzureOpenAI: MockOpenAI
-  }
-})
-
 vi.mock('@/presenter', () => ({
   presenter: {
     devicePresenter: {
-      cacheImage: mockCacheImage
+      cacheImage: vi.fn()
     }
   }
 }))
 
-vi.mock('@anthropic-ai/sdk', () => ({
-  default: vi.fn().mockImplementation((options: Record<string, unknown>) => {
-    mockAnthropicConstructor(options)
-    return {
-      messages: {
-        create: mockAnthropicMessagesCreate
-      },
-      models: {
-        list: mockAnthropicModelsList
-      }
-    }
-  })
-}))
-
 vi.mock('@/eventbus', () => ({
   eventBus: {
     on: vi.fn(),
@@ -123,6 +40,10 @@ vi.mock('@/events', () => ({
     PROVIDER_BATCH_UPDATE: 'PROVIDER_BATCH_UPDATE',
     MODEL_LIST_CHANGED: 'MODEL_LIST_CHANGED'
   },
+  PROVIDER_DB_EVENTS: {
+    LOADED: 'LOADED',
+    UPDATED: 'UPDATED'
+  },
   NOTIFICATION_EVENTS: {
     SHOW_ERROR: 'SHOW_ERROR'
   }
@@ -130,25 +51,23 @@ vi.mock('@/events', () => ({
 
 vi.mock('../../../../src/main/presenter/proxyConfig', () => ({
   proxyConfig: {
-    getProxyUrl: mockGetProxyUrl
+    getProxyUrl: vi.fn().mockReturnValue(null)
   }
 }))
 
-vi.mock('../../../../src/main/presenter/configPresenter/modelCapabilities', () => ({
-  modelCapabilities: {
-    supportsReasoningEffort: vi.fn().mockReturnValue(false),
-    supportsVerbosity: vi.fn().mockReturnValue(false),
-    supportsReasoning: vi.fn().mockReturnValue(false),
-    resolveProviderId: vi.fn((providerId: string) => providerId)
-  }
+vi.mock('../../../../src/main/presenter/llmProviderPresenter/aiSdk', () => ({
+  runAiSdkCoreStream: vi.fn(),
+  runAiSdkDimensions: vi.fn(),
+  runAiSdkEmbeddings: vi.fn(),
+  runAiSdkGenerateText: vi.fn()
 }))
 
-const createConfigPresenter = () =>
+const createConfigPresenter = (): IConfigPresenter =>
   ({
     getProviderModels: vi.fn().mockReturnValue([]),
     getCustomModels: vi.fn().mockReturnValue([]),
-    getModelConfig: vi.fn().mockReturnValue(undefined),
     getDbProviderModels: vi.fn().mockReturnValue([]),
+    getModelConfig: vi.fn().mockReturnValue(undefined),
     getSetting: vi.fn().mockReturnValue(undefined),
     setProviderModels: vi.fn(),
     getModelStatus: vi.fn().mockReturnValue(true)
@@ -164,184 +83,58 @@ const createProvider = (overrides?: Partial<LLM_PROVIDER>): LLM_PROVIDER => ({
   ...overrides
 })
 
-describe('ZenmuxProvider', () => {
+describe('AiSdkProvider zenmux', () => {
   beforeEach(() => {
     vi.clearAllMocks()
-    mockGetProxyUrl.mockReturnValue(null)
-    mockOpenAIModelsList.mockResolvedValue({ data: [] })
-    mockAnthropicModelsList.mockResolvedValue({ data: [] })
-    mockChatCompletionsCreate.mockResolvedValue({
-      choices: [{ message: { content: 'openai-ok' } }]
-    })
-    mockAnthropicMessagesCreate.mockResolvedValue({
-      content: [{ type: 'text', text: 'anthropic-ok' }],
-      usage: undefined
-    })
   })
 
-  it('routes anthropic/* models through the fixed Anthropic endpoint', async () => {
-    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
-
-    const result = await provider.generateText('hello', 'anthropic/claude-sonnet-4-5')
+  it('routes anthropic models through the anthropic runtime', async () => {
+    const provider = new AiSdkProvider(createProvider(), createConfigPresenter())
+    const routeDecision = (provider as any).resolveRouteDecision('anthropic/claude-sonnet-4-5')
+    const runtimeProvider = (provider as any).getRuntimeProvider(routeDecision) as LLM_PROVIDER
 
-    expect(result.content).toBe('anthropic-ok')
-    expect(mockAnthropicConstructor).toHaveBeenCalledWith(
-      expect.objectContaining({
-        apiKey: 'test-key',
-        baseURL: ZENMUX_ANTHROPIC_BASE_URL
-      })
-    )
-    expect(mockAnthropicMessagesCreate).toHaveBeenCalledWith(
-      expect.objectContaining({
-        model: 'anthropic/claude-sonnet-4-5'
-      })
-    )
-    expect(mockAnthropicMessagesCreate.mock.calls.at(-1)?.[0]).not.toHaveProperty('cache_control')
-    expect(mockChatCompletionsCreate).not.toHaveBeenCalled()
+    expect(routeDecision.providerKind).toBe('anthropic')
+    expect(runtimeProvider.baseUrl).toBe('https://zenmux.ai/api/anthropic')
   })
 
-  it('uses explicit Anthropic cache breakpoints for ZenMux Claude history', async () => {
-    mockAnthropicMessagesCreate.mockResolvedValueOnce({
-      content: [{ type: 'text', text: 'anthropic-ok' }],
-      usage: undefined
-    })
-
-    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
-    const anthropicDelegate = (provider as any).anthropicDelegate
-    anthropicDelegate.clientInitialized = true
-    anthropicDelegate.isInitialized = true
-    anthropicDelegate.anthropic = {
-      messages: {
-        create: mockAnthropicMessagesCreate
-      },
-      models: {
-        list: mockAnthropicModelsList
-      }
-    }
-    vi.spyOn(anthropicDelegate, 'ensureClientInitialized').mockResolvedValue(undefined)
-
-    const result = await provider.completions(
-      [
-        { role: 'user', content: 'history' },
-        { role: 'assistant', content: 'stable reply' },
-        { role: 'user', content: 'latest question' }
-      ],
-      'anthropic/claude-sonnet-4-5'
-    )
+  it('routes non-anthropic models through the openai-compatible runtime', async () => {
+    const provider = new AiSdkProvider(createProvider(), createConfigPresenter())
+    const routeDecision = (provider as any).resolveRouteDecision('moonshotai/kimi-k2.5')
 
-    expect(result.content).toBe('anthropic-ok')
-    expect(mockAnthropicMessagesCreate).toHaveBeenCalledWith(
-      expect.objectContaining({
-        model: 'anthropic/claude-sonnet-4-5',
-        messages: [
-          {
-            role: 'user',
-            content: [{ type: 'text', text: 'history' }]
-          },
-          {
-            role: 'assistant',
-            content: [
-              {
-                type: 'text',
-                text: 'stable reply',
-                cache_control: {
-                  type: 'ephemeral'
-                }
-              }
-            ]
-          },
-          {
-            role: 'user',
-            content: [{ type: 'text', text: 'latest question' }]
-          }
-        ]
-      })
-    )
-    expect(mockAnthropicMessagesCreate.mock.calls.at(-1)?.[0]).not.toHaveProperty('cache_control')
+    expect(routeDecision.providerKind).toBe('openai-compatible')
   })
 
-  it('routes non-anthropic models through the configured OpenAI-compatible endpoint', async () => {
-    const provider = new ZenmuxProvider(
-      createProvider({ baseUrl: 'https://custom.zenmux.ai/api/v1' }),
-      createConfigPresenter()
-    )
-
-    const result = await provider.generateText('hello', 'moonshotai/kimi-k2.5')
-
-    expect(result.content).toBe('openai-ok')
-    expect(mockOpenAIConstructor).toHaveBeenCalledWith(
-      expect.objectContaining({
-        apiKey: 'test-key',
-        baseURL: 'https://custom.zenmux.ai/api/v1'
-      })
-    )
-    expect(mockChatCompletionsCreate).toHaveBeenCalledWith(
-      expect.objectContaining({
-        model: 'moonshotai/kimi-k2.5'
+  it('fetches model metadata from the shared OpenAI-compatible path and keeps the ZenMux group', async () => {
+    const fetchMock = vi.fn().mockResolvedValue({
+      ok: true,
+      json: vi.fn().mockResolvedValue({
+        data: [{ id: 'moonshotai/kimi-k2.5' }]
       })
-    )
-    expect(mockAnthropicMessagesCreate).not.toHaveBeenCalled()
-  })
-
-  it('fetches model metadata from the OpenAI-compatible models API and keeps the ZenMux group', async () => {
-    mockOpenAIModelsList.mockResolvedValue({
-      data: [{ id: 'moonshotai/kimi-k2.5' }, { id: 'anthropic/claude-sonnet-4-5' }]
     })
+    vi.stubGlobal('fetch', fetchMock)
+    const provider = new AiSdkProvider(createProvider(), createConfigPresenter())
 
-    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
     const models = await provider.fetchModels()
 
-    expect(mockOpenAIModelsList).toHaveBeenCalled()
     expect(models).toEqual([
       expect.objectContaining({
         id: 'moonshotai/kimi-k2.5',
         group: 'ZenMux',
         providerId: 'zenmux'
-      }),
-      expect.objectContaining({
-        id: 'anthropic/claude-sonnet-4-5',
-        group: 'ZenMux',
-        providerId: 'zenmux'
       })
     ])
   })
 
-  it('uses the OpenAI-compatible check path', async () => {
-    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
-
-    const result = await provider.check()
-
-    expect(result).toEqual({ isOk: true, errorMsg: null })
-    expect(mockOpenAIModelsList).toHaveBeenCalled()
-    expect(mockAnthropicMessagesCreate).not.toHaveBeenCalled()
-  })
-
-  it('refreshes both delegates on proxy updates after the anthropic route has been initialized', async () => {
-    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
-    const openaiDelegate = (provider as any).openaiDelegate
-    const anthropicDelegate = (provider as any).anthropicDelegate
-    const openaiProxySpy = vi.spyOn(openaiDelegate, 'onProxyResolved')
-    const anthropicProxySpy = vi.spyOn(anthropicDelegate, 'onProxyResolved')
-    await anthropicDelegate.ensureClientInitialized()
-
-    expect(anthropicDelegate.isClientInitialized()).toBe(true)
-
-    provider.onProxyResolved()
-
-    expect(openaiProxySpy).toHaveBeenCalledTimes(1)
-    expect(anthropicProxySpy).toHaveBeenCalledTimes(1)
-  })
-
-  it('fails fast for embeddings on anthropic/* models', async () => {
-    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
+  it('fails fast for embeddings on anthropic models', async () => {
+    const provider = new AiSdkProvider(createProvider(), createConfigPresenter())
 
     await expect(provider.getEmbeddings('anthropic/claude-sonnet-4-5', ['hello'])).rejects.toThrow(
       'Embeddings not supported for Anthropic models: anthropic/claude-sonnet-4-5'
     )
   })
 
-  it('fails fast for embedding dimensions on anthropic/* models', async () => {
-    const provider = new ZenmuxProvider(createProvider(), createConfigPresenter())
+  it('fails fast for embedding dimensions on anthropic models', async () => {
+    const provider = new AiSdkProvider(createProvider(), createConfigPresenter())
 
     await expect(provider.getDimensions('anthropic/claude-sonnet-4-5')).rejects.toThrow(
       'Embeddings not supported for Anthropic models: anthropic/claude-sonnet-4-5'
diff --git a/test/main/presenter/sessionPresenter/messageFormatter.test.ts b/test/main/presenter/sessionPresenter/messageFormatter.test.ts
index 1396036cc..bbd6dc30b 100644
--- a/test/main/presenter/sessionPresenter/messageFormatter.test.ts
+++ b/test/main/presenter/sessionPresenter/messageFormatter.test.ts
@@ -110,4 +110,55 @@ describe('messageFormatter', () => {
     expect(String(messages[0].content)).toContain('function_call_record')
     expect(String(messages[0].content)).toContain('search')
   })
+
+  it('preserves tool call provider options when function calling is enabled', () => {
+    const toolCallBlock: AssistantMessageBlock = {
+      type: 'tool_call',
+      status: 'success',
+      timestamp: Date.now(),
+      extra: {
+        providerOptionsJson: JSON.stringify({
+          vertex: {
+            thoughtSignature: 'tool-thought-signature'
+          }
+        })
+      },
+      tool_call: {
+        id: 'tool-1',
+        name: 'search',
+        params: '{"q":"hi"}',
+        response: 'ok'
+      }
+    }
+
+    const assistantMessage = createMessage('assistant-1', 'assistant', [toolCallBlock])
+    const messages = addContextMessages([assistantMessage], false, true)
+
+    expect(messages).toEqual([
+      {
+        role: 'assistant',
+        content: undefined,
+        tool_calls: [
+          {
+            id: 'tool-1',
+            type: 'function',
+            function: {
+              name: 'search',
+              arguments: '{"q":"hi"}'
+            },
+            provider_options: {
+              vertex: {
+                thoughtSignature: 'tool-thought-signature'
+              }
+            }
+          }
+        ]
+      },
+      {
+        role: 'tool',
+        content: 'ok',
+        tool_call_id: 'tool-1'
+      }
+    ])
+  })
 })