Skip to content

Commit

Permalink
Merge pull request #79 from parea-ai/PAI-1382-ts-set-dataset-name-on-…
Browse files Browse the repository at this point in the history
…exper-and-status-fail-on-any-error

fix(experiement): wait for longer for evals to finish)
  • Loading branch information
jalexanderII authored Jul 16, 2024
2 parents 58641d0 + 7a34b33 commit abbaca7
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ async function callFunction(function_call: ChatCompletionMessage.FunctionCall):
>
| Promise<DBItem>
> {
const args = JSON.parse(function_call.arguments!);
let args;
try {
args = JSON.parse(function_call.arguments!);
} catch (e) {
args = function_call.arguments;
}
switch (function_call.name) {
case 'list':
return await list(args['genre']);
Expand Down
21 changes: 20 additions & 1 deletion src/experiment/experiment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export class Experiment<T extends Record<string, any>, R> {
) {
this.runner = new ExperimentRunner(this.options.nWorkers || 10);
this.p = parea;
this.trySetDataset(dataset);
}

/**
Expand All @@ -67,10 +68,11 @@ export class Experiment<T extends Record<string, any>, R> {
return experimentContext.runInContext(experimentUUID, async () => {
try {
this.dataset = await this.determineDataset(this.dataset);
const maxRetries = typeof this.options?.maxRetries === 'number' ? this.options.maxRetries : 60;
const trials = this.dataset.flatMap((data) =>
Array(this.options.nTrials || 1)
.fill(null)
.map(() => new Trial(data, this.func, experimentUUID)),
.map(() => new Trial(data, this.func, experimentUUID, maxRetries)),
);

const results = await this.runner.runTrials(trials);
Expand Down Expand Up @@ -185,4 +187,21 @@ export class Experiment<T extends Record<string, any>, R> {
)}/${experimentUUID}\n`,
);
}

/**
 * Records the dataset name under the reserved `Dataset` metadata key when the
 * dataset is referenced by name.
 *
 * An array of inline data carries no name, so it leaves the metadata untouched.
 * A caller-supplied `Dataset` entry is overwritten, and a warning is printed so
 * the overwrite is not silent.
 *
 * @param dataset The input dataset, either as a string (collection name) or an array of data.
 */
private trySetDataset(dataset: string | T[]): void {
  if (typeof dataset !== 'string') {
    return;
  }

  const metadata = this.options.metadata ?? {};
  // Compare against undefined instead of testing truthiness: a falsy value
  // ('' / 0 / false / null) is still user data that is about to be replaced.
  if (metadata.Dataset !== undefined) {
    console.warn(
      'Metadata key "Dataset" is reserved for the dataset name. Overwriting it with the provided dataset name.',
    );
  }
  this.options.metadata = { ...metadata, Dataset: dataset };
}
}
20 changes: 20 additions & 0 deletions src/experiment/experimentContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ class ExperimentContext {
const context = store.get(experimentUUID) || { logs: [], scores: [] };
context.scores.push(score);
store.set(experimentUUID, context);
} else {
console.error(`Experiment context store not found for experiment ${experimentUUID}`);
}
}

/**
 * Adds several scores to the experiment context in a single call.
 *
 * Logs an error and does nothing else when no context store is available.
 *
 * @param experimentUUID - The UUID of the experiment.
 * @param scores - The evaluation results to add.
 */
addScores(experimentUUID: string, scores: EvaluationResult[]): void {
  const store = this.context.getStore();
  if (!store) {
    console.error(`Experiment context store not found for experiment ${experimentUUID}`);
    return;
  }

  // Reuse the entry for this experiment, or start an empty one on first use.
  const entry = store.get(experimentUUID) || { logs: [], scores: [] };
  entry.scores.push(...scores);
  store.set(experimentUUID, entry);
}

Expand All @@ -63,6 +81,8 @@ class ExperimentContext {
const context = store.get(experimentUUID) || { logs: [], scores: [] };
context.logs.push(log);
store.set(experimentUUID, context);
} else {
console.error(`Experiment context store not found for experiment ${experimentUUID}`);
}
}

