Skip to content

Commit a0766ad

Browse files
authored
AI Proxy Persistence (deven96#70)
* Add persistence for ahnlich ai and add new command(destroy_database) for ai_proxy * update destroydb command to purge stores * add ai purge commands to python ai client and update readme
1 parent 20b6c63 commit a0766ad

File tree

12 files changed

+251
-9
lines changed

12 files changed

+251
-9
lines changed

ahnlich/ai/src/engine/store.rs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use serde::Serialize;
1515
use std::collections::HashMap as StdHashMap;
1616
use std::collections::HashSet as StdHashSet;
1717
use std::sync::atomic::AtomicBool;
18+
use std::sync::atomic::Ordering;
1819
use std::sync::Arc;
1920

2021
/// Contains all the stores that have been created in memory
@@ -34,6 +35,24 @@ impl AIStoreHandler {
3435
write_flag,
3536
}
3637
}
38+
pub(crate) fn get_stores(&self) -> AIStores {
39+
self.stores.clone()
40+
}
41+
42+
#[cfg(test)]
43+
pub fn write_flag(&self) -> Arc<AtomicBool> {
44+
self.write_flag.clone()
45+
}
46+
47+
fn set_write_flag(&self) {
48+
let _ = self
49+
.write_flag
50+
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
51+
}
52+
53+
pub(crate) fn use_snapshot(&mut self, stores_snapshot: AIStores) {
54+
self.stores = stores_snapshot;
55+
}
3756

3857
#[tracing::instrument(skip(self))]
3958
pub(crate) fn create_store(
@@ -42,7 +61,8 @@ impl AIStoreHandler {
4261
store_type: AIStoreType,
4362
model: AIModel,
4463
) -> Result<(), AIProxyError> {
45-
self.stores
64+
if self
65+
.stores
4666
.try_insert(
4767
store_name.clone(),
4868
Arc::new(AIStore::create(
@@ -52,8 +72,11 @@ impl AIStoreHandler {
5272
)),
5373
&self.stores.guard(),
5474
)
55-
.map_err(|_| AIProxyError::StoreAlreadyExists(store_name.clone()))?;
56-
75+
.is_err()
76+
{
77+
return Err(AIProxyError::StoreAlreadyExists(store_name.clone()));
78+
}
79+
self.set_write_flag();
5780
Ok(())
5881
}
5982

@@ -161,11 +184,20 @@ impl AIStoreHandler {
161184
let removed = if !removed {
162185
0
163186
} else {
164-
//self.set_write_flag();
187+
self.set_write_flag();
165188
1
166189
};
167190
Ok(removed)
168191
}
192+
193+
/// Matches DestroyDatabase - Drops all the stores in the database
194+
#[tracing::instrument(skip(self))]
195+
pub(crate) fn purge_stores(&self) -> usize {
196+
let store_length = self.stores.pin().len();
197+
let guard = self.stores.guard();
198+
self.stores.clear(&guard);
199+
store_length
200+
}
169201
}
170202

171203
#[derive(Debug, Serialize, Deserialize)]

ahnlich/ai/src/server/handler.rs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use tokio::net::{TcpListener, TcpStream};
1313
use tokio::select;
1414
use tokio_graceful::Shutdown;
1515
use tracing::Instrument;
16-
use utils::{client::ClientHandler, protocol::AhnlichProtocol};
16+
use utils::{client::ClientHandler, persistence::Persistence, protocol::AhnlichProtocol};
1717

1818
use ahnlich_client_rs::db::{DbClient, DbConnManager};
1919
use deadpool::managed::Pool;
@@ -54,8 +54,30 @@ impl<'a> AIProxyServer<'a> {
5454
}
5555
let write_flag = Arc::new(AtomicBool::new(false));
5656
let db_client = Self::build_db_client(config).await;
57-
let store_handler = AIStoreHandler::new(write_flag.clone());
57+
let mut store_handler = AIStoreHandler::new(write_flag.clone());
5858
let client_handler = Arc::new(ClientHandler::new(config.maximum_clients));
59+
60+
// persistence
61+
if let Some(persist_location) = &config.persist_location {
62+
match Persistence::load_snapshot(persist_location) {
63+
Err(e) => {
64+
tracing::error!("Failed to load snapshot from persist location {e}");
65+
}
66+
Ok(snapshot) => {
67+
store_handler.use_snapshot(snapshot);
68+
}
69+
}
70+
// spawn the persistence task
71+
let mut persistence_task = Persistence::task(
72+
write_flag,
73+
config.persistence_interval,
74+
persist_location,
75+
store_handler.get_stores(),
76+
);
77+
shutdown_token
78+
.spawn_task_fn(|guard| async move { persistence_task.monitor(guard).await });
79+
};
80+
5981
Ok(Self {
6082
listener,
6183
shutdown_token,
@@ -148,4 +170,9 @@ impl<'a> AIProxyServer<'a> {
148170

149171
DbClient::new_with_pool(pool)
150172
}
173+
174+
#[cfg(test)]
175+
pub fn write_flag(&self) -> Arc<AtomicBool> {
176+
self.store_handler.write_flag()
177+
}
151178
}

ahnlich/ai/src/server/task.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ impl AhnlichProtocol for AIProxyTask {
285285
);
286286
}
287287
}
288+
AIQuery::PurgeStores => {
    // Drop every store and report how many were removed.
    let destroyed = self.store_handler.purge_stores();
    Ok(AIServerResponse::Del(destroyed))
}
288292
})
289293
}
290294
result

