Skip to content

Commit 1e2c772

Browse files
committed
feat: automatically repopulate if embedding model is switched
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 59a15e7 commit 1e2c772

File tree

5 files changed

+47
-3
lines changed

5 files changed

+47
-3
lines changed

rag/collection.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func NewPersistentChromeCollection(llmClient *openai.Client, collectionName, dbP
2626
filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
2727
filePath,
2828
chromemDB,
29-
maxChunkSize)
29+
maxChunkSize, llmClient, embeddingModel)
3030
if err != nil {
3131
xlog.Error("Failed to create PersistentKB", err)
3232
os.Exit(1)
@@ -44,7 +44,7 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co
4444
filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
4545
filePath,
4646
ragDB,
47-
maxChunkSize)
47+
maxChunkSize, llmClient, embeddingModel)
4848
if err != nil {
4949
xlog.Error("Failed to create PersistentKB", err)
5050
os.Exit(1)

rag/engine.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
type Engine interface {
99
Store(s string, metadata map[string]string) (engine.Result, error)
1010
StoreDocuments(s []string, metadata map[string]string) ([]engine.Result, error)
11+
GetEmbeddingDimensions() (int, error)
1112
Reset() error
1213
Search(s string, similarEntries int) ([]types.Result, error)
1314
Count() int

rag/engine/chromem.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ func NewChromemDBCollection(collection, path string, openaiClient *openai.Client
3939
}
4040
chromem.collection = c
4141

42+
count := c.Count()
43+
if count > 0 {
44+
chromem.index = count + 1
45+
}
46+
4247
return chromem, nil
4348
}
4449

@@ -59,6 +64,20 @@ func (c *ChromemDB) Reset() error {
5964
return nil
6065
}
6166

67+
func (c *ChromemDB) GetEmbeddingDimensions() (int, error) {
68+
count := c.collection.Count()
69+
if count == 0 {
70+
return 0, fmt.Errorf("no documents in collection")
71+
}
72+
73+
doc, err := c.collection.GetByID(context.Background(), fmt.Sprint(count))
74+
if err != nil {
75+
return 0, fmt.Errorf("error getting document: %v", err)
76+
}
77+
78+
return len(doc.Embedding), nil
79+
}
80+
6281
func (c *ChromemDB) embedding() chromem.EmbeddingFunc {
6382
return chromem.EmbeddingFunc(
6483
func(ctx context.Context, text string) ([]float32, error) {

rag/engine/localai.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ func (db *LocalAIRAGDB) Count() int {
3131
return 0
3232
}
3333

34+
func (db *LocalAIRAGDB) GetEmbeddingDimensions() (int, error) {
35+
return 0, fmt.Errorf("not implemented")
36+
}
37+
3438
func (db *LocalAIRAGDB) StoreDocuments(s []string, metadata map[string]string) ([]Result, error) {
3539
results := []Result{}
3640
for _, content := range s {

rag/persistency.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package rag
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"fmt"
78
"io"
@@ -15,6 +16,7 @@ import (
1516
"github.com/mudler/localrecall/pkg/xlog"
1617
"github.com/mudler/localrecall/rag/engine"
1718
"github.com/mudler/localrecall/rag/types"
19+
"github.com/sashabaranov/go-openai"
1820
)
1921

2022
// CollectionState represents the persistent state of a collection
@@ -55,7 +57,7 @@ func loadDB(path string) (*CollectionState, error) {
5557
return state, nil
5658
}
5759

58-
func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChunkSize int) (*PersistentKB, error) {
60+
func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChunkSize int, llmClient *openai.Client, embeddingModel string) (*PersistentKB, error) {
5961
// if file exists, try to load an existing state
6062
// if file does not exist, create a new state
6163
if err := os.MkdirAll(assetDir, 0755); err != nil {
@@ -89,6 +91,24 @@ func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChun
8991
index: state.Index,
9092
}
9193

94+
// TODO: Automatically repopulate if embeddings dimensions are mismatching.
95+
// To detect a mismatch, compare the number of dimensions of an embedding already stored in the index with the
96+
// number of dimensions that the embedding model currently returns.
97+
resp, err := llmClient.CreateEmbeddings(context.Background(),
98+
openai.EmbeddingRequestStrings{
99+
Input: []string{"test"},
100+
Model: openai.EmbeddingModel(embeddingModel),
101+
},
102+
)
103+
if err == nil && len(resp.Data) > 0 {
104+
embedding := resp.Data[0].Embedding
105+
embeddingDimensions, err := db.Engine.GetEmbeddingDimensions()
106+
if err == nil && len(embedding) != embeddingDimensions {
107+
xlog.Info("Embedding dimensions mismatch, repopulating", "embeddingDimensions", embeddingDimensions, "embedding", embedding)
108+
return db, db.Repopulate()
109+
}
110+
}
111+
92112
return db, nil
93113
}
94114

0 commit comments

Comments
 (0)