Expand Down
29 changes: 19 additions & 10 deletions src/experiment/trial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ export class Trial<T extends Record<string, any>, R> {
* @param data The input data for the trial.
* @param func The function to be executed for the trial.
* @param experimentUUID The UUID of the experiment this trial belongs to.
* @param maxRetries The maximum number of retries while waiting for eval to finish. Each retry waits for 1s. Required here; Experiment passes 60 when the `maxRetries` option is not set.
*/
constructor(
private data: T,
private func: (...args: any[]) => R | Promise<R>,
private experimentUUID: string,
private maxRetries: number,
) {}

/**
Expand Down Expand Up @@ -48,14 +50,12 @@ export class Trial<T extends Record<string, any>, R> {
return funcResult;
});

this.state = ExperimentStatus.COMPLETED;

await this.waitForLogs();

const { state, error } = await this.waitForLogs();
this.state = state;
const scores = experimentContext.getScores(this.experimentUUID);
const logs = experimentContext.getLogs(this.experimentUUID);

return new TrialResult(this.data, result, null, this.state, scores, logs);
return new TrialResult(this.data, result, error || null, state, scores, logs);
} catch (error) {
this.state = ExperimentStatus.FAILED;
const e = error instanceof Error ? error : new Error(String(error));
Expand All @@ -64,14 +64,23 @@ export class Trial<T extends Record<string, any>, R> {
});
}

/**
 * Waits until at least one log has been recorded for this experiment.
 *
 * Sleeps 2.5s, checks once, and then re-checks up to `maxRetries` more times
 * with a 1s pause before each retry. The logs are always inspected at least
 * once, and again after the final pause, so a log that arrives during the last
 * wait is not reported as a failure.
 *
 * @returns `COMPLETED` as soon as logs are present; otherwise `FAILED` together
 * with an error explaining that no logs were collected.
 */
private async waitForLogs(): Promise<{ state: ExperimentStatus; error?: Error }> {
  const sleep = (ms: number): Promise<void> => new Promise((resolve) => setTimeout(resolve, ms));

  await sleep(2500); // Wait for 2.5s before checking logs
  for (let retry = 0; ; retry++) {
    const logs = experimentContext.getLogs(this.experimentUUID);
    if (logs.length > 0) {
      return { state: ExperimentStatus.COMPLETED };
    }
    // Written as !(a < b) rather than a >= b so a NaN limit also ends the loop.
    if (!(retry < this.maxRetries)) {
      break;
    }
    const nextRetry = retry + 1;
    // log every 10 retries
    if (nextRetry % 10 === 0) {
      console.debug(
        `Waiting for eval to finish for trial in experiment ${this.experimentUUID}. Retrying (${nextRetry}/${this.maxRetries})...`,
      );
    }
    await sleep(1000); // Wait for 1s before checking again
  }

  const msg = `No logs were collected for trial in experiment ${this.experimentUUID} after ${this.maxRetries} trys. Eval function likely did not finish, try increasing maxRetries on p.experiment. e.g: p.experiment('ExperimentName', data, func, { maxRetries: 120 })`;
  console.warn(msg);
  return { state: ExperimentStatus.FAILED, error: new Error(msg) };
}
}
6 changes: 6 additions & 0 deletions src/experiment/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ export type TracedFunction<T extends Record<string, any>, R> = (

/**
* Options for configuring an experiment.
*/
export interface ExperimentOptions {
/** The number of trials to run for each dataset item. Experiment uses 1 when this is unset or 0. */
nTrials?: number;
/** Additional metadata for the experiment. When the dataset is passed by name, the `Dataset` key is overwritten with that name. */
metadata?: Record<string, any>;
/** The number of workers to use for parallel execution. Experiment uses 10 when this is unset or 0. */
nWorkers?: number;
/** An array of evaluation functions to run on the entire dataset; each receives the evaluated logs. */
datasetLevelEvalFuncs?: ((logs: EvaluatedLog[]) => EvalFunctionReturn)[];
/** The maximum number of retries to wait for eval to finish. Each retry waits for 1s. Default is 60. */
maxRetries?: number;
}

/**
Expand Down
13 changes: 13 additions & 0 deletions src/utils/core/StreamHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export class StreamHandler<Item> {
let role: string | undefined;
let content: string | undefined;
let tool_calls: any[] | undefined;
let function_call: any | undefined;
let finish_reason: string | undefined;
let metrics: Record<string, number> = {};

Expand Down Expand Up @@ -137,6 +138,17 @@ export class StreamHandler<Item> {
tool_calls[0].function.arguments += delta.tool_calls[0].function.arguments;
}
}

if (delta.function_call) {
if (!function_call) {
function_call = {
name: delta.function_call.name,
arguments: delta.function_call.arguments,
};
} else {
function_call.arguments += delta.function_call.arguments;
}
}
}

return {
Expand All @@ -148,6 +160,7 @@ export class StreamHandler<Item> {
role,
content,
tool_calls,
function_call,
} as ChatCompletionMessage,
logprobs: null,
finish_reason,
Expand Down
4 changes: 1 addition & 3 deletions src/utils/core/TraceManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,7 @@ export class TraceManager {
trace.updateLog({ scores });
if (experiment_uuid) {
experimentContext.addLog(experiment_uuid, trace.getLog());
scores.forEach((score) => {
experimentContext.addScore(experiment_uuid, score);
});
experimentContext.addScores(experiment_uuid, scores);
}
trace.finalize();
});
Expand Down
4 changes: 4 additions & 0 deletions src/utils/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,9 @@ export function processEvaluationResult(
} else {
scores.push(result as EvaluationResult);
}
} else {
const msg = `Evaluation function ${funcName} returned an undefined or null result.`;
console.warn(msg);
scores.push({ name: `error-${funcName}`, score: 0, reason: msg });
}
}
79 changes: 18 additions & 61 deletions src/utils/message-converters.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Message, MessageConverter, Role } from '../types';
import { ChatCompletionMessageParam } from 'openai/src/resources/chat/completions';

/**
* Implements the MessageConverter interface for converting OpenAI messages.
Expand All @@ -9,76 +10,32 @@ export class OpenAIMessageConverter implements MessageConverter {
* @param m - The input message to be converted.
* @returns A standardized Message object.
*/
convert(m: any): Message {
if (m?.role === 'assistant' && !!m?.tool_calls) {
let content = `${m}`;
try {
content = this.formatToolCalls(m);
} catch (e) {
console.error(`Error converting assistant message with tool calls: ${e}`);
}
convert(m: ChatCompletionMessageParam): Message {
if (m.role === 'tool') {
return {
role: Role.tool,
content: JSON.stringify({ tool_call_id: m.tool_call_id, content: m.content }),
};
} else if (m.role === 'function') {
return {
role: Role.function,
content: typeof m.content === 'string' ? m.content : JSON.stringify(m.content),
};
} else if (m.role === 'assistant' && !!m.function_call) {
return {
role: Role.assistant,
content: content,
content: JSON.stringify(m.function_call),
};
} else if (m.role === 'tool') {
} else if (m.role === 'assistant' && !!m.tool_calls) {
return {
role: Role.tool,
content: JSON.stringify({ tool_call_id: m.tool_call_id, content: m.content }),
role: Role.assistant,
content: JSON.stringify(m.tool_calls),
};
} else {
return {
role: Role[m.role as keyof typeof Role],
content: m.content,
content: typeof m.content === 'string' ? m.content : JSON.stringify(m.content || {}),
};
}
}

/**
 * Serializes the tool calls of an OpenAI response message as pretty-printed JSON.
 *
 * Calls of type `function` are rebuilt with their arguments run through
 * `parseArgs`; every other call is emitted unchanged.
 *
 * @param responseMessage - The response message containing tool calls.
 * @returns A JSON string (4-space indent) holding the array of tool calls.
 * @private
 */
private formatToolCalls(responseMessage: any): string {
  const normalized: any[] = [];
  for (const call of responseMessage.tool_calls) {
    if (call.type !== 'function') {
      normalized.push(call);
      continue;
    }
    const name: string = call.function.name;
    const args: any = this.parseArgs(call.function.arguments);
    const id: string = call.id;
    normalized.push({
      id,
      type: call.type,
      function: { name, arguments: args },
    });
  }
  return JSON.stringify(normalized, null, 4);
}

/**
 * Normalizes function-call arguments taken from a response.
 *
 * Values that are already objects are returned as-is. Anything else is passed
 * to `JSON.parse`; if parsing fails, the error is logged and the value is
 * returned as a string instead.
 *
 * @param responseFunctionArgs - The function arguments to parse.
 * @returns The original object, the parsed JSON value, or the string fallback.
 * @private
 */
private parseArgs(responseFunctionArgs: any): any {
  if (responseFunctionArgs instanceof Object) {
    return responseFunctionArgs;
  }

  let parsed: any;
  try {
    parsed = JSON.parse(responseFunctionArgs);
  } catch (e) {
    console.error(`Error parsing tool call arguments as Object, storing as string instead: ${e}`);
    if (typeof responseFunctionArgs === 'string') {
      return responseFunctionArgs;
    }
    return `${responseFunctionArgs}`;
  }
  return parsed;
}
}
Loading

0 comments on commit abbaca7

Please sign in to comment.