Skip to content

Commit

Permalink
Merge branch 'main' into fix/orchestration-e2e-test
Browse files Browse the repository at this point in the history
  • Loading branch information
KavithaSiva authored Aug 13, 2024
2 parents d0a2b85 + 5bf34b5 commit 7534873
Show file tree
Hide file tree
Showing 8 changed files with 392 additions and 5 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/e2e-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ jobs:
fail-fast: false
matrix:
environment: [staging, production]
include:
# disabled, because currently staging is behind production
exclude:
- environment: staging
secret-name: AI_CORE_STAGING
include:
- environment: production
secret-name: AI_CORE_PRODUCTION
name: "Build and Test"
Expand Down
1 change: 1 addition & 0 deletions packages/gen-ai-hub/src/orchestration/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from './client/api/index.js';
export * from './orchestration-client.js';
export * from './orchestration-types.js';
export * from './orchestration-filter-utility.js';
107 changes: 106 additions & 1 deletion packages/gen-ai-hub/src/orchestration/orchestration-client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { mockInference, parseMockResponse } from '../test-util/mock-http.js';
import { BaseLlmParametersWithDeploymentId } from '../core/index.js';
import { CompletionPostResponse } from './client/api/index.js';
import { GenAiHubCompletionParameters } from './orchestration-types.js';
import { azureContentFilter } from './orchestration-filter-utility.js';
jest.unstable_mockModule('../core/context.js', () => ({
getAiCoreDestination: jest.fn(() =>
Promise.resolve(mockGetAiCoreDestination())
Expand Down Expand Up @@ -34,7 +35,7 @@ describe('GenAiHubClient', () => {
jest.restoreAllMocks();
});

it('calls chatCompletion with minimum configuration and parses response', async () => {
it('calls chatCompletion with minimum configuration', async () => {
const request: GenAiHubCompletionParameters = {
deploymentConfiguration,
llmConfig: {
Expand Down Expand Up @@ -70,6 +71,110 @@ describe('GenAiHubClient', () => {
expect(client.chatCompletion(request)).resolves.toEqual(mockResponse);
});

it('calls chatCompletion with filter configuration supplied using convenience function', async () => {
  const request: GenAiHubCompletionParameters = {
    deploymentConfiguration,
    llmConfig: {
      model_name: 'gpt-35-turbo-16k',
      model_params: { max_tokens: 50, temperature: 0.1 }
    },
    prompt: {
      template: [
        { role: 'user', content: 'Create {number} paraphrases of {phrase}' }
      ],
      template_params: { phrase: 'I hate you.', number: 3 }
    },
    // Build the filter configuration via the azureContentFilter convenience helper.
    filterConfig: {
      input: azureContentFilter({ Hate: 4, SelfHarm: 2 }),
      output: azureContentFilter({ Sexual: 0, Violence: 4 })
    }
  };
  const mockResponse = parseMockResponse<CompletionPostResponse>(
    'orchestration',
    'genaihub-chat-completion-filter-config.json'
  );

  mockInference(
    {
      data: {
        deploymentConfiguration,
        ...constructCompletionPostRequest(request)
      }
    },
    {
      data: mockResponse,
      status: 200
    },
    destination,
    {
      url: 'completion'
    }
  );
  // Await the async matcher: without `await` (or `return`), a rejection or a
  // failed `.resolves` assertion would not be reported against this test.
  await expect(client.chatCompletion(request)).resolves.toEqual(mockResponse);
});

it('calls chatCompletion with filtering configuration', async () => {
  const request: GenAiHubCompletionParameters = {
    deploymentConfiguration,
    llmConfig: {
      model_name: 'gpt-35-turbo-16k',
      model_params: { max_tokens: 50, temperature: 0.1 }
    },
    prompt: {
      template: [
        { role: 'user', content: 'Create {number} paraphrases of {phrase}' }
      ],
      template_params: { phrase: 'I hate you.', number: 3 }
    },
    // Spell the filter configuration out verbatim (no convenience helper) to
    // cover the raw-configuration code path.
    filterConfig: {
      input: {
        filters: [
          {
            type: 'azure_content_safety',
            config: {
              Hate: 4,
              SelfHarm: 2
            }
          }
        ]
      },
      output: {
        filters: [
          {
            type: 'azure_content_safety',
            config: {
              Sexual: 0,
              Violence: 4
            }
          }
        ]
      }
    }
  };
  const mockResponse = parseMockResponse<CompletionPostResponse>(
    'orchestration',
    'genaihub-chat-completion-filter-config.json'
  );

  mockInference(
    {
      data: {
        deploymentConfiguration,
        ...constructCompletionPostRequest(request)
      }
    },
    {
      data: mockResponse,
      status: 200
    },
    destination,
    {
      url: 'completion'
    }
  );
  // Await the async matcher: without `await` (or `return`), a rejection or a
  // failed `.resolves` assertion would not be reported against this test.
  await expect(client.chatCompletion(request)).resolves.toEqual(mockResponse);
});

it('sends message history together with templating config', async () => {
const request: GenAiHubCompletionParameters = {
deploymentConfiguration,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ export function constructCompletionPostRequest(
templating_module_config: {
template: input.prompt.template
},
llm_module_config: input.llmConfig
llm_module_config: input.llmConfig,
...(Object.keys(input?.filterConfig || {}).length && {
filtering_module_config: input.filterConfig
})
}
},
...(input.prompt.template_params && {
Expand All @@ -55,4 +58,5 @@ export function constructCompletionPostRequest(
messages_history: input.prompt.messages_history
})
};
return result;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import {
CompletionPostRequest,
FilteringModuleConfig
} from './client/api/index.js';
import { constructCompletionPostRequest } from './orchestration-client.js';
import { azureContentFilter } from './orchestration-filter-utility.js';
import { GenAiHubCompletionParameters } from './orchestration-types.js';

describe('Filter utility', () => {
  // Shared base request; `filterConfig` is assigned per test case and reset
  // in `afterEach` so state cannot leak between tests.
  const genaihubCompletionParameters: GenAiHubCompletionParameters = {
    deploymentConfiguration: {
      deploymentId: 'deployment-id'
    },
    llmConfig: {
      model_name: 'gpt-35-turbo-16k',
      model_params: { max_tokens: 50, temperature: 0.1 }
    },
    prompt: {
      template: [
        { role: 'user', content: 'Create {number} paraphrases of {phrase}' }
      ],
      template_params: { phrase: 'I hate you.', number: 3 }
    }
  };

  afterEach(() => {
    genaihubCompletionParameters.filterConfig = undefined;
  });

  it('constructs filter configuration with only input', () => {
    const filterConfig: FilteringModuleConfig = {
      input: azureContentFilter({ Hate: 4, SelfHarm: 0 })
    };
    const expectedFilterConfig: FilteringModuleConfig = {
      input: {
        filters: [
          {
            type: 'azure_content_safety',
            config: {
              Hate: 4,
              SelfHarm: 0
            }
          }
        ]
      }
    };
    genaihubCompletionParameters.filterConfig = filterConfig;
    const completionPostRequest: CompletionPostRequest =
      constructCompletionPostRequest(genaihubCompletionParameters);
    expect(
      completionPostRequest.orchestration_config.module_configurations
        .filtering_module_config
    ).toEqual(expectedFilterConfig);
  });

  it('constructs filter configuration with only output', () => {
    const filterConfig: FilteringModuleConfig = {
      output: azureContentFilter({ Sexual: 2, Violence: 6 })
    };
    const expectedFilterConfig: FilteringModuleConfig = {
      output: {
        filters: [
          {
            type: 'azure_content_safety',
            config: {
              Sexual: 2,
              Violence: 6
            }
          }
        ]
      }
    };
    genaihubCompletionParameters.filterConfig = filterConfig;
    const completionPostRequest: CompletionPostRequest =
      constructCompletionPostRequest(genaihubCompletionParameters);
    expect(
      completionPostRequest.orchestration_config.module_configurations
        .filtering_module_config
    ).toEqual(expectedFilterConfig);
  });

  it('constructs filter configuration with both input and output', () => {
    const filterConfig: FilteringModuleConfig = {
      input: azureContentFilter({
        Hate: 4,
        SelfHarm: 0,
        Sexual: 2,
        Violence: 6
      }),
      output: azureContentFilter({ Sexual: 2, Violence: 6 })
    };
    const expectedFilterConfig: FilteringModuleConfig = {
      input: {
        filters: [
          {
            type: 'azure_content_safety',
            config: {
              Hate: 4,
              SelfHarm: 0,
              Sexual: 2,
              Violence: 6
            }
          }
        ]
      },
      output: {
        filters: [
          {
            type: 'azure_content_safety',
            config: {
              Sexual: 2,
              Violence: 6
            }
          }
        ]
      }
    };
    genaihubCompletionParameters.filterConfig = filterConfig;
    const completionPostRequest: CompletionPostRequest =
      constructCompletionPostRequest(genaihubCompletionParameters);
    expect(
      completionPostRequest.orchestration_config.module_configurations
        .filtering_module_config
    ).toEqual(expectedFilterConfig);
  });

  it('omits filters if not set', () => {
    // Calling azureContentFilter() with no argument keeps the filter entry but
    // without a `config` property, so the service applies its defaults.
    const filterConfig: FilteringModuleConfig = {
      input: azureContentFilter(),
      output: azureContentFilter()
    };
    genaihubCompletionParameters.filterConfig = filterConfig;
    const completionPostRequest: CompletionPostRequest =
      constructCompletionPostRequest(genaihubCompletionParameters);
    const expectedFilterConfig: FilteringModuleConfig = {
      input: {
        filters: [
          {
            type: 'azure_content_safety'
          }
        ]
      },
      output: {
        filters: [
          {
            type: 'azure_content_safety'
          }
        ]
      }
    };
    expect(
      completionPostRequest.orchestration_config.module_configurations
        .filtering_module_config
    ).toEqual(expectedFilterConfig);
  });

  it('omits filter configuration if not set', () => {
    // An empty filterConfig object must not produce a filtering_module_config
    // entry in the request at all.
    const filterConfig: FilteringModuleConfig = {};
    genaihubCompletionParameters.filterConfig = filterConfig;
    const completionPostRequest: CompletionPostRequest =
      constructCompletionPostRequest(genaihubCompletionParameters);
    expect(
      completionPostRequest.orchestration_config.module_configurations
        .filtering_module_config
    ).toBeUndefined();
  });

  it('throws error when configuring empty filter', () => {
    // An empty object is rejected eagerly by the convenience function.
    expect(() => azureContentFilter({})).toThrow(
      'Filter property cannot be an empty object'
    );
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { AzureContentSafety, FilteringConfig } from './client/api/index.js';

/**
 * Builds a filtering configuration that wraps a single Azure content safety filter.
 * @param filter - Threshold settings for the Azure content safety filter. Omit it to fall back to the default Azure filter configuration.
 * @returns An object with the Azure filtering configuration.
 */
export function azureContentFilter(
  filter?: AzureContentSafety
): FilteringConfig {
  // An explicitly empty object is ambiguous (neither defaults nor settings), so reject it.
  if (filter && Object.keys(filter).length === 0) {
    throw new Error('Filter property cannot be an empty object');
  }
  const azureContentSafetyFilter = {
    type: 'azure_content_safety' as const,
    // Only attach `config` when concrete settings were provided.
    ...(filter ? { config: filter } : {})
  };
  return { filters: [azureContentSafetyFilter] };
}
7 changes: 6 additions & 1 deletion packages/gen-ai-hub/src/orchestration/orchestration-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { BaseLlmParameters } from '../core/index.js';
import {
ChatMessages,
CompletionPostResponse,
FilteringModuleConfig,
InputParamsEntry,
LLMModuleConfig
} from './client/api/index.js';
Expand Down Expand Up @@ -44,11 +45,15 @@ export type LlmConfig = LLMModuleConfig;
*/
export interface OrchestrationCompletionParameters {
/**
* Prompt options.
* Prompt configuration options.
*/
prompt: PromptConfig;
/**
* Llm configuration options.
*/
llmConfig: LlmConfig;
/**
* Filter configuration options.
*/
filterConfig?: FilteringModuleConfig;
}
Loading

0 comments on commit 7534873

Please sign in to comment.