Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: run resolver for multiple function callings #53

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.dist
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
OPENAI_API_KEY=
# Assistant ID - leave it empty if you don't have an assistant yet
ASSISTANT_ID=
ASSISTANT_IS_LOGGER_ENABLED=

# Agents:
# -------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion apps/api/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async function bootstrap() {
const globalPrefix = 'api';
const config = new DocumentBuilder()
.setTitle('@boldare/openai-assistant')
.setVersion('1.0.1')
.setVersion('1.0.2')
.build();
const document = SwaggerModule.createDocument(app, config);

Expand Down
2 changes: 1 addition & 1 deletion libs/openai-assistant/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@boldare/openai-assistant",
"description": "NestJS library for building chatbot solutions based on the OpenAI Assistant API",
"version": "1.0.1",
"version": "1.0.2",
"private": false,
"dependencies": {
"tslib": "^2.3.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ describe('AssistantService', () => {
jest
.spyOn(aiService.provider.beta.assistants, 'update')
.mockRejectedValueOnce('error');
jest.spyOn(assistantService, 'create').mockResolvedValueOnce(undefined);
jest.spyOn(assistantService, 'create').mockResolvedValueOnce({} as Assistant);

await assistantService.init();

Expand All @@ -97,7 +97,7 @@ describe('AssistantService', () => {
.spyOn(configService, 'get')
.mockReturnValue({ ...assistantConfigMock, id: '' });

jest.spyOn(assistantService, 'create').mockResolvedValueOnce(undefined);
jest.spyOn(assistantService, 'create').mockResolvedValueOnce({} as Assistant);

await assistantService.init();

Expand Down
9 changes: 6 additions & 3 deletions libs/openai-assistant/src/lib/assistant/assistant.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export class AssistantService {
};
}

async init(): Promise<void> {
async init(): Promise<Assistant> {
const { id, options } = this.assistantConfig.get();

if (!id) {
Expand All @@ -43,16 +43,17 @@ export class AssistantService {
this.getParams(),
options,
);
return this.assistant;
} catch (e) {
await this.create();
return await this.create();
}
}

async update(params: Partial<AssistantCreateParams>): Promise<void> {
this.assistant = await this.assistants.update(this.assistant.id, params);
}

async create(): Promise<void> {
async create(): Promise<Assistant> {
const { options } = this.assistantConfig.get();
const params = this.getParams();
this.assistant = await this.assistants.create(params, options);
Expand All @@ -63,6 +64,8 @@ export class AssistantService {

this.logger.log(`Created new assistant (${this.assistant.id})`);
await this.assistantMemoryService.saveAssistantId(this.assistant.id);

return this.assistant;
}

async updateFiles(fileNames?: string[]): Promise<Assistant> {
Expand Down
96 changes: 54 additions & 42 deletions libs/openai-assistant/src/lib/chat/chat.gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,54 @@ export class ChatGateway implements OnGatewayConnection {
this.logger = new Logger(ChatGateway.name);
}

log(message: string): void {
try {
const isLoggerEnabled: string = JSON.parse(
(process.env['ASSISTANT_IS_LOGGER_ENABLED'] || 'false').toLowerCase(),
);

if (isLoggerEnabled) {
this.logger.log(message);
}
} catch (error) {
this.logger.error('"ASSISTANT_IS_LOGGER_ENABLED" should be boolean');
}
}

async handleConnection() {
this.logger.log('Client connected');
this.log('Client connected');
}

getCallbacks(socketId: string): ChatCallCallbacks {
return {
[ChatEvents.MessageCreated]: this.emitMessageCreated.bind(this, socketId),
[ChatEvents.MessageDelta]: this.emitMessageDelta.bind(this, socketId),
[ChatEvents.MessageDone]: this.emitMessageDone.bind(this, socketId),
[ChatEvents.TextCreated]: this.emitTextCreated.bind(this, socketId),
[ChatEvents.TextDelta]: this.emitTextDelta.bind(this, socketId),
[ChatEvents.TextDone]: this.emitTextDone.bind(this, socketId),
[ChatEvents.MessageCreated]: eventData =>
this.emitMessageCreated(socketId, eventData),
[ChatEvents.MessageDelta]: eventData =>
this.emitMessageDelta(socketId, eventData),
[ChatEvents.MessageDone]: eventData =>
this.emitMessageDone(socketId, eventData),
[ChatEvents.TextCreated]: eventData =>
this.emitTextCreated(socketId, eventData),
[ChatEvents.TextDelta]: eventData =>
this.emitTextDelta(socketId, eventData),
[ChatEvents.TextDone]: eventData =>
this.emitTextDone(socketId, eventData),
[ChatEvents.ToolCallCreated]: this.emitToolCallCreated.bind(
this,
socketId,
),
[ChatEvents.ToolCallDelta]: this.emitToolCallDelta.bind(this, socketId),
[ChatEvents.ToolCallDone]: this.emitToolCallDone.bind(this, socketId),
[ChatEvents.ImageFileDone]: this.emitImageFileDone.bind(this, socketId),
[ChatEvents.RunStepCreated]: this.emitRunStepCreated.bind(this, socketId),
[ChatEvents.RunStepDelta]: this.emitRunStepDelta.bind(this, socketId),
[ChatEvents.RunStepDone]: this.emitRunStepDone.bind(this, socketId),
[ChatEvents.ToolCallDelta]: eventData =>
this.emitToolCallDelta(socketId, eventData),
[ChatEvents.ToolCallDone]: eventData =>
this.emitToolCallDone(socketId, eventData),
[ChatEvents.ImageFileDone]: eventData =>
this.emitImageFileDone(socketId, eventData),
[ChatEvents.RunStepCreated]: eventData =>
this.emitRunStepCreated(socketId, eventData),
[ChatEvents.RunStepDelta]: eventData =>
this.emitRunStepDelta(socketId, eventData),
[ChatEvents.RunStepDone]: eventData =>
this.emitRunStepDone(socketId, eventData),
};
}

Expand All @@ -69,15 +95,15 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() request: ChatCallDto,
@ConnectedSocket() socket: Socket,
) {
this.logger.log(
this.log(
`Socket "${ChatEvents.CallStart}" | threadId ${request.threadId} | files: ${request?.file_ids?.join(', ')} | content: ${request.content}`,
);

const callbacks: ChatCallCallbacks = this.getCallbacks(socket.id);
const message = await this.chatsService.call(request, callbacks);

this.server?.to(socket.id).emit(ChatEvents.CallDone, message);
this.logger.log(
this.log(
`Socket "${ChatEvents.CallDone}" | threadId ${message.threadId} | content: ${message.content}`,
);
}
Expand All @@ -87,7 +113,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: MessageCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.MessageCreated, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.MessageCreated}" | threadId: ${data.message.thread_id}`,
);
}
Expand All @@ -97,7 +123,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: MessageDeltaPayload,
) {
this.server.to(socketId).emit(ChatEvents.MessageDelta, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.MessageDelta}" | threadId: ${data.message.thread_id}`,
);
}
Expand All @@ -107,7 +133,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: MessageDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.MessageDone, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.MessageDone}" | threadId: ${data.message.thread_id}`,
);
}
Expand All @@ -117,19 +143,17 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: TextCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.TextCreated, data);
this.logger.log(`Socket "${ChatEvents.TextCreated}" | ${data.text.value}`);
this.log(`Socket "${ChatEvents.TextCreated}" | ${data.text.value}`);
}

