Skip to content

Commit

Permalink
feat: change chain default key to node_idx_parallel_idx style
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhahalong committed Jan 14, 2025
1 parent a175721 commit 40334dd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 28 deletions.
43 changes: 17 additions & 26 deletions compose/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ type Chain[I, O any] struct {

gg *Graph[I, O]

namePrefix string
nodeIdx int
nodeIdx int

preNodeKeys []string

Expand Down Expand Up @@ -338,12 +337,12 @@ func (c *Chain[I, O]) AppendBranch(b *ChainBranch) *Chain[I, O] { // nolint: byt
return c
}

pName := c.nextNodeKey("Branch")
prefix := c.nextNodeKey()
key2NodeKey := make(map[string]string, len(b.key2BranchNode))

for key := range b.key2BranchNode {
node := b.key2BranchNode[key]
nodeKey := fmt.Sprintf("%s[%s]_%s", pName, key, genNodeKeySuffix(node.First))
nodeKey := fmt.Sprintf("%s_branch_%s", prefix, key)

if err := c.gg.addNode(nodeKey, node.First, node.Second); err != nil {
c.reportError(fmt.Errorf("add branch node[%s] to chain failed: %w", nodeKey, err))
Expand Down Expand Up @@ -467,18 +466,18 @@ func (c *Chain[I, O]) AppendParallel(p *Parallel) *Chain[I, O] {
return c
}

pName := c.nextNodeKey("Parallel")
prefix := c.nextNodeKey()
var nodeKeys []string

for i := range p.nodes {
node := p.nodes[i]
nodeKey := fmt.Sprintf("%s[%d]_%s", pName, i, genNodeKeySuffix(node.First))
nodeKey := fmt.Sprintf("%s_parallel_%d", prefix, i)
if err := c.gg.addNode(nodeKey, node.First, node.Second); err != nil {
c.reportError(fmt.Errorf("add parallel node[%s] to chain failed: %w", nodeKey, err))
c.reportError(fmt.Errorf("add parallel node [%s] to chain failed: %w", nodeKey, err))
return c
}
if err := c.gg.AddEdge(startNode, nodeKey); err != nil {
c.reportError(fmt.Errorf("add parallel edge[%s]-[%s] to chain failed: %w", startNode, nodeKey, err))
c.reportError(fmt.Errorf("add parallel edge [%s]-[%s] to chain failed: %w", startNode, nodeKey, err))
return c
}
nodeKeys = append(nodeKeys, nodeKey)
Expand Down Expand Up @@ -512,17 +511,15 @@ func (c *Chain[I, O]) AppendPassthrough(opts ...GraphAddNodeOpt) *Chain[I, O] {
return c
}

// nextNodeKey.
// get the next node key for the chain.
// e.g. "Chain[1]_ChatModel" => represent the second node of the chain, and is a ChatModel node.
// e.g. "Chain[2]_NameByUser" => represent the third node of the chain, and the node name is set by user of `NameByUser`.
func (c *Chain[I, O]) nextNodeKey(name string) string {
if c.namePrefix == "" {
c.namePrefix = string(ComponentOfChain)
}
fullKey := fmt.Sprintf("%s[%d]_%s", c.namePrefix, c.nodeIdx, name)
// nextIdx.
// get the next idx for the chain.
// chain key is: node_idx => eg: node_0 => represent the first node of the chain (idx start from 0)
// if has parallel: node_idx_parallel_idx => eg: node_0_parallel_1 => represent the first node of the chain, and is a parallel node, and the second node of the parallel
// if has branch: node_idx_branch_key => eg: node_1_branch_customkey => represent the second node of the chain, and is a branch node, and the 'customkey' is the key of the branch
func (c *Chain[I, O]) nextNodeKey() string {
idx := c.nodeIdx
c.nodeIdx++
return fullKey
return fmt.Sprintf("node_%d", idx)
}

// reportError.
Expand Down Expand Up @@ -551,8 +548,9 @@ func (c *Chain[I, O]) addNode(node *graphNode, options *graphAddNodeOpts) {
}

nodeKey := options.nodeOptions.nodeKey
defaultNodeKey := c.nextNodeKey()
if nodeKey == "" {
nodeKey = c.nextNodeKey(genNodeKeySuffix(node))
nodeKey = defaultNodeKey
}

err := c.gg.addNode(nodeKey, node, options)
Expand All @@ -575,10 +573,3 @@ func (c *Chain[I, O]) addNode(node *graphNode, options *graphAddNodeOpts) {

c.preNodeKeys = []string{nodeKey}
}

func genNodeKeySuffix(node *graphNode) string {
if len(node.nodeInfo.name) == 0 {
return node.executorMeta.componentImplType + string(node.executorMeta.component)
}
return node.nodeInfo.name
}
6 changes: 5 additions & 1 deletion compose/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ func TestChainWithException(t *testing.T) {
// just pass through
t.Log("in view lambda: ", kvs)
return kvs, nil
}))
})).
AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) {
t.Log("in view lambda 02: ", kvs)
return kvs, nil
}), WithNodeKey("xlam"))

// items with parallels
parallel := NewParallel()
Expand Down
2 changes: 1 addition & 1 deletion compose/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func TestValidate(t *testing.T) {
p = NewParallel().AddLambda("1", lA).AddLambda("2", lAB)
c = NewChain[string, map[string]any]().AppendParallel(p)
_, err = c.Compile(context.Background())
assert.ErrorContains(t, err, "add parallel edge[start]-[Chain[0]_Parallel[0]_Lambda] to chain failed: graph edge[start]-[Chain[0]_Parallel[0]_Lambda]: start node's output type[string] and end node's input type[compose.A] mismatch")
assert.ErrorContains(t, err, "add parallel edge [start]-[node_1_parallel_0] to chain failed: graph edge[start]-[node_1_parallel_0]: start node's output type[string] and end node's input type[compose.A] mismatch")

// test graph output type check
gg := NewGraph[string, A]()
Expand Down

0 comments on commit 40334dd

Please sign in to comment.