Skip to content

Commit 8d65eb9

Browse files
feat: parent retriever/indexer
1 parent cc39e05 commit 8d65eb9

File tree

4 files changed

+401
-0
lines changed

4 files changed

+401
-0
lines changed

flow/indexer/parent/parent.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright 2024 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package parent
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
"github.com/cloudwego/eino/components/document"
24+
"github.com/cloudwego/eino/components/indexer"
25+
"github.com/cloudwego/eino/schema"
26+
)
27+
28+
type Config struct {
29+
// Indexer specifies the original indexer used to create document index.
30+
Indexer indexer.Indexer
31+
// Transformer specifies the processor before creating document index, typically a splitter.
32+
Transformer document.Transformer
33+
// ParentIDKey specifies the key in the metadata of the sub-documents generated by the transformer to store the parent document ID.
34+
ParentIDKey string
35+
36+
// SubIDGenerator specifies the method for generating a specified number of sub-document IDs based on the parent document ID.
37+
SubIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
38+
}
39+
40+
func NewIndexer(ctx context.Context, config *Config) (indexer.Indexer, error) {
41+
if config.Indexer == nil {
42+
return nil, fmt.Errorf("indexer is empty")
43+
}
44+
if config.Transformer == nil {
45+
return nil, fmt.Errorf("transformer is empty")
46+
}
47+
if config.SubIDGenerator == nil {
48+
return nil, fmt.Errorf("sub id generator is empty")
49+
}
50+
51+
return &parentIndexer{
52+
indexer: config.Indexer,
53+
transformer: config.Transformer,
54+
parentIDKey: config.ParentIDKey,
55+
subIDGenerator: config.SubIDGenerator,
56+
}, nil
57+
}
58+
59+
type parentIndexer struct {
60+
indexer indexer.Indexer
61+
transformer document.Transformer
62+
parentIDKey string
63+
subIDGenerator func(ctx context.Context, parentID string, num int) ([]string, error)
64+
}
65+
66+
func (p *parentIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) {
67+
subDocs, err := p.transformer.Transform(ctx, docs)
68+
if err != nil {
69+
return nil, fmt.Errorf("transform docs fail: %w", err)
70+
}
71+
if len(subDocs) == 0 {
72+
return nil, fmt.Errorf("doc transformer returned no documents")
73+
}
74+
currentID := subDocs[0].ID
75+
startIdx := 0
76+
for i, subDoc := range subDocs {
77+
if subDoc.MetaData == nil {
78+
subDoc.MetaData = make(map[string]interface{})
79+
}
80+
subDoc.MetaData[p.parentIDKey] = subDoc.ID
81+
82+
if subDoc.ID == currentID {
83+
continue
84+
}
85+
86+
// generate new doc id
87+
subIDs, err := p.subIDGenerator(ctx, subDocs[startIdx].ID, i-startIdx)
88+
if err != nil {
89+
return nil, err
90+
}
91+
if len(subIDs) != i-startIdx {
92+
return nil, fmt.Errorf("generated sub IDs' num is unexpected")
93+
}
94+
for j := startIdx; j < i; j++ {
95+
subDocs[j].ID = subIDs[j-startIdx]
96+
}
97+
startIdx = i
98+
currentID = subDoc.ID
99+
}
100+
// generate new doc id
101+
subIDs, err := p.subIDGenerator(ctx, subDocs[startIdx].ID, len(subDocs)-startIdx)
102+
if err != nil {
103+
return nil, err
104+
}
105+
if len(subIDs) != len(subDocs)-startIdx {
106+
return nil, fmt.Errorf("generated sub IDs' num is unexpected")
107+
}
108+
for j := startIdx; j < len(subDocs); j++ {
109+
subDocs[j].ID = subIDs[j-startIdx]
110+
}
111+
112+
return p.indexer.Store(ctx, subDocs, opts...)
113+
}

