Skip to content

Commit 46b65dd

Browse files
authored
AIProxy Set Command Preprocessing actions for storeinputs (#77)
* Add a new field preprocess_action for aiproxy set command * Add preprocess_action field to python client * Add new fields for aiproxy create store commands * Update python client after changes to create store command * Remove index and query types from aiproxy create store * Update python client with changes to aiproxy create store params * - Pass preproccess actions to set command - Create base on how to edit inputtypes and convert to storekeys - Move preprocessing actions - Tie model info into aimodel - Fix fmt trait on AIStoreInputTypes * regenerate types for python client * Create ai model struct to hold info about supported models, Implement additional functions on aimodel types via traits * Cleanup and rename aistoreinputtypes * Improvements to preprocessing logic, remove &mut references
1 parent 0325f15 commit 46b65dd

File tree

25 files changed

+859
-240
lines changed

25 files changed

+859
-240
lines changed

ahnlich/Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ check: ## cargo check
3434
test: ## cargo test
3535
cargo nextest run --no-capture
3636

37+
generate-specs: ## cargo run --bin typegen generate
38+
cargo run --bin typegen generate
39+
3740

3841
bump-protocol-version: ## Bump project versions. Rules for bumpversion: patch, minor, major.
3942
@echo "Bumping Rust project version to $${RULE:-$(BUMP_RULE)}"

ahnlich/ai/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ ahnlich_client_rs = { path = "../client", version = "*" }
2929
ahnlich_similarity = { path = "../similarity", version = "*", features = ["serde"] }
3030
cap.workspace = true
3131
deadpool.workspace = true
32+
nonzero_ext = "0.3.0"
3233

3334

3435
[dev-dependencies]

ahnlich/ai/src/engine/ai/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
mod models;
2+
use ahnlich_types::keyval::{StoreInput, StoreKey};
3+
use models::ModelInfo;
4+
use std::num::NonZeroUsize;
5+
6+
pub trait AIModelManager {
7+
fn embedding_size(&self) -> NonZeroUsize;
8+
fn model_ndarray(&self, storeinput: &StoreInput) -> StoreKey;
9+
fn model_info(&self) -> ModelInfo;
10+
}

ahnlich/ai/src/engine/ai/models.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
use crate::engine::ai::AIModelManager;
2+
use ahnlich_types::{
3+
ai::{AIModel, AIStoreInputType},
4+
keyval::{StoreInput, StoreKey},
5+
};
6+
use ndarray::Array1;
7+
use nonzero_ext::nonzero;
8+
use serde::{Deserialize, Serialize};
9+
use std::num::NonZeroUsize;
10+
11+
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
12+
pub struct ModelInfo {
13+
pub name: String,
14+
pub embedding_size: NonZeroUsize,
15+
pub input_type: AIStoreInputType,
16+
}
17+
18+
impl AIModelManager for AIModel {
19+
fn embedding_size(&self) -> NonZeroUsize {
20+
self.model_info().embedding_size
21+
}
22+
// TODO: model ndarray values is based on length of string or vec, so for now make sure strings
23+
// or vecs have different lengths
24+
fn model_ndarray(&self, storeinput: &StoreInput) -> StoreKey {
25+
let length = storeinput.len() as f32;
26+
StoreKey(
27+
Array1::from_iter(0..self.model_info().embedding_size.into())
28+
.mapv(|v| v as f32 * length),
29+
)
30+
}
31+
32+
fn model_info(&self) -> ModelInfo {
33+
match self {
34+
AIModel::Llama3 => ModelInfo {
35+
name: String::from("Llama3"),
36+
embedding_size: nonzero!(100usize),
37+
input_type: AIStoreInputType::RawString,
38+
},
39+
AIModel::DALLE3 => ModelInfo {
40+
name: String::from("DALL.E 3"),
41+
embedding_size: nonzero!(300usize),
42+
input_type: AIStoreInputType::Image,
43+
},
44+
}
45+
}
46+
}

ahnlich/ai/src/engine/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub mod ai;
12
pub mod store;

ahnlich/ai/src/engine/store.rs

Lines changed: 137 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
use crate::error::AIProxyError;
21
use crate::AHNLICH_AI_RESERVED_META_KEY;
3-
use ahnlich_types::ai::AIModel;
4-
use ahnlich_types::ai::AIStoreInfo;
2+
use crate::{engine::ai::AIModelManager, error::AIProxyError};
3+
use ahnlich_types::ai::{
4+
AIModel, AIStoreInfo, AIStoreInputType, ImageAction, PreprocessAction, StringAction,
5+
};
56
use ahnlich_types::keyval::StoreInput;
67
use ahnlich_types::keyval::StoreKey;
78
use ahnlich_types::keyval::StoreName;
@@ -57,14 +58,18 @@ impl AIStoreHandler {
5758
pub(crate) fn create_store(
5859
&self,
5960
store_name: StoreName,
60-
61-
model: AIModel,
61+
query_model: AIModel,
62+
index_model: AIModel,
6263
) -> Result<(), AIProxyError> {
6364
if self
6465
.stores
6566
.try_insert(
6667
store_name.clone(),
67-
Arc::new(AIStore::create(store_name.clone(), model.clone())),
68+
Arc::new(AIStore::create(
69+
store_name.clone(),
70+
query_model,
71+
index_model,
72+
)),
6873
&self.stores.guard(),
6974
)
7075
.is_err()
@@ -82,8 +87,9 @@ impl AIStoreHandler {
8287
.iter(&self.stores.guard())
8388
.map(|(store_name, store)| AIStoreInfo {
8489
name: store_name.clone(),
85-
model: store.model.clone(),
86-
embedding_size: store.model.embedding_size().into(),
90+
query_model: store.query_model.clone(),
91+
index_model: store.index_model.clone(),
92+
embedding_size: store.index_model.embedding_size().into(),
8793
})
8894
.collect()
8995
}
@@ -100,28 +106,58 @@ impl AIStoreHandler {
100106
}
101107

102108
/// Converts storeinput into a tuple of storekey and storevalue.
103-
/// Fails if the type of storeinput does not match the store type
109+
/// Fails if the store input type does not match the store index_type
104110
#[tracing::instrument(skip(self))]
105111
pub(crate) fn store_input_to_store_key_val(
106112
&self,
107113
store_name: &StoreName,
108114
store_input: StoreInput,
109-
store_value: &StoreValue,
115+
store_value: StoreValue,
116+
preprocess_action: &PreprocessAction,
110117
) -> Result<(StoreKey, StoreValue), AIProxyError> {
111118
let metadata_key = &*AHNLICH_AI_RESERVED_META_KEY;
112119
if store_value.contains_key(metadata_key) {
113120
return Err(AIProxyError::ReservedError(metadata_key.to_string()));
114121
}
115122
let store = self.get(store_name)?;
116123

117-
let store_key = store.model.model_ndarray(&store_input);
118-
let metadata_value: MetadataValue = store_input.into();
124+
let store_input_type: AIStoreInputType = (&store_input).into();
125+
let store_index_model_info = store.index_model.model_info();
126+
127+
if store_input_type != store_index_model_info.input_type {
128+
return Err(AIProxyError::StoreSetTypeMismatchError {
129+
store_index_model_type: store_index_model_info.input_type,
130+
storeinput_type: store_input_type,
131+
});
132+
}
133+
134+
let metadata_value: MetadataValue = store_input.clone().into();
119135
let mut final_store_value: StdHashMap<MetadataKey, MetadataValue> =
120136
store_value.clone().into_iter().collect();
121137
final_store_value.insert(metadata_key.clone(), metadata_value);
138+
139+
let store_key =
140+
self.create_store_key(store_input, &store.index_model, preprocess_action)?;
122141
return Ok((store_key, final_store_value));
123142
}
124143

144+
/// Converts storeinput into a tuple of storekey and storevalue.
145+
/// Fails if the store input type does not match the store index_type
146+
#[tracing::instrument(skip(self))]
147+
pub(crate) fn create_store_key(
148+
&self,
149+
store_input: StoreInput,
150+
index_store_model: &AIModel,
151+
preprocess_action: &PreprocessAction,
152+
) -> Result<StoreKey, AIProxyError> {
153+
// Process the inner value of a store input and convert it into a ndarray by passing
154+
// it into index model. Create a storekey from ndarray
155+
let processed_input =
156+
self.preprocess_store_input(preprocess_action, store_input, index_store_model)?;
157+
let store_key = index_store_model.model_ndarray(&processed_input);
158+
Ok(store_key)
159+
}
160+
125161
/// Converts (storekey, storevalue) into (storeinput, storevalue)
126162
/// by removing the reserved_key from storevalue
127163
#[tracing::instrument(skip(self))]
@@ -153,7 +189,7 @@ impl AIStoreHandler {
153189
store_input: &StoreInput,
154190
) -> Result<StoreKey, AIProxyError> {
155191
let store = self.get(store_name)?;
156-
Ok(store.model.model_ndarray(store_input))
192+
Ok(store.index_model.model_ndarray(store_input))
157193
}
158194

159195
/// Matches DROPSTORE - Drops a store if exist, else returns an error
@@ -185,20 +221,105 @@ impl AIStoreHandler {
185221
self.stores.clear(&guard);
186222
store_length
187223
}
224+
225+
#[tracing::instrument(skip(self))]
226+
pub(crate) fn preprocess_store_input(
227+
&self,
228+
process_action: &PreprocessAction,
229+
input: StoreInput,
230+
index_model: &AIModel,
231+
) -> Result<StoreInput, AIProxyError> {
232+
match (process_action, input) {
233+
(PreprocessAction::Image(image_action), StoreInput::Image(image_input)) => {
234+
// resize image and edit
235+
let output = self.process_image(image_input, index_model, image_action)?;
236+
Ok(output)
237+
}
238+
(PreprocessAction::RawString(string_action), StoreInput::RawString(string_input)) => {
239+
let output =
240+
self.preprocess_raw_string(string_input, index_model, string_action)?;
241+
Ok(output)
242+
}
243+
(PreprocessAction::RawString(_), StoreInput::Image(_)) => {
244+
Err(AIProxyError::PreprocessingMismatchError {
245+
input_type: AIStoreInputType::Image,
246+
preprocess_action: process_action.clone(),
247+
})
248+
}
249+
250+
(PreprocessAction::Image(_), StoreInput::RawString(_)) => {
251+
Err(AIProxyError::PreprocessingMismatchError {
252+
input_type: AIStoreInputType::RawString,
253+
preprocess_action: process_action.clone(),
254+
})
255+
}
256+
}
257+
}
258+
fn preprocess_raw_string(
259+
&self,
260+
input: String,
261+
index_model: &AIModel,
262+
string_action: &StringAction,
263+
) -> Result<StoreInput, AIProxyError> {
264+
// tokenize string, return error if max token
265+
//let tokenized_input;
266+
let model_embedding_dim = index_model.model_info().embedding_size;
267+
if input.len() > model_embedding_dim.into() {
268+
if let StringAction::ErrorIfTokensExceed = string_action {
269+
return Err(AIProxyError::TokenExceededError {
270+
input_token_size: input.len(),
271+
model_embedding_size: model_embedding_dim.into(),
272+
});
273+
} else {
274+
// truncate raw string
275+
// let tokenized_input;
276+
let _input = input.as_str()[..model_embedding_dim.into()].to_string();
277+
}
278+
}
279+
Ok(StoreInput::RawString(input))
280+
}
281+
282+
fn process_image(
283+
&self,
284+
input: Vec<u8>,
285+
index_model: &AIModel,
286+
image_action: &ImageAction,
287+
) -> Result<StoreInput, AIProxyError> {
288+
// process image, return error if max dimensions exceeded
289+
// let image_data;
290+
let model_embedding_dim = index_model.embedding_size().into();
291+
if input.len() > model_embedding_dim {
292+
if let ImageAction::ErrorIfDimensionsMismatch = image_action {
293+
return Err(AIProxyError::ImageDimensionsMismatchError {
294+
image_dimensions: input.len(),
295+
max_dimensions: model_embedding_dim,
296+
});
297+
} else {
298+
// resize image
299+
}
300+
}
301+
Ok(StoreInput::Image(input))
302+
}
188303
}
189304

190305
#[derive(Debug, Serialize, Deserialize)]
191306
pub struct AIStore {
192307
name: StoreName,
193308
/// Making use of a concurrent hashmap, we should be able to create an engine that manages stores
194-
model: AIModel,
309+
query_model: AIModel,
310+
index_model: AIModel,
195311
}
196312

197313
impl AIStore {
198-
pub(super) fn create(store_name: StoreName, model: AIModel) -> Self {
314+
pub(super) fn create(
315+
store_name: StoreName,
316+
query_model: AIModel,
317+
index_model: AIModel,
318+
) -> Self {
199319
Self {
200320
name: store_name,
201-
model,
321+
query_model,
322+
index_model,
202323
}
203324
}
204325
}

ahnlich/ai/src/error.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use ahnlich_types::keyval::StoreName;
1+
use ahnlich_types::{
2+
ai::{AIStoreInputType, PreprocessAction},
3+
keyval::StoreName,
4+
};
25
use thiserror::Error;
36

47
#[derive(Error, Debug, Eq, PartialEq, PartialOrd, Ord)]
@@ -7,12 +10,44 @@ pub enum AIProxyError {
710
StoreNotFound(StoreName),
811
#[error("Store {0} already exists")]
912
StoreAlreadyExists(StoreName),
13+
1014
#[error("Proxy Errored with {0} ")]
1115
StandardError(String),
16+
1217
#[error("Proxy Errored with {0} ")]
1318
DatabaseClientError(String),
1419
#[error("Reserved key {0} used")]
1520
ReservedError(String),
1621
#[error("Unexpected DB Response {0} ")]
1722
UnexpectedDBResponse(String),
23+
#[error("Cannot Query Using Input. Store expects [{store_query_model_type}], but input type [{storeinput_type}] was provided")]
24+
StoreQueryTypeMismatchError {
25+
store_query_model_type: AIStoreInputType,
26+
storeinput_type: AIStoreInputType,
27+
},
28+
#[error("Cannot Set Input. Store expects [{store_index_model_type}], input type [{storeinput_type}] was provided")]
29+
StoreSetTypeMismatchError {
30+
store_index_model_type: AIStoreInputType,
31+
storeinput_type: AIStoreInputType,
32+
},
33+
34+
#[error("Max Token Exceeded. Model Expects [{model_embedding_size}], input type was [{input_token_size}] ")]
35+
TokenExceededError {
36+
model_embedding_size: usize,
37+
input_token_size: usize,
38+
},
39+
40+
#[error(
41+
"Image Dimensions [{image_dimensions}] exceeds max model dimensions [{max_dimensions}] "
42+
)]
43+
ImageDimensionsMismatchError {
44+
image_dimensions: usize,
45+
max_dimensions: usize,
46+
},
47+
48+
#[error("Used [{preprocess_action}] for [{input_type}] type")]
49+
PreprocessingMismatchError {
50+
input_type: AIStoreInputType,
51+
preprocess_action: PreprocessAction,
52+
},
1853
}

0 commit comments

Comments
 (0)