async emitTextDelta(socketId: string, @MessageBody() data: TextDeltaPayload) {
this.server.to(socketId).emit(ChatEvents.TextDelta, data);
this.logger.log(
`Socket "${ChatEvents.TextDelta}" | ${data.textDelta.value}`,
);
this.log(`Socket "${ChatEvents.TextDelta}" | ${data.textDelta.value}`);
}

async emitTextDone(socketId: string, @MessageBody() data: TextDonePayload) {
this.server.to(socketId).emit(ChatEvents.TextDone, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.TextDone}" | threadId: ${data.message?.thread_id} | ${data.text.value}`,
);
}
Expand All @@ -139,9 +163,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: ToolCallCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.ToolCallCreated, data);
this.logger.log(
`Socket "${ChatEvents.ToolCallCreated}": ${data.toolCall.id}`,
);
this.log(`Socket "${ChatEvents.ToolCallCreated}": ${data.toolCall.id}`);
}

codeInterpreterHandler(
Expand Down Expand Up @@ -185,9 +207,7 @@ export class ChatGateway implements OnGatewayConnection {
socketId: string,
@MessageBody() data: ToolCallDeltaPayload,
) {
this.logger.log(
`Socket "${ChatEvents.ToolCallDelta}": ${data.toolCall.id}`,
);
this.log(`Socket "${ChatEvents.ToolCallDelta}": ${data.toolCall.id}`);

switch (data.toolCallDelta.type) {
case 'code_interpreter':
Expand All @@ -211,46 +231,38 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: ToolCallDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.ToolCallDone, data);
this.logger.log(`Socket "${ChatEvents.ToolCallDone}": ${data.toolCall.id}`);
this.log(`Socket "${ChatEvents.ToolCallDone}": ${data.toolCall.id}`);
}

async emitImageFileDone(
socketId: string,
@MessageBody() data: ImageFileDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.ImageFileDone, data);
this.logger.log(
`Socket "${ChatEvents.ImageFileDone}": ${data.content.file_id}`,
);
this.log(`Socket "${ChatEvents.ImageFileDone}": ${data.content.file_id}`);
}

