diff --git a/compose/chain.go b/compose/chain.go index 0f07e81..2a1dae4 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -75,8 +75,7 @@ type Chain[I, O any] struct { gg *Graph[I, O] - namePrefix string - nodeIdx int + nodeIdx int preNodeKeys []string @@ -338,12 +337,12 @@ func (c *Chain[I, O]) AppendBranch(b *ChainBranch) *Chain[I, O] { // nolint: byt return c } - pName := c.nextNodeKey("Branch") + prefix := c.nextNodeKey() key2NodeKey := make(map[string]string, len(b.key2BranchNode)) for key := range b.key2BranchNode { node := b.key2BranchNode[key] - nodeKey := fmt.Sprintf("%s[%s]_%s", pName, key, genNodeKeySuffix(node.First)) + nodeKey := fmt.Sprintf("%s_branch_%s", prefix, key) if err := c.gg.addNode(nodeKey, node.First, node.Second); err != nil { c.reportError(fmt.Errorf("add branch node[%s] to chain failed: %w", nodeKey, err)) @@ -467,18 +466,18 @@ func (c *Chain[I, O]) AppendParallel(p *Parallel) *Chain[I, O] { return c } - pName := c.nextNodeKey("Parallel") + prefix := c.nextNodeKey() var nodeKeys []string for i := range p.nodes { node := p.nodes[i] - nodeKey := fmt.Sprintf("%s[%d]_%s", pName, i, genNodeKeySuffix(node.First)) + nodeKey := fmt.Sprintf("%s_parallel_%d", prefix, i) if err := c.gg.addNode(nodeKey, node.First, node.Second); err != nil { - c.reportError(fmt.Errorf("add parallel node[%s] to chain failed: %w", nodeKey, err)) + c.reportError(fmt.Errorf("add parallel node [%s] to chain failed: %w", nodeKey, err)) return c } if err := c.gg.AddEdge(startNode, nodeKey); err != nil { - c.reportError(fmt.Errorf("add parallel edge[%s]-[%s] to chain failed: %w", startNode, nodeKey, err)) + c.reportError(fmt.Errorf("add parallel edge [%s]-[%s] to chain failed: %w", startNode, nodeKey, err)) return c } nodeKeys = append(nodeKeys, nodeKey) @@ -512,17 +511,15 @@ func (c *Chain[I, O]) AppendPassthrough(opts ...GraphAddNodeOpt) *Chain[I, O] { return c } -// nextNodeKey. -// get the next node key for the chain. -// e.g. "Chain[1]_ChatModel" => represent the second node of the chain, and is a ChatModel node. -// e.g. "Chain[2]_NameByUser" => represent the third node of the chain, and the node name is set by user of `NameByUser`. -func (c *Chain[I, O]) nextNodeKey(name string) string { - if c.namePrefix == "" { - c.namePrefix = string(ComponentOfChain) - } - fullKey := fmt.Sprintf("%s[%d]_%s", c.namePrefix, c.nodeIdx, name) +// nextIdx. +// get the next idx for the chain. +// chain key is: node_idx => eg: node_0 => represent the first node of the chain (idx start from 0) +// if has parallel: node_idx_parallel_idx => eg: node_0_parallel_1 => represent the first node of the chain, and is a parallel node, and the second node of the parallel +// if has branch: node_idx_branch_key => eg: node_1_branch_customkey => represent the second node of the chain, and is a branch node, and the 'customkey' is the key of the branch +func (c *Chain[I, O]) nextNodeKey() string { + idx := c.nodeIdx c.nodeIdx++ - return fullKey + return fmt.Sprintf("node_%d", idx) } // reportError. @@ -551,8 +548,9 @@ func (c *Chain[I, O]) addNode(node *graphNode, options *graphAddNodeOpts) { } nodeKey := options.nodeOptions.nodeKey + defaultNodeKey := c.nextNodeKey() if nodeKey == "" { - nodeKey = c.nextNodeKey(genNodeKeySuffix(node)) + nodeKey = defaultNodeKey } err := c.gg.addNode(nodeKey, node, options) @@ -575,10 +573,3 @@ func (c *Chain[I, O]) addNode(node *graphNode, options *graphAddNodeOpts) { c.preNodeKeys = []string{nodeKey} } - -func genNodeKeySuffix(node *graphNode) string { - if len(node.nodeInfo.name) == 0 { - return node.executorMeta.componentImplType + string(node.executorMeta.component) - } - return node.nodeInfo.name -} diff --git a/compose/chain_test.go b/compose/chain_test.go index a6fc68f..570ca0e 100644 --- a/compose/chain_test.go +++ b/compose/chain_test.go @@ -116,7 +116,11 @@ func TestChainWithException(t *testing.T) { // just pass through t.Log("in view lambda: ", kvs) return kvs, nil - })) + })). + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in view lambda 02: ", kvs) + return kvs, nil + }), WithNodeKey("xlam")) // items with parallels parallel := NewParallel() diff --git a/compose/graph_test.go b/compose/graph_test.go index 7cc5587..b4835c3 100644 --- a/compose/graph_test.go +++ b/compose/graph_test.go @@ -423,7 +423,7 @@ func TestValidate(t *testing.T) { p = NewParallel().AddLambda("1", lA).AddLambda("2", lAB) c = NewChain[string, map[string]any]().AppendParallel(p) _, err = c.Compile(context.Background()) - assert.ErrorContains(t, err, "add parallel edge[start]-[Chain[0]_Parallel[0]_Lambda] to chain failed: graph edge[start]-[Chain[0]_Parallel[0]_Lambda]: start node's output type[string] and end node's input type[compose.A] mismatch") + assert.ErrorContains(t, err, "add parallel edge [start]-[node_1_parallel_0] to chain failed: graph edge[start]-[node_1_parallel_0]: start node's output type[string] and end node's input type[compose.A] mismatch") // test graph output type check gg := NewGraph[string, A]()