40 changes: 36 additions & 4 deletions ahnlich/ai/src/engine/store.rs
@@ -15,6 +15,7 @@ use serde::Serialize;
use std::collections::HashMap as StdHashMap;
use std::collections::HashSet as StdHashSet;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;

/// Contains all the stores that have been created in memory
@@ -34,6 +35,24 @@ impl AIStoreHandler {
write_flag,
}
}
pub(crate) fn get_stores(&self) -> AIStores {
self.stores.clone()
}

#[cfg(test)]
pub fn write_flag(&self) -> Arc<AtomicBool> {
self.write_flag.clone()
}

fn set_write_flag(&self) {
let _ = self
.write_flag
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
}

pub(crate) fn use_snapshot(&mut self, stores_snapshot: AIStores) {
self.stores = stores_snapshot;
}

#[tracing::instrument(skip(self))]
pub(crate) fn create_store(
@@ -42,7 +61,8 @@ impl AIStoreHandler {
store_type: AIStoreType,
model: AIModel,
) -> Result<(), AIProxyError> {
-        self.stores
+        if self
+            .stores
            .try_insert(
                store_name.clone(),
                Arc::new(AIStore::create(
@@ -52,8 +72,11 @@ impl AIStoreHandler {
                )),
                &self.stores.guard(),
            )
-            .map_err(|_| AIProxyError::StoreAlreadyExists(store_name.clone()))?;
-
+            .is_err()
+        {
+            return Err(AIProxyError::StoreAlreadyExists(store_name.clone()));
+        }
+        self.set_write_flag();
Ok(())
}

@@ -161,11 +184,20 @@ impl AIStoreHandler {
let removed = if !removed {
0
} else {
-            //self.set_write_flag();
+            self.set_write_flag();
1
};
Ok(removed)
}

/// Matches DestroyDatabase - Drops all the stores in the database
#[tracing::instrument(skip(self))]
pub(crate) fn destroy_database(&self) -> usize {
let store_length = self.stores.pin().len();
let guard = self.stores.guard();
self.stores.clear(&guard);
store_length
}
}

#[derive(Debug, Serialize, Deserialize)]
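The write-flag helpers added above form a simple dirty-flag handshake: every mutating call (`create_store`, `drop_store`) flips the shared `AtomicBool` to `true`, and the persistence side presumably swaps it back to `false` when deciding whether to snapshot. Below is a minimal standalone sketch of that handshake using only the standard library; `mark_dirty` mirrors `set_write_flag` above, while `monitor_once` is a hypothetical stand-in for the check the persistence task would make, not part of this crate:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Writer side: mark the stores as dirty. Mirrors `set_write_flag` above.
fn mark_dirty(flag: &AtomicBool) {
    let _ = flag.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
}

// Persistence side (hypothetical): returns true when a write happened since
// the last check, and atomically resets the flag so later writes are seen.
fn monitor_once(flag: &AtomicBool) -> bool {
    flag.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
        .is_ok()
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));
    assert!(!monitor_once(&flag)); // nothing written yet, skip the snapshot
    mark_dirty(&flag);
    assert!(monitor_once(&flag)); // a write happened, take a snapshot
    assert!(!monitor_once(&flag)); // flag was reset, no redundant snapshot
}
```

A plain `store(true, Ordering::SeqCst)` would be functionally equivalent on the writer side; `compare_exchange` simply avoids re-writing a flag that is already set.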
31 changes: 29 additions & 2 deletions ahnlich/ai/src/server/handler.rs
@@ -13,7 +13,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::select;
use tokio_graceful::Shutdown;
use tracing::Instrument;
-use utils::{client::ClientHandler, protocol::AhnlichProtocol};
+use utils::{client::ClientHandler, persistence::Persistence, protocol::AhnlichProtocol};

use ahnlich_client_rs::db::{DbClient, DbConnManager};
use deadpool::managed::Pool;
@@ -54,8 +54,30 @@ impl<'a> AIProxyServer<'a> {
}
let write_flag = Arc::new(AtomicBool::new(false));
let db_client = Self::build_db_client(config).await;
-        let store_handler = AIStoreHandler::new(write_flag.clone());
+        let mut store_handler = AIStoreHandler::new(write_flag.clone());
let client_handler = Arc::new(ClientHandler::new(config.maximum_clients));

// persistence
if let Some(persist_location) = &config.persist_location {
match Persistence::load_snapshot(persist_location) {
Err(e) => {
tracing::error!("Failed to load snapshot from persist location: {e}");
}
Ok(snapshot) => {
store_handler.use_snapshot(snapshot);
}
}
// spawn the persistence task
let mut persistence_task = Persistence::task(
write_flag,
config.persistence_interval,
persist_location,
store_handler.get_stores(),
);
shutdown_token
.spawn_task_fn(|guard| async move { persistence_task.monitor(guard).await });
};

Ok(Self {
listener,
shutdown_token,
@@ -148,4 +170,9 @@ impl<'a> AIProxyServer<'a> {

DbClient::new_with_pool(pool)
}

#[cfg(test)]
pub fn write_flag(&self) -> Arc<AtomicBool> {
self.store_handler.write_flag()
}
}
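For context on the constructor change above: when `persist_location` is configured, the server first tries to rehydrate the store handler from an existing snapshot, then spawns a background task that flushes periodically. The sketch below shows what such a monitor loop typically looks like. It is an illustration under assumptions, not the actual `utils::persistence::Persistence` implementation: the `monitor` signature, the `snapshot` closure, and the plain `tokio::fs::write` flush are hypothetical, and the real task also honours the shutdown guard passed in via `spawn_task_fn`.

```rust
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

// Hypothetical persistence loop: wake up on an interval, and only touch disk
// when the shared write flag says something changed since the last flush.
async fn monitor(
    write_flag: Arc<AtomicBool>,
    interval_ms: u64,
    persist_location: PathBuf,
    snapshot: impl Fn() -> Vec<u8>, // stand-in for serializing the stores
) {
    let mut ticker = tokio::time::interval(Duration::from_millis(interval_ms));
    loop {
        ticker.tick().await;
        // compare_exchange(true, false, ..) both detects and clears the flag.
        if write_flag
            .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok()
        {
            if let Err(e) = tokio::fs::write(&persist_location, snapshot()).await {
                tracing::error!("failed to persist snapshot: {e}");
            }
        }
    }
}
```

This also explains the persistence test later in the PR: after the second server loads the snapshot and only serves read queries, the flag stays `false`, so no further flush is triggered.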
4 changes: 4 additions & 0 deletions ahnlich/ai/src/server/task.rs
@@ -285,6 +285,10 @@ impl AhnlichProtocol for AIProxyTask {
);
}
}
AIQuery::DestroyDatabase => {
let destroyed = self.store_handler.destroy_database();
Ok(AIServerResponse::Del(destroyed))
}
})
}
result
139 changes: 138 additions & 1 deletion ahnlich/ai/src/tests/aiproxy_test.rs
@@ -13,12 +13,13 @@ use ahnlich_types::{

use once_cell::sync::Lazy;
use pretty_assertions::assert_eq;
-use std::{collections::HashSet, num::NonZeroUsize};
+use std::{collections::HashSet, num::NonZeroUsize, sync::atomic::Ordering};

use crate::cli::AIProxyConfig;
use crate::server::handler::AIProxyServer;
use ahnlich_types::bincode::BinCodeSerAndDeser;
use std::net::SocketAddr;
use std::path::PathBuf;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::time::{timeout, Duration};
@@ -31,6 +32,16 @@ static AI_CONFIG: Lazy<AIProxyConfig> = Lazy::new(|| {
ai_proxy
});

static PERSISTENCE_FILE: Lazy<PathBuf> =
Lazy::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("ahnlich_ai_proxy.dat"));

static AI_CONFIG_WITH_PERSISTENCE: Lazy<AIProxyConfig> = Lazy::new(|| {
AIProxyConfig::default()
.os_select_port()
.set_persistence_interval(200)
.set_persist_location((*PERSISTENCE_FILE).clone())
});

async fn get_server_response(
reader: &mut BufReader<TcpStream>,
query: AIServerQuery,
@@ -446,3 +457,129 @@ async fn test_ai_proxy_fails_db_server_unavailable() {
let err = res.err().unwrap();
assert!(err.contains(" kind: ConnectionRefused,"))
}

#[tokio::test]
async fn test_ai_proxy_test_with_persistence() {
let server = Server::new(&CONFIG)
.await
.expect("Could not initialize server");
let ai_server = AIProxyServer::new(&AI_CONFIG_WITH_PERSISTENCE)
.await
.expect("Could not initialize ai proxy");

let address = ai_server.local_addr().expect("Could not get local addr");
let _ = tokio::spawn(async move { server.start().await });
let write_flag = ai_server.write_flag();
// start up ai proxy
let _ = tokio::spawn(async move { ai_server.start().await });
// Allow some time for the servers to start
tokio::time::sleep(Duration::from_millis(200)).await;

let store_name = StoreName(String::from("Main"));
let store_name_2 = StoreName(String::from("Main2"));
let first_stream = TcpStream::connect(address).await.unwrap();

let message = AIServerQuery::from_queries(&[
AIQuery::CreateStore {
r#type: AIStoreType::RawString,
store: store_name.clone(),
model: AIModel::Llama3,
predicates: HashSet::from_iter([]),
non_linear_indices: HashSet::new(),
},
AIQuery::CreateStore {
r#type: AIStoreType::Binary,
store: store_name_2.clone(),
model: AIModel::Llama3,
predicates: HashSet::from_iter([]),
non_linear_indices: HashSet::new(),
},
AIQuery::DropStore {
store: store_name,
error_if_not_exists: true,
},
]);

let mut expected = AIServerResult::with_capacity(3);

expected.push(Ok(AIServerResponse::Unit));
expected.push(Ok(AIServerResponse::Unit));
expected.push(Ok(AIServerResponse::Del(1)));

let mut reader = BufReader::new(first_stream);
query_server_assert_result(&mut reader, message, expected).await;

// write flag should show that a write has occurred
assert!(write_flag.load(Ordering::SeqCst));
// Allow some time for persistence to kick in
tokio::time::sleep(Duration::from_millis(200)).await;
// start another server with existing persistence

let persisted_server = AIProxyServer::new(&AI_CONFIG_WITH_PERSISTENCE)
.await
.unwrap();

// Allow some time for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;

let address = persisted_server
.local_addr()
.expect("Could not get local addr");
let write_flag = persisted_server.write_flag();
let _ = tokio::spawn(async move { persisted_server.start().await });
let second_stream = TcpStream::connect(address).await.unwrap();
let mut reader = BufReader::new(second_stream);

let message = AIServerQuery::from_queries(&[AIQuery::ListStores]);

let mut expected = AIServerResult::with_capacity(1);

expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([
AIStoreInfo {
name: store_name_2.clone(),
r#type: AIStoreType::Binary,
model: AIModel::Llama3,
embedding_size: AIModel::Llama3.embedding_size().into(),
},
]))));

query_server_assert_result(&mut reader, message, expected).await;
assert!(!write_flag.load(Ordering::SeqCst));
// delete persistence
let _ = std::fs::remove_file(&*PERSISTENCE_FILE);
}

#[tokio::test]
async fn test_ai_proxy_destroy_database() {
let address = provision_test_servers().await;
let second_stream = TcpStream::connect(address).await.unwrap();
let store_name = StoreName(String::from("Deven Kicks"));
let message = AIServerQuery::from_queries(&[
AIQuery::CreateStore {
r#type: AIStoreType::RawString,
store: store_name.clone(),
model: AIModel::Llama3,
predicates: HashSet::from_iter([]),
non_linear_indices: HashSet::new(),
},
AIQuery::ListStores,
AIQuery::DestroyDatabase,
AIQuery::ListStores,
]);
let mut expected = AIServerResult::with_capacity(4);

expected.push(Ok(AIServerResponse::Unit));
expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([
AIStoreInfo {
name: store_name,
r#type: AIStoreType::RawString,
model: AIModel::Llama3,
embedding_size: AIModel::Llama3.embedding_size().into(),
},
]))));
expected.push(Ok(AIServerResponse::Del(1)));
expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([]))));

let mut reader = BufReader::new(second_stream);
query_server_assert_result(&mut reader, message, expected).await
}
1 change: 1 addition & 0 deletions ahnlich/types/src/ai/query.rs
@@ -52,6 +52,7 @@ pub enum AIQuery {
},
InfoServer,
ListStores,
DestroyDatabase,
Ping,
}

3 changes: 3 additions & 0 deletions type_specs/query/ai_query.json
@@ -174,6 +174,9 @@
"ListStores": "UNIT"
},
"10": {
"DestoryDatabase": "UNIT"
},
"11": {
"Ping": "UNIT"
}
}
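The JSON spec change mirrors the enum change above: with bincode-style serde encoding, a variant's wire tag is its declaration index, so inserting the new variant immediately before `Ping` gives it key 10 and pushes `Ping` from 10 to 11. Below is a toy illustration of that tag-follows-order behaviour, assuming serde (with the derive feature) and bincode 1.x as dependencies; `ToyQuery` is a made-up stand-in, not the real `AIQuery`:

```rust
use serde::Serialize;

// Made-up tail of the query enum; only the declaration order matters here.
#[derive(Serialize)]
enum ToyQuery {
    InfoServer,      // toy tag 0
    ListStores,      // toy tag 1
    DestroyDatabase, // toy tag 2 -- "10" in the regenerated spec
    Ping,            // toy tag 3 -- pushed from "10" to "11" in the spec
}

fn main() {
    // bincode 1.x's default config encodes the serde variant index as a
    // little-endian u32 prefix, so Ping now carries tag 3 instead of 2.
    let bytes = bincode::serialize(&ToyQuery::Ping).unwrap();
    let tag = u32::from_le_bytes(bytes[..4].try_into().unwrap());
    assert_eq!(tag, 3);
}
```

Because the tag is positional, a client built against the old enum would decode tag 10 as `Ping`, which is presumably why the generated type specs are regenerated alongside the enum change.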