diff --git a/libs/openai-assistant/src/lib/chat/chat.service.spec.ts b/libs/openai-assistant/src/lib/chat/chat.service.spec.ts index 83012c3..e5c6c57 100644 --- a/libs/openai-assistant/src/lib/chat/chat.service.spec.ts +++ b/libs/openai-assistant/src/lib/chat/chat.service.spec.ts @@ -5,14 +5,12 @@ import { AiModule } from './../ai/ai.module'; import { ChatModule } from './chat.module'; import { ChatService } from './chat.service'; import { ChatHelpers } from './chat.helpers'; -import { RunService } from '../run'; import { ChatCallDto } from './chat.model'; import { AssistantStream } from 'openai/lib/AssistantStream'; describe('ChatService', () => { let chatService: ChatService; let chatbotHelpers: ChatHelpers; - let runService: RunService; beforeEach(async () => { const moduleRef = await Test.createTestingModule({ @@ -21,23 +19,19 @@ describe('ChatService', () => { chatService = moduleRef.get(ChatService); chatbotHelpers = moduleRef.get(ChatHelpers); - runService = moduleRef.get(RunService); jest .spyOn(chatbotHelpers, 'getAnswer') .mockReturnValue(Promise.resolve('Hello response') as Promise<string>); - jest.spyOn(runService, 'resolve').mockReturnThis(); jest .spyOn(chatService.threads.messages, 'create') .mockReturnValue({} as APIPromise<Message>); jest.spyOn(chatService, 'assistantStream').mockReturnValue({ - finalRun(): Promise<Run> { - return Promise.resolve({} as Run); - }, - } as AssistantStream); + finalRun: jest.fn(), + } as unknown as Promise<AssistantStream>); }); it('should be defined', () => { diff --git a/libs/openai-assistant/src/lib/chat/chat.service.ts b/libs/openai-assistant/src/lib/chat/chat.service.ts index fc30b91..84e415e 100644 --- a/libs/openai-assistant/src/lib/chat/chat.service.ts +++ b/libs/openai-assistant/src/lib/chat/chat.service.ts @@ -7,7 +7,7 @@ import { ChatCallResponseDto, } from './chat.model'; import { ChatHelpers } from './chat.helpers'; -import { MessageCreateParams } from 'openai/resources/beta/threads'; +import { MessageCreateParams, Run } from 
'openai/resources/beta/threads'; import { AssistantStream } from 'openai/lib/AssistantStream'; import { assistantStreamEventHandler } from '../stream/stream.utils'; @@ -36,12 +36,8 @@ export class ChatService { await this.threads.messages.create(threadId, message); - const assistantId = - payload?.assistantId || process.env['ASSISTANT_ID'] || ''; - const run = this.assistantStream(assistantId, threadId, callbacks); - const finalRun = await run.finalRun(); - - await this.runService.resolve(finalRun, true, callbacks); + const runner = await this.assistantStream(payload, callbacks); + const finalRun = await runner.finalRun(); return { content: await this.chatbotHelpers.getAnswer(finalRun), @@ -49,14 +45,20 @@ export class ChatService { }; } - assistantStream( - assistantId: string, - threadId: string, + async assistantStream( + payload: ChatCallDto, callbacks?: ChatCallCallbacks, - ): AssistantStream { - const runner = this.threads.runs.createAndStream(threadId, { - assistant_id: assistantId, - }); + ): Promise<AssistantStream> { + const assistant_id = + payload?.assistantId || process.env['ASSISTANT_ID'] || ''; + + const runner = this.threads.runs + .createAndStream(payload.threadId, { assistant_id }) + .on('event', event => { + if (event.event === 'thread.run.requires_action') { + this.runService.submitAction(event.data, callbacks); + } + }); return assistantStreamEventHandler(runner, callbacks); } diff --git a/libs/openai-assistant/src/lib/run/run.service.spec.ts b/libs/openai-assistant/src/lib/run/run.service.spec.ts index a235884..acdda6e 100644 --- a/libs/openai-assistant/src/lib/run/run.service.spec.ts +++ b/libs/openai-assistant/src/lib/run/run.service.spec.ts @@ -33,80 +33,6 @@ describe('RunService', () => { expect(runService).toBeDefined(); }); - describe('continueRun', () => { - it('should call threads.runs.retrieve', async () => { - const spyOnRetrieve = jest - .spyOn(aiService.provider.beta.threads.runs, 'retrieve') - .mockReturnThis(); - const run = { thread_id: '1', id: 
'123' } as Run; - - await runService.continueRun(run); - - expect(spyOnRetrieve).toHaveBeenCalled(); - }); - - it('should wait for timeout', async () => { - const run = { thread_id: '1', id: '123' } as Run; - const spyOnTimeout = jest.spyOn(global, 'setTimeout'); - - await runService.continueRun(run); - - expect(spyOnTimeout).toHaveBeenCalledWith( - expect.any(Function), - runService.timeout, - ); - }); - }); - - describe('resolve', () => { - it('should call continueRun', async () => { - const spyOnContinueRun = jest - .spyOn(runService, 'continueRun') - .mockResolvedValue({} as Run); - const run = { status: 'requires_action' } as Run; - - await runService.resolve(run, false); - - expect(spyOnContinueRun).toHaveBeenCalled(); - }); - - it('should call submitAction', async () => { - const spyOnSubmitAction = jest - .spyOn(runService, 'submitAction') - .mockResolvedValue(); - const run = { - status: 'requires_action', - required_action: { type: 'submit_tool_outputs' }, - } as Run; - - await runService.resolve(run, false); - - expect(spyOnSubmitAction).toHaveBeenCalled(); - }); - - it('should call default', async () => { - const spyOnContinueRun = jest - .spyOn(runService, 'continueRun') - .mockResolvedValue({} as Run); - const run = { status: 'unknown' } as unknown as Run; - - await runService.resolve(run, false); - - expect(spyOnContinueRun).toHaveBeenCalled(); - }); - - it('should not invoke action when status is cancelling', async () => { - const spyOnContinueRun = jest - .spyOn(runService, 'continueRun') - .mockResolvedValue({} as Run); - const run = { status: 'cancelling' } as unknown as Run; - - await runService.resolve(run, false); - - expect(spyOnContinueRun).not.toHaveBeenCalled(); - }); - }); - describe('submitAction', () => { it('should call submitToolOutputsStream', async () => { const spyOnSubmitToolOutputsStream = jest diff --git a/libs/openai-assistant/src/lib/run/run.service.ts b/libs/openai-assistant/src/lib/run/run.service.ts index 92e1466..e3c0e94 
100644 --- a/libs/openai-assistant/src/lib/run/run.service.ts +++ b/libs/openai-assistant/src/lib/run/run.service.ts @@ -1,10 +1,5 @@ import { Injectable } from '@nestjs/common'; -import { - Run, - RunSubmitToolOutputsParams, - Text, - TextDelta, -} from 'openai/resources/beta/threads'; +import { Run, RunSubmitToolOutputsParams } from 'openai/resources/beta/threads'; import { AiService } from '../ai'; import { AgentService } from '../agent'; import { ChatCallCallbacks } from '../chat'; @@ -13,43 +8,12 @@ import { assistantStreamEventHandler } from '../stream/stream.utils'; @Injectable() export class RunService { private readonly threads = this.aiService.provider.beta.threads; - timeout = 2000; - isRunning = true; constructor( private readonly aiService: AiService, private readonly agentsService: AgentService, ) {} - async continueRun(run: Run): Promise<Run> { - await new Promise(resolve => setTimeout(resolve, this.timeout)); - return this.threads.runs.retrieve(run.thread_id, run.id); - } - - async resolve( - run: Run, - runningStatus: boolean, - callbacks?: ChatCallCallbacks, - ): Promise<void> { - while (this.isRunning) - switch (run.status) { - case 'cancelling': - case 'cancelled': - case 'failed': - case 'expired': - case 'completed': - return; - case 'requires_action': - await this.submitAction(run, callbacks); - run = await this.continueRun(run); - this.isRunning = runningStatus; - continue; - default: - run = await this.continueRun(run); - this.isRunning = runningStatus; - } - } - async submitAction(run: Run, callbacks?: ChatCallCallbacks): Promise<void> { if (run.required_action?.type !== 'submit_tool_outputs') { return;