Stores to chromem (WIP) #4659

Draft · wants to merge 1 commit into base: master
14 changes: 2 additions & 12 deletions backend/backend.proto
@@ -21,8 +21,7 @@ service Backend {
rpc Status(HealthMessage) returns (StatusResponse) {}

rpc StoresSet(StoresSetOptions) returns (Result) {}
rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
rpc StoresReset(StoresResetOptions) returns (Result) {}
rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}

rpc Rerank(RerankRequest) returns (RerankResult) {}
@@ -78,19 +77,10 @@ message StoresSetOptions {
repeated StoresValue Values = 2;
}

message StoresDeleteOptions {
message StoresResetOptions {
repeated StoresKey Keys = 1;
}

message StoresGetOptions {
repeated StoresKey Keys = 1;
}

message StoresGetResult {
repeated StoresKey Keys = 1;
repeated StoresValue Values = 2;
}

message StoresFindOptions {
StoresKey Key = 1;
int32 TopK = 2;
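
For context, a minimal client-side sketch follows (not part of this PR): it shows how a Go caller could invoke the new StoresReset RPC once the stubs are regenerated. The NewBackendClient constructor and the StoresReset method are assumptions based on the usual protoc-gen-go-grpc output for service Backend.

// Hypothetical client sketch; NewBackendClient and StoresReset are assumed
// names following protoc-gen-go-grpc conventions for the updated Backend service.
package main

import (
	"context"
	"log"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := pb.NewBackendClient(conn)

	// StoresReset replaces the removed StoresDelete/StoresGet pair; it clears
	// the whole store rather than deleting or fetching individual keys.
	if _, err := client.StoresReset(context.Background(), &pb.StoresResetOptions{}); err != nil {
		log.Fatal(err)
	}
}
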
246 changes: 25 additions & 221 deletions backend/go/stores/store.go
@@ -4,101 +4,36 @@ package main
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"container/heap"
"context"
"fmt"
"math"
"slices"
"runtime"

"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
chromem "github.com/philippgille/chromem-go"

"github.com/rs/zerolog/log"
)

type Store struct {
base.SingleThread

// The sorted keys
keys [][]float32
// The sorted values
values [][]byte

// If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions
// TODO: Should we normalize incoming keys if they are not instead?
keysAreNormalized bool
// The first key decides the length of the keys
keyLen int
}

// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because
// that's theoretically best for memory layout and cache locality, but this isn't optimized yet.
type Pair struct {
Key []float32
Value []byte
*chromem.DB
*chromem.Collection
}

func NewStore() *Store {
return &Store{
keys: make([][]float32, 0),
values: make([][]byte, 0),
keysAreNormalized: true,
keyLen: -1,
}
}

func compareSlices(k1, k2 []float32) int {
assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2)))

return slices.Compare(k1, k2)
}

func hasKey(unsortedSlice [][]float32, target []float32) bool {
return slices.ContainsFunc(unsortedSlice, func(k []float32) bool {
return compareSlices(k, target) == 0
})
}

func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) {
return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int {
return compareSlices(k, t)
})
}

func isSortedPairs(kvs []Pair) bool {
for i := 1; i < len(kvs); i++ {
if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 {
return false
}
}

return true
}

func isSortedKeys(keys [][]float32) bool {
for i := 1; i < len(keys); i++ {
if compareSlices(keys[i-1], keys[i]) > 0 {
return false
}
}

return true
}

func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 {
ks := make([][]float32, len(keys))

for i, k := range keys {
ks[i] = k.Floats
}

slices.SortFunc(ks, compareSlices)

assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys)))
assert(isSortedKeys(ks), "keys are not sorted")

return ks
return &Store{}
}

func (s *Store) Load(opts *pb.ModelOptions) error {
db := chromem.NewDB()
collection, err := db.CreateCollection("all-documents", nil, nil)
if err != nil {
return err
}
s.DB = db
s.Collection = collection
return nil
}

@@ -111,156 +46,25 @@ func (s *Store) StoresSet(opts *pb.StoresSetOptions) error {
if len(opts.Keys) != len(opts.Values) {
return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values))
}

if s.keyLen == -1 {
s.keyLen = len(opts.Keys[0].Floats)
} else {
if len(opts.Keys[0].Floats) != s.keyLen {
return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}

kvs := make([]Pair, len(opts.Keys))
docs := []chromem.Document{}

for i, k := range opts.Keys {
if s.keysAreNormalized && !isNormalized(k.Floats) {
s.keysAreNormalized = false
var sample []float32
if len(s.keys) > 5 {
sample = k.Floats[:5]
} else {
sample = k.Floats
}
log.Debug().Msgf("Key is not normalized: %v", sample)
}

kvs[i] = Pair{
Key: k.Floats,
Value: opts.Values[i].Bytes,
}
}

slices.SortFunc(kvs, func(a, b Pair) int {
return compareSlices(a.Key, b.Key)
})

assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys)))
assert(isSortedPairs(kvs), "keys are not sorted")

l := len(kvs) + len(s.keys)
merge_ks := make([][]float32, 0, l)
merge_vs := make([][]byte, 0, l)

i, j := 0, 0
for {
if i+j >= l {
break
}

if i >= len(kvs) {
merge_ks = append(merge_ks, s.keys[j])
merge_vs = append(merge_vs, s.values[j])
j++
continue
}

if j >= len(s.keys) {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
continue
}

c := compareSlices(kvs[i].Key, s.keys[j])
if c < 0 {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
} else if c > 0 {
merge_ks = append(merge_ks, s.keys[j])
merge_vs = append(merge_vs, s.values[j])
j++
} else {
merge_ks = append(merge_ks, kvs[i].Key)
merge_vs = append(merge_vs, kvs[i].Value)
i++
j++
}
docs = append(docs, chromem.Document{
ID: k.String(),
Content: opts.Values[i].String(),
})
}

assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l))
assert(isSortedKeys(merge_ks), "merge keys are not sorted")

s.keys = merge_ks
s.values = merge_vs

return nil
return s.Collection.AddDocuments(context.Background(), docs, runtime.NumCPU())
}

func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error {
if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to delete")
}

if len(opts.Keys) == 0 {
return fmt.Errorf("no keys to add")
func (s *Store) StoresReset(opts *pb.StoresResetOptions) error {
err := s.DB.DeleteCollection("all-documents")
if err != nil {
return err
}

if s.keyLen == -1 {
s.keyLen = len(opts.Keys[0].Floats)
} else {
if len(opts.Keys[0].Floats) != s.keyLen {
return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen)
}
}

ks := sortIntoKeySlicese(opts.Keys)

l := len(s.keys) - len(ks)
merge_ks := make([][]float32, 0, l)
merge_vs := make([][]byte, 0, l)

tail_ks := s.keys
tail_vs := s.values
for _, k := range ks {
j, found := findInSortedSlice(tail_ks, k)

if found {
merge_ks = append(merge_ks, tail_ks[:j]...)
merge_vs = append(merge_vs, tail_vs[:j]...)
tail_ks = tail_ks[j+1:]
tail_vs = tail_vs[j+1:]
} else {
assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k))
}

log.Debug().Msgf("Delete: found = %v, t = %d, j = %d, len(merge_ks) = %d, len(merge_vs) = %d", found, len(tail_ks), j, len(merge_ks), len(merge_vs))
}

merge_ks = append(merge_ks, tail_ks...)
merge_vs = append(merge_vs, tail_vs...)

assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys)))

s.keys = merge_ks
s.values = merge_vs

assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l))
assert(isSortedKeys(s.keys), "keys are not sorted")
assert(func() bool {
for _, k := range ks {
if _, found := findInSortedSlice(s.keys, k); found {
return false
}
}
return true
}(), "Keys to delete still present")

if len(s.keys) != l {
log.Debug().Msgf("Delete: Some keys not found: len(s.keys) = %d, l = %d", len(s.keys), l)
}

return nil
s.Collection, err = s.CreateCollection("all-documents", nil, nil)
return err
}

func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) {
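
To make the rewritten flow easier to follow, here is a self-contained sketch of the chromem-go lifecycle the new backend relies on, restricted to the calls visible in this diff (NewDB, CreateCollection, AddDocuments, DeleteCollection). Filling the Embedding field directly is an assumption made so the example does not depend on chromem-go's default embedding provider.

// Standalone sketch of the chromem-go pattern used above: one in-memory DB, a
// single "all-documents" collection, bulk document adds, and reset implemented
// as delete-and-recreate. The Embedding field is set explicitly here (an
// assumption about chromem.Document) to avoid needing an external embedder.
package main

import (
	"context"
	"log"
	"runtime"

	chromem "github.com/philippgille/chromem-go"
)

func main() {
	ctx := context.Background()

	db := chromem.NewDB()
	col, err := db.CreateCollection("all-documents", nil, nil)
	if err != nil {
		log.Fatal(err)
	}

	// Mirrors StoresSet: each key/value pair becomes a chromem.Document.
	docs := []chromem.Document{
		{ID: "key-1", Content: "value one", Embedding: []float32{0.1, 0.2, 0.3}},
		{ID: "key-2", Content: "value two", Embedding: []float32{0.3, 0.2, 0.1}},
	}
	if err := col.AddDocuments(ctx, docs, runtime.NumCPU()); err != nil {
		log.Fatal(err)
	}

	// Mirrors StoresReset: drop the collection and recreate it empty.
	if err := db.DeleteCollection("all-documents"); err != nil {
		log.Fatal(err)
	}
	if _, err := db.CreateCollection("all-documents", nil, nil); err != nil {
		log.Fatal(err)
	}
}
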
2 changes: 1 addition & 1 deletion core/http/app_test.go
@@ -1000,7 +1000,7 @@ var _ = Describe("API test", func() {
}
}

deleteBody := schema.StoresDelete{
deleteBody := schema.StoresReset{
Keys: [][]float32{
{0.1, 0.2, 0.3},
},
37 changes: 3 additions & 34 deletions core/http/endpoints/localai/stores.go
@@ -6,7 +6,7 @@
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/store"

Check failure on line 9 in core/http/endpoints/localai/stores.go (GitHub Actions / build-linux-arm): no required module provides package github.com/mudler/LocalAI/pkg/store; to add it:
)

func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
@@ -36,9 +36,9 @@
}
}

func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
func StoresResetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresDelete)
input := new(schema.StoresReset)

if err := c.BodyParser(input); err != nil {
return err
@@ -49,45 +49,14 @@
return err
}

if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
if _, err := sb.StoresReset(c.Context(), nil); err != nil {
return err
}

return c.Send(nil)
}
}

func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresGet)

if err := c.BodyParser(input); err != nil {
return err
}

sb, err := backend.StoreBackend(sl, appConfig, input.Store)
if err != nil {
return err
}

keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
if err != nil {
return err
}

res := schema.StoresGetResponse{
Keys: keys,
Values: make([]string, len(vals)),
}

for i, v := range vals {
res.Values[i] = string(v)
}

return c.JSON(res)
}
}

func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresFind)
3 changes: 1 addition & 2 deletions core/http/routes/localai.go
@@ -39,8 +39,7 @@ func RegisterLocalAIRoutes(router *fiber.App,
// Stores
sl := model.NewModelLoader("")
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
router.Post("/stores/reset", localai.StoresDeleteEndpoint(sl, appConfig))
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))

if !appConfig.DisableMetrics {
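
As a usage note, the sketch below exercises the renamed HTTP route. The lowercase "keys" JSON field is an assumption taken from the schema.StoresReset usage in app_test.go above, and localhost:8080 assumes LocalAI's default listen address.

// Hedged sketch of calling the renamed /stores/reset endpoint from Go.
package main

import (
	"bytes"
	"log"
	"net/http"
)

func main() {
	// The "keys" field mirrors schema.StoresReset; the chromem-backed reset
	// drops the whole collection regardless of which keys are sent.
	body := []byte(`{"keys": [[0.1, 0.2, 0.3]]}`)

	resp, err := http.Post("http://localhost:8080/stores/reset", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	log.Println("reset status:", resp.Status)
}
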