Skip to content

Commit

Permalink
feat(openai-assistant): handled "requires_action" event
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianmusial committed Mar 29, 2024
1 parent 059e283 commit 209abbb
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 133 deletions.
10 changes: 2 additions & 8 deletions libs/openai-assistant/src/lib/chat/chat.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import { AiModule } from './../ai/ai.module';
import { ChatModule } from './chat.module';
import { ChatService } from './chat.service';
import { ChatHelpers } from './chat.helpers';
import { RunService } from '../run';
import { ChatCallDto } from './chat.model';
import { AssistantStream } from 'openai/lib/AssistantStream';

describe('ChatService', () => {
let chatService: ChatService;
let chatbotHelpers: ChatHelpers;
let runService: RunService;

beforeEach(async () => {
const moduleRef = await Test.createTestingModule({
Expand All @@ -21,23 +19,19 @@ describe('ChatService', () => {

chatService = moduleRef.get<ChatService>(ChatService);
chatbotHelpers = moduleRef.get<ChatHelpers>(ChatHelpers);
runService = moduleRef.get<RunService>(RunService);

jest
.spyOn(chatbotHelpers, 'getAnswer')
.mockReturnValue(Promise.resolve('Hello response') as Promise<string>);

jest.spyOn(runService, 'resolve').mockReturnThis();

jest
.spyOn(chatService.threads.messages, 'create')
.mockReturnValue({} as APIPromise<Message>);

jest.spyOn(chatService, 'assistantStream').mockReturnValue({
finalRun(): Promise<Run> {
return Promise.resolve({} as Run);
},
} as AssistantStream);
finalRun: jest.fn(),
} as unknown as Promise<AssistantStream>);
});

it('should be defined', () => {
Expand Down
30 changes: 16 additions & 14 deletions libs/openai-assistant/src/lib/chat/chat.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
ChatCallResponseDto,
} from './chat.model';
import { ChatHelpers } from './chat.helpers';
import { MessageCreateParams } from 'openai/resources/beta/threads';
import { MessageCreateParams, Run } from 'openai/resources/beta/threads';
import { AssistantStream } from 'openai/lib/AssistantStream';
import { assistantStreamEventHandler } from '../stream/stream.utils';

Expand Down Expand Up @@ -36,27 +36,29 @@ export class ChatService {

await this.threads.messages.create(threadId, message);

const assistantId =
payload?.assistantId || process.env['ASSISTANT_ID'] || '';
const run = this.assistantStream(assistantId, threadId, callbacks);
const finalRun = await run.finalRun();

await this.runService.resolve(finalRun, true, callbacks);
const runner = await this.assistantStream(payload, callbacks);
const finalRun = await runner.finalRun();

return {
content: await this.chatbotHelpers.getAnswer(finalRun),
threadId,
};
}

assistantStream(
assistantId: string,
threadId: string,
async assistantStream(
payload: ChatCallDto,
callbacks?: ChatCallCallbacks,
): AssistantStream {
const runner = this.threads.runs.createAndStream(threadId, {
assistant_id: assistantId,
});
): Promise<AssistantStream> {
const assistant_id =
payload?.assistantId || process.env['ASSISTANT_ID'] || '';

const runner = this.threads.runs
.createAndStream(payload.threadId, { assistant_id })
.on('event', event => {
if (event.event === 'thread.run.requires_action') {
this.runService.submitAction(event.data, callbacks);
}
});

return assistantStreamEventHandler<AssistantStream>(runner, callbacks);
}
Expand Down
74 changes: 0 additions & 74 deletions libs/openai-assistant/src/lib/run/run.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,80 +33,6 @@ describe('RunService', () => {
expect(runService).toBeDefined();
});

describe('continueRun', () => {
it('should call threads.runs.retrieve', async () => {
const spyOnRetrieve = jest
.spyOn(aiService.provider.beta.threads.runs, 'retrieve')
.mockReturnThis();
const run = { thread_id: '1', id: '123' } as Run;

await runService.continueRun(run);

expect(spyOnRetrieve).toHaveBeenCalled();
});

it('should wait for timeout', async () => {
const run = { thread_id: '1', id: '123' } as Run;
const spyOnTimeout = jest.spyOn(global, 'setTimeout');

await runService.continueRun(run);

expect(spyOnTimeout).toHaveBeenCalledWith(
expect.any(Function),
runService.timeout,
);
});
});

describe('resolve', () => {
it('should call continueRun', async () => {
const spyOnContinueRun = jest
.spyOn(runService, 'continueRun')
.mockResolvedValue({} as Run);
const run = { status: 'requires_action' } as Run;

await runService.resolve(run, false);

expect(spyOnContinueRun).toHaveBeenCalled();
});

it('should call submitAction', async () => {
const spyOnSubmitAction = jest
.spyOn(runService, 'submitAction')
.mockResolvedValue();
const run = {
status: 'requires_action',
required_action: { type: 'submit_tool_outputs' },
} as Run;

await runService.resolve(run, false);

expect(spyOnSubmitAction).toHaveBeenCalled();
});

it('should call default', async () => {
const spyOnContinueRun = jest
.spyOn(runService, 'continueRun')
.mockResolvedValue({} as Run);
const run = { status: 'unknown' } as unknown as Run;

await runService.resolve(run, false);

expect(spyOnContinueRun).toHaveBeenCalled();
});

it('should not invoke action when status is cancelling', async () => {
const spyOnContinueRun = jest
.spyOn(runService, 'continueRun')
.mockResolvedValue({} as Run);
const run = { status: 'cancelling' } as unknown as Run;

await runService.resolve(run, false);

expect(spyOnContinueRun).not.toHaveBeenCalled();
});
});

describe('submitAction', () => {
it('should call submitToolOutputsStream', async () => {
const spyOnSubmitToolOutputsStream = jest
Expand Down
38 changes: 1 addition & 37 deletions libs/openai-assistant/src/lib/run/run.service.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import { Injectable } from '@nestjs/common';
import {
Run,
RunSubmitToolOutputsParams,
Text,
TextDelta,
} from 'openai/resources/beta/threads';
import { Run, RunSubmitToolOutputsParams } from 'openai/resources/beta/threads';
import { AiService } from '../ai';
import { AgentService } from '../agent';
import { ChatCallCallbacks } from '../chat';
Expand All @@ -13,43 +8,12 @@ import { assistantStreamEventHandler } from '../stream/stream.utils';
@Injectable()
export class RunService {
private readonly threads = this.aiService.provider.beta.threads;
timeout = 2000;
isRunning = true;

constructor(
private readonly aiService: AiService,
private readonly agentsService: AgentService,
) {}

async continueRun(run: Run): Promise<Run> {
await new Promise(resolve => setTimeout(resolve, this.timeout));
return this.threads.runs.retrieve(run.thread_id, run.id);
}

async resolve(
run: Run,
runningStatus: boolean,
callbacks?: ChatCallCallbacks,
): Promise<void> {
while (this.isRunning)
switch (run.status) {
case 'cancelling':
case 'cancelled':
case 'failed':
case 'expired':
case 'completed':
return;
case 'requires_action':
await this.submitAction(run, callbacks);
run = await this.continueRun(run);
this.isRunning = runningStatus;
continue;
default:
run = await this.continueRun(run);
this.isRunning = runningStatus;
}
}

async submitAction(run: Run, callbacks?: ChatCallCallbacks): Promise<void> {
if (run.required_action?.type !== 'submit_tool_outputs') {
return;
Expand Down

0 comments on commit 209abbb

Please sign in to comment.