Skip to content

Commit

Permalink
Add machine hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkpiano committed Jun 25, 2024
1 parent 5076375 commit 094bb8a
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 13 deletions.
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"@langchain/core": "^0.1.63",
"@langchain/openai": "^0.0.28",
"@types/node": "^20.14.8",
"@types/object-hash": "^3.0.6",
"dotenv": "^16.4.5",
"json-schema-to-ts": "^3.1.0",
"ts-node": "^10.9.2",
Expand All @@ -48,6 +49,7 @@
"@ai-sdk/openai": "^0.0.31",
"@xstate/graph": "^2.0.0",
"ai": "^3.2.5",
"object-hash": "^3.0.0",
"xstate": "^5.14.0"
},
"packageManager": "[email protected]"
Expand Down
19 changes: 19 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 40 additions & 0 deletions src/agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,46 @@ test('agent.addObservation() adds to observations', () => {
);
});

test('agent.addObservation() adds to observations with machine hash', () => {
const agent = createAgent({
name: 'test',
events: {},
model: {} as any,
});

const machine = createMachine({
initial: 'playing',
states: {
playing: {
on: {
play: 'lost',
},
},
lost: {},
},
});

const observation = agent.addObservation({
prevState: { value: 'playing', context: {} },
event: { type: 'play', position: 3 },
state: { value: 'lost', context: {} },
machine,
});

expect(observation.sessionId).toEqual(agent.sessionId);

expect(agent.select((c) => c.observations)).toContainEqual(
expect.objectContaining({
prevState: { value: 'playing', context: {} },
event: { type: 'play', position: 3 },
state: { value: 'lost', context: {} },
machineHash: expect.any(String),
sessionId: expect.any(String),
timestamp: expect.any(Number),
})
);
});

test('agent.interact() observes machine actors (no 2nd arg)', () => {
const machine = createMachine({
initial: 'a',
Expand Down
17 changes: 13 additions & 4 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import {
ObservedState,
AgentObservationInput,
AgentMemoryContext,
AgentObservation,
} from './types';
import { simplePlanner } from './planners/simplePlanner';
import { agentGenerateText, agentStreamText } from './text';
import { agentDecide } from './decision';
import { vercelAdapter } from './adapters/vercel';
import { randomId } from './utils';
import { getMachineHash, randomId } from './utils';

export const agentLogic: AgentLogic<AnyEventObject> = fromTransition(
(state, event, { emit }) => {
Expand Down Expand Up @@ -189,12 +190,18 @@ export function createAgent<
};

agent.addObservation = (observationInput) => {
const { prevState, event, state } = observationInput;
const observation = {
...observationInput,
prevState,
event,
state,
id: observationInput.id ?? randomId(),
sessionId: agent.sessionId,
timestamp: observationInput.timestamp ?? Date.now(),
};
machineHash: observationInput.machine
? getMachineHash(observationInput.machine)
: undefined,
} satisfies AgentObservation<any>;

agent.send({
type: 'agent.observe',
Expand Down Expand Up @@ -249,7 +256,8 @@ export function createAgent<
event: inspEvent.event,
prevState,
state: inspEvent.snapshot as any,
};
machine: (actorRef as any).src,
} satisfies AgentObservationInput;

await handleObservation(observationInput);
},
Expand All @@ -261,6 +269,7 @@ export function createAgent<
prevState: undefined,
event: { type: '' }, // TODO: unknown events?
state: actorRef.getSnapshot(),
machine: (actorRef as any).src,
});
}

Expand Down
2 changes: 2 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ export interface AgentObservation<TActor extends AnyActorRef> {
prevState: SnapshotFrom<TActor> | undefined;
event: EventFrom<TActor>;
state: SnapshotFrom<TActor>;
machineHash: string | undefined;
sessionId: string;
timestamp: number;
}
Expand All @@ -178,6 +179,7 @@ export interface AgentObservationInput {
prevState: ObservedState | undefined;
event: AnyEventObject;
state: ObservedState;
machine?: AnyStateMachine;
timestamp?: number;
}

Expand Down
62 changes: 53 additions & 9 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,50 @@
import { AnyMachineSnapshot, AnyStateNode } from 'xstate';
import { AnyMachineSnapshot, AnyStateMachine, AnyStateNode } from 'xstate';
import hash from 'object-hash';
import { TransitionData } from './types';

export function getAllTransitions(state: AnyMachineSnapshot): TransitionData[] {
const nodes = state._nodes;
const transitions = (nodes as AnyStateNode[])
.map((node) => [...(node as AnyStateNode).transitions.values()])
.flat(2)
.map((transition) => ({
...transition,
guard:
typeof transition.guard === 'string'
? { type: transition.guard }
: (transition.guard as any), // TODO: fix
}));
.map((nodeTransitions) => {
return nodeTransitions.map((nodeEventTransitions) => {
return nodeEventTransitions.map((transition) => {
return {
...transition,
guard:
typeof transition.guard === 'string'
? { type: transition.guard }
: (transition.guard as any), // TODO: fix
};
});
});
})
.flat(2);

return transitions;
}

export function getAllMachineTransitions(
stateNode: AnyStateNode
): TransitionData[] {
const transitions: TransitionData[] = [...stateNode.transitions.values()]
.map((nodeTransitions) => {
return nodeTransitions.map((transition) => {
return {
...transition,
guard:
typeof transition.guard === 'string'
? { type: transition.guard }
: (transition.guard as any), // TODO: fix
};
});
})
.flat(2);

for (const s of Object.values(stateNode.states)) {
const stateTransitions = getAllMachineTransitions(s);
transitions.push(...stateTransitions);
}

return transitions;
}
Expand All @@ -26,3 +58,15 @@ export function randomId() {
const random = Math.random().toString(36).substring(2, 9);
return timestamp + random;
}

const machineHashes: WeakMap<AnyStateMachine, string> = new WeakMap();
/**
* Returns a string hash representing only the transitions in the state machine.
*/
export function getMachineHash(machine: AnyStateMachine): string {
if (machineHashes.has(machine)) return machineHashes.get(machine)!;
const transitions = getAllMachineTransitions(machine.root);
const machineHash = hash(transitions);
machineHashes.set(machine, machineHash);
return machineHash;
}

0 comments on commit 094bb8a

Please sign in to comment.