Skip to content

Commit c2c984c

Browse files
use the model type to match safely
1 parent 7441c45 commit c2c984c

File tree

3 files changed

+76
-87
lines changed

3 files changed

+76
-87
lines changed

router/src/http/server.rs

Lines changed: 67 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::http::types::{
44
EmbedSparseResponse, Input, OpenAICompatEmbedding, OpenAICompatErrorResponse,
55
OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage, PredictInput, PredictRequest,
66
PredictResponse, Prediction, Rank, RerankRequest, RerankResponse, Sequence, SimpleToken,
7-
SparseValue, TokenizeRequest, TokenizeResponse, VertexInstance, VertexRequest, VertexResponse,
7+
SparseValue, TokenizeRequest, TokenizeResponse, VertexRequest, VertexResponse,
88
VertexResponseInstance,
99
};
1010
use crate::{
@@ -1180,11 +1180,6 @@ async fn vertex_compatibility(
11801180
let result = embed(infer, info, Json(req)).await?;
11811181
Ok(VertexResponseInstance::Embed(result.1 .0))
11821182
};
1183-
let embed_all_future =
1184-
move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedAllRequest| async move {
1185-
let result = embed_all(infer, info, Json(req)).await?;
1186-
Ok(VertexResponseInstance::EmbedAll(result.1 .0))
1187-
};
11881183
let embed_sparse_future =
11891184
move |infer: Extension<Infer>, info: Extension<Info>, req: EmbedSparseRequest| async move {
11901185
let result = embed_sparse(infer, info, Json(req)).await?;
@@ -1200,45 +1195,44 @@ async fn vertex_compatibility(
12001195
let result = rerank(infer, info, Json(req)).await?;
12011196
Ok(VertexResponseInstance::Rerank(result.1 .0))
12021197
};
1203-
let tokenize_future =
1204-
move |infer: Extension<Infer>, info: Extension<Info>, req: TokenizeRequest| async move {
1205-
let result = tokenize(infer, info, Json(req)).await?;
1206-
Ok(VertexResponseInstance::Tokenize(result.0))
1207-
};
12081198

12091199
let mut futures = Vec::with_capacity(req.instances.len());
12101200
for instance in req.instances {
12111201
let local_infer = infer.clone();
12121202
let local_info = info.clone();
12131203

1214-
match instance {
1215-
VertexInstance::Embed(req) => {
1216-
futures.push(embed_future(local_infer, local_info, req).boxed());
1217-
}
1218-
VertexInstance::EmbedAll(req) => {
1219-
futures.push(embed_all_future(local_infer, local_info, req).boxed());
1220-
}
1221-
VertexInstance::EmbedSparse(req) => {
1222-
futures.push(embed_sparse_future(local_infer, local_info, req).boxed());
1223-
}
1224-
VertexInstance::Predict(req) => {
1225-
futures.push(predict_future(local_infer, local_info, req).boxed());
1226-
}
1227-
VertexInstance::Rerank(req) => {
1228-
futures.push(rerank_future(local_infer, local_info, req).boxed());
1204+
// Rerank is the only payload that can be matched safely
1205+
if let Ok(instance) = serde_json::from_value::<RerankRequest>(instance.clone()) {
1206+
futures.push(rerank_future(local_infer, local_info, instance).boxed());
1207+
continue;
1208+
}
1209+
1210+
match info.model_type {
1211+
ModelType::Classifier(_) | ModelType::Reranker(_) => {
1212+
let instance = serde_json::from_value::<PredictRequest>(instance)
1213+
.map_err(ErrorResponse::from)?;
1214+
futures.push(predict_future(local_infer, local_info, instance).boxed());
12291215
}
1230-
VertexInstance::Tokenize(req) => {
1231-
futures.push(tokenize_future(local_infer, local_info, req).boxed());
1216+
ModelType::Embedding(_) => {
1217+
if infer.is_splade() {
1218+
let instance = serde_json::from_value::<EmbedSparseRequest>(instance)
1219+
.map_err(ErrorResponse::from)?;
1220+
futures.push(embed_sparse_future(local_infer, local_info, instance).boxed());
1221+
} else {
1222+
let instance = serde_json::from_value::<EmbedRequest>(instance)
1223+
.map_err(ErrorResponse::from)?;
1224+
futures.push(embed_future(local_infer, local_info, instance).boxed());
1225+
}
12321226
}
12331227
}
12341228
}
12351229

1236-
let results = join_all(futures)
1230+
let predictions = join_all(futures)
12371231
.await
12381232
.into_iter()
12391233
.collect::<Result<Vec<VertexResponseInstance>, (StatusCode, Json<ErrorResponse>)>>()?;
12401234

1241-
Ok(Json(VertexResponse(results)))
1235+
Ok(Json(VertexResponse { predictions }))
12421236
}
12431237

12441238
/// Prometheus metrics scrape endpoint
@@ -1350,12 +1344,7 @@ pub async fn run(
13501344
#[derive(OpenApi)]
13511345
#[openapi(
13521346
paths(vertex_compatibility),
1353-
components(schemas(
1354-
VertexInstance,
1355-
VertexRequest,
1356-
VertexResponse,
1357-
VertexResponseInstance
1358-
))
1347+
components(schemas(VertexRequest, VertexResponse, VertexResponseInstance))
13591348
)]
13601349
struct VertextApiDoc;
13611350

@@ -1394,43 +1383,42 @@ pub async fn run(
13941383

13951384
let mut app = Router::new().merge(base_routes);
13961385

1397-
// Set default routes
1398-
app = match &info.model_type {
1399-
ModelType::Classifier(_) => {
1400-
app.route("/", post(predict))
1401-
// AWS Sagemaker route
1402-
.route("/invocations", post(predict))
1403-
}
1404-
ModelType::Reranker(_) => {
1405-
app.route("/", post(rerank))
1406-
// AWS Sagemaker route
1407-
.route("/invocations", post(rerank))
1408-
}
1409-
ModelType::Embedding(model) => {
1410-
if model.pooling == "splade" {
1411-
app.route("/", post(embed_sparse))
1412-
// AWS Sagemaker route
1413-
.route("/invocations", post(embed_sparse))
1414-
} else {
1415-
app.route("/", post(embed))
1416-
// AWS Sagemaker route
1417-
.route("/invocations", post(embed))
1418-
}
1419-
}
1420-
};
1421-
14221386
#[cfg(feature = "google")]
14231387
{
14241388
tracing::info!("Built with `google` feature");
1425-
tracing::info!(
1426-
"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
1427-
);
1428-
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
1429-
app = app.route(&env_predict_route, post(vertex_compatibility));
1430-
}
1431-
if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
1432-
app = app.route(&env_health_route, get(health));
1433-
}
1389+
let env_predict_route = std::env::var("AIP_PREDICT_ROUTE")
1390+
.context("`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments")?;
1391+
app = app.route(&env_predict_route, post(vertex_compatibility));
1392+
let env_health_route = std::env::var("AIP_HEALTH_ROUTE")
1393+
.context("`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments")?;
1394+
app = app.route(&env_health_route, get(health));
1395+
}
1396+
#[cfg(not(feature = "google"))]
1397+
{
1398+
// Set default routes
1399+
app = match &info.model_type {
1400+
ModelType::Classifier(_) => {
1401+
app.route("/", post(predict))
1402+
// AWS Sagemaker route
1403+
.route("/invocations", post(predict))
1404+
}
1405+
ModelType::Reranker(_) => {
1406+
app.route("/", post(rerank))
1407+
// AWS Sagemaker route
1408+
.route("/invocations", post(rerank))
1409+
}
1410+
ModelType::Embedding(model) => {
1411+
if model.pooling == "splade" {
1412+
app.route("/", post(embed_sparse))
1413+
// AWS Sagemaker route
1414+
.route("/invocations", post(embed_sparse))
1415+
} else {
1416+
app.route("/", post(embed))
1417+
// AWS Sagemaker route
1418+
.route("/invocations", post(embed))
1419+
}
1420+
}
1421+
};
14341422
}
14351423

14361424
let app = app
@@ -1485,3 +1473,12 @@ impl From<ErrorResponse> for (StatusCode, Json<OpenAICompatErrorResponse>) {
14851473
(StatusCode::from(&err.error_type), Json(err.into()))
14861474
}
14871475
}
1476+
1477+
impl From<serde_json::Error> for ErrorResponse {
1478+
fn from(err: serde_json::Error) -> Self {
1479+
ErrorResponse {
1480+
error: err.to_string(),
1481+
error_type: ErrorType::Validation,
1482+
}
1483+
}
1484+
}

router/src/http/types.rs

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -382,32 +382,21 @@ pub(crate) struct SimpleToken {
382382
#[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))]
383383
pub(crate) struct TokenizeResponse(pub Vec<Vec<SimpleToken>>);
384384

385-
#[derive(Deserialize, ToSchema)]
386-
#[serde(tag = "type", rename_all = "snake_case")]
387-
pub(crate) enum VertexInstance {
388-
Embed(EmbedRequest),
389-
EmbedAll(EmbedAllRequest),
390-
EmbedSparse(EmbedSparseRequest),
391-
Predict(PredictRequest),
392-
Rerank(RerankRequest),
393-
Tokenize(TokenizeRequest),
394-
}
395-
396385
#[derive(Deserialize, ToSchema)]
397386
pub(crate) struct VertexRequest {
398-
pub instances: Vec<VertexInstance>,
387+
pub instances: Vec<serde_json::Value>,
399388
}
400389

401390
#[derive(Serialize, ToSchema)]
402-
#[serde(tag = "type", content = "result", rename_all = "snake_case")]
391+
#[serde(untagged)]
403392
pub(crate) enum VertexResponseInstance {
404393
Embed(EmbedResponse),
405-
EmbedAll(EmbedAllResponse),
406394
EmbedSparse(EmbedSparseResponse),
407395
Predict(PredictResponse),
408396
Rerank(RerankResponse),
409-
Tokenize(TokenizeResponse),
410397
}
411398

412399
#[derive(Serialize, ToSchema)]
413-
pub(crate) struct VertexResponse(pub Vec<VertexResponseInstance>);
400+
pub(crate) struct VertexResponse {
401+
pub predictions: Vec<VertexResponseInstance>,
402+
}

router/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ pub async fn run(
244244
std::env::var("AIP_HTTP_PORT")
245245
.ok()
246246
.and_then(|p| p.parse().ok())
247-
.context("Invalid or unset AIP_HTTP_PORT")?
247+
.context("`AIP_HTTP_PORT` env var must be set for Google Vertex deployments")?
248248
} else {
249249
port
250250
};
@@ -262,6 +262,9 @@ pub async fn run(
262262
#[cfg(all(feature = "grpc", feature = "http"))]
263263
compile_error!("Features `http` and `grpc` cannot be enabled at the same time.");
264264

265+
#[cfg(all(feature = "grpc", feature = "google"))]
266+
compile_error!("Features `grpc` and `google` cannot be enabled at the same time.");
267+
265268
#[cfg(not(any(feature = "http", feature = "grpc")))]
266269
compile_error!("Either feature `http` or `grpc` must be enabled.");
267270

0 commit comments

Comments
 (0)