Skip to content

Commit

Permalink
Merge pull request #5 from statelyai/davidkpiano/api-updates-1
Browse files Browse the repository at this point in the history
API updates
  • Loading branch information
davidkpiano authored Feb 3, 2024
2 parents c2400bf + 687bed8 commit 81fe549
Show file tree
Hide file tree
Showing 14 changed files with 641 additions and 506 deletions.
5 changes: 5 additions & 0 deletions .changeset/gorgeous-bats-explain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@statelyai/agent": patch
---

Simplify API (WIP)
5 changes: 5 additions & 0 deletions .changeset/pretty-fishes-shake.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@statelyai/agent': patch
---

Add `createSchemas`, `createOpenAIAdapter`, and change `createAgent`
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,5 @@ dist
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*

.vscode/settings.json
2 changes: 1 addition & 1 deletion examples/helpers/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { fromPromise } from 'xstate';
export const getFromTerminal = fromPromise<string, string>(
async ({ input }) => {
const topic = await new Promise<string>((res) => {
console.log(input);
console.log(input + '\n');
const listener = (data: Buffer) => {
const result = data.toString().trim();
process.stdin.off('data', listener);
Expand Down
62 changes: 32 additions & 30 deletions examples/joke.ts
Original file line number Diff line number Diff line change
@@ -1,40 +1,49 @@
import OpenAI from 'openai';
import {
assign,
createActor,
fromCallback,
fromPromise,
log,
setup,
} from 'xstate';
import { createAgent } from '../src';
import { assign, fromCallback, fromPromise, log, setup } from 'xstate';
import { createAgent, createOpenAIAdapter, createSchemas } from '../src';
import { loadingAnimation } from './helpers/loader';

const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

const agent = createAgent(openai, {
model: 'gpt-3.5-turbo-1106',
const schemas = createSchemas({
context: {
topic: { type: 'string' },
jokes: {
type: 'array',
items: {
type: 'string',
},
desire: { type: ['string', 'null'] },
lastRating: { type: ['string', 'null'] },
},
desire: { type: ['string', 'null'] as const },
lastRating: { type: ['string', 'null'] as const },
},
events: {
askForTopic: {
type: 'object',
properties: {
topic: {
type: 'string',
},
},
},
endJokes: {
type: 'object',
properties: {},
},
},
events: {},
});

const getJokeCompletion = agent.fromChatCompletion(
const adapter = createOpenAIAdapter(openai, {
model: 'gpt-3.5-turbo-1106',
});

const getJokeCompletion = adapter.fromChat(
(topic: string) => `Tell me a joke about ${topic}.`
);

const rateJoke = agent.fromChatCompletion(
const rateJoke = adapter.fromChat(
(joke: string) => `Rate this joke on a scale of 1 to 10: ${joke}`
);

Expand All @@ -52,7 +61,7 @@ const getTopic = fromPromise(async () => {
return topic;
});

const decide = agent.fromEvent(
const decide = adapter.fromEventChoice(
(lastRating: string) =>
`Choose what to do next, given the previous rating of the joke: ${lastRating}`
);
Expand Down Expand Up @@ -96,15 +105,8 @@ const loader = fromCallback(({ input }: { input: string }) => {
});

const jokeMachine = setup({
types: {
context: {} as {
topic: string;
jokes: string[];
desire: string | null;
lastRating: string | null;
},
input: {} as { topic: string },
},
schemas,
types: schemas.types,
actors: {
getJokeCompletion,
getTopic,
Expand Down Expand Up @@ -146,7 +148,7 @@ const jokeMachine = setup({
event.output.choices[0]!.message.content!
),
}),
log((x) => x.context.jokes.at(-1)),
log((x) => `\n` + x.context.jokes.at(-1)),
],
target: 'rateJoke',
},
Expand All @@ -168,7 +170,7 @@ const jokeMachine = setup({
lastRating: ({ event }) =>
event.output.choices[0]!.message.content!,
}),
log(({ context }) => context.lastRating),
log(({ context }) => '\n' + context.lastRating),
],
target: 'decide',
},
Expand Down Expand Up @@ -210,5 +212,5 @@ const jokeMachine = setup({
},
});

const actor = createActor(jokeMachine);
actor.start();
const agent = createAgent(jokeMachine);
agent.start();
Empty file.
36 changes: 17 additions & 19 deletions examples/ticTacToe.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import { assign, setup, assertEvent, createActor } from 'xstate';
import { assign, setup, assertEvent } from 'xstate';
import OpenAI from 'openai';
import { createAgent } from '../src/openai';
import { createOpenAIAdapter, createSchemas, createAgent } from '../src';

const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

type Player = 'x' | 'o';

const agent = createAgent(openai, {
model: 'gpt-4-1106-preview',
const schemas = createSchemas({
context: {
board: {
type: 'array',
Expand Down Expand Up @@ -69,28 +68,32 @@ const agent = createAgent(openai, {
},
});

const adapter = createOpenAIAdapter(openai, {
model: 'gpt-4-1106-preview',
});

const initialContext = {
board: Array(9).fill(null) as Array<Player | null>,
moves: 0,
player: 'x' as Player,
gameReport: '',
events: [],
} satisfies typeof agent.types.context;
} satisfies typeof schemas.types.context;

const bot = agent.fromEvent(
({ context }: { context: typeof agent.types.context }) => `
const bot = adapter.fromEventChoice(
({ context }: { context: typeof schemas.types.context }) => `
You are playing a game of tic tac toe. This is the current game state. The 3x3 board is represented by a 9-element array. The first element is the top-left cell, the second element is the top-middle cell, the third element is the top-right cell, the fourth element is the middle-left cell, and so on. The value of each cell is either null, x, or o. The value of null means that the cell is empty. The value of x means that the cell is occupied by an x. The value of o means that the cell is occupied by an o.
${JSON.stringify(context, null, 2)}
Execute the single best next move to try to win the game. Do not play on an existing cell.`
);

const gameReporter = agent.fromChatCompletionStream(
const gameReporter = adapter.fromChatStream(
({
context,
}: {
context: typeof agent.types.context;
context: typeof schemas.types.context;
}) => `Here is the game board:
${JSON.stringify(context.board, null, 2)}
Expand Down Expand Up @@ -124,7 +127,8 @@ function getWinner(board: typeof initialContext.board): Player | null {
}

export const ticTacToeMachine = setup({
types: agent.types,
schemas,
types: schemas.types,
actors: {
bot,
gameReporter,
Expand Down Expand Up @@ -248,14 +252,8 @@ export const ticTacToeMachine = setup({
},
});

const actor = createActor(ticTacToeMachine, {
inspect: (e) => {
if (e.type === '@xstate.event') {
console.log(e.event);
}
},
});
actor.subscribe((s) => {
const agent = createAgent(ticTacToeMachine);
agent.subscribe((s) => {
console.log(s.value, s.context);
});
actor.start();
agent.start();
47 changes: 30 additions & 17 deletions examples/weather.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import OpenAI from 'openai';
import { createAgent, fromEventChoice } from '../src';
import { assign, createActor, fromPromise, log, setup } from 'xstate';
import { createAgent, createOpenAIAdapter, createSchemas } from '../src';
import { assign, fromPromise, log, setup } from 'xstate';
import { getFromTerminal } from './helpers/helpers';

async function searchTavily(
Expand All @@ -23,6 +23,7 @@ async function searchTavily(
},
body: JSON.stringify(body),
});

const json = await response.json();
if (!response.ok) {
throw new Error(
Expand All @@ -39,8 +40,7 @@ const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

const agent = createAgent(openai, {
model: 'gpt-4-1106-preview',
const schemas = createSchemas({
context: {
location: { type: 'string' },
history: { type: 'array', items: { type: 'string' } },
Expand All @@ -64,17 +64,27 @@ const agent = createAgent(openai, {
},
});

const adapter = createOpenAIAdapter(openai, {
model: 'gpt-4-1106-preview',
});

const getWeather = fromPromise(async ({ input }: { input: string }) => {
const results = await searchTavily(
`Get the weather for this location: ${input}`,
{
maxResults: 5,
apiKey: process.env.TAVILY_API_KEY!,
}
);
return results;
});

const machine = setup({
types: agent.types,
schemas,
types: schemas.types,
actors: {
searchTavily: fromPromise(async ({ input }: { input: string }) => {
const results = await searchTavily(input, {
maxResults: 5,
apiKey: process.env.TAVILY_API_KEY!,
});
return results;
}),
decide: agent.fromEvent(
getWeather,
decide: adapter.fromEventChoice(
(input: string) =>
`Decide what to do based on the given input, which may or may not be a location: ${input}`
),
Expand Down Expand Up @@ -121,9 +131,8 @@ const machine = setup({
gettingWeather: {
entry: log('Getting weather...'),
invoke: {
src: 'searchTavily',
input: ({ context }) =>
`Get the weather for this location: ${context.location}`,
src: 'getWeather',
input: ({ context }) => context.location,
onDone: {
actions: [
log(({ event }) => event.output),
Expand All @@ -144,4 +153,8 @@ const machine = setup({
},
});

createActor(machine).start();
createAgent(machine, {
input: {
location: 'New York',
},
}).start();
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"access": "public"
},
"dependencies": {
"xstate": "^5.5.1"
"xstate": "^5.6.0"
},
"packageManager": "[email protected]"
}
Loading

0 comments on commit 81fe549

Please sign in to comment.