ahnlich/ai/src/tests/aiproxy_test.rs

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ use ahnlich_types::{
1313

1414
use once_cell::sync::Lazy;
1515
use pretty_assertions::assert_eq;
16-
use std::{collections::HashSet, num::NonZeroUsize};
16+
use std::{collections::HashSet, num::NonZeroUsize, sync::atomic::Ordering};
1717

1818
use crate::cli::AIProxyConfig;
1919
use crate::server::handler::AIProxyServer;
2020
use ahnlich_types::bincode::BinCodeSerAndDeser;
2121
use std::net::SocketAddr;
22+
use std::path::PathBuf;
2223
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
2324
use tokio::net::TcpStream;
2425
use tokio::time::{timeout, Duration};
@@ -31,6 +32,16 @@ static AI_CONFIG: Lazy<AIProxyConfig> = Lazy::new(|| {
3132
ai_proxy
3233
});
3334

35+
static PERSISTENCE_FILE: Lazy<PathBuf> =
36+
Lazy::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("ahnlich_ai_proxy.dat"));
37+
38+
static AI_CONFIG_WITH_PERSISTENCE: Lazy<AIProxyConfig> = Lazy::new(|| {
39+
AIProxyConfig::default()
40+
.os_select_port()
41+
.set_persistence_interval(200)
42+
.set_persist_location((*PERSISTENCE_FILE).clone())
43+
});
44+
3445
async fn get_server_response(
3546
reader: &mut BufReader<TcpStream>,
3647
query: AIServerQuery,
@@ -446,3 +457,129 @@ async fn test_ai_proxy_fails_db_server_unavailable() {
446457
let err = res.err().unwrap();
447458
assert!(err.contains(" kind: ConnectionRefused,"))
448459
}
460+
461+
#[tokio::test]
462+
async fn test_ai_proxy_test_with_persistence() {
463+
let server = Server::new(&CONFIG)
464+
.await
465+
.expect("Could not initialize server");
466+
let ai_server = AIProxyServer::new(&AI_CONFIG_WITH_PERSISTENCE)
467+
.await
468+
.expect("Could not initialize ai proxy");
469+
470+
let address = ai_server.local_addr().expect("Could not get local addr");
471+
let _ = tokio::spawn(async move { server.start().await });
472+
let write_flag = ai_server.write_flag();
473+
// start up ai proxy
474+
let _ = tokio::spawn(async move { ai_server.start().await });
475+
// Allow some time for the servers to start
476+
tokio::time::sleep(Duration::from_millis(200)).await;
477+
478+
let store_name = StoreName(String::from("Main"));
479+
let store_name_2 = StoreName(String::from("Main2"));
480+
let first_stream = TcpStream::connect(address).await.unwrap();
481+
482+
let message = AIServerQuery::from_queries(&[
483+
AIQuery::CreateStore {
484+
r#type: AIStoreType::RawString,
485+
store: store_name.clone(),
486+
model: AIModel::Llama3,
487+
predicates: HashSet::from_iter([]),
488+
non_linear_indices: HashSet::new(),
489+
},
490+
AIQuery::CreateStore {
491+
r#type: AIStoreType::Binary,
492+
store: store_name_2.clone(),
493+
model: AIModel::Llama3,
494+
predicates: HashSet::from_iter([]),
495+
non_linear_indices: HashSet::new(),
496+
},
497+
AIQuery::DropStore {
498+
store: store_name,
499+
error_if_not_exists: true,
500+
},
501+
]);
502+
503+
let mut expected = AIServerResult::with_capacity(3);
504+
505+
expected.push(Ok(AIServerResponse::Unit));
506+
expected.push(Ok(AIServerResponse::Unit));
507+
expected.push(Ok(AIServerResponse::Del(1)));
508+
509+
let mut reader = BufReader::new(first_stream);
510+
query_server_assert_result(&mut reader, message, expected).await;
511+
512+
// write flag should show that a write has occurred
513+
assert!(write_flag.load(Ordering::SeqCst));
514+
// Allow some time for persistence to kick in
515+
tokio::time::sleep(Duration::from_millis(200)).await;
516+
// start another server with existing persistence
517+
518+
let persisted_server = AIProxyServer::new(&AI_CONFIG_WITH_PERSISTENCE)
519+
.await
520+
.unwrap();
521+
522+
// Allow some time for the server to start
523+
tokio::time::sleep(Duration::from_millis(100)).await;
524+
525+
let address = persisted_server
526+
.local_addr()
527+
.expect("Could not get local addr");
528+
let write_flag = persisted_server.write_flag();
529+
let _ = tokio::spawn(async move { persisted_server.start().await });
530+
let second_stream = TcpStream::connect(address).await.unwrap();
531+
let mut reader = BufReader::new(second_stream);
532+
533+
let message = AIServerQuery::from_queries(&[AIQuery::ListStores]);
534+
535+
let mut expected = AIServerResult::with_capacity(1);
536+
537+
expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([
538+
AIStoreInfo {
539+
name: store_name_2.clone(),
540+
r#type: AIStoreType::Binary,
541+
model: AIModel::Llama3,
542+
embedding_size: AIModel::Llama3.embedding_size().into(),
543+
},
544+
]))));
545+
546+
query_server_assert_result(&mut reader, message, expected).await;
547+
assert!(!write_flag.load(Ordering::SeqCst));
548+
// delete persistence
549+
let _ = std::fs::remove_file(&*PERSISTENCE_FILE);
550+
}
551+
552+
#[tokio::test]
553+
async fn test_ai_proxy_destroy_database() {
554+
let address = provision_test_servers().await;
555+
let second_stream = TcpStream::connect(address).await.unwrap();
556+
let store_name = StoreName(String::from("Deven Kicks"));
557+
let message = AIServerQuery::from_queries(&[
558+
AIQuery::CreateStore {
559+
r#type: AIStoreType::RawString,
560+
store: store_name.clone(),
561+
model: AIModel::Llama3,
562+
predicates: HashSet::from_iter([]),
563+
non_linear_indices: HashSet::new(),
564+
},
565+
AIQuery::ListStores,
566+
AIQuery::PurgeStores,
567+
AIQuery::ListStores,
568+
]);
569+
let mut expected = AIServerResult::with_capacity(4);
570+
571+
expected.push(Ok(AIServerResponse::Unit));
572+
expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([
573+
AIStoreInfo {
574+
name: store_name,
575+
r#type: AIStoreType::RawString,
576+
model: AIModel::Llama3,
577+
embedding_size: AIModel::Llama3.embedding_size().into(),
578+
},
579+
]))));
580+
expected.push(Ok(AIServerResponse::Del(1)));
581+
expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([]))));
582+
583+
let mut reader = BufReader::new(second_stream);
584+
query_server_assert_result(&mut reader, message, expected).await
585+
}

