@@ -2,6 +2,7 @@ package rag
2
2
3
3
import (
4
4
"bytes"
5
+ "context"
5
6
"encoding/json"
6
7
"fmt"
7
8
"io"
@@ -15,6 +16,7 @@ import (
15
16
"github.com/mudler/localrecall/pkg/xlog"
16
17
"github.com/mudler/localrecall/rag/engine"
17
18
"github.com/mudler/localrecall/rag/types"
19
+ "github.com/sashabaranov/go-openai"
18
20
)
19
21
20
22
// CollectionState represents the persistent state of a collection
@@ -55,7 +57,7 @@ func loadDB(path string) (*CollectionState, error) {
55
57
return state , nil
56
58
}
57
59
58
- func NewPersistentCollectionKB (stateFile , assetDir string , store Engine , maxChunkSize int ) (* PersistentKB , error ) {
60
+ func NewPersistentCollectionKB (stateFile , assetDir string , store Engine , maxChunkSize int , llmClient * openai. Client , embeddingModel string ) (* PersistentKB , error ) {
59
61
// if file exists, try to load an existing state
60
62
// if file does not exist, create a new state
61
63
if err := os .MkdirAll (assetDir , 0755 ); err != nil {
@@ -89,6 +91,24 @@ func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChun
89
91
index : state .Index ,
90
92
}
91
93
94
+ // TODO: Automatically repopulate if the embedding dimensions do not match.
95
+ // To check if dimensions are mismatching, we can check whether the number of dimensions of the first embedding in the index is the same as the
96
+ // dimension that the embedding model returns.
97
+ resp , err := llmClient .CreateEmbeddings (context .Background (),
98
+ openai.EmbeddingRequestStrings {
99
+ Input : []string {"test" },
100
+ Model : openai .EmbeddingModel (embeddingModel ),
101
+ },
102
+ )
103
+ if err == nil && len (resp .Data ) > 0 {
104
+ embedding := resp .Data [0 ].Embedding
105
+ embeddingDimensions , err := db .Engine .GetEmbeddingDimensions ()
106
+ if err == nil && len (embedding ) != embeddingDimensions {
107
+ xlog .Info ("Embedding dimensions mismatch, repopulating" , "embeddingDimensions" , embeddingDimensions , "embedding" , embedding )
108
+ return db , db .Repopulate ()
109
+ }
110
+ }
111
+
92
112
return db , nil
93
113
}
94
114
0 commit comments