Skip to content

Commit a880489

Browse files
committed
Make DocumentDatabase and ChromaDatabase methods async
1 parent 9986530 commit a880489

18 files changed

+290
-107
lines changed

src/parlant/adapters/db/chroma/database.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import operator
2121
from pathlib import Path
2222
from typing import Generic, Optional, Sequence, TypeVar, TypedDict, cast
23-
from typing_extensions import override
23+
import chromadb.api
24+
from typing_extensions import override, Self
2425
import chromadb
2526

2627
from parlant.core.common import Version
@@ -67,18 +68,15 @@ def __init__(
6768
dir_path: Path,
6869
embedder_factory: EmbedderFactory,
6970
) -> None:
71+
self._dir_path = dir_path
7072
self._logger = logger
7173
self._embedder_factory = embedder_factory
7274

73-
self._chroma_client = chromadb.PersistentClient(str(dir_path))
74-
self._collections: dict[str, ChromaCollection[BaseDocument]] = (
75-
self._load_chromadb_collections()
76-
)
75+
self._chroma_client: chromadb.api.ClientAPI
76+
self._collections: dict[str, ChromaCollection[BaseDocument]] = {}
7777

78-
def _load_chromadb_collections(
79-
self,
80-
) -> dict[str, ChromaCollection[BaseDocument]]:
81-
collections: dict[str, ChromaCollection[BaseDocument]] = {}
78+
async def __aenter__(self) -> Self:
79+
self._chroma_client = chromadb.PersistentClient(str(self._dir_path))
8280
for chromadb_collection in self._chroma_client.list_collections():
8381
embedder_module = importlib.import_module(
8482
chromadb_collection.metadata["embedder_module_path"]
@@ -94,7 +92,7 @@ def _load_chromadb_collections(
9492
embedding_function=None,
9593
)
9694

97-
collections[chromadb_collection.name] = ChromaCollection(
95+
self._collections[chromadb_collection.name] = ChromaCollection(
9896
logger=self._logger,
9997
chromadb_collection=chroma_collection,
10098
name=chromadb_collection.name,
@@ -103,9 +101,17 @@ def _load_chromadb_collections(
103101
),
104102
embedder=embedder,
105103
)
106-
return collections
104+
return self
105+
106+
async def __aexit__(
107+
self,
108+
exc_type: Optional[type[BaseException]],
109+
exc_value: Optional[BaseException],
110+
traceback: Optional[object],
111+
) -> None:
112+
pass
107113

108-
def create_collection(
114+
async def create_collection(
109115
self,
110116
name: str,
111117
schema: type[TDocument],
@@ -133,7 +139,7 @@ def create_collection(
133139

134140
return cast(ChromaCollection[TDocument], self._collections[name])
135141

136-
def get_collection(
142+
async def get_collection(
137143
self,
138144
name: str,
139145
) -> ChromaCollection[TDocument]:
@@ -142,7 +148,7 @@ def get_collection(
142148

143149
raise ValueError(f'ChromaDB collection "{name}" not found.')
144150

145-
def get_or_create_collection(
151+
async def get_or_create_collection(
146152
self,
147153
name: str,
148154
schema: type[TDocument],
@@ -171,7 +177,7 @@ def get_or_create_collection(
171177

172178
return cast(ChromaCollection[TDocument], self._collections[name])
173179

174-
def delete_collection(
180+
async def delete_collection(
175181
self,
176182
name: str,
177183
) -> None:

src/parlant/adapters/db/chroma/glossary.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from datetime import datetime, timezone
1717
from itertools import chain
1818
from typing import Optional, Sequence
19-
from typing_extensions import override, TypedDict
19+
from typing_extensions import override, TypedDict, Self
2020

21-
from parlant.adapters.db.chroma.database import ChromaDatabase
21+
from parlant.adapters.db.chroma.database import ChromaCollection, ChromaDatabase
2222
from parlant.core.common import (
2323
ItemNotFoundError,
2424
UniqueId,
@@ -49,12 +49,28 @@ def __init__(
4949
chroma_db: ChromaDatabase,
5050
embedder_type: type[Embedder],
5151
):
52-
self._collection = chroma_db.get_or_create_collection(
52+
self._embedder_type = embedder_type
53+
self._embedder = embedder_type()
54+
55+
self._chroma_db = chroma_db
56+
self._collection: ChromaCollection[_TermDocument]
57+
58+
async def __aenter__(self) -> Self:
59+
self._collection = await self._chroma_db.get_or_create_collection(
5360
name="glossary",
5461
schema=_TermDocument,
55-
embedder_type=embedder_type,
62+
embedder_type=self._embedder_type,
5663
)
57-
self._embedder = embedder_type()
64+
65+
return self
66+
67+
async def __aexit__(
68+
self,
69+
exc_type: Optional[type[BaseException]],
70+
exc_value: Optional[BaseException],
71+
traceback: Optional[object],
72+
) -> None:
73+
pass
5874

5975
def _serialize(self, term: Term, term_set: str, content: str) -> _TermDocument:
6076
return _TermDocument(

src/parlant/adapters/db/json_file.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import operator
2020
from pathlib import Path
2121
from typing import Any, Mapping, Optional, Sequence, cast
22-
from typing_extensions import override
22+
from typing_extensions import override, Self
2323
import aiofiles
2424

2525
from parlant.core.persistence.document_database import (
@@ -57,7 +57,7 @@ async def flush(self) -> None:
5757
async with self._lock:
5858
await self._flush_unlocked()
5959

60-
async def __aenter__(self) -> JSONFileDocumentDatabase:
60+
async def __aenter__(self) -> Self:
6161
async with self._lock:
6262
raw_data = await self._load_data()
6363

@@ -122,7 +122,7 @@ async def _save_data(
122122
await file.write(json_string)
123123

124124
@override
125-
def create_collection(
125+
async def create_collection(
126126
self,
127127
name: str,
128128
schema: type[TDocument],
@@ -138,7 +138,7 @@ def create_collection(
138138
return cast(JSONFileDocumentCollection[TDocument], self._collections[name])
139139

140140
@override
141-
def get_collection(
141+
async def get_collection(
142142
self,
143143
name: str,
144144
) -> JSONFileDocumentCollection[TDocument]:
@@ -147,7 +147,7 @@ def get_collection(
147147
raise ValueError(f'Collection "{name}" does not exists')
148148

149149
@override
150-
def get_or_create_collection(
150+
async def get_or_create_collection(
151151
self,
152152
name: str,
153153
schema: type[TDocument],
@@ -164,7 +164,7 @@ def get_or_create_collection(
164164
return cast(JSONFileDocumentCollection[TDocument], self._collections[name])
165165

166166
@override
167-
def delete_collection(
167+
async def delete_collection(
168168
self,
169169
name: str,
170170
) -> None:

src/parlant/adapters/db/transient.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self) -> None:
3636
self._collections: dict[str, _TransientDocumentCollection[BaseDocument]] = {}
3737

3838
@override
39-
def create_collection(
39+
async def create_collection(
4040
self,
4141
name: str,
4242
schema: type[TDocument],
@@ -52,7 +52,7 @@ def create_collection(
5252
return cast(_TransientDocumentCollection[TDocument], self._collections[name])
5353

5454
@override
55-
def get_collection(
55+
async def get_collection(
5656
self,
5757
name: str,
5858
) -> _TransientDocumentCollection[TDocument]:
@@ -61,7 +61,7 @@ def get_collection(
6161
raise ValueError(f'Collection "{name}" does not exist')
6262

6363
@override
64-
def get_or_create_collection(
64+
async def get_or_create_collection(
6565
self,
6666
name: str,
6767
schema: type[TDocument],
@@ -72,13 +72,13 @@ def get_or_create_collection(
7272
annotations = get_type_hints(schema)
7373
assert "id" in annotations and annotations["id"] == ObjectId
7474

75-
return self.create_collection(
75+
return await self.create_collection(
7676
name=name,
7777
schema=schema,
7878
)
7979

8080
@override
81-
def delete_collection(
81+
async def delete_collection(
8282
self,
8383
name: str,
8484
) -> None:

src/parlant/bin/server.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,19 +224,25 @@ async def setup_container(nlp_service_name: str) -> AsyncIterator[Container]:
224224
JSONFileDocumentDatabase(LOGGER, PARLANT_HOME_DIR / "services.json")
225225
)
226226

227-
c[AgentStore] = AgentDocumentStore(agents_db)
228-
c[ContextVariableStore] = ContextVariableDocumentStore(context_variables_db)
229-
c[TagStore] = TagDocumentStore(tags_db)
230-
c[CustomerStore] = CustomerDocumentStore(customers_db)
231-
c[GuidelineStore] = GuidelineDocumentStore(guidelines_db)
232-
c[GuidelineToolAssociationStore] = GuidelineToolAssociationDocumentStore(
233-
guideline_tool_associations_db
234-
)
235-
c[GuidelineConnectionStore] = GuidelineConnectionDocumentStore(guideline_connections_db)
236-
c[SessionStore] = SessionDocumentStore(sessions_db)
227+
c[AgentStore] = await EXIT_STACK.enter_async_context(AgentDocumentStore(agents_db))
228+
c[ContextVariableStore] = await EXIT_STACK.enter_async_context(
229+
ContextVariableDocumentStore(context_variables_db)
230+
)
231+
c[TagStore] = await EXIT_STACK.enter_async_context(TagDocumentStore(tags_db))
232+
c[CustomerStore] = await EXIT_STACK.enter_async_context(CustomerDocumentStore(customers_db))
233+
c[GuidelineStore] = await EXIT_STACK.enter_async_context(GuidelineDocumentStore(guidelines_db))
234+
c[GuidelineToolAssociationStore] = await EXIT_STACK.enter_async_context(
235+
GuidelineToolAssociationDocumentStore(guideline_tool_associations_db)
236+
)
237+
c[GuidelineConnectionStore] = await EXIT_STACK.enter_async_context(
238+
GuidelineConnectionDocumentStore(guideline_connections_db)
239+
)
240+
c[SessionStore] = await EXIT_STACK.enter_async_context(SessionDocumentStore(sessions_db))
237241
c[SessionListener] = PollingSessionListener
238242

239-
c[EvaluationStore] = EvaluationDocumentStore(evaluations_db)
243+
c[EvaluationStore] = await EXIT_STACK.enter_async_context(
244+
EvaluationDocumentStore(evaluations_db)
245+
)
240246
c[EvaluationListener] = PollingEvaluationListener
241247

242248
c[EventEmitterFactory] = Singleton(EventPublisherFactory)

src/parlant/core/agents.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
from dataclasses import dataclass
1717
from datetime import datetime, timezone
1818
from typing import NewType, Optional, Sequence, cast
19-
from typing_extensions import override, TypedDict
19+
from typing_extensions import override, TypedDict, Self
2020

2121
from parlant.core.common import ItemNotFoundError, UniqueId, Version, generate_id
2222
from parlant.core.persistence.document_database import (
23+
DocumentCollection,
2324
DocumentDatabase,
2425
ObjectId,
2526
)
@@ -93,10 +94,23 @@ def __init__(
9394
self,
9495
database: DocumentDatabase,
9596
):
96-
self._collection = database.get_or_create_collection(
97+
self._database = database
98+
self._collection: DocumentCollection[_AgentDocument]
99+
100+
async def __aenter__(self) -> Self:
101+
self._collection = await self._database.get_or_create_collection(
97102
name="agents",
98103
schema=_AgentDocument,
99104
)
105+
return self
106+
107+
async def __aexit__(
108+
self,
109+
exc_type: Optional[type[BaseException]],
110+
exc_value: Optional[BaseException],
111+
traceback: Optional[object],
112+
) -> None:
113+
pass
100114

101115
def _serialize(self, agent: Agent) -> _AgentDocument:
102116
return _AgentDocument(

src/parlant/core/context_variables.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616
from abc import ABC, abstractmethod
1717
from typing import NewType, Optional, Sequence, cast
18-
from typing_extensions import TypedDict, override
18+
from typing_extensions import TypedDict, override, Self
1919
from datetime import datetime, timezone
2020
from dataclasses import dataclass
2121

@@ -27,6 +27,7 @@
2727
generate_id,
2828
)
2929
from parlant.core.persistence.document_database import (
30+
DocumentCollection,
3031
DocumentDatabase,
3132
ObjectId,
3233
)
@@ -158,15 +159,29 @@ class ContextVariableDocumentStore(ContextVariableStore):
158159
VERSION = Version.from_string("0.1.0")
159160

160161
def __init__(self, database: DocumentDatabase):
161-
self._variable_collection = database.get_or_create_collection(
162+
self._database = database
163+
self._variable_collection: DocumentCollection[_ContextVariableDocument]
164+
self._value_collection: DocumentCollection[_ContextVariableValueDocument]
165+
166+
async def __aenter__(self) -> Self:
167+
self._variable_collection = await self._database.get_or_create_collection(
162168
name="variables",
163169
schema=_ContextVariableDocument,
164170
)
165171

166-
self._value_collection = database.get_or_create_collection(
172+
self._value_collection = await self._database.get_or_create_collection(
167173
name="values",
168174
schema=_ContextVariableValueDocument,
169175
)
176+
return self
177+
178+
async def __aexit__(
179+
self,
180+
exc_type: Optional[type[BaseException]],
181+
exc_value: Optional[BaseException],
182+
traceback: Optional[object],
183+
) -> None:
184+
pass
170185

171186
def _serialize_context_variable(
172187
self,

0 commit comments

Comments
 (0)