feat: Add message history option (#27)
* messages_history + tests

* fix: Changes from lint

* fix openai tests

* review

* fix: Changes from lint

---------

Co-authored-by: cloud-sdk-js <[email protected]>
deekshas8 and cloud-sdk-js authored Jul 19, 2024
1 parent 8814496 commit 3d1926a
Showing 8 changed files with 183 additions and 129 deletions.
95 changes: 34 additions & 61 deletions packages/gen-ai-hub/src/client/openai/openai-client.test.ts
@@ -1,14 +1,14 @@
-import fs from 'fs';
-import path from 'path';
 import nock from 'nock';
 import { HttpDestination } from '@sap-cloud-sdk/connectivity';
 import { mockGetAiCoreDestination } from '../../../test-util/mock-context.js';
 import {
-  BaseLlmParameters,
   BaseLlmParametersWithDeploymentId,
   EndpointOptions
 } from '../../core/http-client.js';
-import { mockInference } from '../../../test-util/mock-http.js';
+import {
+  mockInference,
+  parseMockResponse
+} from '../../../test-util/mock-http.js';
 import { OpenAiClient } from './openai-client.js';
 import {
   OpenAiChatCompletionOutput,
@@ -20,17 +20,15 @@ import {

 describe('openai client', () => {
   let destination: HttpDestination;
-  let deploymentConfig: BaseLlmParameters;
+  const deploymentConfiguration: BaseLlmParametersWithDeploymentId = {
+    deploymentId: 'deployment-id'
+  };
   let chatCompletionEndpoint: EndpointOptions;
   let embeddingsEndpoint: EndpointOptions;

   beforeAll(() => {
     destination = mockGetAiCoreDestination();
-    deploymentConfig = {
-      deploymentConfiguration: {
-        deploymentId: 'deployment-id'
-      } as BaseLlmParametersWithDeploymentId
-    };
+
     chatCompletionEndpoint = {
       url: 'chat/completions',
       apiVersion: '2024-02-01'
@@ -58,60 +56,47 @@ describe('openai client', () => {
       };
       const request: OpenAiChatCompletionParameters = {
         ...prompt,
-        ...deploymentConfig
+        deploymentConfiguration
       };
-      const mockResponse = fs.readFileSync(
-        path.join(
-          'test-util',
-          'mock-data',
-          'openai',
-          'openai-chat-completion-success-response.json'
-        ),
-        'utf8'
+      const mockResponse = parseMockResponse<OpenAiChatCompletionOutput>(
+        'openai',
+        'openai-chat-completion-success-response.json'
       );

       mockInference(
         {
           data: request
         },
         {
-          data: JSON.parse(mockResponse),
+          data: mockResponse,
           status: 200
         },
         destination,
         chatCompletionEndpoint
       );

-      const result: OpenAiChatCompletionOutput =
-        await new OpenAiClient().chatCompletion(request);
-      const expectedResponse: OpenAiChatCompletionOutput =
-        JSON.parse(mockResponse);
-
-      expect(result).toEqual(expectedResponse);
+      expect(new OpenAiClient().chatCompletion(request)).resolves.toEqual(
+        mockResponse
+      );
     });

     it('throws on bad request', async () => {
       const prompt = { messages: [] };
       const request: OpenAiChatCompletionParameters = {
         ...prompt,
-        ...deploymentConfig
+        deploymentConfiguration
       };
-      const mockResponse = fs.readFileSync(
-        path.join(
-          'test-util',
-          'mock-data',
-          'openai',
-          'openai-error-response.json'
-        ),
-        'utf8'
+      const mockResponse = parseMockResponse(
+        'openai',
+        'openai-error-response.json'
       );

       mockInference(
         {
           data: request
         },
         {
-          data: JSON.parse(mockResponse),
+          data: mockResponse,
           status: 400
         },
         destination,
@@ -129,66 +114,54 @@
       const prompt = { input: ['AI is fascinating'] };
       const request: OpenAiEmbeddingParameters = {
         ...prompt,
-        ...deploymentConfig
+        deploymentConfiguration
       };
-      const mockResponse = fs.readFileSync(
-        path.join(
-          'test-util',
-          'mock-data',
-          'openai',
-          'openai-embeddings-success-response.json'
-        ),
-        'utf8'
+      const mockResponse = parseMockResponse<OpenAiEmbeddingOutput>(
+        'openai',
+        'openai-embeddings-success-response.json'
       );

       mockInference(
         {
           data: request
         },
         {
-          data: JSON.parse(mockResponse),
+          data: mockResponse,
           status: 200
         },
         destination,
         embeddingsEndpoint
       );

-      const result: OpenAiEmbeddingOutput = await new OpenAiClient().embeddings(
-        request
+      expect(new OpenAiClient().embeddings(request)).resolves.toEqual(
+        mockResponse
       );
-      const expectedResponse: OpenAiEmbeddingOutput = JSON.parse(mockResponse);
-      expect(result).toEqual(expectedResponse);
     });

     it('throws on bad request', async () => {
       const prompt = { input: [] };
       const request: OpenAiEmbeddingParameters = {
         ...prompt,
-        ...deploymentConfig
+        deploymentConfiguration
       };
-      const mockResponse = fs.readFileSync(
-        path.join(
-          'test-util',
-          'mock-data',
-          'openai',
-          'openai-error-response.json'
-        ),
-        'utf8'
+      const mockResponse = parseMockResponse(
+        'openai',
+        'openai-error-response.json'
       );

       mockInference(
         {
           data: request
         },
         {
-          data: JSON.parse(mockResponse),
+          data: mockResponse,
           status: 400
         },
         destination,
         embeddingsEndpoint
       );

-      await expect(new OpenAiClient().embeddings(request)).rejects.toThrow();
+      expect(new OpenAiClient().embeddings(request)).rejects.toThrow();
     });
   });
 });
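
The tests above now go through a shared `parseMockResponse` helper exported from `test-util/mock-http.js` instead of repeating the `fs.readFileSync`/`JSON.parse` boilerplate. The helper itself is not shown in this diff; based on the paths used by the deleted `fs.readFileSync` calls, a minimal sketch could look like the following (the exact signature and error handling are assumptions):

```ts
// Hypothetical sketch of the shared test helper -- not part of this commit's diff.
import fs from 'fs';
import path from 'path';

/**
 * Reads a JSON fixture from test-util/mock-data/<dir>/<fileName> and parses it.
 * Mirrors the fs.readFileSync(path.join('test-util', 'mock-data', ...)) calls
 * removed from the tests above.
 */
export function parseMockResponse<T = any>(dir: string, fileName: string): T {
  const content = fs.readFileSync(
    path.join('test-util', 'mock-data', dir, fileName),
    'utf8'
  );
  return JSON.parse(content) as T;
}
```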
@@ -8,6 +8,6 @@
  * Representation of the 'ChatMessage' schema.
  */
 export type ChatMessage = {
-  role: string;
+  role: 'user' | 'assistant' | 'system';
   content: string;
 } & Record<string, any>;
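
Narrowing `role` from `string` to a closed union means an invalid role is now rejected at compile time instead of only by the orchestration service. An illustrative snippet (the type is restated from the diff above; the message objects are made up):

```ts
// Restated from the schema change above.
type ChatMessage = {
  role: 'user' | 'assistant' | 'system';
  content: string;
} & Record<string, any>;

const history: ChatMessage[] = [
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Hi! Im Bob' },
  { role: 'assistant', content: 'Nice to meet you, Bob.' }
];

// No longer compiles after this change (it did while role was a plain string):
// const invalid: ChatMessage = { role: 'actor', content: 'Hello' };
```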
103 changes: 52 additions & 51 deletions packages/gen-ai-hub/src/orchestration/orchestration-client.test.ts
@@ -1,25 +1,33 @@
-import fs from 'fs';
-import path from 'path';
 import nock from 'nock';
 import { HttpDestination } from '@sap-cloud-sdk/connectivity';
-import { BaseLlmParametersWithDeploymentId } from '../core/index.js';
 import { mockGetAiCoreDestination } from '../../test-util/mock-context.js';
-import { mockInference } from '../../test-util/mock-http.js';
+import { mockInference, parseMockResponse } from '../../test-util/mock-http.js';
+import { BaseLlmParametersWithDeploymentId } from '../core/index.js';
 import {
   GenAiHubClient,
   GenAiHubCompletionParameters
 } from './orchestration-client.js';
-import { CompletionPostResponse, ModuleConfigs } from './api/index.js';
+import {
+  CompletionPostResponse,
+  LLMModuleConfig,
+  ModuleConfigs
+} from './api/index.js';

 describe('GenAiHubClient', () => {
   let destination: HttpDestination;
-  let deploymentConfiguration: BaseLlmParametersWithDeploymentId;
   let client: GenAiHubClient;
+  const deploymentConfiguration: BaseLlmParametersWithDeploymentId = {
+    deploymentId: 'deployment-id'
+  };
+  const llm_module_config: LLMModuleConfig = {
+    model_name: 'gpt-35-turbo-16k',
+    model_params: {
+      max_tokens: 50,
+      temperature: 0.1
+    }
+  };

   beforeAll(() => {
-    deploymentConfiguration = {
-      deploymentId: 'deployment-id'
-    };
     destination = mockGetAiCoreDestination();
     client = new GenAiHubClient();
   });
@@ -28,92 +36,85 @@ describe('GenAiHubClient', () => {
     nock.cleanAll();
   });

-  it(' calls chatCompletion and parses response', async () => {
+  it('calls chatCompletion with minimum configuration and parses response', async () => {
     const module_configurations: ModuleConfigs = {
       templating_module_config: {
         template: [{ role: 'user', content: 'Hello!' }]
       },
-      llm_module_config: {
-        model_name: 'gpt-35-turbo-16k',
-        model_params: {
-          max_tokens: 50,
-          temperature: 0.1
-        }
-      }
+      llm_module_config
     };
     const request: GenAiHubCompletionParameters = {
       deploymentConfiguration,
       orchestration_config: { module_configurations }
     };

-    const mockResponse = fs.readFileSync(
-      path.join(
-        'test-util',
-        'mock-data',
-        'orchestration',
-        'genaihub-chat-completion-success-response.json'
-      ),
-      'utf-8'
+    const mockResponse = parseMockResponse<CompletionPostResponse>(
+      'orchestration',
+      'genaihub-chat-completion-success-response.json'
     );

     mockInference(
       {
         data: { ...request, input_params: {} }
       },
       {
-        data: JSON.parse(mockResponse),
+        data: mockResponse,
         status: 200
       },
       destination,
       {
         url: 'completion'
       }
     );
-    const result = await client.chatCompletion(request);
-    const expectedResponse: CompletionPostResponse = JSON.parse(mockResponse);
-    expect(result).toEqual(expectedResponse);
+    expect(client.chatCompletion(request)).resolves.toEqual(mockResponse);
   });

-  it('throws error for incorrect input parameters', async () => {
+  it('sends message history together with templating config', async () => {
     const module_configurations: ModuleConfigs = {
       templating_module_config: {
-        template: [{ role: 'actor', content: 'Hello' }]
+        template: [{ role: 'user', content: "What's my name?" }]
       },
-      llm_module_config: {
-        model_name: 'gpt-35-turbo-16k',
-        model_params: {
-          max_tokens: 50,
-          temperature: 0.1
-        }
-      }
+      llm_module_config
    };
     const request: GenAiHubCompletionParameters = {
       deploymentConfiguration,
-      orchestration_config: { module_configurations }
+      orchestration_config: { module_configurations },
+      messages_history: [
+        {
+          role: 'system',
+          content:
+            'You are a helpful assistant who remembers all details the user shares with you.'
+        },
+        {
+          role: 'user',
+          content: 'Hi! Im Bob'
+        },
+        {
+          role: 'assistant',
+          content:
+            "Hi Bob, nice to meet you! I'm an AI assistant. I'll remember that your name is Bob as we continue our conversation."
+        }
+      ]
     };
-    const mockResponse = fs.readFileSync(
-      path.join(
-        'test-util',
-        'mock-data',
-        'orchestration',
-        'genaihub-error-response.json'
-      ),
-      'utf-8'
-    );

+    const mockResponse = parseMockResponse<CompletionPostResponse>(
+      'orchestration',
+      'genaihub-chat-completion-message-history.json'
+    );
     mockInference(
       {
         data: { ...request, input_params: {} }
       },
       {
-        data: JSON.parse(mockResponse),
-        status: 400
+        data: mockResponse,
+        status: 200
       },
       destination,
       {
         url: 'completion'
       }
     );
-    await expect(client.chatCompletion(request)).rejects.toThrow();
+
+    expect(client.chatCompletion(request)).resolves.toEqual(mockResponse);
   });
 });
@@ -12,7 +12,7 @@ import {
  * Input Parameters for GenAI hub chat completion.
  */
 export type GenAiHubCompletionParameters = BaseLlmParameters &
-  Pick<CompletionPostRequest, 'orchestration_config'>;
+  Pick<CompletionPostRequest, 'orchestration_config' | 'messages_history'>;

 /**
  * Get the orchestration client.
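
With `messages_history` picked into `GenAiHubCompletionParameters`, callers can send prior conversation turns alongside the templating and LLM module configuration. A usage sketch, assuming the imports resolve as in the tests above and reusing the same illustrative deployment ID and model settings:

```ts
import {
  GenAiHubClient,
  GenAiHubCompletionParameters
} from './orchestration-client.js';

const client = new GenAiHubClient();

const request: GenAiHubCompletionParameters = {
  deploymentConfiguration: { deploymentId: 'deployment-id' },
  orchestration_config: {
    module_configurations: {
      templating_module_config: {
        template: [{ role: 'user', content: "What's my name?" }]
      },
      llm_module_config: {
        model_name: 'gpt-35-turbo-16k',
        model_params: { max_tokens: 50, temperature: 0.1 }
      }
    }
  },
  // New with this commit: previous turns travel with the request.
  messages_history: [
    { role: 'user', content: 'Hi! Im Bob' },
    { role: 'assistant', content: 'Nice to meet you, Bob. How can I help?' }
  ]
};

async function askWithHistory(): Promise<void> {
  // Resolves with a CompletionPostResponse, as exercised in the tests above.
  const response = await client.chatCompletion(request);
  console.log(response);
}
```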