20
20
import operator
21
21
from pathlib import Path
22
22
from typing import Generic , Optional , Sequence , TypeVar , TypedDict , cast
23
- from typing_extensions import override
23
+ import chromadb .api
24
+ from typing_extensions import override , Self
24
25
import chromadb
25
26
26
27
from parlant .core .common import Version
@@ -67,18 +68,15 @@ def __init__(
67
68
dir_path : Path ,
68
69
embedder_factory : EmbedderFactory ,
69
70
) -> None :
71
+ self ._dir_path = dir_path
70
72
self ._logger = logger
71
73
self ._embedder_factory = embedder_factory
72
74
73
- self ._chroma_client = chromadb .PersistentClient (str (dir_path ))
74
- self ._collections : dict [str , ChromaCollection [BaseDocument ]] = (
75
- self ._load_chromadb_collections ()
76
- )
75
+ self ._chroma_client : chromadb .api .ClientAPI
76
+ self ._collections : dict [str , ChromaCollection [BaseDocument ]] = {}
77
77
78
- def _load_chromadb_collections (
79
- self ,
80
- ) -> dict [str , ChromaCollection [BaseDocument ]]:
81
- collections : dict [str , ChromaCollection [BaseDocument ]] = {}
78
+ async def __aenter__ (self ) -> Self :
79
+ self ._chroma_client = chromadb .PersistentClient (str (self ._dir_path ))
82
80
for chromadb_collection in self ._chroma_client .list_collections ():
83
81
embedder_module = importlib .import_module (
84
82
chromadb_collection .metadata ["embedder_module_path" ]
@@ -94,7 +92,7 @@ def _load_chromadb_collections(
94
92
embedding_function = None ,
95
93
)
96
94
97
- collections [chromadb_collection .name ] = ChromaCollection (
95
+ self . _collections [chromadb_collection .name ] = ChromaCollection (
98
96
logger = self ._logger ,
99
97
chromadb_collection = chroma_collection ,
100
98
name = chromadb_collection .name ,
@@ -103,9 +101,17 @@ def _load_chromadb_collections(
103
101
),
104
102
embedder = embedder ,
105
103
)
106
- return collections
104
+ return self
105
+
106
+ async def __aexit__ (
107
+ self ,
108
+ exc_type : Optional [type [BaseException ]],
109
+ exc_value : Optional [BaseException ],
110
+ traceback : Optional [object ],
111
+ ) -> None :
112
+ pass
107
113
108
- def create_collection (
114
+ async def create_collection (
109
115
self ,
110
116
name : str ,
111
117
schema : type [TDocument ],
@@ -133,7 +139,7 @@ def create_collection(
133
139
134
140
return cast (ChromaCollection [TDocument ], self ._collections [name ])
135
141
136
- def get_collection (
142
+ async def get_collection (
137
143
self ,
138
144
name : str ,
139
145
) -> ChromaCollection [TDocument ]:
@@ -142,7 +148,7 @@ def get_collection(
142
148
143
149
raise ValueError (f'ChromaDB collection "{ name } " not found.' )
144
150
145
- def get_or_create_collection (
151
+ async def get_or_create_collection (
146
152
self ,
147
153
name : str ,
148
154
schema : type [TDocument ],
@@ -171,7 +177,7 @@ def get_or_create_collection(
171
177
172
178
return cast (ChromaCollection [TDocument ], self ._collections [name ])
173
179
174
- def delete_collection (
180
+ async def delete_collection (
175
181
self ,
176
182
name : str ,
177
183
) -> None :
0 commit comments