Skip to content

Commit 44b6181

Browse files
authored
Create Rust AIProxy Client (deven96#75)
* WIP: AIProxy Client - Move client's connection send and read logic to trait * Create AIproxy rust client * add some tests to rust ai proxy client * update documentation for ai proxy * add tests for aiproxy get pred command * AI and DB Conn struct rename and doc update * Add more binary store tests, Change AIProxyServer, now owns aiproxyconfig * update aiclient get pred tests and remove unnessary clone
1 parent 6740dba commit 44b6181

File tree

18 files changed

+1205
-102
lines changed

18 files changed

+1205
-102
lines changed

ahnlich/ai/src/engine/store.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,7 @@ impl AIStoreHandler {
119119
return Err(AIProxyError::ReservedError(metadata_key.to_string()));
120120
}
121121
let store = self.get(store_name)?;
122-
let input_type = store_input.clone().into();
123122

124-
if store.r#type != input_type {
125-
return Err(AIProxyError::StoreTypeMismatch {
126-
store_type: store.r#type.clone(),
127-
input_type,
128-
});
129-
}
130123
let store_key = store.model.model_ndarray(&store_input);
131124
let metadata_value: MetadataValue = store_input.into();
132125
let mut final_store_value: StdHashMap<MetadataKey, MetadataValue> =

