diff --git a/Cargo.toml b/Cargo.toml index 8b7ef64..20605f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,24 +16,23 @@ categories = ["algorithms", "data-structures"] [dependencies] bitpacking = "0.9.2" -bytecheck = { version = "~0.6.8", default-features = false, optional = true } num = "0.4.1" -rkyv = { version = "0.7.42", features = ["validation", "strict"], optional = true } -wyhash = "0.5.0" +rkyv = { version = "0.8", optional = true } +wyhash = "0.6" [dev-dependencies] bitvec = "1.0.1" -criterion = { version = "0.5.1", features = ["html_reports"] } +criterion = { version = "0.7", features = ["html_reports"] } paste = "1.0.14" proptest = "1.4.0" rand = "0.8.5" rand_chacha = "0.3.1" -rkyv = { version = "0.7.42", features = ["validation", "strict"] } +rustc-hash = "2" test-case = "3.3.1" [features] default = [] -rkyv_derive = ["rkyv", "bytecheck"] +rkyv_derive = ["rkyv"] [[bench]] name = "rank" @@ -44,6 +43,11 @@ name = "mphf" harness = false required-features = ["rkyv_derive"] +[[bench]] +name = "map" +harness = false +required-features = ["rkyv_derive"] + [[bench]] name = "map_with_dict" harness = false @@ -59,11 +63,8 @@ name = "set" harness = false required-features = ["rkyv_derive"] -[profile.bench] -debug = true - -[profile.release] -codegen-units = 1 -debug = true -lto = "fat" -opt-level = 3 +# [profile.release] +# codegen-units = 1 +# debug = true +# lto = "fat" +# opt-level = 3 diff --git a/benches/map.rs b/benches/map.rs new file mode 100644 index 0000000..30e78d2 --- /dev/null +++ b/benches/map.rs @@ -0,0 +1,107 @@ +use std::{collections::HashMap, env, hint::black_box, time::Instant}; + +use entropy_map::{ArchivedMap, Map}; + +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +/// Benchmark results for N = 1M: +/// +/// Map construction took: 13.64868523s +/// +/// Map/HashMap get +/// time: [18.350 ms 18.518 ms 18.707 ms] +/// thrpt: [53.455 Melem/s 54.001 Melem/s 54.496 Melem/s] +/// +/// Map/entropy get +/// time: [37.033 ms 37.293 ms 37.613 ms] +/// thrpt: [26.587 Melem/s 26.815 Melem/s 27.003 Melem/s] +/// +/// Map/HashMap archived get +/// time: [37.152 ms 37.373 ms 37.712 ms] +/// thrpt: [26.517 Melem/s 26.757 Melem/s 26.917 Melem/s] +/// +/// Map rkyv serialization took: 4.447392ms +/// +/// Map/entropy archived get +/// time: [40.613 ms 41.039 ms 41.563 ms] +/// thrpt: [24.060 Melem/s 24.367 Melem/s 24.623 Melem/s] +pub fn benchmark(c: &mut Criterion) { + let n: usize = env::var("N").unwrap_or("1000000".to_string()).parse().unwrap(); + let query_n: usize = env::var("QN").unwrap_or("1000000".to_string()).parse().unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(123); + + let original_map: HashMap = (0..n) + .map(|_| { + let key = rng.gen::(); + let value = rng.gen::(); + (key, value) + }) + .collect(); + + // created with another hasher so the memory order is different to check random access + let hash_map: HashMap = HashMap::from_iter(original_map.clone()); + + let t0 = Instant::now(); + let map: Map = Map::from_iter_with_params(original_map.clone(), 2.4).unwrap(); + println!("Map construction took: {:?}", t0.elapsed()); + + let mut group = c.benchmark_group("Map"); + group.throughput(Throughput::Elements(query_n as u64)); + + group.bench_function("HashMap get", |b| { + b.iter(|| { + for key in original_map.keys().take(query_n) { + black_box(hash_map.get(key).unwrap()); + } + }); + }); + + group.bench_function("entropy get", |b| { + b.iter(|| { + for key in original_map.keys().take(query_n) { + black_box(map.get(key).unwrap()); + } + }); + }); + + let rkyv_bytes = rkyv::to_bytes::(&hash_map).unwrap(); + let rkyv_hash_map = rkyv::access::< + rkyv::collections::swiss_table::map::ArchivedHashMap, + rkyv::rancor::Error, + >(&rkyv_bytes) + .unwrap(); + + group.bench_function("HashMap archived get", |b| { + b.iter(|| { + for key in original_map.keys().take(query_n) { + black_box(rkyv_hash_map.get(key).unwrap()); + } + }); + }); + + let t0 = Instant::now(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + println!("Map rkyv serialization took: {:?}", t0.elapsed()); + + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); + + group.bench_function("entropy archived get", |b| { + b.iter(|| { + for key in original_map.keys().take(query_n) { + black_box(rkyv_map.get(key).unwrap()); + } + }); + }); + + group.finish(); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = benchmark, +} +criterion_main!(benches); diff --git a/benches/map_with_dict.rs b/benches/map_with_dict.rs index c7e960b..568c70d 100644 --- a/benches/map_with_dict.rs +++ b/benches/map_with_dict.rs @@ -1,67 +1,79 @@ -use std::collections::HashMap; -use std::env; -use std::time::Instant; +use std::{collections::HashMap, env, hint::black_box, time::Instant}; -use entropy_map::MapWithDict; +use entropy_map::{ArchivedMapWithDict, MapWithDict}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; /// Benchmark results for N = 1M: /// -/// map generation took: 55.309498ms -/// map_with_dict construction took: 1.411034205s -/// map_with_dict rkyv serialization took: 8.233451ms +/// MapWithDict construction took: 1.103815976s /// -/// # map_with_dict/get -/// time: [75.423 ms 75.814 ms 76.304 ms] -/// thrpt: [13.106 Melem/s 13.190 Melem/s 13.259 Melem/s] +/// MapWithDict/HashMap get +/// time: [18.856 ms 18.921 ms 18.994 ms] +/// thrpt: [52.650 Melem/s 52.850 Melem/s 53.033 Melem/s] /// -/// # map_with_dict/get-rkyv -/// time: [74.267 ms 74.681 ms 75.225 ms] -/// thrpt: [13.293 Melem/s 13.390 Melem/s 13.465 Melem/s] +/// MapWithDict/entropy get +/// time: [45.107 ms 45.406 ms 45.728 ms] +/// thrpt: [21.868 Melem/s 22.023 Melem/s 22.170 Melem/s] +/// +/// MapWithDict rkyv serialization took: 2.496905ms +/// +/// MapWithDict/entropy archived get +/// time: [40.738 ms 41.139 ms 41.575 ms] +/// thrpt: [24.053 Melem/s 24.308 Melem/s 24.547 Melem/s] pub fn benchmark(c: &mut Criterion) { let n: usize = env::var("N").unwrap_or("1000000".to_string()).parse().unwrap(); let query_n: usize = env::var("QN").unwrap_or("1000000".to_string()).parse().unwrap(); let mut rng = ChaCha8Rng::seed_from_u64(123); - let t0 = Instant::now(); let original_map: HashMap = (0..n) .map(|_| { let key = rng.gen::(); - let value = rng.gen_range(1..=10); + // let value = rng.gen_range(1..=10); + let value = rng.gen::(); (key, value) }) .collect(); - println!("map generation took: {:?}", t0.elapsed()); + + // created with another hasher so the memory order is different to check random access + let hash_map: HashMap = HashMap::from_iter(original_map.clone()); let t0 = Instant::now(); let map = MapWithDict::try_from(original_map.clone()).expect("failed to build map"); - println!("map_with_dict construction took: {:?}", t0.elapsed()); + println!("MapWithDict construction took: {:?}", t0.elapsed()); - let mut group = c.benchmark_group("map_with_dict"); + let mut group = c.benchmark_group("MapWithDict"); group.throughput(Throughput::Elements(query_n as u64)); - group.bench_function("get", |b| { + group.bench_function("HashMap get", |b| { + b.iter(|| { + for key in original_map.keys().take(query_n) { + black_box(hash_map.get(key).unwrap()); + } + }); + }); + + group.bench_function("entropy get", |b| { b.iter(|| { for key in original_map.keys().take(query_n) { - map.get(black_box(key)).unwrap(); + black_box(map.get(key).unwrap()); } }); }); let t0 = Instant::now(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap(); - println!("map_with_dict rkyv serialization took: {:?}", t0.elapsed()); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + println!("MapWithDict rkyv serialization took: {:?}", t0.elapsed()); - let rkyv_map = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); - group.bench_function("get-rkyv", |b| { + group.bench_function("entropy archived get", |b| { b.iter(|| { for key in original_map.keys().take(query_n) { - rkyv_map.get(black_box(key)).unwrap(); + black_box(rkyv_map.get(key).unwrap()); } }); }); diff --git a/benches/map_with_dict_bitpacked.rs b/benches/map_with_dict_bitpacked.rs index 92121ec..fbaa069 100644 --- a/benches/map_with_dict_bitpacked.rs +++ b/benches/map_with_dict_bitpacked.rs @@ -1,33 +1,30 @@ -use std::collections::HashMap; -use std::env; -use std::time::Instant; +use std::{collections::HashMap, env, hint::black_box, time::Instant}; -use entropy_map::MapWithDictBitpacked; +use entropy_map::{ArchivedMapWithDictBitpacked, MapWithDictBitpacked}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; /// Benchmark results for N = 1M: /// -/// map generation took: 199.621887ms -/// map_with_dict_bitpacked construction took: 2.36439657s -/// map_with_dict_bitpacked rkyv serialization took: 20.455775ms +/// MapWithDictBitpacked construction took: 1.530962829s /// -/// # map_with_dict_bitpacked/get_values -/// time: [169.36 ms 170.24 ms 171.06 ms] -/// thrpt: [5.8459 Melem/s 5.8740 Melem/s 5.9044 Melem/s] +/// MapWithDictBitpacked/get_values +/// time: [95.556 ms 96.288 ms 97.068 ms] +/// thrpt: [10.302 Melem/s 10.385 Melem/s 10.465 Melem/s] /// -/// # map_with_dict_bitpacked/get_values-rkyv -/// time: [167.92 ms 168.82 ms 169.65 ms] -/// thrpt: [5.8946 Melem/s 5.9233 Melem/s 5.9553 Melem/s] +/// MapWithDictBitpacked rkyv serialization took: 4.85859ms +/// +/// MapWithDictBitpacked/archived get_values +/// time: [79.066 ms 79.977 ms 81.002 ms] +/// thrpt: [12.345 Melem/s 12.504 Melem/s 12.648 Melem/s] pub fn benchmark(c: &mut Criterion) { let n: usize = env::var("N").unwrap_or("1000000".to_string()).parse().unwrap(); let query_n: usize = env::var("QN").unwrap_or("1000000".to_string()).parse().unwrap(); let mut rng = ChaCha8Rng::seed_from_u64(123); - let t0 = Instant::now(); let mut values_buf = vec![0; 10]; let original_map: HashMap> = (0..n) .map(|_| { @@ -36,33 +33,32 @@ pub fn benchmark(c: &mut Criterion) { (key, value) }) .collect(); - println!("map generation took: {:?}", t0.elapsed()); let t0 = Instant::now(); let map = MapWithDictBitpacked::try_from(original_map.clone()).expect("failed to build map"); - println!("map_with_dict_bitpacked construction took: {:?}", t0.elapsed()); + println!("MapWithDictBitpacked construction took: {:?}", t0.elapsed()); - let mut group = c.benchmark_group("map_with_dict_bitpacked"); + let mut group = c.benchmark_group("MapWithDictBitpacked"); group.throughput(Throughput::Elements(query_n as u64)); group.bench_function("get_values", |b| { b.iter(|| { for key in original_map.keys().take(query_n) { - map.get_values(black_box(key), &mut values_buf); + black_box(map.get_values(key, &mut values_buf)); } }); }); let t0 = Instant::now(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap(); - println!("map_with_dict_bitpacked rkyv serialization took: {:?}", t0.elapsed()); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + println!("MapWithDictBitpacked rkyv serialization took: {:?}", t0.elapsed()); - let rkyv_map = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); - group.bench_function("get-rkyv", |b| { + group.bench_function("archived get_values", |b| { b.iter(|| { for key in original_map.keys().take(query_n) { - rkyv_map.get_values(black_box(key), &mut values_buf); + black_box(rkyv_map.get_values(key, &mut values_buf)); } }); }); diff --git a/benches/mphf.rs b/benches/mphf.rs index 9cc3d00..150136b 100644 --- a/benches/mphf.rs +++ b/benches/mphf.rs @@ -1,74 +1,73 @@ -use std::env; -use std::time::Instant; +use std::{env, hint::black_box, time::Instant}; -use entropy_map::Mphf; +use entropy_map::{ArchivedMphf, Mphf}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use rand::random; /// # Benchmark results for N = 1M: /// -/// items generation took: 7.164763ms +/// Mphf (1.0) construction took: 1.291480619s, bits per key: 2.10 /// -/// # mphf/mphf-get/gamma-1.0 -/// mphf (1.0) construction took: 1.510804159s, bits per key = 2.10 -/// time: [14.326 ms 14.372 ms 14.427 ms] -/// thrpt: [69.315 Melem/s 69.582 Melem/s 69.803 Melem/s] +/// Mphf/get gamma: 1.0 +/// time: [26.144 ms 26.267 ms 26.412 ms] +/// thrpt: [37.862 Melem/s 38.071 Melem/s 38.250 Melem/s] /// -/// # mphf/rkyv-mphf-get/gamma-1.0 -/// mphf (1.0) rkyv serialization took: 128.191µs -/// time: [14.389 ms 14.413 ms 14.446 ms] -/// thrpt: [69.225 Melem/s 69.382 Melem/s 69.499 Melem/s] +/// Mphf (1.0) rkyv serialization took: 21.024µs /// -/// # mphf/mphf-get/gamma-2.0 -/// mphf (2.0) construction took: 1.188994719s, bits per key = 2.72 -/// time: [4.5842 ms 4.5959 ms 4.6084 ms] -/// thrpt: [217.00 Melem/s 217.59 Melem/s 218.14 Melem/s] +/// Mphf/archived get gamma: 1.0 +/// time: [26.309 ms 26.397 ms 26.520 ms] +/// thrpt: [37.707 Melem/s 37.883 Melem/s 38.010 Melem/s] /// -/// # mphf/rkyv-mphf-get/gamma-2.0 -/// mphf (2.0) rkyv serialization took: 165.901µs -/// time: [4.6885 ms 4.7272 ms 4.7728 ms] -/// thrpt: [209.52 Melem/s 211.54 Melem/s 213.29 Melem/s] +/// Mphf (2.0) construction took: 982.578471ms, bits per key: 2.72 +/// +/// Mphf/get gamma: 2.0 +/// time: [19.458 ms 19.683 ms 19.928 ms] +/// thrpt: [50.179 Melem/s 50.805 Melem/s 51.392 Melem/s] +/// +/// Mphf (2.0) rkyv serialization took: 24.643µs +/// +/// Mphf/archived get gamma: 2.0 +/// time: [19.901 ms 20.239 ms 20.663 ms] +/// thrpt: [48.396 Melem/s 49.411 Melem/s 50.250 Melem/s] pub fn benchmark(c: &mut Criterion) { let n: usize = env::var("N").unwrap_or("1000000".to_string()).parse().unwrap(); let query_n: usize = env::var("QN").unwrap_or("1000000".to_string()).parse().unwrap(); - let mut group = c.benchmark_group("mphf"); + let mut group = c.benchmark_group("Mphf"); group.throughput(Throughput::Elements(query_n as u64)); - let t0 = Instant::now(); let items: Vec = (0..n).map(|_| random()).collect(); - println!("items generation took: {:?}", t0.elapsed()); for &gamma in &[1.0_f32, 2.0_f32] { let t0 = Instant::now(); let mphf = Mphf::<32, 8>::from_slice(&items, gamma).expect("failed to build mphf"); let bits = (mphf.size() as f32) * 8.0 / (n as f32); println!( - "mphf ({:.1}) construction took: {:?}, bits per key: {:.2}", + "Mphf ({:.1}) construction took: {:?}, bits per key: {:.2}", gamma, t0.elapsed(), bits ); - group.bench_function(format!("mphf-get/gamma-{:.1}", gamma), |b| { + group.bench_function(format!("get gamma: {:.1}", gamma), |b| { b.iter(|| { for item in items.iter().take(query_n) { - mphf.get(black_box(item)).unwrap(); + black_box(mphf.get(item).unwrap()); } }); }); let t0 = Instant::now(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&mphf).unwrap(); - println!("mphf ({:.1}) rkyv serialization took: {:?}", gamma, t0.elapsed()); + let rkyv_bytes = rkyv::to_bytes::(&mphf).unwrap(); + println!("Mphf ({:.1}) rkyv serialization took: {:?}", gamma, t0.elapsed()); - let rkyv_mphf = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_mphf = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); - group.bench_function(format!("rkyv-mphf-get/gamma-{:.1}", gamma), |b| { + group.bench_function(format!("archived get gamma: {:.1}", gamma), |b| { b.iter(|| { for item in items.iter().take(query_n) { - rkyv_mphf.get(black_box(item)).unwrap(); + black_box(rkyv_mphf.get(item).unwrap()); } }); }); diff --git a/benches/rank.rs b/benches/rank.rs index c038ef8..a0f2fa3 100644 --- a/benches/rank.rs +++ b/benches/rank.rs @@ -1,46 +1,41 @@ -use std::env; -use std::time::Instant; +use std::{env, hint::black_box, time::Instant}; use entropy_map::{RankedBits, RankedBitsAccess}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; -use rand::prelude::SliceRandom; -use rand::random; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use rand::{prelude::SliceRandom, random}; /// Benchmark results for N = 1M: /// -/// indices generation took: 15.462759ms -/// ranked bits construction took: 21.978µs, overhead: 3.16% +/// RankedBits construction took: 9.904µs, overhead: 3.16% /// -/// # ranked_bits/rank -/// time: [616.89 µs 629.04 µs 643.46 µs] -/// thrpt: [1.5541 Gelem/s 1.5897 Gelem/s 1.6210 Gelem/s] +/// RankedBits/rank +/// time: [8.3597 ms 8.4021 ms 8.4608 ms] +/// thrpt: [118.19 Melem/s 119.02 Melem/s 119.62 Melem/s] pub fn benchmark(c: &mut Criterion) { let n: usize = env::var("N").unwrap_or("1000000".to_string()).parse().unwrap(); let query_n: usize = env::var("QN").unwrap_or("1000000".to_string()).parse().unwrap(); let n_u64 = n / 64; - let t0 = Instant::now(); let data: Vec = (0..n_u64).map(|_| random()).collect(); let mut indices: Vec = (0..n).collect(); indices.shuffle(&mut rand::thread_rng()); - println!("indices generation took: {:?}", t0.elapsed()); let t0 = Instant::now(); let ranked_bits = RankedBits::new(data.into_boxed_slice()); let overhead = ((ranked_bits.size() as f32) * 8.0 / (n as f32) - 1.0) * 100.0; println!( - "ranked bits construction took: {:?}, overhead: {:.2}%", + "RankedBits construction took: {:?}, overhead: {:.2}%", t0.elapsed(), overhead ); - let mut group = c.benchmark_group("ranked_bits"); + let mut group = c.benchmark_group("RankedBits"); group.throughput(Throughput::Elements(query_n as u64)); group.bench_function("rank", |b| { b.iter(|| { for &idx in indices.iter().take(query_n) { - ranked_bits.rank(black_box(idx)).unwrap_or_default(); + black_box(ranked_bits.rank(idx).unwrap_or_default()); } }); }); diff --git a/benches/set.rs b/benches/set.rs index d4654b7..ac5ed30 100644 --- a/benches/set.rs +++ b/benches/set.rs @@ -1,65 +1,99 @@ -use std::env; -use std::hash::{BuildHasherDefault, DefaultHasher}; -use std::time::Instant; -use std::{collections::HashSet, default}; +use std::{ + collections::HashSet, + env, + hash::{BuildHasherDefault, DefaultHasher}, + hint::black_box, + time::Instant, +}; use entropy_map::{Set, DEFAULT_GAMMA}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha8Rng; -use rkyv::collections; +/// Benchmark results for N = 1M: +/// +/// Set construction took: 1.022528566s +/// +/// Set/entropy contains +/// time: [26.171 ms 26.454 ms 26.795 ms] +/// thrpt: [37.320 Melem/s 37.802 Melem/s 38.210 Melem/s] +/// +/// Set/entropy-fxhash contains +/// time: [24.902 ms 25.265 ms 25.667 ms] +/// thrpt: [38.960 Melem/s 39.581 Melem/s 40.158 Melem/s] +/// +/// Set/entropy-DefaultHasher contains +/// time: [32.400 ms 32.839 ms 33.343 ms] +/// thrpt: [29.991 Melem/s 30.452 Melem/s 30.864 Melem/s] +/// +/// Set/HashSet-fxhash contains +/// time: [14.398 ms 14.704 ms 15.039 ms] +/// thrpt: [66.494 Melem/s 68.007 Melem/s 69.454 Melem/s] +/// +/// Set/HashSet-DefaultHasher contains +/// time: [34.512 ms 34.877 ms 35.292 ms] +/// thrpt: [28.335 Melem/s 28.673 Melem/s 28.975 Melem/s] pub fn benchmark(c: &mut Criterion) { let n: usize = env::var("N").unwrap_or("1000000".to_string()).parse().unwrap(); let query_n: usize = env::var("QN").unwrap_or("1000000".to_string()).parse().unwrap(); let mut rng = ChaCha8Rng::seed_from_u64(123); - let t0 = Instant::now(); let original_set: HashSet = (0..n).map(|_| rng.gen::()).collect(); - println!("set generation took: {:?}", t0.elapsed()); let t0 = Instant::now(); - let set = Set::try_from(original_set.clone()).expect("failed to build set"); - println!("set construction took: {:?}", t0.elapsed()); + let set = + Set::::from_iter_with_params(original_set.iter().cloned(), DEFAULT_GAMMA).expect("failed to build set"); + println!("Set construction took: {:?}", t0.elapsed()); - let mut group = c.benchmark_group("set"); + let mut group = c.benchmark_group("Set"); group.throughput(Throughput::Elements(query_n as u64)); - group.bench_function("entropy-contains-fxhash", |b| { + group.bench_function("entropy contains", |b| { + b.iter(|| { + for key in original_set.iter().take(query_n) { + black_box(set.contains(key)); + } + }); + }); + + let set_fxhash: Set = + Set::from_iter_with_params(original_set.iter().cloned(), DEFAULT_GAMMA).expect("failed to build set"); + group.bench_function("entropy-fxhash contains", |b| { b.iter(|| { for key in original_set.iter().take(query_n) { - set.contains(black_box(key)); + black_box(set_fxhash.contains(key)); } }); }); let set_default_hasher: Set = Set::from_iter_with_params(original_set.iter().cloned(), DEFAULT_GAMMA).expect("failed to build set"); - group.bench_function("entropy-contains-defaulthasher", |b| { + group.bench_function("entropy-DefaultHasher contains", |b| { b.iter(|| { for key in original_set.iter().take(query_n) { - set_default_hasher.contains(black_box(key)); + black_box(set_default_hasher.contains(key)); } }); }); - let fxhash_set: HashSet = HashSet::from_iter(original_set.iter().cloned()); - group.bench_function("std-contains-fxhash", |b| { + let fxhash_set: HashSet = HashSet::from_iter(original_set.iter().cloned()); + group.bench_function("HashSet-fxhash contains", |b| { b.iter(|| { for key in original_set.iter().take(query_n) { - fxhash_set.contains(black_box(key)); + black_box(fxhash_set.contains(key)); } }); }); let defaulthasher_set: HashSet> = HashSet::from_iter(original_set.iter().cloned()); - group.bench_function("std-contains-defaulthasher", |b| { + group.bench_function("HashSet-DefaultHasher contains", |b| { b.iter(|| { for key in original_set.iter().take(query_n) { - defaulthasher_set.contains(black_box(key)); + black_box(defaulthasher_set.contains(key)); } }); }); diff --git a/examples/map_with_dict.rs b/examples/map_with_dict.rs index 210058d..fd93783 100644 --- a/examples/map_with_dict.rs +++ b/examples/map_with_dict.rs @@ -20,9 +20,11 @@ fn main() { #[cfg(feature = "rkyv_derive")] { + use entropy_map::ArchivedMapWithDict; + // Serialize map to rkyv and test again - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap(); - let rkyv_map = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); assert_eq!(rkyv_map.get(&1).unwrap(), &"Dog".to_string()); assert_eq!(rkyv_map.get(&2).unwrap(), &"Cat".to_string()); diff --git a/examples/map_with_dict_bitpacked.rs b/examples/map_with_dict_bitpacked.rs index 096ea30..40a548a 100644 --- a/examples/map_with_dict_bitpacked.rs +++ b/examples/map_with_dict_bitpacked.rs @@ -23,9 +23,11 @@ fn main() { #[cfg(feature = "rkyv_derive")] { + use entropy_map::ArchivedMapWithDictBitpacked; + // Serialize map to rkyv and test again - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap(); - let rkyv_map = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); assert!(rkyv_map.get_values(&1, &mut values_buf)); assert_eq!(values_buf, vec![1, 2, 3]); diff --git a/examples/mphf.rs b/examples/mphf.rs index 5039813..d31f012 100644 --- a/examples/mphf.rs +++ b/examples/mphf.rs @@ -15,9 +15,11 @@ fn main() { #[cfg(feature = "rkyv_derive")] { + use entropy_map::ArchivedMphf; + // Serialize mphf to rkyv and test again - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&mphf).unwrap(); - let rkyv_mphf = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&mphf).unwrap(); + let rkyv_mphf = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); assert!(rkyv_mphf.get(&1).is_some()); assert!(rkyv_mphf.get(&5).is_some()); diff --git a/examples/set.rs b/examples/set.rs index 00ebc5d..6319a21 100644 --- a/examples/set.rs +++ b/examples/set.rs @@ -19,8 +19,10 @@ fn main() { #[cfg(feature = "rkyv_derive")] { - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap(); - let rkyv = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + use entropy_map::ArchivedSet; + + let rkyv_bytes = rkyv::to_bytes::(&set).unwrap(); + let rkyv = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); assert!(rkyv.contains(&1)); assert!(rkyv.contains(&2)); diff --git a/rustfmt.toml b/rustfmt.toml index 9998926..ddfd589 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,3 @@ max_width = 120 -struct_lit_width = 80 \ No newline at end of file +struct_lit_width = 80 +imports_granularity = "Crate" diff --git a/src/lib.rs b/src/lib.rs index 76ab64c..9838f1b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,97 @@ +pub mod map; pub mod map_with_dict; pub mod map_with_dict_bitpacked; pub mod mphf; pub mod rank; pub mod set; +pub use map::*; pub use map_with_dict::*; pub use map_with_dict_bitpacked::*; pub use mphf::*; pub use rank::*; pub use set::*; + +pub trait IntoGroupSeed: Copy { + fn into_u32(self) -> u32; +} + +impl IntoGroupSeed for u8 { + #[inline(always)] + fn into_u32(self) -> u32 { + self as u32 + } +} + +impl IntoGroupSeed for u16 { + #[inline(always)] + fn into_u32(self) -> u32 { + self as u32 + } +} + +#[cfg(feature = "rkyv_derive")] +impl IntoGroupSeed for rkyv::rend::u16_le { + #[inline(always)] + fn into_u32(self) -> u32 { + self.to_native() as u32 + } +} + +#[cfg(feature = "rkyv_derive")] +impl IntoGroupSeed for rkyv::rend::u16_be { + #[inline(always)] + fn into_u32(self) -> u32 { + self.to_native() as u32 + } +} + +impl IntoGroupSeed for u32 { + #[inline(always)] + fn into_u32(self) -> u32 { + self + } +} + +#[cfg(feature = "rkyv_derive")] +impl IntoGroupSeed for rkyv::rend::u32_le { + #[inline(always)] + fn into_u32(self) -> u32 { + self.to_native() + } +} + +#[cfg(feature = "rkyv_derive")] +impl IntoGroupSeed for rkyv::rend::u32_be { + #[inline(always)] + fn into_u32(self) -> u32 { + self.to_native() + } +} + +pub trait IntoRankBits: Copy { + fn into_u64(self) -> u64; +} + +impl IntoRankBits for u64 { + #[inline(always)] + fn into_u64(self) -> u64 { + self + } +} + +#[cfg(feature = "rkyv_derive")] +impl IntoRankBits for rkyv::rend::u64_le { + #[inline(always)] + fn into_u64(self) -> u64 { + self.to_native() + } +} + +#[cfg(feature = "rkyv_derive")] +impl IntoRankBits for rkyv::rend::u64_be { + #[inline(always)] + fn into_u64(self) -> u64 { + self.to_native() + } +} diff --git a/src/map.rs b/src/map.rs new file mode 100644 index 0000000..d1c1a0a --- /dev/null +++ b/src/map.rs @@ -0,0 +1,510 @@ +//! A module providing `Map`, an immutable hash map implementation. +//! +//! `Map` is a hash map structure that optimizes for space by utilizing a minimal perfect +//! hash function (MPHF) for indexing the map's keys. +//! The MPHF provides direct access to the indices of keys. +//! Keys are stored to ensure that `get` operation will return `None` if key +//! wasn't present in the original set. + +use std::{ + borrow::Borrow, + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + mem::size_of_val, +}; + +use num::{PrimInt, Unsigned}; +use wyhash::WyHash; + +use crate::{ + mphf::{Mphf, MphfError, DEFAULT_GAMMA}, + IntoGroupSeed, +}; + +/// An efficient, immutable hash map. +#[derive(Default)] +#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))] +pub struct Map +where + ST: PrimInt + Unsigned, + H: Hasher + Default, +{ + /// Minimally Perfect Hash Function for keys indices retrieval + mphf: Mphf, + keys_vals: Box<[(K, V)]>, +} + +impl Map +where + K: Hash, + ST: PrimInt + Unsigned + IntoGroupSeed, + H: Hasher + Default, +{ + /// Constructs a `Map` from an iterator of key-value pairs and MPHF function params. + pub fn from_iter_with_params(iter: I, gamma: f32) -> Result + where + I: IntoIterator, + { + let mut keys_vals: Vec<_> = iter.into_iter().collect(); + + let mphf = Mphf::from_iter(keys_vals.iter().map(|(k, _v)| k), gamma)?; + + // Re-order `keys` and `values_index` according to `mphf` + for i in 0..keys_vals.len() { + loop { + let idx = mphf.get(&keys_vals[i].0).unwrap(); + if idx == i { + break; + } + keys_vals.swap(i, idx); + } + } + + Ok(Self { mphf, keys_vals: keys_vals.into_boxed_slice() }) + } + + /// Returns a reference to the value corresponding to the key. Returns `None` if the key is + /// not present in the map. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// assert_eq!(map.get(&1), Some(&2)); + /// assert_eq!(map.get(&5), None); + /// ``` + #[inline] + pub fn get(&self, key: &Q) -> Option<&V> + where + K: Borrow + PartialEq, + Q: Hash + Eq + ?Sized, + { + let idx = self.mphf.get(key)?; + + // SAFETY: `idx` is always within bounds (ensured during construction) + unsafe { + let (k, v) = self.keys_vals.get_unchecked(idx); + if k == key { + Some(v) + } else { + None + } + } + } + + /// Returns the number of key-value pairs in the map. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// assert_eq!(map.len(), 2); + /// ``` + #[inline] + pub fn len(&self) -> usize { + self.keys_vals.len() + } + + /// Returns `true` if the map contains no elements. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(0, 0); 0])).unwrap(); + /// assert_eq!(map.is_empty(), true); + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// assert_eq!(map.is_empty(), false); + /// ``` + #[inline] + pub fn is_empty(&self) -> bool { + self.keys_vals.is_empty() + } + + /// Checks if the map contains the specified key. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// assert_eq!(map.contains_key(&1), true); + /// assert_eq!(map.contains_key(&2), false); + /// ``` + #[inline] + pub fn contains_key(&self, key: &Q) -> bool + where + K: Borrow + PartialEq, + Q: Hash + Eq + ?Sized, + { + if let Some(idx) = self.mphf.get(key) { + // SAFETY: `idx` is always within bounds (ensured during construction) + unsafe { &self.keys_vals.get_unchecked(idx).0 == key } + } else { + false + } + } + + /// Returns an iterator over the map, yielding key-value pairs. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// for (key, val) in map.iter() { + /// println!("key: {key} val: {val}"); + /// } + /// ``` + #[inline] + pub fn iter(&self) -> impl Iterator { + self.keys_vals.iter() + } + + /// Returns an iterator over the keys of the map. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// for key in map.keys() { + /// println!("{key}"); + /// } + /// ``` + #[inline] + pub fn keys(&self) -> impl Iterator { + self.keys_vals.iter().map(|(k, _v)| k) + } + + /// Returns an iterator over the values of the map. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// for val in map.values() { + /// println!("{val}"); + /// } + /// ``` + #[inline] + pub fn values(&self) -> impl Iterator { + self.keys_vals.iter().map(|(_k, v)| v) + } + + /// Returns the total number of bytes occupied by the structure. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// assert_eq!(map.size(), 222); + /// ``` + #[inline] + pub fn size(&self) -> usize { + size_of_val(self) + self.mphf.size() + size_of_val(self.keys_vals.as_ref()) + } +} + +/// Creates a `Map` from a `HashMap`. +impl TryFrom> for Map +where + K: Hash, + B: BuildHasher, +{ + type Error = MphfError; + + #[inline] + fn try_from(value: HashMap) -> Result { + Self::from_iter_with_params(value, DEFAULT_GAMMA) + } +} + +/// Implement `get` for `Archived` version of `Map` if feature is enabled +#[cfg(feature = "rkyv_derive")] +impl ArchivedMap +where + K: PartialEq + Hash + rkyv::Archive, + K::Archived: PartialEq, + V: rkyv::Archive, + ST: PrimInt + Unsigned + rkyv::Archive, + ::Archived: IntoGroupSeed, + H: Hasher + Default, +{ + /// Checks if the map contains the specified key. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::ArchivedMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// let bytes = rkyv::to_bytes::(&map).unwrap(); + /// let archived_map = rkyv::access::, rkyv::rancor::Error>(&bytes).unwrap(); + /// assert_eq!(archived_map.contains_key(&1), true); + /// assert_eq!(archived_map.contains_key(&2), false); + /// ``` + #[inline] + pub fn contains_key(&self, key: &Q) -> bool + where + K: Borrow, + ::Archived: PartialEq, + Q: Hash + Eq + ?Sized, + { + if let Some(idx) = self.mphf.get(key) { + // SAFETY: `idx` is always within bounds (ensured during construction) + let rkyv::tuple::ArchivedTuple2(k, _v) = unsafe { self.keys_vals.get_unchecked(idx) }; + + k == key + } else { + false + } + } + + /// Returns a reference to the value corresponding to the key. Returns `None` if the key is + /// not present in the map. + /// + /// # Examples + /// ``` + /// # use std::collections::HashMap; + /// # use entropy_map::ArchivedMap; + /// # use entropy_map::Map; + /// let map = Map::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); + /// let bytes = rkyv::to_bytes::(&map).unwrap(); + /// let archived_map = rkyv::access::, rkyv::rancor::Error>(&bytes).unwrap(); + /// assert_eq!(archived_map.get(&1).map(|v| v.to_native()), Some(2)); + /// assert_eq!(archived_map.get(&5).map(|v| v.to_native()), None); + /// ``` + #[inline] + pub fn get(&self, key: &Q) -> Option<&V::Archived> + where + K: Borrow, + ::Archived: PartialEq, + Q: Hash + Eq + ?Sized, + { + let idx = self.mphf.get(key)?; + + // SAFETY: `idx` is always within bounds (ensured during construction) + unsafe { + let rkyv::tuple::ArchivedTuple2(k, v) = self.keys_vals.get_unchecked(idx); + if k == key { + Some(v) + } else { + None + } + } + } + + /// Returns an iterator over the archived map, yielding archived key-value pairs. + #[inline] + pub fn iter(&self) -> impl Iterator::Archived> { + self.keys_vals.iter() + } + + /// Returns the number of key-value pairs in the map. + #[inline] + pub fn len(&self) -> usize { + self.keys_vals.len() + } + + /// Returns `true` if the map contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.keys_vals.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use paste::paste; + use proptest::prelude::*; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + use std::collections::{hash_map::RandomState, HashSet}; + + fn gen_map(items_num: usize) -> HashMap { + let mut rng = ChaCha8Rng::seed_from_u64(123); + + (0..items_num) + .map(|_| { + let key = rng.gen::(); + let value = rng.gen_range(1..=10); + (key, value) + }) + .collect() + } + + #[test] + fn test_map() { + // Collect original key-value pairs directly into a HashMap + let original_map = gen_map(1000); + + // Create the map from the iterator + let map: Map = Map::from_iter_with_params(original_map.clone(), DEFAULT_GAMMA).unwrap(); + + // Test len + assert_eq!(map.len(), original_map.len()); + + // Test is_empty + assert_eq!(map.is_empty(), original_map.is_empty()); + + // Test get, contains_key + for (key, value) in &original_map { + assert_eq!(map.get(key), Some(value)); + assert!(map.contains_key(key)); + } + + // Test iter + for (k, v) in map.iter() { + assert_eq!(original_map.get(k), Some(v)); + } + + // Test keys + for k in map.keys() { + assert!(original_map.contains_key(k)); + } + + // Test values + for &v in map.values() { + assert!(original_map.values().any(|&val| val == v)); + } + + // Test size + assert_eq!(map.size(), 16530); + } + + /// Assert that we can call `.get()` with `K::borrow()`. + #[test] + fn test_get_borrow() { + let original_map: HashMap = + HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]); + let map = Map::try_from(original_map).unwrap(); + + assert_eq!(map.get("a"), Some(&())); + assert!(map.contains_key("a")); + assert_eq!(map.get("b"), Some(&())); + assert!(map.contains_key("b")); + assert_eq!(map.get("c"), None); + assert!(!map.contains_key("c")); + } + + #[cfg(feature = "rkyv_derive")] + #[test] + fn test_rkyv() { + // create regular `HashMap`, then `Map`, then serialize to `rkyv` bytes. + let original_map = gen_map(1000); + let map: Map = Map::from_iter_with_params(original_map.clone(), DEFAULT_GAMMA).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + + assert_eq!(rkyv_bytes.len(), 16424); + + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); + + // Test get on `Archived` version + for (k, v) in original_map.iter() { + assert_eq!(v, rkyv_map.get(k).unwrap()); + } + + // Test iter on `Archived` version + for rkyv::tuple::ArchivedTuple2(k, v) in rkyv_map.iter() { + assert_eq!(original_map.get(&k.to_native()), Some(&v.to_native())); + } + } + + #[cfg(feature = "rkyv_derive")] + #[test] + fn test_rkyv_get_borrow() { + let original_map: HashMap = + HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]); + let map = Map::try_from(original_map).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); + + assert_eq!(map.get("a"), Some(&())); + assert!(rkyv_map.contains_key("a")); + assert_eq!(map.get("b"), Some(&())); + assert!(rkyv_map.contains_key("b")); + assert_eq!(map.get("c"), None); + assert!(!rkyv_map.contains_key("c")); + } + + macro_rules! proptest_map_model { + ($(($b:expr, $s:expr, $gamma:expr)),* $(,)?) => { + $( + paste! { + proptest! { + #[test] + fn [](model: HashMap, arbitrary: HashSet) { + let entropy_map: Map = Map::from_iter_with_params( + model.clone(), + $gamma as f32 / 100.0 + ).unwrap(); + + // Assert that length matches model. + assert_eq!(entropy_map.len(), model.len()); + assert_eq!(entropy_map.is_empty(), model.is_empty()); + + // Assert that keys and values match model. + assert_eq!( + HashSet::<_, RandomState>::from_iter(entropy_map.keys()), + HashSet::from_iter(model.keys()) + ); + assert_eq!( + HashSet::<_, RandomState>::from_iter(entropy_map.values()), + HashSet::from_iter(model.values()) + ); + + // Assert that contains and get operations match model for contained elements. + for (k, v) in &model { + assert!(entropy_map.contains_key(&k)); + assert_eq!(entropy_map.get(&k), Some(v)); + } + + // Assert that contains and get operations match model for random elements. + for k in arbitrary { + assert_eq!( + model.contains_key(&k), + entropy_map.contains_key(&k), + ); + assert_eq!(entropy_map.get(&k), model.get(&k)); + } + } + } + } + )* + }; + } + + proptest_map_model!( + // (1, 8, 100), + (2, 8, 100), + (4, 8, 100), + (7, 8, 100), + (8, 8, 100), + (15, 8, 100), + (16, 8, 100), + (23, 8, 100), + (24, 8, 100), + (31, 8, 100), + (32, 8, 100), + (33, 8, 100), + (48, 8, 100), + (53, 8, 100), + (61, 8, 100), + (63, 8, 100), + (64, 8, 100), + (32, 7, 100), + (32, 5, 100), + (32, 4, 100), + (32, 3, 100), + (32, 1, 100), + (32, 0, 100), + (32, 8, 200), + (32, 6, 200), + ); +} diff --git a/src/map_with_dict.rs b/src/map_with_dict.rs index 66fdf83..637e0b1 100644 --- a/src/map_with_dict.rs +++ b/src/map_with_dict.rs @@ -5,22 +5,26 @@ //! as it reduces the overall memory footprint by packing unique values into a dictionary. The MPHF //! provides direct access to the indices of keys, which correspond to their respective values in //! the values dictionary. Keys are stored to ensure that `get` operation will return `None` if key -//! wasn't present in original set. +//! wasn't present in the original set. -use std::borrow::Borrow; -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; -use std::mem::size_of_val; +use std::{ + borrow::Borrow, + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + mem::size_of_val, +}; use num::{PrimInt, Unsigned}; use wyhash::WyHash; -use crate::mphf::{Mphf, MphfError, DEFAULT_GAMMA}; +use crate::{ + mphf::{Mphf, MphfError, DEFAULT_GAMMA}, + IntoGroupSeed, +}; /// An efficient, immutable hash map with values dictionary-packed for optimized space usage. #[derive(Default)] #[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))] -#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))] pub struct MapWithDict where ST: PrimInt + Unsigned, @@ -30,7 +34,9 @@ where mphf: Mphf, /// Map keys keys: Box<[K]>, - /// Points to the value index in the dictionary + /// Points to the value index in the dictionary. + /// If rkyv pointer width feature is not enabled, it will serialize usize as 32-bit integers by default. + /// So it limits the max amount of values if you use archived MapWithDict. values_index: Box<[usize]>, /// Map unique values values_dict: Box<[V]>, @@ -38,9 +44,9 @@ where impl MapWithDict where - K: Eq + Hash + Clone, + K: Hash, V: Eq + Clone + Hash, - ST: PrimInt + Unsigned, + ST: PrimInt + Unsigned + IntoGroupSeed, H: Hasher + Default, { /// Constructs a `MapWithDict` from an iterator of key-value pairs and MPHF function params. @@ -54,7 +60,7 @@ where let mut offsets_cache = HashMap::new(); for (k, v) in iter { - keys.push(k.clone()); + keys.push(k); if let Some(&offset) = offsets_cache.get(&v) { // re-use dictionary offset if found in cache @@ -64,7 +70,7 @@ where let offset = values_dict.len(); offsets_cache.insert(v.clone(), offset); values_index.push(offset); - values_dict.push(v.clone()); + values_dict.push(v); } } @@ -82,7 +88,7 @@ where } } - Ok(MapWithDict { + Ok(Self { mphf, keys: keys.into_boxed_slice(), values_index: values_index.into_boxed_slice(), @@ -109,10 +115,9 @@ where { let idx = self.mphf.get(key)?; - // SAFETY: `idx` is always within bounds (ensured during construction) + // SAFETY: `idx` and `value_idx` are always within bounds (ensured during construction) unsafe { if self.keys.get_unchecked(idx) == key { - // SAFETY: `idx` and `value_idx` are always within bounds (ensure during construction) let value_idx = *self.values_index.get_unchecked(idx); Some(self.values_dict.get_unchecked(value_idx)) } else { @@ -253,16 +258,17 @@ where } /// Creates a `MapWithDict` from a `HashMap`. -impl TryFrom> for MapWithDict +impl TryFrom> for MapWithDict where - K: Eq + Hash + Clone, + K: Hash, V: Eq + Clone + Hash, + B: BuildHasher, { type Error = MphfError; #[inline] - fn try_from(value: HashMap) -> Result { - MapWithDict::::from_iter_with_params(value, DEFAULT_GAMMA) + fn try_from(value: HashMap) -> Result { + Self::from_iter_with_params(value, DEFAULT_GAMMA) } } @@ -273,7 +279,8 @@ where K: PartialEq + Hash + rkyv::Archive, K::Archived: PartialEq, V: rkyv::Archive, - ST: PrimInt + Unsigned + rkyv::Archive, + ST: PrimInt + Unsigned + rkyv::Archive, + ::Archived: IntoGroupSeed, H: Hasher + Default, { /// Checks if the map contains the specified key. @@ -281,20 +288,20 @@ where /// # Examples /// ``` /// # use std::collections::HashMap; + /// # use entropy_map::ArchivedMapWithDict; /// # use entropy_map::MapWithDict; /// let map = MapWithDict::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); - /// let archived_map = rkyv::from_bytes::>( - /// &rkyv::to_bytes::<_, 1024>(&map).unwrap() - /// ).unwrap(); + /// let bytes = rkyv::to_bytes::(&map).unwrap(); + /// let archived_map = rkyv::access::, rkyv::rancor::Error>(&bytes).unwrap(); /// assert_eq!(archived_map.contains_key(&1), true); /// assert_eq!(archived_map.contains_key(&2), false); /// ``` #[inline] - pub fn contains_key(&self, key: &Q) -> bool + pub fn contains_key(&self, key: &Q) -> bool where K: Borrow, ::Archived: PartialEq, - Q: Hash + Eq, + Q: Hash + Eq + ?Sized, { if let Some(idx) = self.mphf.get(key) { // SAFETY: `idx` is always within bounds (ensured during construction) @@ -310,28 +317,27 @@ where /// # Examples /// ``` /// # use std::collections::HashMap; + /// # use entropy_map::ArchivedMapWithDict; /// # use entropy_map::MapWithDict; /// let map = MapWithDict::try_from(HashMap::from([(1, 2), (3, 4)])).unwrap(); - /// let archived_map = rkyv::from_bytes::>( - /// &rkyv::to_bytes::<_, 1024>(&map).unwrap() - /// ).unwrap(); - /// assert_eq!(archived_map.get(&1), Some(&2)); - /// assert_eq!(archived_map.get(&5), None); + /// let bytes = rkyv::to_bytes::(&map).unwrap(); + /// let archived_map = rkyv::access::, rkyv::rancor::Error>(&bytes).unwrap(); + /// assert_eq!(archived_map.get(&1).map(|v| v.to_native()), Some(2)); + /// assert_eq!(archived_map.get(&5).map(|v| v.to_native()), None); /// ``` #[inline] - pub fn get(&self, key: &Q) -> Option<&V::Archived> + pub fn get(&self, key: &Q) -> Option<&V::Archived> where K: Borrow, ::Archived: PartialEq, - Q: Hash + Eq, + Q: Hash + Eq + ?Sized, { let idx = self.mphf.get(key)?; - // SAFETY: `idx` is always within bounds (ensured during construction) + // SAFETY: `idx` and `value_idx` are always within bounds (ensured during construction) unsafe { if self.keys.get_unchecked(idx) == key { - // SAFETY: `idx` and `value_idx` are always within bounds (ensure during construction) - let value_idx = *self.values_index.get_unchecked(idx) as usize; + let value_idx = self.values_index.get_unchecked(idx).to_native() as usize; Some(self.values_dict.get_unchecked(value_idx)) } else { None @@ -347,7 +353,7 @@ where .zip(self.values_index.iter()) .map(move |(key, &value_idx)| { // SAFETY: `value_idx` is always within bounds (ensured during construction) - let value = unsafe { self.values_dict.get_unchecked(value_idx as usize) }; + let value = unsafe { self.values_dict.get_unchecked(value_idx.to_native() as usize) }; (key, value) }) } @@ -416,7 +422,8 @@ mod tests { /// Assert that we can call `.get()` with `K::borrow()`. #[test] fn test_get_borrow() { - let original_map = HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]); + let original_map: HashMap = + HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]); let map = MapWithDict::try_from(original_map).unwrap(); assert_eq!(map.get("a"), Some(&())); @@ -433,11 +440,11 @@ mod tests { // create regular `HashMap`, then `MapWithDict`, then serialize to `rkyv` bytes. let original_map = gen_map(1000); let map = MapWithDict::try_from(original_map.clone()).unwrap(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); - assert_eq!(rkyv_bytes.len(), 12464); + assert_eq!(rkyv_bytes.len(), 12480); - let rkyv_map = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); // Test get on `Archived` version for (k, v) in original_map.iter() { @@ -445,18 +452,19 @@ mod tests { } // Test iter on `Archived` version - for (&k, &v) in rkyv_map.iter() { - assert_eq!(original_map.get(&k), Some(&v)); + for (k, v) in rkyv_map.iter() { + assert_eq!(original_map.get(&k.to_native()), Some(&v.to_native())); } } #[cfg(feature = "rkyv_derive")] #[test] fn test_rkyv_get_borrow() { - let original_map = HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]); + let original_map: HashMap = + HashMap::from_iter([("a".to_string(), ()), ("b".to_string(), ())]); let map = MapWithDict::try_from(original_map).unwrap(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap(); - let rkyv_map = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); assert_eq!(map.get("a"), Some(&())); assert!(rkyv_map.contains_key("a")); diff --git a/src/map_with_dict_bitpacked.rs b/src/map_with_dict_bitpacked.rs index 203f176..d74eb48 100644 --- a/src/map_with_dict_bitpacked.rs +++ b/src/map_with_dict_bitpacked.rs @@ -11,21 +11,25 @@ //! stored in the byte dictionary. Keys are maintained for validation during retrieval. A `get` //! query for a non-existent key at construction returns `false`, similar to `MapWithDict`. -use std::borrow::Borrow; -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; -use std::mem::size_of_val; +use std::{ + borrow::Borrow, + collections::HashMap, + hash::{BuildHasher, Hash, Hasher}, + mem::size_of_val, +}; use bitpacking::{BitPacker, BitPacker1x}; use num::{PrimInt, Unsigned}; use wyhash::WyHash; -use crate::mphf::{Mphf, DEFAULT_GAMMA}; +use crate::{ + mphf::{Mphf, DEFAULT_GAMMA}, + IntoGroupSeed, +}; /// An efficient, immutable hash map with bit-packed `Vec` values for optimized space usage. #[derive(Default)] #[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))] -#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))] pub struct MapWithDictBitpacked where ST: PrimInt + Unsigned, @@ -52,8 +56,8 @@ pub enum Error { impl MapWithDictBitpacked where - K: Hash + PartialEq + Clone, - ST: PrimInt + Unsigned, + K: Hash, + ST: PrimInt + Unsigned + IntoGroupSeed, H: Hasher + Default, { /// Constructs a `MapWithDictBitpacked` from an iterator of key-value pairs and MPHF function params. @@ -70,7 +74,7 @@ where let v_len = iter.peek().map_or(0, |(_, v)| v.len()); for (k, v) in iter { - keys.push(k.clone()); + keys.push(k); if v.len() != v_len { return Err(Error::NotEqualValuesLengths); @@ -80,13 +84,14 @@ where // re-use dictionary offset if found in cache values_index.push(offset); } else { - // store current dictionary length as an offset in both index and cache let offset = values_dict.len(); - offsets_cache.insert(v.clone(), offset); - values_index.push(offset); // append packed values to the dictionary pack_values(&v, &mut values_dict); + + // store dictionary length as an offset in both index and cache + offsets_cache.insert(v, offset); + values_index.push(offset); } } @@ -139,13 +144,12 @@ where None => return false, }; - // SAFETY: `idx` is always within bounds (ensured during construction) + // SAFETY: `idx` and `value_idx` are always within bounds (ensured during construction) unsafe { if self.keys.get_unchecked(idx) != key { return false; } - // SAFETY: `idx` and `value_idx` are always within bounds (ensure during construction) let value_idx = *self.values_index.get_unchecked(idx); let dict = self.values_dict.get_unchecked(value_idx..); unpack_values(dict, values); @@ -287,14 +291,15 @@ where } /// Creates a `MapWithDictBitpacked` from a `HashMap`. -impl TryFrom>> for MapWithDictBitpacked +impl TryFrom, B>> for MapWithDictBitpacked where - K: PartialEq + Hash + Clone, + K: Hash, + B: BuildHasher, { type Error = Error; #[inline] - fn try_from(value: HashMap>) -> Result { + fn try_from(value: HashMap, B>) -> Result { MapWithDictBitpacked::from_iter_with_params(value, DEFAULT_GAMMA) } } @@ -354,7 +359,8 @@ impl ArchivedMapWithDictBitpacked, - ST: PrimInt + Unsigned + rkyv::Archive, + ST: PrimInt + Unsigned + rkyv::Archive, + ::Archived: IntoGroupSeed, H: Hasher + Default, { /// Updates `values` to the array of values corresponding to the key. Returns `false` if the @@ -363,11 +369,11 @@ where /// # Examples /// ``` /// # use std::collections::HashMap; + /// # use entropy_map::ArchivedMapWithDictBitpacked; /// # use entropy_map::MapWithDictBitpacked; /// let map = MapWithDictBitpacked::try_from(HashMap::from([(1, vec![2]), (3, vec![4])])).unwrap(); - /// let archived_map = rkyv::from_bytes::>( - /// &rkyv::to_bytes::<_, 1024>(&map).unwrap() - /// ).unwrap(); + /// let bytes = rkyv::to_bytes::(&map).unwrap(); + /// let archived_map = rkyv::access::, rkyv::rancor::Error>(&bytes).unwrap(); /// let mut values = [0]; /// assert_eq!(archived_map.get_values(&1, &mut values), true); /// assert_eq!(values, [2]); @@ -380,14 +386,13 @@ where None => return false, }; - // SAFETY: `idx` is always within bounds (ensured during construction) + // SAFETY: `idx` and `value_idx` are always within bounds (ensured during construction) unsafe { if self.keys.get_unchecked(idx) != key { return false; } - // SAFETY: `idx` and `value_idx` are always within bounds (ensure during construction) - let value_idx = *self.values_index.get_unchecked(idx) as usize; + let value_idx = self.values_index.get_unchecked(idx).to_native() as usize; let dict = self.values_dict.get_unchecked(value_idx..); unpack_values(dict, values); } @@ -544,11 +549,11 @@ mod tests { let values_num = 10; let original_map = gen_map(items_num, values_num); let map = MapWithDictBitpacked::try_from(original_map.clone()).unwrap(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&map).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&map).unwrap(); assert_eq!(rkyv_bytes.len(), 18516); - let rkyv_map = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_map = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); // Test get_values on `Archived` version of `MapWithDictBitpacked` let mut values_buf = vec![0; values_num]; diff --git a/src/mphf.rs b/src/mphf.rs index acb52ae..8f99d12 100644 --- a/src/mphf.rs +++ b/src/mphf.rs @@ -3,19 +3,28 @@ //! This module implements a Minimal Perfect Hash Function (MPHF) based on fingerprinting techniques, //! as detailed in [Fingerprinting-based minimal perfect hashing revisited](https://doi.org/10.1145/3596453). //! +//! If you query with keys that were not used at the time of construction, collisions can happen. +//! Other structures are free of collisions, because they store `keys` and compare on each get. +//! //! This implementation is inspired by existing Rust crate [ph](https://github.com/beling/bsuccinct-rs/tree/main/ph), //! but prioritizes code simplicity and portability, with a special focus on optimizing the rank //! storage mechanism and reducing the construction time and querying latency of MPHF. -use std::hash::{Hash, Hasher}; -use std::marker::PhantomData; -use std::mem::size_of_val; +use core::fmt; +use std::{ + hash::{Hash, Hasher}, + marker::PhantomData, + mem::size_of_val, +}; use num::{Integer, PrimInt, Unsigned}; use wyhash::WyHash; -use crate::mphf::MphfError::*; -use crate::rank::{RankedBits, RankedBitsAccess}; +use crate::{ + mphf::MphfError::*, + rank::{RankedBits, RankedBitsAccess}, + IntoGroupSeed, +}; /// A Minimal Perfect Hash Function (MPHF). /// @@ -26,7 +35,6 @@ use crate::rank::{RankedBits, RankedBitsAccess}; /// - `H`: hasher used to hash keys, default `WyHash`. #[derive(Default)] #[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))] -#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))] pub struct Mphf { /// Ranked bits for efficient rank queries ranked_bits: RankedBits, @@ -44,14 +52,26 @@ const MAX_LEVELS: usize = 64; /// Errors that can occur when initializing `Mphf`. #[derive(Debug)] pub enum MphfError { - /// Error when the maximum number of levels is exceeded during initialization. - MaxLevelsExceeded, - /// Error when the seed type `ST` is too small to store `S` bits - InvalidSeedType, /// Error when the `gamma` parameter is less than 1.0. InvalidGammaParameter, + /// Error when the seed type `ST` is too small to store `S` bits + InvalidSeedType, + /// Error when the maximum number of levels is exceeded during initialization. + MaxLevelsExceeded, +} + +impl fmt::Display for MphfError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::InvalidGammaParameter => write!(f, "the `gamma` parameter is less than 1.0"), + Self::InvalidSeedType => write!(f, "the seed type `ST` is too small to store `S` bits"), + Self::MaxLevelsExceeded => write!(f, "the maximum number of levels is exceeded during initialization"), + } + } } +impl std::error::Error for MphfError {} + /// Default `gamma` parameter for MPHF. pub const DEFAULT_GAMMA: f32 = 2.0; @@ -68,7 +88,13 @@ impl(keys: &[K], gamma: f32) -> Result { + Self::from_iter(keys.iter(), gamma) + } + + /// Initializes `Mphf` using iter of `keys` and parameter `gamma`. + pub fn from_iter<'k, K: Hash + 'k>(keys_iter: impl Iterator, gamma: f32) -> Result { if gamma < 1.0 { return Err(InvalidGammaParameter); } @@ -77,7 +103,7 @@ impl = keys.iter().map(|key| hash_key::(key)).collect(); + let mut hashes: Vec = keys_iter.map(|key| hash_key::(key)).collect(); let mut group_bits = vec![]; let mut group_seeds = vec![]; let mut level_groups = vec![]; @@ -95,7 +121,7 @@ impl(&self, key: &K) -> Option { - Self::get_impl(key, &self.level_groups, &self.group_seeds, &self.ranked_bits) + pub fn get(&self, key: &K) -> Option + where + ST: IntoGroupSeed, + { + Self::get_impl( + key, + self.level_groups.iter().copied(), + &self.group_seeds, + &self.ranked_bits, + ) } /// Inner implementation of `get` with `level_groups`, `group_seeds` and `ranked_bits` passed /// from standard and `Archived` version of `Mphf`. #[inline] - fn get_impl( + fn get_impl( key: &K, - level_groups: &[u32], - group_seeds: &[ST], + level_groups: impl Iterator, + group_seeds: &[GS], ranked_bits: &impl RankedBitsAccess, ) -> Option { let mut groups_before = 0; - for (level, &groups) in level_groups.iter().enumerate() { + for (level, groups) in level_groups.enumerate() { let level_hash = hash_with_seed(hash_key::(key), level as u32); let group_idx = groups_before + fastmod32(level_hash as u32, groups); // SAFETY: `group_idx` is always within bounds (ensured during calculation) - let group_seed = unsafe { group_seeds.get_unchecked(group_idx).to_u32().unwrap() }; + let group_seed = unsafe { group_seeds.get_unchecked(group_idx).into_u32() }; let bit_idx = bit_index_for_seed::(level_hash, group_seed, group_idx); if let Some(rank) = ranked_bits.rank(bit_idx) { return Some(rank); @@ -313,12 +347,18 @@ fn fastmod32(x: u32, n: u32) -> usize { #[cfg(feature = "rkyv_derive")] impl ArchivedMphf where - ST: PrimInt + Unsigned + rkyv::Archive, + ST: PrimInt + Unsigned + rkyv::Archive, + ::Archived: IntoGroupSeed, H: Hasher + Default, { #[inline] pub fn get(&self, key: &K) -> Option { - Mphf::::get_impl(key, &self.level_groups, &self.group_seeds, &self.ranked_bits) + Mphf::::get_impl( + key, + self.level_groups.iter().map(|v| v.to_native()), + self.group_seeds.get(), + &self.ranked_bits, + ) } } @@ -412,11 +452,11 @@ mod tests { let n = 10000; let keys = (0..n as u64).collect::>(); let mphf = Mphf::<32, 4>::from_slice(&keys, DEFAULT_GAMMA).expect("failed to create mphf"); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&mphf).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&mphf).unwrap(); - assert_eq!(rkyv_bytes.len(), 3804); + assert_eq!(rkyv_bytes.len(), 3884); - let rkyv_mphf = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_mphf = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); // Ensure that all keys are assigned unique index which is less than `n` let mut set = HashSet::with_capacity(n); diff --git a/src/rank.rs b/src/rank.rs index c1aad42..27ed994 100644 --- a/src/rank.rs +++ b/src/rank.rs @@ -5,6 +5,8 @@ use std::mem::size_of_val; +use crate::IntoRankBits; + /// Size of the L2 block in bits. const L2_BIT_SIZE: usize = 512; /// Size of the L1 block in bits, calculated as a multiple of the L2 block size. @@ -24,10 +26,10 @@ pub trait RankedBitsAccess { /// This method is unsafe because `idx` must be within the bounds of the bits stored in `RankedBitsAccess`. /// An index out of bounds can lead to undefined behavior. #[inline] - unsafe fn rank_impl(bits: &[u64], l12_ranks: &T, idx: usize) -> Option { + unsafe fn rank_impl(bits: &[B], l12_ranks: &T, idx: usize) -> Option { let word_idx = idx / 64; let bit_idx = idx % 64; - let word = *bits.get_unchecked(word_idx); + let word = bits.get_unchecked(word_idx).into_u64(); if (word & (1u64 << bit_idx)) == 0 { return None; @@ -41,9 +43,9 @@ pub trait RankedBitsAccess { let offset = (idx / L2_BIT_SIZE) * 8; let block = bits.get_unchecked(offset..offset + blocks_num); - let block_rank = block.iter().map(|&x| x.count_ones() as usize).sum::(); + let block_rank = block.iter().map(|&x| x.into_u64().count_ones() as usize).sum::(); - let word = *bits.get_unchecked(offset + blocks_num); + let word = bits.get_unchecked(offset + blocks_num).into_u64(); let word_mask = ((1u64 << (idx_within_l2 % 64)) - 1) * (idx_within_l2 > 0) as u64; let word_rank = (word & word_mask).count_ones() as usize; @@ -56,7 +58,6 @@ pub trait RankedBitsAccess { #[derive(Debug, Default)] #[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))] -#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))] pub struct RankedBits { /// The bit vector represented as an array of u64 integers. bits: Box<[u64]>, @@ -70,7 +71,6 @@ pub struct RankedBits { /// See https://github.com/rkyv/rkyv/issues/409 for more details. #[derive(Debug)] #[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))] -#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))] pub struct L12Rank([u8; 16]); /// Trait used to access archived and non-archived L1 and L2 ranks @@ -141,7 +141,7 @@ impl RankedBits { l12_ranks.push(l12_rank.into()); } - RankedBits { bits, l12_ranks: l12_ranks.into_boxed_slice() } + Self { bits, l12_ranks: l12_ranks.into_boxed_slice() } } /// Returns the total number of bytes occupied by `RankedBits` @@ -163,17 +163,15 @@ impl RankedBitsAccess for RankedBits { impl RankedBitsAccess for ArchivedRankedBits { #[inline] fn rank(&self, idx: usize) -> Option { - unsafe { Self::rank_impl(&self.bits, &self.l12_ranks, idx) } + unsafe { Self::rank_impl(self.bits.get(), &self.l12_ranks, idx) } } } #[cfg(test)] mod tests { use super::*; - use bitvec::order::Lsb0; - use bitvec::vec::BitVec; - use rand::distributions::Standard; - use rand::Rng; + use bitvec::{order::Lsb0, vec::BitVec}; + use rand::{distributions::Standard, Rng}; #[test] fn test_rank_and_get() { diff --git a/src/set.rs b/src/set.rs index 9b8b08c..3484883 100644 --- a/src/set.rs +++ b/src/set.rs @@ -10,20 +10,24 @@ //! dynamically update membership. However, when the `rkyv_derive` feature is enabled, you can use //! [`rkyv`](https://rkyv.org/) to perform zero-copy deserialization of a new set. -use std::borrow::Borrow; -use std::collections::HashSet; -use std::hash::{Hash, Hasher}; -use std::mem::size_of_val; +use std::{ + borrow::Borrow, + collections::HashSet, + hash::{Hash, Hasher}, + mem::size_of_val, +}; use num::{PrimInt, Unsigned}; use wyhash::WyHash; -use crate::mphf::{Mphf, MphfError, DEFAULT_GAMMA}; +use crate::{ + mphf::{Mphf, MphfError, DEFAULT_GAMMA}, + IntoGroupSeed, +}; /// An efficient, immutable set. #[derive(Default)] #[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))] -#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))] pub struct Set where ST: PrimInt + Unsigned, @@ -38,7 +42,7 @@ where impl Set where K: Eq + Hash, - ST: PrimInt + Unsigned, + ST: PrimInt + Unsigned + IntoGroupSeed, H: Hasher + Default, { /// Constructs a `Set` from an iterator of keys and MPHF function parameters. @@ -69,7 +73,7 @@ where } } - Ok(Set { mphf, keys: keys.into_boxed_slice() }) + Ok(Self { mphf, keys: keys.into_boxed_slice() }) } /// Returns `true` if the set contains the value. @@ -175,7 +179,8 @@ impl ArchivedSet where K: Eq + Hash + rkyv::Archive, K::Archived: PartialEq, - ST: PrimInt + Unsigned + rkyv::Archive, + ST: PrimInt + Unsigned + rkyv::Archive, + ::Archived: IntoGroupSeed, H: Hasher + Default, { /// Returns `true` if the set contains the value. @@ -185,18 +190,17 @@ where /// # use std::collections::HashSet; /// # use entropy_map::{ArchivedSet, Set}; /// let set: Set = Set::try_from(HashSet::from([1, 2, 3])).unwrap(); - /// let archived_set = rkyv::from_bytes::>( - /// &rkyv::to_bytes::<_, 1024>(&set).unwrap() - /// ).unwrap(); + /// let bytes = rkyv::to_bytes::(&set).unwrap(); + /// let archived_set = rkyv::access::, rkyv::rancor::Error>(&bytes).unwrap(); /// assert_eq!(archived_set.contains(&1), true); /// assert_eq!(archived_set.contains(&4), false); /// ``` #[inline] - pub fn contains(&self, key: &Q) -> bool + pub fn contains(&self, key: &Q) -> bool where K: Borrow, ::Archived: PartialEq, - Q: Hash + Eq, + Q: Hash + Eq + ?Sized, { // SAFETY: `idx` is always within bounds (ensured during construction) self.mphf @@ -264,11 +268,11 @@ mod tests { // create regular `HashSet`, then `Set`, then serialize to `rkyv` bytes. let original_set = gen_set(1000); let set = Set::try_from(original_set.clone()).unwrap(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&set).unwrap(); assert_eq!(rkyv_bytes.len(), 8408); - let rkyv_set = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_set = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); // Test get on `Archived` version for k in original_set.iter() { @@ -280,8 +284,8 @@ mod tests { #[test] fn test_rkyv_contains_borrow() { let set = Set::try_from(HashSet::from(["a".to_string(), "b".to_string()])).unwrap(); - let rkyv_bytes = rkyv::to_bytes::<_, 1024>(&set).unwrap(); - let rkyv_set = rkyv::check_archived_root::>(&rkyv_bytes).unwrap(); + let rkyv_bytes = rkyv::to_bytes::(&set).unwrap(); + let rkyv_set = rkyv::access::, rkyv::rancor::Error>(&rkyv_bytes).unwrap(); assert!(rkyv_set.contains("a")); assert!(rkyv_set.contains("b"));