@@ -13,12 +13,13 @@ use ahnlich_types::{
13
13
14
14
use once_cell:: sync:: Lazy ;
15
15
use pretty_assertions:: assert_eq;
16
- use std:: { collections:: HashSet , num:: NonZeroUsize } ;
16
+ use std:: { collections:: HashSet , num:: NonZeroUsize , sync :: atomic :: Ordering } ;
17
17
18
18
use crate :: cli:: AIProxyConfig ;
19
19
use crate :: server:: handler:: AIProxyServer ;
20
20
use ahnlich_types:: bincode:: BinCodeSerAndDeser ;
21
21
use std:: net:: SocketAddr ;
22
+ use std:: path:: PathBuf ;
22
23
use tokio:: io:: { AsyncReadExt , AsyncWriteExt , BufReader } ;
23
24
use tokio:: net:: TcpStream ;
24
25
use tokio:: time:: { timeout, Duration } ;
@@ -31,6 +32,16 @@ static AI_CONFIG: Lazy<AIProxyConfig> = Lazy::new(|| {
31
32
ai_proxy
32
33
} ) ;
33
34
35
+ static PERSISTENCE_FILE : Lazy < PathBuf > =
36
+ Lazy :: new ( || PathBuf :: from ( env ! ( "CARGO_MANIFEST_DIR" ) ) . join ( "ahnlich_ai_proxy.dat" ) ) ;
37
+
38
+ static AI_CONFIG_WITH_PERSISTENCE : Lazy < AIProxyConfig > = Lazy :: new ( || {
39
+ AIProxyConfig :: default ( )
40
+ . os_select_port ( )
41
+ . set_persistence_interval ( 200 )
42
+ . set_persist_location ( ( * PERSISTENCE_FILE ) . clone ( ) )
43
+ } ) ;
44
+
34
45
async fn get_server_response (
35
46
reader : & mut BufReader < TcpStream > ,
36
47
query : AIServerQuery ,
@@ -446,3 +457,129 @@ async fn test_ai_proxy_fails_db_server_unavailable() {
446
457
let err = res. err ( ) . unwrap ( ) ;
447
458
assert ! ( err. contains( " kind: ConnectionRefused," ) )
448
459
}
460
+
461
+ #[ tokio:: test]
462
+ async fn test_ai_proxy_test_with_persistence ( ) {
463
+ let server = Server :: new ( & CONFIG )
464
+ . await
465
+ . expect ( "Could not initialize server" ) ;
466
+ let ai_server = AIProxyServer :: new ( & AI_CONFIG_WITH_PERSISTENCE )
467
+ . await
468
+ . expect ( "Could not initialize ai proxy" ) ;
469
+
470
+ let address = ai_server. local_addr ( ) . expect ( "Could not get local addr" ) ;
471
+ let _ = tokio:: spawn ( async move { server. start ( ) . await } ) ;
472
+ let write_flag = ai_server. write_flag ( ) ;
473
+ // start up ai proxy
474
+ let _ = tokio:: spawn ( async move { ai_server. start ( ) . await } ) ;
475
+ // Allow some time for the servers to start
476
+ tokio:: time:: sleep ( Duration :: from_millis ( 200 ) ) . await ;
477
+
478
+ let store_name = StoreName ( String :: from ( "Main" ) ) ;
479
+ let store_name_2 = StoreName ( String :: from ( "Main2" ) ) ;
480
+ let first_stream = TcpStream :: connect ( address) . await . unwrap ( ) ;
481
+
482
+ let message = AIServerQuery :: from_queries ( & [
483
+ AIQuery :: CreateStore {
484
+ r#type : AIStoreType :: RawString ,
485
+ store : store_name. clone ( ) ,
486
+ model : AIModel :: Llama3 ,
487
+ predicates : HashSet :: from_iter ( [ ] ) ,
488
+ non_linear_indices : HashSet :: new ( ) ,
489
+ } ,
490
+ AIQuery :: CreateStore {
491
+ r#type : AIStoreType :: Binary ,
492
+ store : store_name_2. clone ( ) ,
493
+ model : AIModel :: Llama3 ,
494
+ predicates : HashSet :: from_iter ( [ ] ) ,
495
+ non_linear_indices : HashSet :: new ( ) ,
496
+ } ,
497
+ AIQuery :: DropStore {
498
+ store : store_name,
499
+ error_if_not_exists : true ,
500
+ } ,
501
+ ] ) ;
502
+
503
+ let mut expected = AIServerResult :: with_capacity ( 3 ) ;
504
+
505
+ expected. push ( Ok ( AIServerResponse :: Unit ) ) ;
506
+ expected. push ( Ok ( AIServerResponse :: Unit ) ) ;
507
+ expected. push ( Ok ( AIServerResponse :: Del ( 1 ) ) ) ;
508
+
509
+ let mut reader = BufReader :: new ( first_stream) ;
510
+ query_server_assert_result ( & mut reader, message, expected) . await ;
511
+
512
+ // write flag should show that a write has occured
513
+ assert ! ( write_flag. load( Ordering :: SeqCst ) ) ;
514
+ // Allow some time for persistence to kick in
515
+ tokio:: time:: sleep ( Duration :: from_millis ( 200 ) ) . await ;
516
+ // start another server with existing persistence
517
+
518
+ let persisted_server = AIProxyServer :: new ( & AI_CONFIG_WITH_PERSISTENCE )
519
+ . await
520
+ . unwrap ( ) ;
521
+
522
+ // Allow some time for the server to start
523
+ tokio:: time:: sleep ( Duration :: from_millis ( 100 ) ) . await ;
524
+
525
+ let address = persisted_server
526
+ . local_addr ( )
527
+ . expect ( "Could not get local addr" ) ;
528
+ let write_flag = persisted_server. write_flag ( ) ;
529
+ let _ = tokio:: spawn ( async move { persisted_server. start ( ) . await } ) ;
530
+ let second_stream = TcpStream :: connect ( address) . await . unwrap ( ) ;
531
+ let mut reader = BufReader :: new ( second_stream) ;
532
+
533
+ let message = AIServerQuery :: from_queries ( & [ AIQuery :: ListStores ] ) ;
534
+
535
+ let mut expected = AIServerResult :: with_capacity ( 1 ) ;
536
+
537
+ expected. push ( Ok ( AIServerResponse :: StoreList ( HashSet :: from_iter ( [
538
+ AIStoreInfo {
539
+ name : store_name_2. clone ( ) ,
540
+ r#type : AIStoreType :: Binary ,
541
+ model : AIModel :: Llama3 ,
542
+ embedding_size : AIModel :: Llama3 . embedding_size ( ) . into ( ) ,
543
+ } ,
544
+ ] ) ) ) ) ;
545
+
546
+ query_server_assert_result ( & mut reader, message, expected) . await ;
547
+ assert ! ( !write_flag. load( Ordering :: SeqCst ) ) ;
548
+ // delete persistence
549
+ let _ = std:: fs:: remove_file ( & * PERSISTENCE_FILE ) ;
550
+ }
551
+
552
+ #[ tokio:: test]
553
+ async fn test_ai_proxy_destroy_database ( ) {
554
+ let address = provision_test_servers ( ) . await ;
555
+ let second_stream = TcpStream :: connect ( address) . await . unwrap ( ) ;
556
+ let store_name = StoreName ( String :: from ( "Deven Kicks" ) ) ;
557
+ let message = AIServerQuery :: from_queries ( & [
558
+ AIQuery :: CreateStore {
559
+ r#type : AIStoreType :: RawString ,
560
+ store : store_name. clone ( ) ,
561
+ model : AIModel :: Llama3 ,
562
+ predicates : HashSet :: from_iter ( [ ] ) ,
563
+ non_linear_indices : HashSet :: new ( ) ,
564
+ } ,
565
+ AIQuery :: ListStores ,
566
+ AIQuery :: PurgeStores ,
567
+ AIQuery :: ListStores ,
568
+ ] ) ;
569
+ let mut expected = AIServerResult :: with_capacity ( 4 ) ;
570
+
571
+ expected. push ( Ok ( AIServerResponse :: Unit ) ) ;
572
+ expected. push ( Ok ( AIServerResponse :: StoreList ( HashSet :: from_iter ( [
573
+ AIStoreInfo {
574
+ name : store_name,
575
+ r#type : AIStoreType :: RawString ,
576
+ model : AIModel :: Llama3 ,
577
+ embedding_size : AIModel :: Llama3 . embedding_size ( ) . into ( ) ,
578
+ } ,
579
+ ] ) ) ) ) ;
580
+ expected. push ( Ok ( AIServerResponse :: Del ( 1 ) ) ) ;
581
+ expected. push ( Ok ( AIServerResponse :: StoreList ( HashSet :: from_iter ( [ ] ) ) ) ) ;
582
+
583
+ let mut reader = BufReader :: new ( second_stream) ;
584
+ query_server_assert_result ( & mut reader, message, expected) . await
585
+ }
0 commit comments