async emitRunStepCreated(
socketId: string,
@MessageBody() data: RunStepCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.RunStepCreated, data);
this.logger.log(
`Socket "${ChatEvents.RunStepCreated}": ${data.runStep.status}`,
);
this.log(`Socket "${ChatEvents.RunStepCreated}": ${data.runStep.status}`);
}

async emitRunStepDelta(
socketId: string,
@MessageBody() data: RunStepDeltaPayload,
) {
this.server.to(socketId).emit(ChatEvents.RunStepDelta, data);
this.logger.log(
`Socket "${ChatEvents.RunStepDelta}": ${data.runStep.status}`,
);
this.log(`Socket "${ChatEvents.RunStepDelta}": ${data.runStep.status}`);
}

async emitRunStepDone(
socketId: string,
@MessageBody() data: RunStepDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.RunStepDone, data);
this.logger.log(
`Socket "${ChatEvents.RunStepDone}": ${data.runStep.status}`,
);
this.log(`Socket "${ChatEvents.RunStepDone}": ${data.runStep.status}`);
}
}
15 changes: 12 additions & 3 deletions libs/openai-assistant/src/lib/chat/chat.service.spec.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import { Test } from '@nestjs/testing';
import { APIPromise } from 'openai/core';
import { Message, Run } from 'openai/resources/beta/threads';
import { AssistantStream } from 'openai/lib/AssistantStream';
import { AiModule } from './../ai/ai.module';
import { ChatModule } from './chat.module';
import { ChatService } from './chat.service';
import { ChatHelpers } from './chat.helpers';
import { ChatCallDto } from './chat.model';
import { AssistantStream } from 'openai/lib/AssistantStream';
import { RunService } from '../run/run.service';

jest.mock('../stream/stream.utils', () => ({
assistantStreamEventHandler: jest.fn(),
}));

describe('ChatService', () => {
let chatService: ChatService;
let chatbotHelpers: ChatHelpers;
let runService: RunService;

beforeEach(async () => {
const moduleRef = await Test.createTestingModule({
Expand All @@ -19,18 +25,21 @@ describe('ChatService', () => {

chatService = moduleRef.get<ChatService>(ChatService);
chatbotHelpers = moduleRef.get<ChatHelpers>(ChatHelpers);
runService = moduleRef.get<RunService>(RunService);

jest.spyOn(runService, 'resolve').mockReturnThis();

jest
.spyOn(chatbotHelpers, 'getAnswer')
.mockReturnValue(Promise.resolve('Hello response') as Promise<string>);


jest
.spyOn(chatService.threads.messages, 'create')
.mockReturnValue({} as APIPromise<Message>);

jest.spyOn(chatService, 'assistantStream').mockReturnValue({
jest.spyOn(chatService, 'getAssistantStream').mockReturnValue({
finalRun: jest.fn(),
on: () => jest.fn(),
} as unknown as Promise<AssistantStream>);
});

Expand Down
Loading