ahnlich/types/src/ai/query.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub enum AIQuery {
5252
},
5353
InfoServer,
5454
ListStores,
55+
PurgeStores,
5556
Ping,
5657
}
5758

docs/ai.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ Here is a rudimentary list of commands for the AI proxy to accept
2121
- `LISTSTORES`: List all the stores on the server. It also returns information like store length/size, embedding size, AI model, etc.
2222
- `PING`: Tests whether the server is reachable
2323
- `DROPSTORE`: takes in a store and deletes it. Destroys everything pertaining to the store
24+
- `PURGESTORES`: Destroys all created stores.

sdk/ahnlich-client-py/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ response = client.delete_key(
508508

509509

510510
## Bulk Requests
511-
The clients has the ability to send multiple requests at once, and these requests will be handled sequentially. The builder class takes care of this. The response is a list of all individual request responses.
511+
Clients have the ability to send multiple requests at once, and these requests will be handled sequentially. The builder class takes care of this. The response is a list of all individual request responses.
512512

513513

514514
```py

sdk/ahnlich-client-py/ahnlich_client_py/builders/ai.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def drop_store(self, store_name: str, error_if_not_exists: bool):
9898
)
9999
)
100100

101+
def purge_stores(self):
102+
self.queries.append(ai_query.AIQuery__PurgeStores())
103+
101104
def info_server(self):
102105
self.queries.append(ai_query.AIQuery__InfoServer())
103106

sdk/ahnlich-client-py/ahnlich_client_py/clients/ai.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def drop_store(self, store_name: str, error_if_not_exists: bool):
106106
)
107107
return self.process_request(self.builder.to_server_query())
108108

109+
def purge_stores(self):
110+
self.builder.purge_stores()
111+
return self.process_request(self.builder.to_server_query())
112+
109113
def info_server(self):
110114
self.builder.info_server()
111115
return self.process_request(self.builder.to_server_query())

sdk/ahnlich-client-py/ahnlich_client_py/internals/ai_query.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,17 @@ class AIQuery__ListStores(AIQuery):
123123

124124

125125
@dataclass(frozen=True)
126-
class AIQuery__Ping(AIQuery):
126+
class AIQuery__PurgeStores(AIQuery):
127127
INDEX = 10 # type: int
128128
pass
129129

130130

131+
@dataclass(frozen=True)
132+
class AIQuery__Ping(AIQuery):
133+
INDEX = 11 # type: int
134+
pass
135+
136+
131137
AIQuery.VARIANTS = [
132138
AIQuery__CreateStore,
133139
AIQuery__GetPred,
@@ -139,6 +145,7 @@ class AIQuery__Ping(AIQuery):
139145
AIQuery__DropStore,
140146
AIQuery__InfoServer,
141147
AIQuery__ListStores,
148+
AIQuery__PurgeStores,
142149
AIQuery__Ping,
143150
]
144151

0 commit comments

Comments
 (0)