diff --git a/src/CodexAcpClient.ts b/src/CodexAcpClient.ts index d5b59f3..8578633 100644 --- a/src/CodexAcpClient.ts +++ b/src/CodexAcpClient.ts @@ -13,6 +13,7 @@ import type {Disposable} from "vscode-jsonrpc"; import type { ClientInfo, ReasoningEffort, + ServiceTier, ServerNotification } from "./app-server"; import type {JsonValue} from "./app-server/serde_json/JsonValue"; @@ -220,6 +221,7 @@ export class CodexAcpClient { sessionId: request.sessionId, currentModelId: currentModelId, models: codexModels, + currentServiceTier: response.serviceTier ?? null, } } @@ -242,6 +244,7 @@ export class CodexAcpClient { sessionId: request.sessionId, currentModelId: currentModelId, models: codexModels, + currentServiceTier: response.serviceTier ?? null, thread: response.thread, }; } @@ -271,6 +274,7 @@ export class CodexAcpClient { sessionId: response.thread.id, currentModelId: currentModelId, models: codexModels, + currentServiceTier: response.serviceTier ?? null, }; } @@ -394,6 +398,7 @@ export class CodexAcpClient { request: acp.PromptRequest, agentMode: AgentMode, modelId: ModelId, + serviceTier: ServiceTier | null, disableSummary: boolean, cwd: string, ): Promise { @@ -412,6 +417,7 @@ export class CodexAcpClient { cwd: null, effort: effort, model: modelId.model, + serviceTier: serviceTier, }); } @@ -605,6 +611,7 @@ export type SessionMetadata = { sessionId: string, currentModelId: string, models: Model[], + currentServiceTier?: ServiceTier | null, } export type SessionMetadataWithThread = SessionMetadata & { diff --git a/src/CodexAcpServer.ts b/src/CodexAcpServer.ts index 1a80d6c..2ee88ae 100644 --- a/src/CodexAcpServer.ts +++ b/src/CodexAcpServer.ts @@ -32,6 +32,14 @@ import { createFileChangeUpdate, createMcpToolCallUpdate, } from "./CodexToolCallMapper"; +import { + createFastModeConfigOption, + FAST_MODE_CONFIG_ID, + FAST_MODE_OFF, + FAST_MODE_ON, + modelSupportsFast, + resolveFastServiceTier, +} from "./FastModeConfig"; export interface SessionState { sessionId: string, @@ -46,6 +54,8 @@ export interface SessionState { rateLimits: RateLimitsMap | null; account: Account | null; cwd: string; + fastModeEnabled: boolean; + currentModelSupportsFast: boolean; sessionMcpServers?: Array; } @@ -163,6 +173,7 @@ export class CodexAcpServer implements acp.Agent { const {sessionId, currentModelId, models} = sessionMetadata; const sessionMcpServers = this.resolveSessionMcpServers(requestedMcpServers, "sessionId" in request); const currentModel = this.findCurrentModel(models, currentModelId); + const currentModelSupportsFast = modelSupportsFast(currentModel); const sessionState: SessionState = { sessionId: sessionId, currentModelId: currentModelId, @@ -176,6 +187,8 @@ export class CodexAcpServer implements acp.Agent { rateLimits: null, account: account, cwd: request.cwd, + fastModeEnabled: sessionMetadata.currentServiceTier === "fast", + currentModelSupportsFast: currentModelSupportsFast, sessionMcpServers: sessionMcpServers, } this.sessions.set(sessionId, sessionState); @@ -221,7 +234,8 @@ export class CodexAcpServer implements acp.Agent { }); return { models: modelState, - modes: modeState + modes: modeState, + configOptions: this.createSessionConfigOptions(this.getSessionState(sessionId)), }; } @@ -236,7 +250,8 @@ export class CodexAcpServer implements acp.Agent { }); return { models: modelState, - modes: modeState + modes: modeState, + configOptions: this.createSessionConfigOptions(this.getSessionState(sessionId)), }; } @@ -261,7 +276,8 @@ export class CodexAcpServer implements acp.Agent { return { sessionId: sessionId, models: modelState, - modes: modeState + modes: modeState, + configOptions: this.createSessionConfigOptions(this.getSessionState(sessionId)), }; } @@ -302,6 +318,28 @@ export class CodexAcpServer implements acp.Agent { return {}; } + async setSessionConfigOption(params: acp.SetSessionConfigOptionRequest): Promise { + logger.log("Set session config option requested", { + sessionId: params.sessionId, + configId: params.configId, + }); + const sessionState = this.sessions.get(params.sessionId); + if (!sessionState) throw new Error(`Session ${params.sessionId} not found`); + + if (params.configId !== FAST_MODE_CONFIG_ID || ("type" in params && params.type === "boolean")) { + throw RequestError.invalidParams(); + } + + if (params.value !== FAST_MODE_ON && params.value !== FAST_MODE_OFF) { + throw RequestError.invalidParams(); + } + + sessionState.fastModeEnabled = params.value === FAST_MODE_ON; + return { + configOptions: this.createSessionConfigOptions(sessionState), + }; + } + async unstable_setSessionModel(params: acp.SetSessionModelRequest): Promise { logger.log("Set session model requested", { sessionId: params.sessionId, @@ -337,10 +375,17 @@ export class CodexAcpServer implements acp.Agent { sessionState.currentModelId = ModelId.fromComponents(model, reasoningEffort).toString(); sessionState.supportedReasoningEfforts = model.supportedReasoningEfforts; sessionState.supportedInputModalities = model.inputModalities; + sessionState.currentModelSupportsFast = modelSupportsFast(model); return {}; } + private createSessionConfigOptions(sessionState: SessionState): Array { + return [ + createFastModeConfigOption(sessionState.fastModeEnabled), + ]; + } + private publishAvailableCommandsAsync(sessionId: string) { void this.availableCommands.publish(sessionId); } @@ -388,6 +433,7 @@ export class CodexAcpServer implements acp.Agent { const {sessionId, currentModelId, models, thread} = sessionMetadata; const sessionMcpServers = this.resolveSessionMcpServers(requestedMcpServers, true); const currentModel = this.findCurrentModel(models, currentModelId); + const currentModelSupportsFast = modelSupportsFast(currentModel); const sessionState: SessionState = { sessionId: sessionId, currentModelId: currentModelId, @@ -401,6 +447,8 @@ export class CodexAcpServer implements acp.Agent { rateLimits: null, account: account, cwd: request.cwd, + fastModeEnabled: sessionMetadata.currentServiceTier === "fast", + currentModelSupportsFast: currentModelSupportsFast, sessionMcpServers: sessionMcpServers, }; this.sessions.set(sessionId, sessionState); @@ -767,8 +815,12 @@ export class CodexAcpServer implements acp.Agent { throw RequestError.invalidRequest("The current model does not support image input"); } const agentMode = sessionState.agentMode; + const serviceTier = resolveFastServiceTier( + sessionState.fastModeEnabled, + sessionState.currentModelSupportsFast, + ); const turnCompleted = await this.runWithProcessCheck( - () => this.codexAcpClient.sendPrompt(params, agentMode, modelId, disableSummary, sessionState.cwd)); + () => this.codexAcpClient.sendPrompt(params, agentMode, modelId, serviceTier, disableSummary, sessionState.cwd)); // Check if turn was interrupted (cancelled) if (turnCompleted.turn.status === "interrupted") { diff --git a/src/FastModeConfig.ts b/src/FastModeConfig.ts new file mode 100644 index 0000000..4187aaa --- /dev/null +++ b/src/FastModeConfig.ts @@ -0,0 +1,40 @@ +import type {SessionConfigOption} from "@agentclientprotocol/sdk"; +import type {ServiceTier} from "./app-server"; +import type {Model} from "./app-server/v2"; + +export const FAST_MODE_CONFIG_ID = "fast-mode"; +export const FAST_MODE_ON = "on"; +export const FAST_MODE_OFF = "off"; + +const FAST_MODE_DESCRIPTION = "1.5x speed, increased usage"; + +export function modelSupportsFast(model: Model | undefined): boolean { + return model?.additionalSpeedTiers?.includes("fast") ?? false; +} + +export function resolveFastServiceTier(fastModeEnabled: boolean, currentModelSupportsFast: boolean): ServiceTier | null { + return fastModeEnabled && currentModelSupportsFast ? "fast" : null; +} + +export function createFastModeConfigOption(fastModeEnabled: boolean): SessionConfigOption { + return { + id: FAST_MODE_CONFIG_ID, + name: "Fast mode", + description: FAST_MODE_DESCRIPTION, + category: FAST_MODE_CONFIG_ID, + type: "select", + currentValue: fastModeEnabled ? FAST_MODE_ON : FAST_MODE_OFF, + options: [ + { + value: FAST_MODE_OFF, + name: "Off", + description: "Default speed, normal usage", + }, + { + value: FAST_MODE_ON, + name: "On", + description: FAST_MODE_DESCRIPTION, + }, + ], + }; +} diff --git a/src/__tests__/CodexACPAgent/data/send-attachments-turn-start.json b/src/__tests__/CodexACPAgent/data/send-attachments-turn-start.json index 9c685ec..c4ab4c5 100644 --- a/src/__tests__/CodexACPAgent/data/send-attachments-turn-start.json +++ b/src/__tests__/CodexACPAgent/data/send-attachments-turn-start.json @@ -56,7 +56,8 @@ "personality": null, "cwd": "cwd", "effort": "effort", - "model": "model" + "model": "model", + "serviceTier": null } } { diff --git a/src/__tests__/CodexACPAgent/fast-mode-config.test.ts b/src/__tests__/CodexACPAgent/fast-mode-config.test.ts new file mode 100644 index 0000000..0e10369 --- /dev/null +++ b/src/__tests__/CodexACPAgent/fast-mode-config.test.ts @@ -0,0 +1,149 @@ +import {describe, expect, it, vi} from "vitest"; +import { + createCodexMockTestFixture, + createTestModel, + mockPromptTurn, + setupPromptTestSession, +} from "../acp-test-utils"; +import { + createFastModeConfigOption, + FAST_MODE_CONFIG_ID, + FAST_MODE_OFF, + FAST_MODE_ON, +} from "../../FastModeConfig"; + +describe("Fast mode session config", () => { + async function createSession(currentServiceTier: "fast" | "flex" | null = null) { + const fixture = createCodexMockTestFixture(); + const codexAcpAgent = fixture.getCodexAcpAgent(); + const codexAcpClient = fixture.getCodexAcpClient(); + const fastModel = createTestModel({ + id: "fast-model", + additionalSpeedTiers: ["fast"], + }); + + vi.spyOn(codexAcpClient, "authRequired").mockResolvedValue(false); + vi.spyOn(codexAcpClient, "getAccount").mockResolvedValue({account: null, requiresOpenaiAuth: false}); + vi.spyOn(codexAcpClient, "newSession").mockResolvedValue({ + sessionId: "session-id", + currentModelId: "fast-model[medium]", + models: [fastModel], + currentServiceTier, + }); + + const response = await codexAcpAgent.newSession({cwd: "/test/cwd", mcpServers: []}); + return {fixture, codexAcpAgent, codexAcpClient, response}; + } + + function setupPromptSession(fastModeEnabled: boolean, currentModelSupportsFast: boolean) { + const {mockFixture, turnStartSpy} = setupPromptTestSession({ + sessionId: "session-id", + currentModelId: "fast-model[medium]", + fastModeEnabled, + currentModelSupportsFast, + }); + return {codexAcpAgent: mockFixture.getCodexAcpAgent(), turnStartSpy}; + } + + it("returns the Fast mode config option defaulted to Off for new sessions", async () => { + const {response} = await createSession(); + + expect(response.configOptions).toEqual([createFastModeConfigOption(false)]); + }); + + it("initializes Fast mode as On when the app-server session tier is fast", async () => { + const {response, codexAcpAgent} = await createSession("fast"); + + expect(response.configOptions).toEqual([createFastModeConfigOption(true)]); + expect(codexAcpAgent.getSessionState("session-id").fastModeEnabled).toBe(true); + }); + + it("toggles Fast mode through session config options", async () => { + const {codexAcpAgent} = await createSession(); + + const onResponse = await codexAcpAgent.setSessionConfigOption({ + sessionId: "session-id", + configId: FAST_MODE_CONFIG_ID, + value: FAST_MODE_ON, + }); + expect(onResponse.configOptions).toEqual([createFastModeConfigOption(true)]); + expect(codexAcpAgent.getSessionState("session-id").fastModeEnabled).toBe(true); + + const offResponse = await codexAcpAgent.setSessionConfigOption({ + sessionId: "session-id", + configId: FAST_MODE_CONFIG_ID, + value: FAST_MODE_OFF, + }); + expect(offResponse.configOptions).toEqual([createFastModeConfigOption(false)]); + expect(codexAcpAgent.getSessionState("session-id").fastModeEnabled).toBe(false); + }); + + it("rejects unknown Fast mode config ids and values", async () => { + const {codexAcpAgent} = await createSession(); + + await expect(codexAcpAgent.setSessionConfigOption({ + sessionId: "session-id", + configId: "unknown-config", + value: FAST_MODE_ON, + })).rejects.toThrow(); + + await expect(codexAcpAgent.setSessionConfigOption({ + sessionId: "session-id", + configId: FAST_MODE_CONFIG_ID, + value: "turbo", + })).rejects.toThrow(); + }); + + it("sends the fast service tier when Fast mode is enabled for a fast-capable model", async () => { + const {codexAcpAgent, turnStartSpy} = setupPromptSession(true, true); + + await codexAcpAgent.prompt({sessionId: "session-id", prompt: [{type: "text", text: "test"}]}); + + expect(turnStartSpy).toHaveBeenCalledWith(expect.objectContaining({ + serviceTier: "fast", + })); + }); + + it("explicitly clears service tier when Fast mode is off", async () => { + const {codexAcpAgent, turnStartSpy} = setupPromptSession(false, true); + + await codexAcpAgent.prompt({sessionId: "session-id", prompt: [{type: "text", text: "test"}]}); + + expect(turnStartSpy).toHaveBeenCalledWith(expect.objectContaining({ + serviceTier: null, + })); + }); + + it("explicitly clears service tier when the selected model does not support fast", async () => { + const {codexAcpAgent, turnStartSpy} = setupPromptSession(true, false); + + await codexAcpAgent.prompt({sessionId: "session-id", prompt: [{type: "text", text: "test"}]}); + + expect(turnStartSpy).toHaveBeenCalledWith(expect.objectContaining({ + serviceTier: null, + })); + }); + + it("keeps Fast mode selected across model switches but stops applying it for non-fast models", async () => { + const {codexAcpAgent, codexAcpClient, fixture} = await createSession("fast"); + const slowModel = createTestModel({id: "slow-model"}); + vi.spyOn(codexAcpClient, "fetchAvailableModels").mockResolvedValue([slowModel]); + const turnStartSpy = mockPromptTurn(fixture, "session-id"); + + await codexAcpAgent.unstable_setSessionModel({ + sessionId: "session-id", + modelId: "slow-model[medium]", + }); + + const sessionState = codexAcpAgent.getSessionState("session-id"); + expect(sessionState.fastModeEnabled).toBe(true); + expect(sessionState.currentModelSupportsFast).toBe(false); + + await codexAcpAgent.prompt({sessionId: "session-id", prompt: [{type: "text", text: "test"}]}); + + expect(turnStartSpy).toHaveBeenCalledWith(expect.objectContaining({ + model: "slow-model", + serviceTier: null, + })); + }); +}); diff --git a/src/__tests__/acp-test-utils.ts b/src/__tests__/acp-test-utils.ts index a162921..ec510dc 100644 --- a/src/__tests__/acp-test-utils.ts +++ b/src/__tests__/acp-test-utils.ts @@ -10,6 +10,7 @@ import fs from "node:fs"; import os from "node:os"; import {AgentMode} from "../AgentMode"; import {expect, vi} from "vitest"; +import type {Model, ReasoningEffortOption} from "../app-server/v2"; export type MethodCallEvent = { method: string; args: any[] }; @@ -322,10 +323,73 @@ export function createTestSessionState(overrides?: Partial): Sessi supportedReasoningEfforts: [], supportedInputModalities: ["text", "image"], agentMode: AgentMode.DEFAULT_AGENT_MODE, + fastModeEnabled: false, + currentModelSupportsFast: false, ...overrides, }; } +export function createTestModel(overrides?: Partial): Model { + const id = overrides?.id ?? "model-id"; + const defaultEffort: ReasoningEffortOption = {reasoningEffort: "medium", description: "Balanced"}; + return { + id, + model: id, + upgrade: null, + upgradeInfo: null, + availabilityNux: null, + displayName: id, + description: `${id} model`, + hidden: false, + supportedReasoningEfforts: [defaultEffort], + defaultReasoningEffort: "medium", + inputModalities: ["text", "image"], + supportsPersonality: false, + additionalSpeedTiers: [], + isDefault: true, + ...overrides, + }; +} + +export function setupPromptTestSession(sessionOverrides?: Partial) { + const mockFixture = createCodexMockTestFixture(); + const sessionState = createTestSessionState(sessionOverrides); + + vi.spyOn(mockFixture.getCodexAcpAgent(), "getSessionState").mockReturnValue(sessionState); + const turnStartSpy = mockPromptTurn(mockFixture, sessionState.sessionId); + + return {mockFixture, sessionState, turnStartSpy}; +} + +export function mockPromptTurn(fixture: CodexMockTestFixture, sessionId: string) { + const codexAppServerClient = fixture.getCodexAppServerClient(); + const turnStartSpy = vi.spyOn(codexAppServerClient, "turnStart").mockResolvedValue({ + turn: { + id: "turn-id", + items: [], + status: "inProgress", + error: null, + startedAt: null, + completedAt: null, + durationMs: null, + } + }); + vi.spyOn(codexAppServerClient, "awaitTurnCompleted").mockResolvedValue({ + threadId: sessionId, + turn: { + id: "turn-id", + items: [], + status: "completed", + error: null, + startedAt: null, + completedAt: null, + durationMs: null, + } + }); + + return turnStartSpy; +} + export async function setupPromptAndSendNotifications( fixture: CodexMockTestFixture, sessionId: string,