ahnlich/ai/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::error::Error;
55
#[tokio::main]
66
async fn main() -> Result<(), Box<dyn Error>> {
77
let cli = ahnlich_ai_proxy::cli::Cli::parse();
8-
match &cli.command {
8+
match cli.command {
99
ahnlich_ai_proxy::cli::Commands::Start(config) => {
1010
let server = ahnlich_ai_proxy::server::handler::AIProxyServer::new(config).await?;
1111
server.start().await?;

ahnlich/ai/src/server/handler.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,22 @@ use deadpool::managed::Pool;
2222
//#[global_allocator]
2323
pub(super) static AI_ALLOCATOR: Cap<alloc::System> = Cap::new(alloc::System, usize::max_value());
2424

25-
pub struct AIProxyServer<'a> {
25+
pub struct AIProxyServer {
2626
listener: TcpListener,
27-
config: &'a AIProxyConfig,
27+
config: AIProxyConfig,
2828
client_handler: Arc<ClientHandler>,
2929
store_handler: Arc<AIStoreHandler>,
3030
shutdown_token: Shutdown,
3131
db_client: Arc<DbClient>,
3232
}
3333

34-
impl<'a> AIProxyServer<'a> {
35-
pub async fn new(config: &'a AIProxyConfig) -> IoResult<Self> {
34+
impl AIProxyServer {
35+
pub async fn new(config: AIProxyConfig) -> IoResult<Self> {
3636
let shutdown_token = Shutdown::default();
3737
Self::build(config, shutdown_token).await
3838
}
3939

40-
pub async fn build(config: &'a AIProxyConfig, shutdown_token: Shutdown) -> IoResult<Self> {
40+
pub async fn build(config: AIProxyConfig, shutdown_token: Shutdown) -> IoResult<Self> {
4141
AI_ALLOCATOR
4242
.set_limit(config.allocator_size)
4343
.expect("Could not set up ai-proxy with allocator_size");
@@ -53,7 +53,7 @@ impl<'a> AIProxyServer<'a> {
5353
tracer::init_tracing("ahnlich-db", Some(&config.log_level), &otel_url)
5454
}
5555
let write_flag = Arc::new(AtomicBool::new(false));
56-
let db_client = Self::build_db_client(config).await;
56+
let db_client = Self::build_db_client(&config).await;
5757
let mut store_handler = AIStoreHandler::new(write_flag.clone());
5858
let client_handler = Arc::new(ClientHandler::new(config.maximum_clients));
5959

@@ -161,7 +161,7 @@ impl<'a> AIProxyServer<'a> {
161161
})
162162
}
163163

164-
async fn build_db_client(config: &'a AIProxyConfig) -> DbClient {
164+
async fn build_db_client(config: &AIProxyConfig) -> DbClient {
165165
let manager = DbConnManager::new(config.db_host.clone(), config.db_port);
166166
let pool = Pool::builder(manager)
167167
.max_size(config.db_client_pool_size)

ahnlich/ai/src/tests/aiproxy_test.rs

Lines changed: 217 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,7 @@ use tokio::net::TcpStream;
2525
use tokio::time::{timeout, Duration};
2626

2727
static CONFIG: Lazy<ServerConfig> = Lazy::new(|| ServerConfig::default());
28-
static AI_CONFIG: Lazy<AIProxyConfig> = Lazy::new(|| {
29-
let mut ai_proxy = AIProxyConfig::default().os_select_port();
30-
ai_proxy.db_port = CONFIG.port.clone();
31-
ai_proxy.db_host = CONFIG.host.clone();
32-
ai_proxy
33-
});
28+
static AI_CONFIG: Lazy<AIProxyConfig> = Lazy::new(|| AIProxyConfig::default().os_select_port());
3429

3530
static PERSISTENCE_FILE: Lazy<PathBuf> =
3631
Lazy::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("ahnlich_ai_proxy.dat"));
@@ -89,7 +84,11 @@ async fn provision_test_servers() -> SocketAddr {
8984
let server = Server::new(&CONFIG)
9085
.await
9186
.expect("Could not initialize server");
92-
let ai_server = AIProxyServer::new(&AI_CONFIG)
87+
let db_port = server.local_addr().unwrap().port();
88+
let mut config = AI_CONFIG.clone();
89+
config.db_port = db_port;
90+
91+
let ai_server = AIProxyServer::new(config)
9392
.await
9493
.expect("Could not initialize ai proxy");
9594

@@ -422,7 +421,7 @@ async fn test_ai_proxy_del_key_drop_store() {
422421

423422
#[tokio::test]
424423
async fn test_ai_proxy_fails_db_server_unavailable() {
425-
let ai_server = AIProxyServer::new(&AI_CONFIG)
424+
let ai_server = AIProxyServer::new(AI_CONFIG.clone())
426425
.await
427426
.expect("Could not initialize ai proxy");
428427

@@ -463,7 +462,7 @@ async fn test_ai_proxy_test_with_persistence() {
463462
let server = Server::new(&CONFIG)
464463
.await
465464
.expect("Could not initialize server");
466-
let ai_server = AIProxyServer::new(&AI_CONFIG_WITH_PERSISTENCE)
465+
let ai_server = AIProxyServer::new(AI_CONFIG_WITH_PERSISTENCE.clone())
467466
.await
468467
.expect("Could not initialize ai proxy");
469468

@@ -515,7 +514,7 @@ async fn test_ai_proxy_test_with_persistence() {
515514
tokio::time::sleep(Duration::from_millis(200)).await;
516515
// start another server with existing persistence
517516

518-
let persisted_server = AIProxyServer::new(&AI_CONFIG_WITH_PERSISTENCE)
517+
let persisted_server = AIProxyServer::new(AI_CONFIG_WITH_PERSISTENCE.clone())
519518
.await
520519
.unwrap();
521520

@@ -583,3 +582,211 @@ async fn test_ai_proxy_destroy_database() {
583582
let mut reader = BufReader::new(second_stream);
584583
query_server_assert_result(&mut reader, message, expected).await
585584
}
585+
586+
#[tokio::test]
587+
async fn test_ai_proxy_binary_store_actions() {
588+
let address = provision_test_servers().await;
589+
590+
let store_name = StoreName(String::from("Deven Image Store"));
591+
let matching_metadatakey = MetadataKey::new("Name".to_owned());
592+
let matching_metadatavalue = MetadataValue::RawString("Greatness".to_owned());
593+
594+
let store_value_1 =
595+
StoreValue::from_iter([(matching_metadatakey.clone(), matching_metadatavalue.clone())]);
596+
let store_value_2 = StoreValue::from_iter([(
597+
matching_metadatakey.clone(),
598+
MetadataValue::RawString("Deven".to_owned()),
599+
)]);
600+
let store_data = vec![
601+
(
602+
StoreInput::Binary(vec![93, 4, 1, 6, 2, 8, 8, 32, 45]),
603+
store_value_1.clone(),
604+
),
605+
(
606+
StoreInput::Binary(vec![102, 3, 4, 6, 7, 8, 4, 190]),
607+
store_value_2.clone(),
608+
),
609+
(
610+
StoreInput::Binary(vec![211, 2, 4, 6, 7, 8, 8, 92, 21, 10]),
611+
StoreValue::from_iter([(
612+
matching_metadatakey.clone(),
613+
MetadataValue::RawString("Daniel".to_owned()),
614+
)]),
615+
),
616+
];
617+
618+
let message = AIServerQuery::from_queries(&[
619+
AIQuery::CreateStore {
620+
r#type: AIStoreType::Binary,
621+
store: store_name.clone(),
622+
model: AIModel::Llama3,
623+
predicates: HashSet::new(),
624+
non_linear_indices: HashSet::new(),
625+
},
626+
AIQuery::ListStores,
627+
AIQuery::CreatePredIndex {
628+
store: store_name.clone(),
629+
predicates: HashSet::from_iter([
630+
MetadataKey::new("Name".to_string()),
631+
MetadataKey::new("Age".to_string()),
632+
]),
633+
},
634+
AIQuery::Set {
635+
store: store_name.clone(),
636+
inputs: store_data,
637+
},
638+
AIQuery::DropPredIndex {
639+
store: store_name.clone(),
640+
predicates: HashSet::from_iter([MetadataKey::new("Age".to_string())]),
641+
error_if_not_exists: true,
642+
},
643+
AIQuery::GetPred {
644+
store: store_name.clone(),
645+
condition: PredicateCondition::Value(Predicate::Equals {
646+
key: matching_metadatakey.clone(),
647+
value: matching_metadatavalue,
648+
}),
649+
},
650+
AIQuery::PurgeStores,
651+
]);
652+
653+
let mut expected = AIServerResult::with_capacity(7);
654+
655+
expected.push(Ok(AIServerResponse::Unit));
656+
expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([
657+
AIStoreInfo {
658+
name: store_name,
659+
r#type: AIStoreType::Binary,
660+
model: AIModel::Llama3,
661+
embedding_size: AIModel::Llama3.embedding_size().into(),
662+
},
663+
]))));
664+
expected.push(Ok(AIServerResponse::CreateIndex(2)));
665+
expected.push(Ok(AIServerResponse::Set(StoreUpsert {
666+
inserted: 3,
667+
updated: 0,
668+
})));
669+
expected.push(Ok(AIServerResponse::Del(1)));
670+
expected.push(Ok(AIServerResponse::Get(vec![(
671+
StoreInput::Binary(vec![93, 4, 1, 6, 2, 8, 8, 32, 45]),
672+
store_value_1.clone(),
673+
)])));
674+
expected.push(Ok(AIServerResponse::Del(1)));
675+
676+
let connected_stream = TcpStream::connect(address).await.unwrap();
677+
let mut reader = BufReader::new(connected_stream);
678+
679+
query_server_assert_result(&mut reader, message, expected).await;
680+
}
681+
682+
#[tokio::test]
683+
async fn test_ai_proxy_binary_store_with_text_and_binary() {
684+
let address = provision_test_servers().await;
685+
686+
let store_name = StoreName(String::from("Deven Mixed Store"));
687+
let matching_metadatakey = MetadataKey::new("Brand".to_owned());
688+
let matching_metadatavalue = MetadataValue::RawString("Nike".to_owned());
689+
690+
let store_value_1 =
691+
StoreValue::from_iter([(matching_metadatakey.clone(), matching_metadatavalue.clone())]);
692+
let store_value_2 = StoreValue::from_iter([(
693+
matching_metadatakey.clone(),
694+
MetadataValue::RawString("Deven".to_owned()),
695+
)]);
696+
let store_data = vec![
697+
(
698+
StoreInput::Binary(vec![93, 4, 1, 6, 2, 8, 8, 32, 45]),
699+
store_value_1.clone(),
700+
),
701+
(
702+
StoreInput::Binary(vec![102, 3, 4, 6, 7, 8, 4, 190]),
703+
store_value_2.clone(),
704+
),
705+
(
706+
StoreInput::Binary(vec![211, 2, 4, 6, 7, 8, 8, 92, 21, 10]),
707+
StoreValue::from_iter([(
708+
matching_metadatakey.clone(),
709+
MetadataValue::RawString("Daniel".to_owned()),
710+
)]),
711+
),
712+
(
713+
StoreInput::RawString(String::from("Buster Matthews is the name")),
714+
StoreValue::from_iter([(
715+
MetadataKey::new("Description".to_string()),
716+
MetadataValue::RawString("20 year old line backer".to_owned()),
717+
)]),
718+
),
719+
];
720+
721+
let message = AIServerQuery::from_queries(&[
722+
AIQuery::CreateStore {
723+
r#type: AIStoreType::Binary,
724+
store: store_name.clone(),
725+
model: AIModel::Llama3,
726+
predicates: HashSet::new(),
727+
non_linear_indices: HashSet::new(),
728+
},
729+
AIQuery::ListStores,
730+
AIQuery::CreatePredIndex {
731+
store: store_name.clone(),
732+
predicates: HashSet::from_iter([
733+
MetadataKey::new("Name".to_string()),
734+
MetadataKey::new("Description".to_string()),
735+
]),
736+
},
737+
AIQuery::Set {
738+
store: store_name.clone(),
739+
inputs: store_data,
740+
},
741+
AIQuery::DropPredIndex {
742+
store: store_name.clone(),
743+
predicates: HashSet::from_iter([MetadataKey::new("Age".to_string())]),
744+
error_if_not_exists: true,
745+
},
746+
AIQuery::GetPred {
747+
store: store_name.clone(),
748+
condition: PredicateCondition::Value(Predicate::In {
749+
key: MetadataKey::new("Description".to_owned()),
750+
value: HashSet::from_iter([MetadataValue::RawString(
751+
"20 year old line backer".to_owned(),
752+
)]),
753+
}),
754+
},
755+
AIQuery::PurgeStores,
756+
]);
757+
758+
let mut expected = AIServerResult::with_capacity(7);
759+
760+
expected.push(Ok(AIServerResponse::Unit));
761+
expected.push(Ok(AIServerResponse::StoreList(HashSet::from_iter([
762+
AIStoreInfo {
763+
name: store_name,
764+
r#type: AIStoreType::Binary,
765+
model: AIModel::Llama3,
766+
embedding_size: AIModel::Llama3.embedding_size().into(),
767+
},
768+
]))));
769+
expected.push(Ok(AIServerResponse::CreateIndex(2)));
770+
expected.push(Ok(AIServerResponse::Set(StoreUpsert {
771+
inserted: 4,
772+
updated: 0,
773+
})));
774+
expected.push(Err(
775+
"db error Predicate Age not found in store, attempt CREATEPREDINDEX with predicate"
776+
.to_string(),
777+
));
778+
779+
expected.push(Ok(AIServerResponse::Get(vec![(
780+
StoreInput::RawString(String::from("Buster Matthews is the name")),
781+
StoreValue::from_iter([(
782+
MetadataKey::new("Description".to_owned()),
783+
MetadataValue::RawString("20 year old line backer".to_owned()),
784+
)]),
785+
)])));
786+
expected.push(Ok(AIServerResponse::Del(1)));
787+
788+
let connected_stream = TcpStream::connect(address).await.unwrap();
789+
let mut reader = BufReader::new(connected_stream);
790+
791+
query_server_assert_result(&mut reader, message, expected).await;
792+
}

ahnlich/client/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ tokio.workspace = true
2424
deadpool.workspace = true
2525
[dev-dependencies]
2626
db = { path = "../db", version = "*" }
27+
ai = { path = "../ai", version = "*" }
2728
pretty_assertions.workspace = true
2829
ndarray.workspace = true

0 commit comments

Comments
 (0)