diff --git a/docs/specs/ollama-model-selection/plan.md b/docs/specs/ollama-model-selection/plan.md
new file mode 100644
index 000000000..1f935638c
--- /dev/null
+++ b/docs/specs/ollama-model-selection/plan.md
@@ -0,0 +1,82 @@
+# Ollama Model Selectability and Cross-Window Sync Implementation Plan
+
+## 1. Key Decisions
+
+1. Unify the source of truth into the main/config persistence layer; the renderer `ollamaStore` keeps UI state only.
+2. When the Ollama provider refreshes, fetch both `listModels()` and `listRunningModels()` and merge the results.
+3. Enable newly discovered models by default via `ensureModelStatus(..., true)`, without overwriting an existing explicit status.
+4. Only audit the SDK this time; do not upgrade the `ollama` dependency version.
+
+## 2. main/config Design
+
+### 2.1 Status Semantics
+
+1. Add `ensureModelStatus` to `ModelStatusHelper`.
+2. Write the default value only when no status exists yet.
+3. Write to the store and update the cache directly, without emitting `MODEL_STATUS_CHANGED`. A minimal sketch of the intended write-once semantics follows (the real implementation in `modelStatusHelper.ts` below also backfills the cache from the store).
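+
+```ts
+// Sketch only: seed a write-once default; an existing explicit status always wins.
+ensureModelStatus(providerId: string, modelId: string, enabled: boolean): void {
+  const statusKey = this.getStatusKey(providerId, modelId)
+  if (this.cache.has(statusKey) || this.hasStoredStatus(statusKey)) return
+  this.store.set(statusKey, enabled) // direct write: no MODEL_STATUS_CHANGED event
+  this.cache.set(statusKey, enabled)
+}
+```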
+
+### 2.2 Ollama Provider List Building
+
+1. `fetchProviderModels()` fetches local and running models in parallel.
+2. Merge by `model.name`, preferring the local model's fields as the base.
+3. Merge `capabilities`, `model_info`, and any cached capability metadata.
+4. Build the `MODEL_META` list and write it to the config provider models; the union-by-name step is sketched below (simplified from `fetchProviderModels()` in this change).
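+
+```ts
+// Sketch only: union by model.name; the local entry wins over a running entry with the same name.
+const merged = new Map<string, OllamaModel>()
+for (const model of localModels) merged.set(model.name, model)
+for (const model of runningModels) {
+  const existing = merged.get(model.name)
+  merged.set(model.name, existing ? this.mergeOllamaModels(existing, model) : model)
+}
+```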
+
+## 3. Renderer Refresh Chain
+
+1. `ollamaStore.refreshOllamaModels(providerId)` first updates the settings page's local/running lists.
+2. It then calls `llmP.refreshModels(providerId)` so main rebuilds the persisted catalog.
+3. The current window then calls `modelStore.refreshProviderModels(providerId)` to converge what it displays.
+4. The pull-success event reuses the same refresh chain, sketched below (simplified; error handling omitted).
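+
+```ts
+// Sketch only: the single chain every entry point (manual refresh, init, pull success) funnels through.
+const refreshOllamaModels = async (providerId: string): Promise<void> => {
+  const [running, local] = await Promise.all([
+    llmP.listOllamaRunningModels(providerId),
+    llmP.listOllamaModels(providerId)
+  ])
+  setRunningModels(providerId, running) // settings-page UI state only
+  setLocalModels(providerId, local)
+  await llmP.refreshModels(providerId) // main rebuilds the persisted catalog
+  await modelStore.refreshProviderModels(providerId) // this window converges its view
+}
+```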
+
+## 4. Regression Points
+
+1. Remove the early return for Ollama in `modelStore.updateModelStatus()` so explicit enable/disable lands in config.
+2. The chat side keeps relying on `modelStore.enabledModels`; no temporary compatibility branch is added.
+
+## 5. Test Strategy
+
+### Main
+
+1. `ModelStatusHelper.ensureModelStatus` does not overwrite an explicitly disabled status.
+2. `OllamaProvider.fetchModels()` merges local and running models and preserves capability metadata.
+
+### Renderer
+
+1. `ollamaStore.refreshOllamaModels()` calls `llmP.refreshModels()` and `modelStore.refreshProviderModels()`.
+2. The pull-success event triggers the same refresh chain.
+3. `ChatStatusBar`, `ModelSelect`, and `ModelChooser` show Ollama chat models and filter out Ollama embedding models.
diff --git a/docs/specs/ollama-model-selection/spec.md b/docs/specs/ollama-model-selection/spec.md
new file mode 100644
index 000000000..8e1e64dc3
--- /dev/null
+++ b/docs/specs/ollama-model-selection/spec.md
@@ -0,0 +1,80 @@
+# Ollama Model Selectability and Cross-Window Sync Specification
+
+## Overview
+
+Fix the issue where Ollama models are visible on the settings page but cannot be selected from the chat status bar, model pickers, and other entry points, and unify the model catalog's source of truth into the main/config persistence layer.
+
+## Background and Goals
+
+1. The settings window currently assembles the selectable models directly from `ollamaStore`'s local transient state, which diverges from the `modelStore`/config source of truth the main chat window relies on.
+2. Running Ollama models do not reliably enter the persisted catalog, so behavior differs across windows, after a refresh, and after a restart.
+3. Newly discovered local/running models should be selectable by default, but a user's explicit disable must be preserved.
+
+## User Stories
+
+### US-1: Selectable from chat entry points
+
+As a user, after refreshing Ollama models on the settings page, I want to select them in `ChatStatusBar` and the other model pickers without restarting the chat window.
+
+### US-2: Running models are visible
+
+As a user, I want Ollama models that only appear in the "running models" list to enter the selectable catalog too, instead of only the local model list being recognized.
+
+### US-3: Explicit disable is preserved
+
+As a user, after I manually disable an Ollama model, I do not want later refreshes to force it back on.
+
+## Functional Requirements
+
+### A. Unified Source of Truth
+
+- [ ] The selectable Ollama model catalog is owned by the main/config persistence layer.
+- [ ] The renderer `ollamaStore` only maintains settings-page UI state and no longer writes the global model source of truth directly.
+
+### B. Model Merging
+
+- [ ] On refresh, the Ollama provider merges `local models ∪ running models`.
+- [ ] Deduplicate by `model.name`.
+- [ ] When a local and a running model share a name, the local model is the base and the running model fills in extra information.
+- [ ] Running-only models also enter the persisted provider model list.
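+
+A hypothetical merge, with illustrative names and values:
+
+```ts
+// local:   { name: 'qwen3:8b', details: { family: 'qwen' }, capabilities: ['chat'] }
+// running: { name: 'qwen3:8b', capabilities: ['chat', 'thinking'], model_info: { context_length: 65536 } }
+// merged:  local fields win where both sides define a value; capabilities are unioned to
+//          ['chat', 'thinking'] and context_length is backfilled from the running entry.
+```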
+
+### C. Metadata Fidelity
+
+- [ ] The merged result preserves key capability fields such as `type/contextLength/vision/functionCall/reasoning`.
+- [ ] Embedding models stay marked as `embedding`, and chat pickers keep filtering them out.
+
+### D. Default-Enable Semantics
+
+- [ ] Newly discovered Ollama local/running models are selectable by default.
+- [ ] An existing explicit status must not be overwritten, especially for models the user has disabled.
+- [ ] This default-enable write does not emit per-model status-change events.
+
+### E. Refresh Chain
+
+- [ ] Manual refresh on the settings page, initialization, and pull success all go through the same refresh chain:
+  1. Fetch local/running models and update the settings-page UI.
+  2. Call main's `refreshModels(providerId)` to rebuild the persisted catalog.
+  3. Refresh `modelStore.refreshProviderModels(providerId)` in the current window.
+
+## Non-Goals
+
+1. The `ollama` SDK is not upgraded in this change.
+2. The "running models / local models" sections of the Ollama settings page UI are unchanged.
+3. No new user-visible configuration options.
+
+## Acceptance Criteria
+
+- [ ] After refreshing Ollama on the settings page, the chat window's model picker immediately shows the new chat models.
+- [ ] A running-only Ollama model can be selected to start a conversation.
+- [ ] Ollama embedding models do not appear in the chat model list.
+- [ ] After the user disables an Ollama model, a refresh does not re-enable it.
+- [ ] The `ollama` version in `package.json` remains unchanged.
diff --git a/docs/specs/ollama-model-selection/tasks.md b/docs/specs/ollama-model-selection/tasks.md
new file mode 100644
index 000000000..0d6bb371e
--- /dev/null
+++ b/docs/specs/ollama-model-selection/tasks.md
@@ -0,0 +1,11 @@
+# Ollama Model Selectability and Cross-Window Sync Task Breakdown
+
+- [x] Add `ensureModelStatus` semantics to config/model status.
+- [x] Rework the Ollama provider's model fetching to merge local and running models.
+- [x] Rework the renderer `ollamaStore` into UI state plus the main-process refresh chain.
+- [x] Remove the local short-circuit in Ollama model status updates so explicit disables persist.
+- [x] Add main/store/component regression tests.
+- [x] Run `pnpm run format`
+- [x] Run `pnpm run i18n`
+- [x] Run `pnpm run lint`
+- [x] Run the relevant tests and confirm they pass
diff --git a/src/main/presenter/configPresenter/index.ts b/src/main/presenter/configPresenter/index.ts
index 82b69368d..6d386fc69 100644
--- a/src/main/presenter/configPresenter/index.ts
+++ b/src/main/presenter/configPresenter/index.ts
@@ -1016,6 +1016,10 @@ export class ConfigPresenter implements IConfigPresenter {
     this.modelStatusHelper.setModelStatus(providerId, modelId, enabled)
   }
 
+  ensureModelStatus(providerId: string, modelId: string, enabled: boolean): void {
+    this.modelStatusHelper.ensureModelStatus(providerId, modelId, enabled)
+  }
+
   enableModel(providerId: string, modelId: string): void {
     this.modelStatusHelper.enableModel(providerId, modelId)
   }
diff --git a/src/main/presenter/configPresenter/modelStatusHelper.ts b/src/main/presenter/configPresenter/modelStatusHelper.ts
index 9203ad36c..0585e9a8a 100644
--- a/src/main/presenter/configPresenter/modelStatusHelper.ts
+++ b/src/main/presenter/configPresenter/modelStatusHelper.ts
@@ -65,6 +65,14 @@
     return result
   }
 
+  private hasStoredStatus(statusKey: string): boolean {
+    const candidate = this.store as ElectronStore & { has?: (key: string) => boolean }
+    if (typeof candidate.has === 'function') {
+      return candidate.has(statusKey)
+    }
+    return this.store.get(statusKey) !== undefined
+  }
+
   setModelStatus(providerId: string, modelId: string, enabled: boolean): void {
     const statusKey = this.getStatusKey(providerId, modelId)
     this.setSetting(statusKey, enabled)
@@ -84,6 +92,22 @@
     this.setModelStatus(providerId, modelId, false)
   }
 
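+  // Seed a default exactly once: an existing explicit status (cache or store) wins, and no MODEL_STATUS_CHANGED event is emitted.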
+  ensureModelStatus(providerId: string, modelId: string, enabled: boolean): void {
+    const statusKey = this.getStatusKey(providerId, modelId)
+
+    if (this.cache.has(statusKey) || this.hasStoredStatus(statusKey)) {
+      if (!this.cache.has(statusKey)) {
+        const status = this.store.get(statusKey) as boolean | undefined
+        this.cache.set(statusKey, typeof status === 'boolean' ? status : false)
+      }
+      return
+    }
+
+    this.store.set(statusKey, enabled)
+    this.cache.set(statusKey, enabled)
+  }
+
   clearModelStatusCache(): void {
     this.cache.clear()
   }
diff --git a/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts b/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts
index d8400b2ce..d1a21e4c0 100644
--- a/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts
+++ b/src/main/presenter/llmProviderPresenter/providers/ollamaProvider.ts
@@ -11,6 +11,7 @@ import {
   LLM_EMBEDDING_ATTRS,
   IConfigPresenter
 } from '@shared/presenter'
+import { ModelType } from '@shared/model'
 import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults'
 import { createStreamEvent } from '@shared/types/core/llm-events'
 import { BaseLLMProvider, SUMMARY_TITLES_PROMPT } from '../baseProvider'
@@ -75,24 +76,131 @@
     return headers
   }
 
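+  // Helpers for merging the local (listModels) and running (listRunningModels) views of the same model.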
+  private mergeCapabilities(...sources: Array<string[] | undefined>): string[] {
+    return Array.from(new Set(sources.flatMap((source) => (Array.isArray(source) ? source : []))))
+  }
+
+  private mergeModelInfo(
+    primary?: OllamaModel['model_info'],
+    secondary?: OllamaModel['model_info']
+  ): OllamaModel['model_info'] {
+    if (!primary && !secondary) {
+      return undefined
+    }
+
+    const mergedGeneral =
+      secondary?.general || primary?.general
+        ? {
+            ...secondary?.general,
+            ...primary?.general
+          }
+        : undefined
+
+    const mergedVisionEmbeddingLength =
+      primary?.vision?.embedding_length ?? secondary?.vision?.embedding_length
+    const mergedVision =
+      typeof mergedVisionEmbeddingLength === 'number'
+        ? {
+            embedding_length: mergedVisionEmbeddingLength
+          }
+        : undefined
+
+    return {
+      ...secondary,
+      ...primary,
+      ...(mergedGeneral ? { general: mergedGeneral } : {}),
+      ...(mergedVision ? { vision: mergedVision } : {})
+    }
+  }
+
+  private mergeOllamaModels(preferred: OllamaModel, secondary?: OllamaModel): OllamaModel {
+    if (!secondary) {
+      return preferred
+    }
+
+    return {
+      ...secondary,
+      ...preferred,
+      details: {
+        ...secondary.details,
+        ...preferred.details
+      },
+      model_info: this.mergeModelInfo(preferred.model_info, secondary.model_info),
+      capabilities: this.mergeCapabilities(preferred.capabilities, secondary.capabilities)
+    }
+  }
+
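+  // Build MODEL_META from a merged Ollama model, backfilling capability flags from the cached catalog entry.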
+  private resolveOllamaModelMeta(model: OllamaModel, cachedModel?: MODEL_META): MODEL_META {
+    const capabilitySet = new Set(
+      this.mergeCapabilities(
+        model.capabilities,
+        cachedModel?.type === ModelType.Embedding ? ['embedding'] : undefined,
+        cachedModel?.vision ? ['vision'] : undefined,
+        cachedModel?.functionCall ? ['tools'] : undefined,
+        cachedModel?.reasoning ? ['thinking'] : undefined
+      )
+    )
+
+    const resolvedType = capabilitySet.has('embedding')
+      ? ModelType.Embedding
+      : (cachedModel?.type ?? ModelType.Chat)
+
+    const family = model.details?.family || cachedModel?.group || 'default'
+    const parameterSize = model.details?.parameter_size || ''
+    const description = `${parameterSize} ${family} model`.trim()
+
+    return {
+      id: model.name,
+      name: model.name,
+      providerId: this.provider.id,
+      contextLength:
+        model.model_info?.context_length ??
+        cachedModel?.contextLength ??
+        DEFAULT_MODEL_CONTEXT_LENGTH,
+      maxTokens: cachedModel?.maxTokens ?? DEFAULT_MODEL_MAX_TOKENS,
+      isCustom: false,
+      group: family,
+      description,
+      vision: capabilitySet.has('vision') || Boolean(model.model_info?.vision?.embedding_length),
+      functionCall: capabilitySet.has('tools'),
+      reasoning: capabilitySet.has('thinking'),
+      type: resolvedType
+    }
+  }
+
   // Basic Provider functionality implementation
   protected async fetchProviderModels(): Promise<MODEL_META[]> {
     try {
       console.log('Ollama service check', this.ollama, this.provider)
-      // Get list of locally installed Ollama models
-      const ollamaModels = await this.listModels()
+      const [localModels, runningModels] = await Promise.all([
+        this.listModels(),
+        this.listRunningModels()
+      ])
 
-      // Convert Ollama model format to application's MODEL_META format
-      return ollamaModels.map((model) => ({
-        id: model.name,
-        name: model.name,
-        providerId: this.provider.id,
-        contextLength: DEFAULT_MODEL_CONTEXT_LENGTH,
-        maxTokens: DEFAULT_MODEL_MAX_TOKENS,
-        isCustom: false,
-        group: model.details?.family || 'default',
-        description: `${model.details?.parameter_size || ''} ${model.details?.family || ''} model`
-      }))
+      const cachedModels = new Map<string, MODEL_META>(
+        this.configPresenter.getProviderModels(this.provider.id).map((model) => [model.id, model])
+      )
+
+      const mergedModels = new Map<string, OllamaModel>()
+      for (const localModel of localModels) {
+        mergedModels.set(localModel.name, localModel)
+      }
+      for (const runningModel of runningModels) {
+        const existing = mergedModels.get(runningModel.name)
+        const merged = existing
+          ? this.mergeOllamaModels(existing, runningModel)
+          : this.mergeOllamaModels(runningModel)
+        mergedModels.set(runningModel.name, merged)
+      }
+
+      const resolvedModels = Array.from(mergedModels.values()).map((model) => {
+        this.configPresenter.ensureModelStatus(this.provider.id, model.name, true)
+        return this.resolveOllamaModelMeta(model, cachedModels.get(model.name))
+      })
+
+      return resolvedModels
     } catch (error) {
       console.error('Failed to fetch Ollama models:', error)
       // Fallback to aggregated Provider DB curated list for Ollama
diff --git a/src/renderer/src/stores/modelStore.ts b/src/renderer/src/stores/modelStore.ts
index c6b97f89e..821da3a06 100644
--- a/src/renderer/src/stores/modelStore.ts
+++ b/src/renderer/src/stores/modelStore.ts
@@ -577,11 +577,6 @@
     const previousState = getLocalModelEnabledState(providerId, modelId)
     updateLocalModelStatus(providerId, modelId, enabled)
 
-    const provider = providerStore.providers.find((p) => p.id === providerId)
-    if (provider?.apiType === 'ollama') {
-      return
-    }
-
     try {
       await llmP.updateModelStatus(providerId, modelId, enabled)
       await refreshProviderModels(providerId)
diff --git a/src/renderer/src/stores/ollamaStore.ts b/src/renderer/src/stores/ollamaStore.ts
index 12e60e610..9ac903083 100644
--- a/src/renderer/src/stores/ollamaStore.ts
+++ b/src/renderer/src/stores/ollamaStore.ts
@@ -1,19 +1,15 @@
 import { onBeforeUnmount, onMounted, ref } from 'vue'
-import { defineStore, storeToRefs } from 'pinia'
+import { defineStore } from 'pinia'
 import { OLLAMA_EVENTS } from '@/events'
 import { usePresenter } from '@/composables/usePresenter'
-import type { OllamaModel, RENDERER_MODEL_META } from '@shared/presenter'
-import { ModelType } from '@shared/model'
-import { DEFAULT_MODEL_CONTEXT_LENGTH, DEFAULT_MODEL_MAX_TOKENS } from '@shared/modelConfigDefaults'
+import type { OllamaModel } from '@shared/presenter'
 import { useModelStore } from '@/stores/modelStore'
 import { useProviderStore } from '@/stores/providerStore'
 
 export const useOllamaStore = defineStore('ollama', () => {
   const llmP = usePresenter('llmproviderPresenter')
-  const configP = usePresenter('configPresenter')
   const modelStore = useModelStore()
   const providerStore = useProviderStore()
-  const { allProviderModels, enabledModels } = storeToRefs(modelStore)
 
   const runningModels = ref<Record<string, OllamaModel[]>>({})
   const localModels = ref<Record<string, OllamaModel[]>>({})
@@ -60,119 +56,6 @@ export const useOllamaStore = defineStore('ollama', () => {
   const getOllamaPullingModels = (providerId: string): Record<string, number> =>
     pullingProgress.value[providerId] || {}
 
-  type OllamaRendererModel = RENDERER_MODEL_META & {
-    ollamaModel?: OllamaModel
-    temperature?: number
-    reasoningEffort?: string
-    verbosity?: string
-    thinkingBudget?: number
-    forcedSearch?: boolean
-    searchStrategy?: string
-  }
-
-  const syncOllamaModelsToGlobal = async (providerId: string): Promise<void> => {
-    const ollamaProvider = providerStore.providers.find((p) => p.id === providerId)
-    if (!ollamaProvider) return
-
-    const existingOllamaModels =
-      allProviderModels.value.find((item) => item.providerId === providerId)?.models || []
-
-    const existingModelMap = new Map(
-      existingOllamaModels.map((model) => [model.id, model as OllamaRendererModel])
-    )
-
-    const local = getOllamaLocalModels(providerId)
-
-    const ollamaModelsAsGlobal = await Promise.all(
-      local.map(async (model) => {
-        const existingModel = existingModelMap.get(model.name)
-        const capabilitySources: string[] = []
-        if (Array.isArray((model as any)?.capabilities)) {
-          capabilitySources.push(...((model as any).capabilities as string[]))
-        }
-        if (
-          existingModel?.ollamaModel &&
-          Array.isArray((existingModel.ollamaModel as any)?.capabilities)
-        ) {
-          capabilitySources.push(...((existingModel.ollamaModel as any).capabilities as string[]))
-        }
-        const capabilitySet = new Set(capabilitySources)
-
-        const modelConfig = await configP.getModelConfig(model.name, providerId)
-
-        const contextLength =
-          modelConfig?.contextLength ??
-          existingModel?.contextLength ??
-          (model as any)?.model_info?.context_length ??
-          DEFAULT_MODEL_CONTEXT_LENGTH
-
-        const maxTokens =
-          modelConfig?.maxTokens ?? existingModel?.maxTokens ?? DEFAULT_MODEL_MAX_TOKENS
-
-        const resolvedType =
-          modelConfig?.type ??
-          existingModel?.type ??
-          (capabilitySet.has('embedding') ? ModelType.Embedding : ModelType.Chat)
-
-        const normalized: OllamaRendererModel = {
-          ...existingModel,
-          id: model.name,
-          name: model.name,
-          contextLength,
-          maxTokens,
-          group: existingModel?.group || 'local',
-          enabled: true,
-          isCustom: existingModel?.isCustom || false,
-          providerId,
-          vision: modelConfig?.vision ?? existingModel?.vision ?? capabilitySet.has('vision'),
-          functionCall:
-            modelConfig?.functionCall ?? existingModel?.functionCall ?? capabilitySet.has('tools'),
-          reasoning:
-            modelConfig?.reasoning ?? existingModel?.reasoning ?? capabilitySet.has('thinking'),
-          temperature: modelConfig?.temperature ?? existingModel?.temperature,
-          reasoningEffort: modelConfig?.reasoningEffort ?? existingModel?.reasoningEffort,
-          verbosity: modelConfig?.verbosity ?? existingModel?.verbosity,
-          thinkingBudget: modelConfig?.thinkingBudget ?? existingModel?.thinkingBudget,
-          type: resolvedType,
-          ollamaModel: model
-        }
-
-        return normalized
-      })
-    )
-
-    const existingIndex = allProviderModels.value.findIndex(
-      (item) => item.providerId === providerId
-    )
-
-    if (existingIndex !== -1) {
-      allProviderModels.value[existingIndex].models = ollamaModelsAsGlobal
-    } else {
-      allProviderModels.value.push({
-        providerId,
-        models: ollamaModelsAsGlobal
-      })
-    }
-
-    const enabledIndex = enabledModels.value.findIndex((item) => item.providerId === providerId)
-    const enabledOllamaModels = ollamaModelsAsGlobal.filter((model) => model.enabled)
-
-    if (enabledIndex !== -1) {
-      if (enabledOllamaModels.length > 0) {
-        enabledModels.value[enabledIndex].models = enabledOllamaModels
-      } else {
-        enabledModels.value.splice(enabledIndex, 1)
-      }
-    } else if (enabledOllamaModels.length > 0) {
-      enabledModels.value.push({
-        providerId,
-        models: enabledOllamaModels
-      })
-    }
-
-    enabledModels.value = [...enabledModels.value]
-  }
-
   const refreshOllamaModels = async (providerId: string): Promise<void> => {
     try {
       const [running, local] = await Promise.all([
@@ -181,7 +64,9 @@
       ])
       setRunningModels(providerId, running)
       setLocalModels(providerId, local)
-      await syncOllamaModelsToGlobal(providerId)
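+      // Main rebuilds the persisted catalog, then this window converges its modelStore view.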
+      await llmP.refreshModels(providerId)
+      await modelStore.refreshProviderModels(providerId)
     } catch (error) {
       console.error('Failed to refresh Ollama models for', providerId, error)
     }
@@ -218,9 +103,10 @@
     }
 
     if (status === 'success' || status === 'completed') {
-      setTimeout(() => {
+      setTimeout(async () => {
         updatePullingProgress(providerId, modelName)
-        modelStore.getProviderModelsQuery(providerId).refetch()
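+        // Reuse the same refresh chain as a manual refresh so pull success behaves identically.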
' + } + })) + + vi.doMock('@shadcn/components/ui/button', () => ({ + Button: { + name: 'Button', + props: { + type: { type: String, default: 'button' } + }, + emits: ['click'], + template: '' + } + })) + + vi.doMock('@shadcn/components/ui/card', () => ({ + Card: { + name: 'Card', + template: '
' + }, + CardContent: { + name: 'CardContent', + template: '
' + } + })) + + vi.doMock('@shadcn/components/ui/input', () => ({ + Input: { + name: 'Input', + props: ['modelValue'], + emits: ['update:modelValue'], + template: + '' + } + })) + + vi.doMock('@shadcn/components/ui/scroll-area', () => ({ + ScrollArea: { + name: 'ScrollArea', + template: '
' + } + })) + + vi.doMock('@/components/icons/ModelIcon.vue', () => ({ + default: { + name: 'ModelIcon', + template: '' + } + })) + + vi.doMock('@iconify/vue', () => ({ + Icon: { + name: 'Icon', + template: '' + } + })) + + const ModelChooser = (await import('@/components/ModelChooser.vue')).default + + return mount(ModelChooser, { + props: { + type: [ModelType.Chat] + } + }) +} + +describe('ModelChooser', () => { + it('includes Ollama chat models and excludes Ollama embedding models', async () => { + const wrapper = await setup() + + expect(wrapper.text()).toContain('deepseek-r1:1.5b') + expect(wrapper.text()).not.toContain('nomic-embed-text:latest') + + const firstButton = wrapper.find('button') + await firstButton.trigger('click') + + expect(wrapper.emitted('update:model')?.[0]).toEqual([ + { id: 'deepseek-r1:1.5b', name: 'deepseek-r1:1.5b', type: 'chat' }, + 'ollama' + ]) + }) +}) diff --git a/test/renderer/components/ModelSelect.test.ts b/test/renderer/components/ModelSelect.test.ts new file mode 100644 index 000000000..39755cc9b --- /dev/null +++ b/test/renderer/components/ModelSelect.test.ts @@ -0,0 +1,96 @@ +import { describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' +import { ref } from 'vue' +import { ModelType } from '../../../src/shared/model' + +const setup = async () => { + vi.resetModules() + + vi.doMock('@/stores/providerStore', () => ({ + useProviderStore: () => ({ + sortedProviders: [ + { id: 'ollama', name: 'Ollama', enable: true }, + { id: 'openai', name: 'OpenAI', enable: true } + ] + }) + })) + + vi.doMock('@/stores/modelStore', () => ({ + useModelStore: () => ({ + enabledModels: [ + { + providerId: 'ollama', + models: [ + { id: 'deepseek-r1:1.5b', name: 'deepseek-r1:1.5b', type: 'chat' }, + { id: 'nomic-embed-text:latest', name: 'nomic-embed-text:latest', type: 'embedding' } + ] + } + ] + }) + })) + + vi.doMock('@/stores/theme', () => ({ + useThemeStore: () => ({ + isDark: false + }) + })) + + vi.doMock('@/stores/language', () => ({ + useLanguageStore: () => ({ + dir: 'ltr' + }) + })) + + vi.doMock('@/components/chat-input/composables/useChatMode', () => ({ + useChatMode: () => ({ + currentMode: ref('agent') + }) + })) + + vi.doMock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string) => key + }) + })) + + vi.doMock('@shadcn/components/ui/input', () => ({ + Input: { + name: 'Input', + props: ['modelValue'], + emits: ['update:modelValue'], + template: + '' + } + })) + + vi.doMock('@/components/icons/ModelIcon.vue', () => ({ + default: { + name: 'ModelIcon', + template: '' + } + })) + + const ModelSelect = (await import('@/components/ModelSelect.vue')).default + + return mount(ModelSelect, { + props: { + type: [ModelType.Chat] + } + }) +} + +describe('ModelSelect', () => { + it('includes Ollama chat models and excludes Ollama embedding models', async () => { + const wrapper = await setup() + + expect(wrapper.text()).toContain('deepseek-r1:1.5b') + expect(wrapper.text()).not.toContain('nomic-embed-text:latest') + + const firstOption = wrapper.findAll('.cursor-pointer')[0] + await firstOption.trigger('click') + + expect(wrapper.emitted('update:model')).toEqual([ + [{ id: 'deepseek-r1:1.5b', name: 'deepseek-r1:1.5b', type: 'chat' }, 'ollama'] + ]) + }) +}) diff --git a/test/renderer/stores/modelStore.test.ts b/test/renderer/stores/modelStore.test.ts index 8dbd838e4..39fd76f76 100644 --- a/test/renderer/stores/modelStore.test.ts +++ b/test/renderer/stores/modelStore.test.ts @@ -21,7 +21,11 @@ const createQueryCache = () => { } } -const 
+const setupStore = async (overrides?: {
+  configPresenter?: Record<string, unknown>
+  llmPresenter?: Record<string, unknown>
+  providerStore?: Record<string, unknown>
+}) => {
   vi.resetModules()
 
   const queryCache = createQueryCache()
@@ -39,7 +43,13 @@ const setupStore = async (overrides?: { configPresenter?: Record<string, unknown> }) => {
     ...overrides?.configPresenter
   }
   const llmPresenter = {
-    getModelList: vi.fn(async () => [])
+    getModelList: vi.fn(async () => []),
+    updateModelStatus: vi.fn(async () => undefined),
+    ...overrides?.llmPresenter
+  }
+  const providerStore = {
+    providers: [],
+    ...overrides?.providerStore
   }
 
   vi.doMock('pinia', () => ({
@@ -59,7 +69,7 @@
   }))
 
   vi.doMock('@/stores/providerStore', () => ({
-    useProviderStore: () => ({ providers: [] })
+    useProviderStore: () => providerStore
   }))
 
   vi.doMock('@/composables/usePresenter', () => ({
@@ -151,4 +161,29 @@ describe('modelStore.refreshProviderModels', () => {
       }
     ])
   })
+
+  it('persists ollama model status changes through llm presenter', async () => {
+    const { store, llmPresenter } = await setupStore({
+      providerStore: {
+        providers: [{ id: 'ollama', apiType: 'ollama' }]
+      },
+      configPresenter: {
+        getDbProviderModels: vi.fn(async () => []),
+        getProviderModels: vi.fn(async () => [
+          {
+            id: 'deepseek-r1:1.5b',
+            name: 'deepseek-r1:1.5b',
+            providerId: 'ollama',
+            isCustom: false
+          }
+        ]),
+        getBatchModelStatus: vi.fn(async () => ({ 'deepseek-r1:1.5b': true }))
+      }
+    })
+
+    await store.refreshProviderModels('ollama')
+    await store.updateModelStatus('ollama', 'deepseek-r1:1.5b', false)
+
+    expect(llmPresenter.updateModelStatus).toHaveBeenCalledWith('ollama', 'deepseek-r1:1.5b', false)
+  })
 })
diff --git a/test/renderer/stores/ollamaStore.test.ts b/test/renderer/stores/ollamaStore.test.ts
new file mode 100644
index 000000000..503632aa9
--- /dev/null
+++ b/test/renderer/stores/ollamaStore.test.ts
@@ -0,0 +1,114 @@
+import { afterEach, describe, expect, it, vi } from 'vitest'
+
+const createModel = (name: string) => ({
+  name,
+  model: name,
+  size: 1,
+  digest: `${name}-digest`,
+  modified_at: new Date(),
+  details: {
+    format: 'gguf',
+    family: 'llama',
+    families: ['llama'],
+    parameter_size: '7b',
+    quantization_level: 'Q4_K_M'
+  },
+  model_info: {
+    context_length: 8192,
+    embedding_length: 0
+  },
+  capabilities: ['chat']
+})
+
+const setupStore = async () => {
+  vi.resetModules()
+  vi.useFakeTimers()
+
+  const llmPresenter = {
+    listOllamaRunningModels: vi.fn(async () => [createModel('qwen3:8b')]),
+    listOllamaModels: vi.fn(async () => [createModel('deepseek-r1:1.5b')]),
+    refreshModels: vi.fn(async () => undefined),
+    pullOllamaModels: vi.fn(async () => true)
+  }
+  const modelStore = {
+    refreshProviderModels: vi.fn(async () => undefined)
+  }
+  const providerStore = {
+    providers: [{ id: 'ollama', apiType: 'ollama', enable: true }]
+  }
+
+  ;(window as any).electron = {
+    ipcRenderer: {
+      on: vi.fn(),
+      removeAllListeners: vi.fn()
+    }
+  }
+
+  vi.doMock('pinia', () => ({
+    defineStore: (_id: string, setup: () => unknown) => setup
+  }))
+
+  vi.doMock('vue', async () => {
+    const actual = await vi.importActual<typeof import('vue')>('vue')
+    return {
+      ...actual,
+      onMounted: vi.fn(),
+      onBeforeUnmount: vi.fn()
+    }
+  })
+
+  vi.doMock('@/composables/usePresenter', () => ({
+    usePresenter: () => llmPresenter
+  }))
+
+  vi.doMock('@/stores/modelStore', () => ({
+    useModelStore: () => modelStore
+  }))
+
+  vi.doMock('@/stores/providerStore', () => ({
+    useProviderStore: () => providerStore
+  }))
+
+  const { useOllamaStore } = await import('@/stores/ollamaStore')
+
+  return {
+    store: useOllamaStore(),
+    llmPresenter,
+    modelStore
+  }
+}
+
+afterEach(() => {
+  vi.useRealTimers()
+})
+
+describe('ollamaStore', () => {
+  it('refreshes UI lists, then persists through main refresh and local modelStore refresh', async () => {
+    const { store, llmPresenter, modelStore } = await setupStore()
+
+    await store.refreshOllamaModels('ollama')
+
+    expect(store.getOllamaRunningModels('ollama').map((model) => model.name)).toEqual(['qwen3:8b'])
+    expect(store.getOllamaLocalModels('ollama').map((model) => model.name)).toEqual([
+      'deepseek-r1:1.5b'
+    ])
+    expect(llmPresenter.refreshModels).toHaveBeenCalledWith('ollama')
+    expect(modelStore.refreshProviderModels).toHaveBeenCalledWith('ollama')
+  })
+
+  it('reuses the same refresh chain when pull completes', async () => {
+    const { store, llmPresenter, modelStore } = await setupStore()
+
+    store.handleOllamaModelPullEvent({
+      eventId: 'pullOllamaModels',
+      providerId: 'ollama',
+      modelName: 'deepseek-r1:1.5b',
+      status: 'success'
+    })
+
+    await vi.advanceTimersByTimeAsync(600)
+
+    expect(llmPresenter.refreshModels).toHaveBeenCalledWith('ollama')
+    expect(modelStore.refreshProviderModels).toHaveBeenCalledWith('ollama')
+  })
+})