flow/indexer/parent/parent_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2024 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package parent
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"reflect"
23+
"strconv"
24+
"strings"
25+
"testing"
26+
27+
"github.com/cloudwego/eino/components/document"
28+
"github.com/cloudwego/eino/components/indexer"
29+
"github.com/cloudwego/eino/schema"
30+
)
31+
32+
type testIndexer struct{}
33+
34+
func (t *testIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
35+
ret := make([]string, len(docs))
36+
for i, d := range docs {
37+
ret[i] = d.ID
38+
if !strings.HasPrefix(d.ID, d.MetaData["parent"].(string)) {
39+
return nil, fmt.Errorf("invalid parent key")
40+
}
41+
}
42+
return ret, nil
43+
}
44+
45+
type testTransformer struct {
46+
}
47+
48+
func (t *testTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) {
49+
var ret []*schema.Document
50+
for _, d := range src {
51+
ret = append(ret, &schema.Document{
52+
ID: d.ID,
53+
Content: d.Content[:len(d.Content)/2],
54+
MetaData: deepCopyMap(d.MetaData),
55+
}, &schema.Document{
56+
ID: d.ID,
57+
Content: d.Content[len(d.Content)/2:],
58+
MetaData: deepCopyMap(d.MetaData),
59+
})
60+
}
61+
return ret, nil
62+
}
63+
64+
func TestParentIndexer(t *testing.T) {
65+
tests := []struct {
66+
name string
67+
config *Config
68+
input []*schema.Document
69+
want []string
70+
}{
71+
{
72+
name: "success",
73+
config: &Config{
74+
Indexer: &testIndexer{},
75+
Transformer: &testTransformer{},
76+
ParentIDKey: "parent",
77+
SubIDGenerator: func(ctx context.Context, parentID string, num int) ([]string, error) {
78+
ret := make([]string, num)
79+
for i := range ret {
80+
ret[i] = parentID + strconv.Itoa(i)
81+
}
82+
return ret, nil
83+
},
84+
},
85+
input: []*schema.Document{{
86+
ID: "id",
87+
Content: "1234567890",
88+
MetaData: map[string]interface{}{},
89+
}, {
90+
ID: "ID",
91+
Content: "0987654321",
92+
MetaData: map[string]interface{}{},
93+
}},
94+
want: []string{"id0", "id1", "ID0", "ID1"},
95+
},
96+
}
97+
ctx := context.Background()
98+
for _, tt := range tests {
99+
t.Run(tt.name, func(t *testing.T) {
100+
index, err := NewIndexer(ctx, tt.config)
101+
if err != nil {
102+
t.Fatal(err)
103+
}
104+
ret, err := index.Store(ctx, tt.input)
105+
if err != nil {
106+
t.Fatal(err)
107+
}
108+
if !reflect.DeepEqual(ret, tt.want) {
109+
t.Errorf("NewHeaderSplitter() got = %v, want %v", ret, tt.want)
110+
}
111+
})
112+
}
113+
}
114+
115+
// deepCopyMap returns a new, always non-nil map holding the same key/value
// pairs as in. Values are assigned as-is, so reference-typed values (nested
// maps, slices, pointers) remain shared with the input.
func deepCopyMap(in map[string]interface{}) map[string]interface{} {
	copied := make(map[string]interface{}, len(in))
	for key := range in {
		copied[key] = in[key]
	}
	return copied
}

flow/retriever/parent/parent.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright 2024 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package parent
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
"github.com/cloudwego/eino/components/retriever"
24+
"github.com/cloudwego/eino/schema"
25+
)
26+
27+
type Config struct {
28+
// Retriever specifies the original retriever used to retrieve documents.
29+
Retriever retriever.Retriever
30+
// ParentIDKey specifies the key used in the sub-document metadata to store the parent document ID. Documents without this key will be removed from the recall results.
31+
ParentIDKey string
32+
// OrigDocGetter specifies the method for getting original documents by ids from the sub-document metadata.
33+
OrigDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
34+
}
35+
36+
func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) {
37+
if config.Retriever == nil {
38+
return nil, fmt.Errorf("retriever is required")
39+
}
40+
if config.OrigDocGetter == nil {
41+
return nil, fmt.Errorf("orig doc getter is required")
42+
}
43+
return &parentRetriever{
44+
retriever: config.Retriever,
45+
parentIDKey: config.ParentIDKey,
46+
origDocGetter: config.OrigDocGetter,
47+
}, nil
48+
}
49+
50+
type parentRetriever struct {
51+
retriever retriever.Retriever
52+
parentIDKey string
53+
origDocGetter func(ctx context.Context, ids []string) ([]*schema.Document, error)
54+
}
55+
56+
func (p *parentRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
57+
subDocs, err := p.retriever.Retrieve(ctx, query, opts...)
58+
if err != nil {
59+
return nil, err
60+
}
61+
ids := make([]string, 0, len(subDocs))
62+
for _, subDoc := range subDocs {
63+
if k, ok := subDoc.MetaData[p.parentIDKey]; ok {
64+
if s, okk := k.(string); okk && !inList(s, ids) {
65+
ids = append(ids, s)
66+
}
67+
}
68+
}
69+
return p.origDocGetter(ctx, ids)
70+
}
71+
72+
// inList reports whether elem occurs in list. A nil or empty list never
// contains elem.
func inList(elem string, list []string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == elem {
			return true
		}
	}
	return false
}

0 commit comments

Comments
 (0)