Skip to content

Commit ae272ff

Browse files
authored
Merge pull request deven96#72 from deven96/diretnan/fix-persistence
Fixing persistence issues from predicate index
2 parents a0766ad + 6bf53d7 commit ae272ff

File tree

3 files changed

+53
-7
lines changed

3 files changed

+53
-7
lines changed

ahnlich/db/src/engine/predicate.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ use serde::Serialize;
1414
use std::collections::HashSet as StdHashSet;
1515
use std::mem::size_of_val;
1616

17-
type InnerPredicateIndex = ConcurrentHashMap<MetadataValue, ConcurrentHashSet<StoreKeyId>>;
17+
type InnerPredicateIndexVal = ConcurrentHashSet<StoreKeyId>;
18+
type InnerPredicateIndex = ConcurrentHashMap<MetadataValue, InnerPredicateIndexVal>;
1819
type InnerPredicateIndices = ConcurrentHashMap<MetadataKey, PredicateIndex>;
1920

2021
/// Predicate indices are all the indexes referenced by their names
@@ -200,7 +201,35 @@ impl PredicateIndices {
200201
/// ids. This is essential in helping us filter down the entire dataset using a predicate before
201202
/// performing similarity algorithmic search
202203
#[derive(Debug, Serialize, Deserialize)]
203-
struct PredicateIndex(InnerPredicateIndex);
204+
struct PredicateIndex(#[serde(with = "custom_metadata_map")] InnerPredicateIndex);
205+
206+
mod custom_metadata_map {
207+
use super::*;
208+
use serde::{self, Deserialize, Deserializer, Serialize, Serializer};
209+
210+
pub fn serialize<S>(map: &InnerPredicateIndex, serializer: S) -> Result<S::Ok, S::Error>
211+
where
212+
S: Serializer,
213+
{
214+
let vec: Vec<(MetadataValue, InnerPredicateIndexVal)> = map
215+
.iter(&map.guard())
216+
.map(|(k, v)| (k.clone(), v.clone()))
217+
.collect();
218+
vec.serialize(serializer)
219+
}
220+
221+
pub fn deserialize<'de, D>(deserializer: D) -> Result<InnerPredicateIndex, D::Error>
222+
where
223+
D: Deserializer<'de>,
224+
{
225+
let vec: Vec<(MetadataValue, InnerPredicateIndexVal)> = Vec::deserialize(deserializer)?;
226+
let map = ConcurrentHashMap::new();
227+
for (k, v) in vec {
228+
map.insert(k, v, &map.guard());
229+
}
230+
Ok(map)
231+
}
232+
}
204233

205234
impl PredicateIndex {
206235
fn size(&self) -> usize {

ahnlich/db/src/tests/server_test.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,13 @@ async fn test_server_with_persistence() {
421421
store: StoreName("Main".to_string()),
422422
inputs: vec![
423423
(StoreKey(array![1.0, 1.1, 1.2, 1.3]), HashMap::new()),
424-
(StoreKey(array![1.1, 1.2, 1.3, 1.4]), HashMap::new()),
424+
(
425+
StoreKey(array![1.1, 1.2, 1.3, 1.4]),
426+
HashMap::from_iter([(
427+
MetadataKey::new("medal".into()),
428+
MetadataValue::Binary(vec![1, 2, 3]),
429+
)]),
430+
),
425431
],
426432
},
427433
DBQuery::ListStores,
@@ -449,7 +455,7 @@ async fn test_server_with_persistence() {
449455
StoreInfo {
450456
name: StoreName("Main".to_string()),
451457
len: 2,
452-
size_in_bytes: 1880,
458+
size_in_bytes: 1936,
453459
},
454460
]))));
455461
expected.push(Err(
@@ -460,7 +466,7 @@ async fn test_server_with_persistence() {
460466
StoreInfo {
461467
name: StoreName("Main".to_string()),
462468
len: 1,
463-
size_in_bytes: 1808,
469+
size_in_bytes: 1864,
464470
},
465471
]))));
466472
let stream = TcpStream::connect(address).await.unwrap();
@@ -478,7 +484,11 @@ async fn test_server_with_persistence() {
478484
let address = server.local_addr().expect("Could not get local addr");
479485
let _ = tokio::spawn(async move { server.start().await });
480486
// Allow some time for the server to start
487+
let file_metadata =
488+
std::fs::metadata(&CONFIG_WITH_PERSISTENCE.persist_location.clone().unwrap()).unwrap();
489+
assert!(file_metadata.len() > 0, "The persistence file is empty");
481490
tokio::time::sleep(Duration::from_millis(100)).await;
491+
// check peristence was not overriden
482492
let message = ServerDBQuery::from_queries(&[
483493
// should error as store was loaded from previous persistence and main still exists
484494
DBQuery::CreateStore {
@@ -504,7 +514,10 @@ async fn test_server_with_persistence() {
504514
expected.push(Ok(ServerResponse::Del(0)));
505515
expected.push(Ok(ServerResponse::Get(vec![(
506516
StoreKey(array![1.1, 1.2, 1.3, 1.4]),
507-
HashMap::new(),
517+
HashMap::from_iter([(
518+
MetadataKey::new("medal".into()),
519+
MetadataValue::Binary(vec![1, 2, 3]),
520+
)]),
508521
)])));
509522
let stream = TcpStream::connect(address).await.unwrap();
510523
let mut reader = BufReader::new(stream);

ahnlich/utils/src/persistence.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use serde::de::DeserializeOwned;
22
use serde::Serialize;
33
use std::fs::File;
4+
use std::fs::OpenOptions;
45
use std::io::BufReader;
56
use std::path::Path;
67
use std::sync::atomic::AtomicBool;
@@ -42,7 +43,10 @@ impl<T: Serialize + DeserializeOwned> Persistence<T> {
4243
persist_location: &std::path::PathBuf,
4344
persist_object: T,
4445
) -> Self {
45-
let _ = File::create(persist_location)
46+
let _ = OpenOptions::new()
47+
.append(true)
48+
.create(true)
49+
.open(persist_location)
4650
.expect("Persistence enabled but could not open peristence file");
4751
Self {
4852
write_flag,

0 commit comments

Comments
 (0)