Skip to content

Commit 5c397a1

Browse files
authored
Setup Basic AI Proxy Types (#49)
* Create ai-proxy query enum * Use storeinputs instead of metadatavalue * add ai server response for queries * add aiserver response * Trace ai db queries and server responses for clients * Setup entry point for ai_proxy, using db parts * WIP: moving ahnlich protocol to traits * Use traits to handle ahnlich protocol * Cleanup unused errors by ai * rename to ai instead of ahnlich_ai * Fleshing out commands * Adding dbclient to tasks * Move dbclient to aiproxy task and some cleanups * Change ai set query format * Add set command for ai proxy * Add getsimn variant * Update ai query commands to match db(create_pred_index, drop_pred_index) * Initialize reserved metadatakey * Add ai allocator and begin tests for ai * More tests for ahnlich ai proxy * remove test for unavailability * update todos and regen typespecs * Match portion of error in aiproxytests
1 parent 94eb429 commit 5c397a1

File tree

38 files changed

+2674
-109
lines changed

38 files changed

+2674
-109
lines changed

ahnlich/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ members = [
55
"client",
66
"tracer",
77
"typegen",
8-
"utils", "similarity",
8+
"utils",
9+
"similarity",
10+
"ai",
911
]
1012
resolver = "2"
1113

@@ -35,3 +37,5 @@ tokio = { version = "1.37.0", features = [
3537
tokio-graceful = "0.1.6"
3638
rand = "0.8"
3739
rayon = "1.10"
40+
cap = "0.1.2"
41+
deadpool = { version = "0.10", features = ["rt_tokio_1"]}

ahnlich/ai/Cargo.toml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
[package]
2+
name = "ai"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
# the library target exists only for use in Rust client tests and is not to be released
7+
[lib]
8+
name = "ahnlich_ai_proxy"
9+
path = "src/lib.rs"
10+
11+
[dependencies]
12+
flurry.workspace = true
13+
tokio.workspace = true
14+
serde.workspace = true
15+
blake3.workspace = true
16+
ndarray.workspace = true
17+
bincode.workspace = true
18+
itertools.workspace = true
19+
clap.workspace = true
20+
thiserror.workspace = true
21+
async-trait.workspace = true
22+
utils = { path = "../utils", version = "*" }
23+
ahnlich_types = { path = "../types", version = "*" }
24+
tokio-graceful.workspace = true
25+
once_cell.workspace = true
26+
tracing.workspace = true
27+
tracer = { path = "../tracer", version = "*" }
28+
ahnlich_client_rs = { path = "../client", version = "*" }
29+
ahnlich_similarity = { path = "../similarity", version = "*", features = ["serde"] }
30+
cap.workspace = true
31+
deadpool.workspace = true
32+
33+
34+
[dev-dependencies]
35+
db = { path = "../db", version = "*" }
36+
pretty_assertions.workspace = true

ahnlich/ai/src/cli/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pub mod server;
2+
3+
pub use server::{AIProxyConfig, Cli, Commands};

ahnlich/ai/src/cli/server.rs

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
use clap::{ArgAction, Args, Parser, Subcommand};
2+
3+
#[derive(Parser)]
4+
#[command(version, about, long_about = None)]
5+
pub struct Cli {
6+
#[command(subcommand)]
7+
pub command: Commands,
8+
}
9+
10+
#[derive(Subcommand)]
11+
pub enum Commands {
12+
/// Starts Anhlich AI Proxy
13+
Start(AIProxyConfig),
14+
}
15+
16+
#[derive(Args, Debug, Clone)]
17+
pub struct AIProxyConfig {
18+
/// Ahnlich AI proxy host
19+
#[arg(long, default_value_t = String::from("127.0.0.1"))]
20+
pub host: String,
21+
22+
/// Ahnlich AI proxy port
23+
#[arg(long, default_value_t = 8000)]
24+
pub port: u16,
25+
26+
/// Allows server to persist data to disk on occassion
27+
#[arg(long, default_value_t = false, action=ArgAction::SetTrue)]
28+
pub(crate) enable_persistence: bool,
29+
30+
/// persistence location
31+
#[arg(long, requires_if("true", "enable_persistence"))]
32+
pub(crate) persist_location: Option<std::path::PathBuf>,
33+
34+
/// persistence interval in milliseconds
35+
/// A new persistence round would be scheduled for persistence_interval into the future after
36+
/// current persistence round is completed
37+
#[arg(long, default_value_t = 1000 * 60 * 5)]
38+
pub(crate) persistence_interval: u64,
39+
40+
/// Ahnlich Database Host
41+
#[arg(long, default_value_t = String::from("127.0.0.1"))]
42+
pub db_host: String,
43+
44+
/// Ahnlich Database port
45+
#[arg(long, default_value_t = 1369)]
46+
pub db_port: u16,
47+
48+
/// Ahnlich Database Client Connection Pool Size
49+
#[arg(long, default_value_t = 10)]
50+
pub db_client_pool_size: usize,
51+
52+
/// sets size(in bytes) for global allocator used
53+
/// Defaults to 1 Gi (1 * 1024 * 1024 * 1024)
54+
#[arg(long, default_value_t = 1_073_741_824)]
55+
pub allocator_size: usize,
56+
57+
/// limits the message size of expected messages, defaults to 1MiB (1 * 1024 * 1024)
58+
#[arg(long, default_value_t = 1_048_576)]
59+
pub message_size: usize,
60+
/// Allows enables tracing
61+
#[arg(long, default_value_t = false, action=ArgAction::SetTrue)]
62+
pub(crate) enable_tracing: bool,
63+
64+
/// Otel collector url to send traces to
65+
#[arg(long, requires_if("true", "enable_tracing"))]
66+
pub(crate) otel_endpoint: Option<String>,
67+
68+
/// Log level
69+
#[arg(long, default_value_t = String::from("info"))]
70+
pub(crate) log_level: String,
71+
72+
/// Maximum client connections allowed
73+
/// Defaults to 1000
74+
#[arg(long, default_value_t = 1000)]
75+
pub(crate) maximum_clients: usize,
76+
}
77+
78+
impl Default for AIProxyConfig {
79+
fn default() -> Self {
80+
Self {
81+
host: String::from("127.0.0.1"),
82+
port: 8000,
83+
enable_persistence: false,
84+
persist_location: None,
85+
persistence_interval: 1000 * 60 * 5,
86+
87+
db_host: String::from("127.0.0.1"),
88+
db_port: 1369,
89+
db_client_pool_size: 10,
90+
91+
allocator_size: 1_073_741_824,
92+
message_size: 1_048_576,
93+
94+
enable_tracing: false,
95+
otel_endpoint: None,
96+
log_level: String::from("info"),
97+
maximum_clients: 1000,
98+
}
99+
}
100+
}
101+
102+
impl AIProxyConfig {
103+
pub fn os_select_port(mut self) -> Self {
104+
// allow OS to pick a port
105+
self.port = 0;
106+
self
107+
}
108+
109+
pub fn set_persist_location(mut self, location: std::path::PathBuf) -> Self {
110+
self.persist_location = Some(location);
111+
self
112+
}
113+
114+
pub fn set_persistence_interval(mut self, interval: u64) -> Self {
115+
self.enable_persistence = true;
116+
self.persistence_interval = interval;
117+
self
118+
}
119+
120+
pub fn set_maximum_clients(mut self, maximum_clients: usize) -> Self {
121+
self.maximum_clients = maximum_clients;
122+
self
123+
}
124+
}

ahnlich/ai/src/engine/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod store;

ahnlich/ai/src/engine/store.rs

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
use crate::error::AIProxyError;
2+
use crate::AHNLICH_AI_RESERVED_META_KEY;
3+
use ahnlich_types::ai::AIModel;
4+
use ahnlich_types::ai::AIStoreInfo;
5+
use ahnlich_types::ai::AIStoreType;
6+
use ahnlich_types::keyval::StoreInput;
7+
use ahnlich_types::keyval::StoreKey;
8+
use ahnlich_types::keyval::StoreName;
9+
use ahnlich_types::keyval::StoreValue;
10+
use ahnlich_types::metadata::MetadataKey;
11+
use ahnlich_types::metadata::MetadataValue;
12+
use flurry::HashMap as ConcurrentHashMap;
13+
use serde::Deserialize;
14+
use serde::Serialize;
15+
use std::collections::HashMap as StdHashMap;
16+
use std::collections::HashSet as StdHashSet;
17+
use std::sync::atomic::AtomicBool;
18+
use std::sync::Arc;
19+
20+
/// Contains all the stores that have been created in memory
21+
#[derive(Debug)]
22+
pub struct AIStoreHandler {
23+
/// Making use of a concurrent hashmap, we should be able to create an engine that manages stores
24+
stores: AIStores,
25+
pub write_flag: Arc<AtomicBool>,
26+
}
27+
28+
pub type AIStores = Arc<ConcurrentHashMap<StoreName, Arc<AIStore>>>;
29+
30+
impl AIStoreHandler {
31+
pub fn new(write_flag: Arc<AtomicBool>) -> Self {
32+
Self {
33+
stores: Arc::new(ConcurrentHashMap::new()),
34+
write_flag,
35+
}
36+
}
37+
38+
#[tracing::instrument(skip(self))]
39+
pub(crate) fn create_store(
40+
&self,
41+
store_name: StoreName,
42+
store_type: AIStoreType,
43+
model: AIModel,
44+
) -> Result<(), AIProxyError> {
45+
self.stores
46+
.try_insert(
47+
store_name.clone(),
48+
Arc::new(AIStore::create(
49+
store_type,
50+
store_name.clone(),
51+
model.clone(),
52+
)),
53+
&self.stores.guard(),
54+
)
55+
.map_err(|_| AIProxyError::StoreAlreadyExists(store_name.clone()))?;
56+
57+
Ok(())
58+
}
59+
60+
/// matches LISTSTORES - to return statistics of all stores
61+
#[tracing::instrument(skip(self))]
62+
pub(crate) fn list_stores(&self) -> StdHashSet<AIStoreInfo> {
63+
self.stores
64+
.iter(&self.stores.guard())
65+
.map(|(store_name, store)| AIStoreInfo {
66+
name: store_name.clone(),
67+
model: store.model.clone(),
68+
r#type: store.r#type.clone(),
69+
embedding_size: store.model.embedding_size().into(),
70+
})
71+
.collect()
72+
}
73+
74+
/// Returns a store using the store name, else returns an error
75+
#[tracing::instrument(skip(self))]
76+
pub(crate) fn get(&self, store_name: &StoreName) -> Result<Arc<AIStore>, AIProxyError> {
77+
let store = self
78+
.stores
79+
.get(store_name, &self.stores.guard())
80+
.cloned()
81+
.ok_or(AIProxyError::StoreNotFound(store_name.clone()))?;
82+
Ok(store)
83+
}
84+
85+
/// Converts storeinput into a tuple of storekey and storevalue.
86+
/// Fails if the type of storeinput does not match the store type
87+
#[tracing::instrument(skip(self))]
88+
pub(crate) fn store_input_to_store_key_val(
89+
&self,
90+
store_name: &StoreName,
91+
store_input: StoreInput,
92+
store_value: &StoreValue,
93+
) -> Result<(StoreKey, StoreValue), AIProxyError> {
94+
let metadata_key = &*AHNLICH_AI_RESERVED_META_KEY;
95+
if store_value.contains_key(metadata_key) {
96+
return Err(AIProxyError::ReservedError(metadata_key.to_string()));
97+
}
98+
let store = self.get(store_name)?;
99+
let input_type = store_input.clone().into();
100+
101+
if store.r#type != input_type {
102+
return Err(AIProxyError::StoreTypeMismatch {
103+
store_type: store.r#type.clone(),
104+
input_type,
105+
});
106+
}
107+
let store_key = store.model.model_ndarray(&store_input);
108+
let metadata_value: MetadataValue = store_input.into();
109+
let mut final_store_value: StdHashMap<MetadataKey, MetadataValue> =
110+
store_value.clone().into_iter().collect();
111+
final_store_value.insert(metadata_key.clone(), metadata_value);
112+
return Ok((store_key, final_store_value));
113+
}
114+
115+
/// Converts (storekey, storevalue) into (storeinput, storevalue)
116+
/// by removing the reserved_key from storevalue
117+
#[tracing::instrument(skip(self))]
118+
pub(crate) fn store_key_val_to_store_input_val(
119+
&self,
120+
output: Vec<(StoreKey, StoreValue)>,
121+
) -> Vec<(StoreInput, StoreValue)> {
122+
let metadata_key = &*AHNLICH_AI_RESERVED_META_KEY;
123+
124+
// TODO: Will parallelized
125+
output
126+
.into_iter()
127+
.filter_map(|(_, mut store_value)| {
128+
store_value
129+
.remove(metadata_key)
130+
.map(|val| (val, store_value))
131+
})
132+
.map(|(metadata_value, store_value)| {
133+
let store_input: StoreInput = metadata_value.into();
134+
(store_input, store_value)
135+
})
136+
.collect()
137+
}
138+
139+
#[tracing::instrument(skip(self))]
140+
pub(crate) fn get_ndarray_repr_for_store(
141+
&self,
142+
store_name: &StoreName,
143+
store_input: &StoreInput,
144+
) -> Result<StoreKey, AIProxyError> {
145+
let store = self.get(store_name)?;
146+
Ok(store.model.model_ndarray(store_input))
147+
}
148+
149+
/// Matches DROPSTORE - Drops a store if exist, else returns an error
150+
#[tracing::instrument(skip(self))]
151+
pub(crate) fn drop_store(
152+
&self,
153+
store_name: StoreName,
154+
error_if_not_exists: bool,
155+
) -> Result<usize, AIProxyError> {
156+
let pinned = self.stores.pin();
157+
let removed = pinned.remove(&store_name).is_some();
158+
if !removed && error_if_not_exists {
159+
return Err(AIProxyError::StoreNotFound(store_name));
160+
}
161+
let removed = if !removed {
162+
0
163+
} else {
164+
//self.set_write_flag();
165+
1
166+
};
167+
Ok(removed)
168+
}
169+
}
170+
171+
#[derive(Debug, Serialize, Deserialize)]
172+
pub struct AIStore {
173+
name: StoreName,
174+
/// Making use of a concurrent hashmap, we should be able to create an engine that manages stores
175+
r#type: AIStoreType,
176+
model: AIModel,
177+
}
178+
179+
impl AIStore {
180+
pub(super) fn create(r#type: AIStoreType, store_name: StoreName, model: AIModel) -> Self {
181+
Self {
182+
r#type,
183+
name: store_name,
184+
model,
185+
}
186+
}
187+
}

ahnlich/ai/src/error.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use ahnlich_types::{ai::AIStoreType, keyval::StoreName};
2+
use thiserror::Error;
3+
4+
#[derive(Error, Debug, Eq, PartialEq, PartialOrd, Ord)]
5+
pub enum AIProxyError {
6+
#[error("Store {0} not found")]
7+
StoreNotFound(StoreName),
8+
#[error("Store {0} already exists")]
9+
StoreAlreadyExists(StoreName),
10+
#[error("Proxy Errored with {0} ")]
11+
StandardError(String),
12+
#[error("Proxy Errored with {0} ")]
13+
DatabaseClientError(String),
14+
#[error("Reserved key {0} used")]
15+
ReservedError(String),
16+
#[error("Unexpected DB Response {0} ")]
17+
UnexpectedDBResponse(String),
18+
#[error("Store dimension is [{store_type}], input dimension of [{input_type}] was specified")]
19+
StoreTypeMismatch {
20+
store_type: AIStoreType,
21+
input_type: AIStoreType,
22+
},
23+
}

0 commit comments

Comments
 (0)