@@ -4,7 +4,7 @@ use crate::http::types::{
4
4
EmbedSparseResponse , Input , OpenAICompatEmbedding , OpenAICompatErrorResponse ,
5
5
OpenAICompatRequest , OpenAICompatResponse , OpenAICompatUsage , PredictInput , PredictRequest ,
6
6
PredictResponse , Prediction , Rank , RerankRequest , RerankResponse , Sequence , SimpleToken ,
7
- SparseValue , TokenizeRequest , TokenizeResponse , VertexInstance , VertexRequest , VertexResponse ,
7
+ SparseValue , TokenizeRequest , TokenizeResponse , VertexRequest , VertexResponse ,
8
8
VertexResponseInstance ,
9
9
} ;
10
10
use crate :: {
@@ -1180,11 +1180,6 @@ async fn vertex_compatibility(
1180
1180
let result = embed ( infer, info, Json ( req) ) . await ?;
1181
1181
Ok ( VertexResponseInstance :: Embed ( result. 1 . 0 ) )
1182
1182
} ;
1183
- let embed_all_future =
1184
- move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedAllRequest | async move {
1185
- let result = embed_all ( infer, info, Json ( req) ) . await ?;
1186
- Ok ( VertexResponseInstance :: EmbedAll ( result. 1 . 0 ) )
1187
- } ;
1188
1183
let embed_sparse_future =
1189
1184
move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedSparseRequest | async move {
1190
1185
let result = embed_sparse ( infer, info, Json ( req) ) . await ?;
@@ -1200,45 +1195,44 @@ async fn vertex_compatibility(
1200
1195
let result = rerank ( infer, info, Json ( req) ) . await ?;
1201
1196
Ok ( VertexResponseInstance :: Rerank ( result. 1 . 0 ) )
1202
1197
} ;
1203
- let tokenize_future =
1204
- move |infer : Extension < Infer > , info : Extension < Info > , req : TokenizeRequest | async move {
1205
- let result = tokenize ( infer, info, Json ( req) ) . await ?;
1206
- Ok ( VertexResponseInstance :: Tokenize ( result. 0 ) )
1207
- } ;
1208
1198
1209
1199
let mut futures = Vec :: with_capacity ( req. instances . len ( ) ) ;
1210
1200
for instance in req. instances {
1211
1201
let local_infer = infer. clone ( ) ;
1212
1202
let local_info = info. clone ( ) ;
1213
1203
1214
- match instance {
1215
- VertexInstance :: Embed ( req) => {
1216
- futures. push ( embed_future ( local_infer, local_info, req) . boxed ( ) ) ;
1217
- }
1218
- VertexInstance :: EmbedAll ( req) => {
1219
- futures. push ( embed_all_future ( local_infer, local_info, req) . boxed ( ) ) ;
1220
- }
1221
- VertexInstance :: EmbedSparse ( req) => {
1222
- futures. push ( embed_sparse_future ( local_infer, local_info, req) . boxed ( ) ) ;
1223
- }
1224
- VertexInstance :: Predict ( req) => {
1225
- futures. push ( predict_future ( local_infer, local_info, req) . boxed ( ) ) ;
1226
- }
1227
- VertexInstance :: Rerank ( req) => {
1228
- futures. push ( rerank_future ( local_infer, local_info, req) . boxed ( ) ) ;
1204
+ // Rerank is the only payload that can be matched safely
1205
+ if let Ok ( instance) = serde_json:: from_value :: < RerankRequest > ( instance. clone ( ) ) {
1206
+ futures. push ( rerank_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1207
+ continue ;
1208
+ }
1209
+
1210
+ match info. model_type {
1211
+ ModelType :: Classifier ( _) | ModelType :: Reranker ( _) => {
1212
+ let instance = serde_json:: from_value :: < PredictRequest > ( instance)
1213
+ . map_err ( ErrorResponse :: from) ?;
1214
+ futures. push ( predict_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1229
1215
}
1230
- VertexInstance :: Tokenize ( req) => {
1231
- futures. push ( tokenize_future ( local_infer, local_info, req) . boxed ( ) ) ;
1216
+ ModelType :: Embedding ( _) => {
1217
+ if infer. is_splade ( ) {
1218
+ let instance = serde_json:: from_value :: < EmbedSparseRequest > ( instance)
1219
+ . map_err ( ErrorResponse :: from) ?;
1220
+ futures. push ( embed_sparse_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1221
+ } else {
1222
+ let instance = serde_json:: from_value :: < EmbedRequest > ( instance)
1223
+ . map_err ( ErrorResponse :: from) ?;
1224
+ futures. push ( embed_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1225
+ }
1232
1226
}
1233
1227
}
1234
1228
}
1235
1229
1236
- let results = join_all ( futures)
1230
+ let predictions = join_all ( futures)
1237
1231
. await
1238
1232
. into_iter ( )
1239
1233
. collect :: < Result < Vec < VertexResponseInstance > , ( StatusCode , Json < ErrorResponse > ) > > ( ) ?;
1240
1234
1241
- Ok ( Json ( VertexResponse ( results ) ) )
1235
+ Ok ( Json ( VertexResponse { predictions } ) )
1242
1236
}
1243
1237
1244
1238
/// Prometheus metrics scrape endpoint
@@ -1350,12 +1344,7 @@ pub async fn run(
1350
1344
#[ derive( OpenApi ) ]
1351
1345
#[ openapi(
1352
1346
paths( vertex_compatibility) ,
1353
- components( schemas(
1354
- VertexInstance ,
1355
- VertexRequest ,
1356
- VertexResponse ,
1357
- VertexResponseInstance
1358
- ) )
1347
+ components( schemas( VertexRequest , VertexResponse , VertexResponseInstance ) )
1359
1348
) ]
1360
1349
struct VertextApiDoc ;
1361
1350
@@ -1394,43 +1383,42 @@ pub async fn run(
1394
1383
1395
1384
let mut app = Router :: new ( ) . merge ( base_routes) ;
1396
1385
1397
- // Set default routes
1398
- app = match & info. model_type {
1399
- ModelType :: Classifier ( _) => {
1400
- app. route ( "/" , post ( predict) )
1401
- // AWS Sagemaker route
1402
- . route ( "/invocations" , post ( predict) )
1403
- }
1404
- ModelType :: Reranker ( _) => {
1405
- app. route ( "/" , post ( rerank) )
1406
- // AWS Sagemaker route
1407
- . route ( "/invocations" , post ( rerank) )
1408
- }
1409
- ModelType :: Embedding ( model) => {
1410
- if model. pooling == "splade" {
1411
- app. route ( "/" , post ( embed_sparse) )
1412
- // AWS Sagemaker route
1413
- . route ( "/invocations" , post ( embed_sparse) )
1414
- } else {
1415
- app. route ( "/" , post ( embed) )
1416
- // AWS Sagemaker route
1417
- . route ( "/invocations" , post ( embed) )
1418
- }
1419
- }
1420
- } ;
1421
-
1422
1386
#[ cfg( feature = "google" ) ]
1423
1387
{
1424
1388
tracing:: info!( "Built with `google` feature" ) ;
1425
- tracing:: info!(
1426
- "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
1427
- ) ;
1428
- if let Ok ( env_predict_route) = std:: env:: var ( "AIP_PREDICT_ROUTE" ) {
1429
- app = app. route ( & env_predict_route, post ( vertex_compatibility) ) ;
1430
- }
1431
- if let Ok ( env_health_route) = std:: env:: var ( "AIP_HEALTH_ROUTE" ) {
1432
- app = app. route ( & env_health_route, get ( health) ) ;
1433
- }
1389
+ let env_predict_route = std:: env:: var ( "AIP_PREDICT_ROUTE" )
1390
+ . context ( "`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments" ) ?;
1391
+ app = app. route ( & env_predict_route, post ( vertex_compatibility) ) ;
1392
+ let env_health_route = std:: env:: var ( "AIP_HEALTH_ROUTE" )
1393
+ . context ( "`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments" ) ?;
1394
+ app = app. route ( & env_health_route, get ( health) ) ;
1395
+ }
1396
+ #[ cfg( not( feature = "google" ) ) ]
1397
+ {
1398
+ // Set default routes
1399
+ app = match & info. model_type {
1400
+ ModelType :: Classifier ( _) => {
1401
+ app. route ( "/" , post ( predict) )
1402
+ // AWS Sagemaker route
1403
+ . route ( "/invocations" , post ( predict) )
1404
+ }
1405
+ ModelType :: Reranker ( _) => {
1406
+ app. route ( "/" , post ( rerank) )
1407
+ // AWS Sagemaker route
1408
+ . route ( "/invocations" , post ( rerank) )
1409
+ }
1410
+ ModelType :: Embedding ( model) => {
1411
+ if model. pooling == "splade" {
1412
+ app. route ( "/" , post ( embed_sparse) )
1413
+ // AWS Sagemaker route
1414
+ . route ( "/invocations" , post ( embed_sparse) )
1415
+ } else {
1416
+ app. route ( "/" , post ( embed) )
1417
+ // AWS Sagemaker route
1418
+ . route ( "/invocations" , post ( embed) )
1419
+ }
1420
+ }
1421
+ } ;
1434
1422
}
1435
1423
1436
1424
let app = app
@@ -1485,3 +1473,12 @@ impl From<ErrorResponse> for (StatusCode, Json<OpenAICompatErrorResponse>) {
1485
1473
( StatusCode :: from ( & err. error_type ) , Json ( err. into ( ) ) )
1486
1474
}
1487
1475
}
1476
+
1477
// Conversion from a `serde_json::Error` into the HTTP-facing `ErrorResponse`.
// This is what allows `.map_err(ErrorResponse::from)` to be applied to the result of
// `serde_json::from_value` when an instance payload fails to deserialize.
+ impl From < serde_json:: Error > for ErrorResponse {
1478
// Builds the response from the error's `to_string()` text and tags it as a validation error.
+ fn from ( err : serde_json:: Error ) -> Self {
1479
+ ErrorResponse {
1480
// The message is the serde_json error rendered through `to_string()`.
+ error : err. to_string ( ) ,
1481
// Every serde_json error is reported with the `Validation` error type.
+ error_type : ErrorType :: Validation ,
1482
+ }
1483
+ }
1484
+ }
0 commit comments