Skip to content

Commit

Permalink
Merge pull request #38 from statelyai/davidkpiano/context
Browse files Browse the repository at this point in the history
Add `context` field in agent
  • Loading branch information
davidkpiano authored Jul 16, 2024
2 parents 56ac295 + ae9f389 commit dc81287
Show file tree
Hide file tree
Showing 25 changed files with 155 additions and 89 deletions.
26 changes: 26 additions & 0 deletions .changeset/wild-bobcats-care.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
'@statelyai/agent': minor
---

You can now add `context` Zod schema to your agent. For now, this is meant to be passed directly to the state machine, but in the future, the schema can be shared with the LLM agent to better understand the state machine and its context for decision making.

Breaking: The `context` and `events` types are now in `agent.types` instead of ~~`agent.eventTypes`.

```ts
const agent = createAgent({
// ...
context: {
score: z.number().describe('The score of the game'),
// ...
},
});

const machine = setup({
types: agent.types,
}).createMachine({
context: {
score: 0,
},
// ...
});
```
26 changes: 9 additions & 17 deletions examples/chatbot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,18 @@ const agent = createAgent({
}),
'agent.endConversation': z.object({}).describe('Stop the conversation'),
},
context: {
userMessage: z.string(),
},
});

const machine = setup({
types: {
context: {} as {
conversation: string[];
},
events: agent.eventTypes,
},
types: agent.types,
actors: { agent: fromDecision(agent), getFromTerminal },
}).createMachine({
initial: 'listening',
context: {
conversation: [],
userMessage: '',
},
states: {
listening: {
Expand All @@ -35,8 +33,7 @@ const machine = setup({
input: 'User:',
onDone: {
actions: assign({
conversation: (x) =>
x.context.conversation.concat('User: ' + x.event.output),
userMessage: (x) => x.event.output,
}),
target: 'responding',
},
Expand All @@ -47,20 +44,15 @@ const machine = setup({
src: 'agent',
input: (x) => ({
context: {
conversation: x.context.conversation,
userMessage: 'User says: ' + x.context.userMessage,
},
messages: agent.select((mem) => mem.messages),
goal: 'Respond to the user, unless they want to end the conversation.',
}),
},
on: {
'agent.respond': {
actions: [
assign({
conversation: (x) =>
x.context.conversation.concat('Assistant: ' + x.event.response),
}),
log((x) => `Agent: ${x.event.response}`),
],
actions: [log((x) => `Agent: ${x.event.response}`)],
target: 'listening',
},
'agent.endConversation': 'finished',
Expand Down
12 changes: 5 additions & 7 deletions examples/cot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@ const agent = createAgent({
answer: z.string().describe('The answer to the question'),
}),
},
context: {
question: z.string().nullable(),
thought: z.string().nullable(),
},
});

const machine = setup({
types: {
context: {} as {
question: string | null;
thought: string | null;
},
events: agent.eventTypes,
},
types: agent.types,
actors: { agent: fromDecision(agent), getFromTerminal },
}).createMachine({
initial: 'asking',
Expand Down
2 changes: 1 addition & 1 deletion examples/email.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const agent = createAgent({

const machine = setup({
types: {
events: agent.eventTypes,
events: agent.types.events,
input: {} as {
email: string;
instructions: string;
Expand Down
2 changes: 1 addition & 1 deletion examples/example.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const agent = createAgent({

const machine = setup({
types: {
events: agent.eventTypes,
events: agent.types.events,
},
actors: { agent: fromDecision(agent), summarizer: fromText(agent) },
}).createMachine({
Expand Down
2 changes: 1 addition & 1 deletion examples/goal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const machine = setup({
question: string | null;
goal: string | null;
},
events: agent.eventTypes,
events: agent.types.events,
},
actors: { decider, getFromTerminal },
}).createMachine({
Expand Down
18 changes: 8 additions & 10 deletions examples/joke.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,17 @@ const agent = createAgent({
.describe('Explains why the joke was irrelevant'),
'agent.markAsRelevant': z.object({}).describe('The joke was relevant'),
},
context: {
topic: z.string().describe('The topic for the joke'),
jokes: z.array(z.string()).describe('The jokes told so far'),
desire: z.string().nullable().describe('The user desire'),
lastRating: z.number().nullable().describe('The last joke rating'),
loader: z.string().nullable().describe('The loader text'),
},
});

const jokeMachine = setup({
types: {
context: {} as {
topic: string;
jokes: string[];
desire: string | null;
lastRating: number | null;
loader: string | null;
},
events: agent.eventTypes,
},
types: agent.types,
actors: {
agent: fromDecision(agent),
loader,
Expand Down
2 changes: 1 addition & 1 deletion examples/number.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const machine = setup({
previousGuesses: number[];
answer: number | null;
},
events: agent.eventTypes,
events: agent.types.events,
},
actors: {
agent: fromDecision(agent),
Expand Down
2 changes: 1 addition & 1 deletion examples/raffle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const machine = setup({
lastInput: string | null;
entries: string[];
},
events: {} as typeof agent.eventTypes | { type: 'draw' },
events: {} as typeof agent.types.events | { type: 'draw' },
},
actors: { agent: fromDecision(agent), getFromTerminal },
}).createMachine({
Expand Down
2 changes: 1 addition & 1 deletion examples/support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ const agent = createAgent({

const machine = setup({
types: {
events: agent.eventTypes,
events: agent.types.events,
input: {} as string,
context: {} as {
customerIssue: string;
Expand Down
21 changes: 18 additions & 3 deletions examples/ticTacToe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ const agent = createAgent({
}),
reset: z.object({}).describe('Reset the game to the initial state'),
},
context: {
board: z
.array(z.union([z.literal(null), z.literal('x'), z.literal('o')]))
.describe('The 3x3 board represented as a 9-element array.'),
moves: z
.number()
.min(0)
.max(9)
.describe('The number of moves made in the game.'),
player: z
.union([z.literal('x'), z.literal('o')])
.describe('The current player (x or o)'),
gameReport: z.string(),
events: z.array(z.string()),
},
});

type Player = 'x' | 'o';
Expand All @@ -41,7 +56,7 @@ const initialContext = {
player: 'x' as Player,
gameReport: '',
events: [],
} satisfies GameContext;
} satisfies typeof agent.types.context;

function getWinner(board: typeof initialContext.board): Player | null {
const lines = [
Expand All @@ -64,8 +79,8 @@ function getWinner(board: typeof initialContext.board): Player | null {

export const ticTacToeMachine = setup({
types: {
context: {} as GameContext,
events: agent.eventTypes,
context: agent.types.context,
events: agent.types.events,
},
actors: {
agent: fromDecision(agent),
Expand Down
4 changes: 3 additions & 1 deletion examples/todo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ const machine = setup({
todos: Todo[];
command: string | null;
},
events: {} as typeof agent.eventTypes | { type: 'assist'; command: string },
events: {} as
| typeof agent.types.events
| { type: 'assist'; command: string },
},
actors: { agent: fromDecision(agent), getFromTerminal },
}).createMachine({
Expand Down
2 changes: 1 addition & 1 deletion examples/tutor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const machine = setup({
context: {} as {
conversation: string[];
},
events: agent.eventTypes,
events: agent.types.events,
},
actors: { agent: fromDecision(agent), getFromTerminal },
}).createMachine({
Expand Down
2 changes: 1 addition & 1 deletion examples/verify.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ const machine = setup({
answer: string | null;
validation: string | null;
},
events: agent.eventTypes,
events: agent.types.events,
},
actors: {
getFromTerminal,
Expand Down
2 changes: 1 addition & 1 deletion examples/weather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ const machine = setup({
history: string[];
count: number;
},
events: agent.eventTypes,
events: agent.types.events,
},
actors: {
agent: fromDecision(agent),
Expand Down
2 changes: 1 addition & 1 deletion examples/word.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const agent = createAgent({
const wordGuesserMachine = setup({
types: {
context: {} as typeof context,
events: agent.eventTypes,
events: agent.types.events,
},
actors: {
agent: fromDecision(agent),
Expand Down
21 changes: 21 additions & 0 deletions src/agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,24 @@ test('You can listen for plan events', async () => {
})
);
});

test('agent.types provides context and event types', () => {
const agent = createAgent({
model: {} as any,
events: {
setScore: z.object({
score: z.number(),
}),
},
context: {
score: z.number(),
},
});

agent.types satisfies { context: any; events: any };

agent.types.context satisfies { score: number };

// @ts-expect-error
agent.types.context satisfies { score: string };
});
22 changes: 15 additions & 7 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
Observer,
toObserver,
} from 'xstate';
import { ZodEventMapping } from './schemas';
import { ZodContextMapping, ZodEventMapping } from './schemas';
import {
Agent,
AgentLogic,
Expand All @@ -21,12 +21,14 @@ import {
AgentObservationInput,
AgentMemoryContext,
AgentObservation,
ContextFromZodContextMapping,
} from './types';
import { simplePlanner } from './planners/simplePlanner';
import { agentGenerateText, agentStreamText } from './text';
import { agentDecide } from './decision';
import { vercelAdapter } from './adapters/vercel';
import { getMachineHash, randomId } from './utils';
import { SomeZodObject, TypeOf } from 'zod';

export const agentLogic: AgentLogic<AnyEventObject> = fromTransition(
(state, event, { emit }) => {
Expand Down Expand Up @@ -81,14 +83,17 @@ export const agentLogic: AgentLogic<AnyEventObject> = fromTransition(
);

export function createAgent<
const TContextSchema extends ZodContextMapping,
const TEventSchemas extends ZodEventMapping,
TEvents extends EventObject = EventsFromZodEventMapping<TEventSchemas>
TEvents extends EventObject = EventsFromZodEventMapping<TEventSchemas>,
TContext = ContextFromZodContextMapping<TContextSchema>
>({
name,
description,
model,
events,
planner = simplePlanner as AgentPlanner<Agent<TEvents>>,
context,
planner = simplePlanner as AgentPlanner<Agent<TContext, TEvents>>,
stringify = JSON.stringify,
getMemory,
logic = agentLogic as AgentLogic<TEvents>,
Expand Down Expand Up @@ -123,21 +128,22 @@ export function createAgent<
* that the agent knows about.
*/
events: TEventSchemas;
planner?: AgentPlanner<Agent<TEvents>>;
context?: TContextSchema;
planner?: AgentPlanner<Agent<TContext, TEvents>>;
stringify?: typeof JSON.stringify;
/**
* A function that retrieves the agent's long term memory
*/
getMemory?: (agent: Agent<any>) => AgentLongTermMemory;
getMemory?: (agent: Agent<TContext, TEvents>) => AgentLongTermMemory;
/**
* Agent logic
*/
logic?: AgentLogic<TEvents>;
adapter?: AIAdapter;
} & GenerateTextOptions): Agent<TEvents> {
} & GenerateTextOptions): Agent<TContext, TEvents> {
const messageHistoryListeners: Observer<AgentMessageHistory>[] = [];

const agent = createActor(logic) as unknown as Agent<TEvents>;
const agent = createActor(logic) as unknown as Agent<TContext, TEvents>;
agent.events = events;
agent.model = model;
agent.name = name;
Expand Down Expand Up @@ -280,6 +286,8 @@ export function createAgent<
};
};

agent.types = {} as any;

agent.start();

return agent;
Expand Down
Loading

0 comments on commit dc81287

Please sign in to comment.