diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..fcbf9b9 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,102 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +```bash +make test # Start Redis, run full test suite, tear down +make bench # Start Redis, run criterion benchmarks, tear down +make redis-up # Start Redis only (docker-compose, port REDIS_PORT=16379) +make redis-down # Stop Redis and remove volumes +``` + +To run tests manually (Redis must be running): +```bash +REDIS_URL="redis://127.0.0.1:16379/" cargo test --all-features +REDIS_URL="redis://127.0.0.1:16379/" cargo test --all-features # single test +cargo test --doc --all-features # doctests only +``` + +**Always use `make test`, not bare `cargo test`** — tests require `REDIS_URL` to be set and a live Redis instance. + +## Architecture + +**distkit** is an async Rust library (Tokio + Redis) providing distributed counting primitives with two consistency modes and two counter families. + +### Counter families + +| Family | Trait | Strict impl | Lax impl | +|--------|-------|-------------|----------| +| Simple | `CounterTrait` | `StrictCounter` | `LaxCounter` | +| Instance-aware | `InstanceAwareCounterTrait` | `StrictInstanceAwareCounter` | `LaxInstanceAwareCounter` | + +**Strict** counters: every operation is an atomic Lua script round-trip — fully consistent. + +**Lax** counters: `inc`/`dec`/`get`/`set_on_instance` are served from an in-memory `DashMap`; a background Tokio task flushes accumulated deltas to Redis every `flush_interval` (default 20 ms). Epoch-bumping operations (`set`, `del`, `clear`) flush first, then delegate to the strict backend. The background task holds a `Weak` reference and stops when the counter is dropped. + +### Instance-aware counters + +Each counter instance gets a UUID. Operations return `(cumulative, instance_count)`. Redis stores: +- `cumulative_key` hash: key → global total +- `instance_count_key` hash: per-instance contributions +- `instances_key` sorted set: instance_id → last-heartbeat timestamp (ms) +- `epoch_key` hash: per-key epoch counter (bumped by `set`/`del` to invalidate stale slices) + +Dead instances (no heartbeat for `dead_instance_threshold_ms`, default 30 s) are cleaned up by the next live instance that touches an affected key. + +### Lua scripts + +All Redis logic is embedded as inline Lua strings inside each `*_counter.rs` file — no external `.lua` files. `HELPERS_LUA` (in `strict_instance_aware_counter.rs`) defines shared helpers (`now_ms`, `delete_dead_instances`, `check_and_zadd`) that are prepended to every icounter script via string concatenation. + +Scripts echo keys back in their return values so callers can build `HashMap` results instead of relying on positional ordering. **Never use `.zip()` to align pipeline results with input keys** — use the HashMap keyed on the returned key string. + +### `execute_pipeline_with_script_retry` + +`src/common/mod.rs` exports this generic helper used by every batch pipeline operation: + +```rust +execute_pipeline_with_script_retry(conn, script, items, |item| { + let mut inv = script.key(...); + inv.key(...).arg(...); + inv // return owned ScriptInvocation +}) +``` + +On `NOSCRIPT` error it prepends `load_script` and retries the entire pipeline. Callers pass a closure returning one `ScriptInvocation<'s>` per item; the function owns all pipeline mechanics. + +### `RedisKey` + +Newtype wrapping `String`, validated on construction (`TryFrom`): non-empty, ≤255 chars, no colons. Used as the public API key type throughout. `RedisKeyGenerator` prepends the counter-type prefix when building actual Redis keys. + +### `ActivityTracker` + +Drives the lax flush task's sleep/wake cycle. `signal()` sets `is_active = true` atomically and sends on a `watch` channel. The flush task parks at `is_active_watch.changed()` when idle; `run_is_active_task` sets `is_active = false` after `epoch_interval / 2` (7.5 s) of inactivity. The epoch advances every `EPOCH_CHANGE_INTERVAL` (15 s); `signal()` is a no-op within the same epoch to avoid redundant sends. + +### Feature flags + +- `counter` (default): `StrictCounter`, `LaxCounter` +- `instance-aware-counter` (default): `StrictInstanceAwareCounter`, `LaxInstanceAwareCounter` +- `trypema`: re-exports the `trypema` rate-limiting crate + +### Module layout + +``` +src/ + lib.rs # feature-gated re-exports + error.rs # DistkitError + common/ + mod.rs # RedisKey, RedisKeyGenerator, execute_pipeline_with_script_retry, + # ActivityTracker, EPOCH_CHANGE_INTERVAL + counter/ + counter_trait.rs # CounterTrait (async_trait) + strict_counter.rs # StrictCounter + embedded Lua + lax_counter.rs # LaxCounter + embedded Lua + tests/ # unit tests per impl + icounter/ + mod.rs # InstanceAwareCounterTrait (async_trait) + strict_instance_aware_counter.rs # StrictInstanceAwareCounter + all Lua scripts + lax_instance_aware_counter.rs # LaxInstanceAwareCounter (wraps strict) + tests/ + __doctest_helpers.rs # Counter factory helpers for inline doc examples +``` diff --git a/Cargo.lock b/Cargo.lock index 564fab2..8437b4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -332,12 +332,13 @@ dependencies = [ [[package]] name = "distkit" -version = "0.2.3" +version = "0.3.0" dependencies = [ "async-trait", "criterion", "dashmap", "redis", + "regex", "strum", "strum_macros", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index f47fce4..d05179f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "distkit" -version = "0.2.3" +version = "0.3.0" edition = "2024" description = "A toolkit of distributed systems primitives for Rust, backed by Redis" authors = ["Oyinbo David Bayode "] @@ -40,6 +40,7 @@ trypema = ["trypema-crate", "trypema-crate/redis-tokio"] [dependencies] async-trait = "0.1.89" dashmap = "6.1.0" +regex = "1.12.3" redis = { version = "1.2.0", features = [ "aio", "connection-manager", diff --git a/README.md b/README.md index cd6057a..2ec9968 100644 --- a/README.md +++ b/README.md @@ -61,17 +61,17 @@ distkit requires a running Redis instance (5.0+ for Lua script support). ## Quick start ```rust -use distkit::{RedisKey, counter::{StrictCounter, LaxCounter, CounterOptions, CounterTrait}}; +use distkit::{DistkitRedisKey, counter::{StrictCounter, LaxCounter, CounterOptions, CounterTrait}}; #[tokio::main] async fn main() -> Result<(), Box> { let client = redis::Client::open("redis://127.0.0.1/")?; let conn = client.get_connection_manager().await?; - let prefix = RedisKey::try_from("my_app".to_string())?; + let prefix = DistkitRedisKey::try_from("my_app".to_string())?; let options = CounterOptions::new(prefix, conn); - let key = RedisKey::try_from("page_views".to_string())?; + let key = DistkitRedisKey::try_from("page_views".to_string())?; // Strict: immediate consistency let strict = StrictCounter::new(options.clone()); @@ -97,13 +97,36 @@ Every call is a single Redis round-trip executing an atomic Lua script. The counter value is always authoritative. ```rust -let key = RedisKey::try_from("orders".to_string())?; +let key = DistkitRedisKey::try_from("orders".to_string())?; strict.inc(&key, 1).await?; // HINCRBY via Lua strict.set(&key, 100).await?; // HSET via Lua strict.del(&key).await?; // HDEL, returns old value strict.clear().await?; // DEL on the hash ``` +Conditional writes use `CounterComparator` and return the current value +unchanged when the comparison fails. + +```rust +use distkit::CounterComparator; + +strict.set(&key, 10).await?; +assert_eq!(strict.inc_if(&key, CounterComparator::Eq(10), 5).await?, 15); +assert_eq!(strict.set_if(&key, CounterComparator::Gt(20), 99).await?, 15); +``` + +Batch increments follow the same rules and preserve input order. + +```rust +let results = strict + .inc_all_if(&[ + (&key, CounterComparator::Eq(15), 2), + (&key, CounterComparator::Nil, 3), + ]) + .await?; +assert_eq!(results, vec![(&key, 17), (&key, 20)]); +``` + ### LaxCounter Writes are buffered in a local `DashMap` and flushed to Redis in batched @@ -112,7 +135,7 @@ pipelines every `allowed_lag` (default 20 ms). Reads return the local view process. ```rust -let key = RedisKey::try_from("impressions".to_string())?; +let key = DistkitRedisKey::try_from("impressions".to_string())?; lax.inc(&key, 1).await?; // local atomic add, sub-microsecond let val = lax.get(&key).await?; // reads local state, no Redis hit ``` @@ -156,6 +179,12 @@ This makes them well-suited for: restarts or crashes. - **Per-node metrics** -- see both the global total and each instance's slice. +Conditional instance-aware writes follow the same rule set: + +- `inc_if` and `set_if` compare against the cumulative total. +- `set_on_instance_if` compares against the calling instance's slice. +- Failed comparisons return the current `(cumulative, instance_count)` unchanged. + ### StrictInstanceAwareCounter Every call is immediately consistent with Redis. `set` and `del` bump a @@ -167,16 +196,16 @@ use distkit::icounter::{ InstanceAwareCounterTrait, StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions, }; -use distkit::RedisKey; +use distkit::DistkitRedisKey; let client = redis::Client::open("redis://127.0.0.1/")?; let conn = client.get_connection_manager().await?; -let prefix = RedisKey::try_from("my_app".to_string())?; +let prefix = DistkitRedisKey::try_from("my_app".to_string())?; let counter = StrictInstanceAwareCounter::new( StrictInstanceAwareCounterOptions::new(prefix, conn), ); -let key = RedisKey::try_from("connections".to_string())?; +let key = DistkitRedisKey::try_from("connections".to_string())?; // Increment this instance's contribution; returns (cumulative, instance_count). let (total, mine) = counter.inc(&key, 5).await?; @@ -211,13 +240,13 @@ use distkit::icounter::{ InstanceAwareCounterTrait, StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions, }; -use distkit::RedisKey; +use distkit::DistkitRedisKey; let client = redis::Client::open("redis://127.0.0.1/")?; let conn1 = client.get_connection_manager().await?; let conn2 = client.get_connection_manager().await?; -let prefix = RedisKey::try_from("my_app".to_string())?; -let key = RedisKey::try_from("connections".to_string())?; +let prefix = DistkitRedisKey::try_from("my_app".to_string())?; +let key = DistkitRedisKey::try_from("connections".to_string())?; let opts = |conn| StrictInstanceAwareCounterOptions { prefix: prefix.clone(), @@ -250,12 +279,12 @@ use distkit::icounter::{ InstanceAwareCounterTrait, LaxInstanceAwareCounter, LaxInstanceAwareCounterOptions, }; -use distkit::RedisKey; +use distkit::DistkitRedisKey; use std::time::Duration; let client = redis::Client::open("redis://127.0.0.1/")?; let conn = client.get_connection_manager().await?; -let prefix = RedisKey::try_from("my_app".to_string())?; +let prefix = DistkitRedisKey::try_from("my_app".to_string())?; let counter = LaxInstanceAwareCounter::new(LaxInstanceAwareCounterOptions { prefix, connection_manager: conn, @@ -264,7 +293,7 @@ let counter = LaxInstanceAwareCounter::new(LaxInstanceAwareCounterOptions { allowed_lag: Duration::from_millis(20), }); -let key = RedisKey::try_from("connections".to_string())?; +let key = DistkitRedisKey::try_from("connections".to_string())?; // Returns the local estimate immediately — no Redis round-trip on warm path. let (local_total, mine) = counter.inc(&key, 1).await?; diff --git a/benches/common.rs b/benches/common.rs index 22fd830..1809a56 100644 --- a/benches/common.rs +++ b/benches/common.rs @@ -4,17 +4,16 @@ use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; +use distkit::DistkitRedisKey; use distkit::counter::{CounterOptions, LaxCounter, StrictCounter}; use distkit::icounter::{ LaxInstanceAwareCounter, LaxInstanceAwareCounterOptions, StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions, }; -use distkit::RedisKey; use redis::aio::ConnectionManager; pub async fn make_connection() -> ConnectionManager { - let url = std::env::var("REDIS_URL") - .expect("REDIS_URL must be set — run via `make bench`"); + let url = std::env::var("REDIS_URL").expect("REDIS_URL must be set — run via `make bench`"); let client = redis::Client::open(url).expect("REDIS_URL is not a valid Redis URL"); client .get_connection_manager() @@ -48,9 +47,9 @@ pub async fn make_lax_counter(bench_name: &str) -> Arc { LaxCounter::new(CounterOptions::new(bench_prefix(bench_name), conn)) } -/// Builds a `RedisKey` from a plain name string. -pub fn key(name: &str) -> RedisKey { - RedisKey::try_from(name.to_string()) +/// Builds a `DistkitRedisKey` from a plain name string. +pub fn key(name: &str) -> DistkitRedisKey { + DistkitRedisKey::try_from(name.to_string()) .expect("bench key must be non-empty, ≤255 chars, and colon-free") } @@ -58,11 +57,11 @@ pub fn key(name: &str) -> RedisKey { // Internal helpers // --------------------------------------------------------------------------- -fn bench_prefix(bench_name: &str) -> RedisKey { +fn bench_prefix(bench_name: &str) -> DistkitRedisKey { let ts = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_nanos(); - RedisKey::try_from(format!("bench_{}_{}", ts, bench_name)) + DistkitRedisKey::try_from(format!("bench_{}_{}", ts, bench_name)) .expect("constructed bench prefix is always valid") } diff --git a/benches/strict_instance_aware_counter.rs b/benches/strict_instance_aware_counter.rs index bf58ffc..6d9739a 100644 --- a/benches/strict_instance_aware_counter.rs +++ b/benches/strict_instance_aware_counter.rs @@ -7,15 +7,15 @@ // function uses a distinct member key so there is no cross-bench state // interference. Destructive operations (del, del_on_instance, clear, // clear_on_instance) use `iter_batched` to re-seed the key before every -// measured call. `inc_batch_10` rebuilds its input Vec via `iter_batched` -// because the vec is drained after each call. +// measured call. `inc_all_10` rebuilds its borrowed input Vec each iteration. mod common; use std::time::Duration; use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; -use distkit::RedisKey; +use distkit::DistkitRedisKey; +use distkit::icounter::InstanceAwareCounterTrait; use tokio::runtime::Runtime; fn bench_strict_instance_aware_counter(c: &mut Criterion) { @@ -91,17 +91,17 @@ fn bench_strict_instance_aware_counter(c: &mut Criterion) { ); }); - // inc_batch_10 — pipeline 10 distinct keys in a single batch. - // The Vec is rebuilt each iteration (inc_batch drains it). The allocation - // is negligible compared to the Redis round-trip being measured. - let batch_keys: Vec = (0..10) + // inc_all_10 — pipeline 10 distinct keys in a single batch. + // The borrowed input Vec is rebuilt each iteration. The allocation is + // negligible compared to the Redis round-trip being measured. + let batch_keys: Vec = (0..10) .map(|i| common::key(&format!("batch_{i}"))) .collect(); - group.bench_function("inc_batch_10", |b| { + group.bench_function("inc_all_10", |b| { b.to_async(&rt).iter(|| async { - let mut increments: Vec<(RedisKey, i64)> = - batch_keys.iter().map(|k| (k.clone(), 1i64)).collect(); - counter.inc_batch(&mut increments, 50).await.unwrap(); + let increments: Vec<(&DistkitRedisKey, i64)> = + batch_keys.iter().map(|k| (k, 1i64)).collect(); + counter.inc_all(&increments).await.unwrap(); }); }); diff --git a/docs/lib.md b/docs/lib.md index 4e11a71..a8f428e 100644 --- a/docs/lib.md +++ b/docs/lib.md @@ -27,17 +27,17 @@ currently offers three modules and they all run on the tokio runtime: # Quick start ```rust -# use distkit::{RedisKey, counter::{StrictCounter, LaxCounter, CounterOptions, CounterTrait}}; +# use distkit::{DistkitRedisKey, counter::{StrictCounter, LaxCounter, CounterOptions, CounterTrait}}; # async fn example() -> Result<(), Box> { # let client = redis::Client::open("redis://127.0.0.1/")?; # let conn = client.get_connection_manager().await?; // Servers sharing the same prefix coordinate through the same Redis keys. -let prefix = RedisKey::try_from("my_app".to_string())?; +let prefix = DistkitRedisKey::try_from("my_app".to_string())?; let options = CounterOptions::new(prefix, conn); // Strict: every call hits Redis immediately let strict = StrictCounter::new(options.clone()); -let key = RedisKey::try_from("page_views".to_string())?; +let key = DistkitRedisKey::try_from("page_views".to_string())?; strict.inc(&key, 1).await?; let total = strict.get(&key).await?; @@ -66,18 +66,18 @@ let approx = lax.get(&key).await?; types implement [`InstanceAwareCounterTrait`]. Write generic code against either trait: ```rust -# use distkit::{DistkitError, RedisKey, counter::CounterTrait}; +# use distkit::{DistkitError, DistkitRedisKey, counter::CounterTrait}; // Example: bumping a counter by 1 (strict or lax) -async fn bump(counter: &C, key: &RedisKey) -> Result { +async fn bump(counter: &C, key: &DistkitRedisKey) -> Result { counter.inc(key, 1).await } ``` ```rust,no_run -# use distkit::{DistkitError, RedisKey, icounter::InstanceAwareCounterTrait}; +# use distkit::{DistkitError, DistkitRedisKey, icounter::InstanceAwareCounterTrait}; async fn report_connection( counter: &C, - key: &RedisKey, + key: &DistkitRedisKey, delta: i64, ) -> Result { let (total, _mine) = counter.inc(key, delta).await?; @@ -87,19 +87,61 @@ async fn report_connection( # Key types -- [`RedisKey`] -- A validated key string (non-empty, 255 chars max, no colons). - Constructed via `TryFrom`. +- [`DistkitRedisKey`] -- A validated key string (non-empty, 255 chars max, no + colons), with helpers like `new`, `new_or_panic`, and `try_sanitize`. +- [`CounterComparator`] -- The comparison operator used by conditional writes: + [`Eq`](crate::CounterComparator::Eq), [`Lt`](crate::CounterComparator::Lt), + [`Gt`](crate::CounterComparator::Gt), [`Ne`](crate::CounterComparator::Ne), + or [`Nil`](crate::CounterComparator::Nil). - [`CounterOptions`] -- Configuration bundle for counter construction. Carries a prefix, Redis connection, and the `allowed_lag` duration (default 20 ms). Implements `Clone`, so the same options can be passed to both counter types. - [`CounterTrait`] -- The async trait that both counter types implement: - `inc`, `dec`, `get`, `set`, `del`, `clear`. + `inc`, `inc_if`, `dec`, `get`, `set`, `set_if`, `del`, `clear`, and + multi-key helpers including `inc_all_if` and `set_all_if`. + +# Conditional writes + +Use [`CounterComparator`] with the `*_if` methods to apply a write only when +the current value matches a condition. Failed comparisons return the current +value unchanged. + +```rust,no_run +# use distkit::{CounterComparator, DistkitRedisKey, counter::{CounterOptions, CounterTrait, StrictCounter}}; +# async fn example() -> Result<(), Box> { +# let client = redis::Client::open("redis://127.0.0.1/")?; +# let conn = client.get_connection_manager().await?; +# let prefix = DistkitRedisKey::try_from("my_app".to_string())?; +# let counter = StrictCounter::new(CounterOptions::new(prefix, conn)); +let key = DistkitRedisKey::try_from("orders".to_string())?; +counter.set(&key, 10).await?; + +assert_eq!( + counter.inc_if(&key, CounterComparator::Eq(10), 5).await?, + 15 +); +assert_eq!( + counter.set_if(&key, CounterComparator::Gt(20), 99).await?, + 15 +); +assert_eq!( + counter + .inc_all_if(&[ + (&key, CounterComparator::Eq(15), 2), + (&key, CounterComparator::Nil, 3), + ]) + .await?, + vec![(&key, 17), (&key, 20)] +); +# Ok(()) +# } +``` # Error handling All fallible operations return [`DistkitError`]: -- **`InvalidRedisKey`** -- Returned by `RedisKey::try_from` when the input is +- **`InvalidRedisKey`** -- Returned by `DistkitRedisKey::try_from` when the input is empty, longer than 255 characters, or contains a colon. - **`RedisError`** -- A Redis operation failed (connection lost, script error, etc.). Wraps [`redis::RedisError`]. @@ -135,6 +177,12 @@ This makes them well-suited for: restarts or crashes. - **Per-node metrics** -- see both the global total and each instance's slice. +Conditional instance-aware writes follow the same pattern: + +- `inc_if` and `set_if` compare against the cumulative total. +- `set_on_instance_if` and `set_all_on_instance_if` compare against this + instance's current slice. + ## StrictInstanceAwareCounter Every call is immediately consistent with Redis. `set` and `del` bump a @@ -146,16 +194,16 @@ their next operation, preventing double-counting. # InstanceAwareCounterTrait, # StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions, # }; -# use distkit::RedisKey; +# use distkit::DistkitRedisKey; # async fn example() -> Result<(), Box> { # let client = redis::Client::open("redis://127.0.0.1/")?; # let conn = client.get_connection_manager().await?; -let prefix = RedisKey::try_from("my_app".to_string())?; +let prefix = DistkitRedisKey::try_from("my_app".to_string())?; let counter = StrictInstanceAwareCounter::new( StrictInstanceAwareCounterOptions::new(prefix, conn), ); -let key = RedisKey::try_from("connections".to_string())?; +let key = DistkitRedisKey::try_from("connections".to_string())?; // Increment this instance's contribution; returns (cumulative, instance_count). let (total, mine) = counter.inc(&key, 5).await?; @@ -194,13 +242,13 @@ them touches the same key. # InstanceAwareCounterTrait, # StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions, # }; -# use distkit::RedisKey; +# use distkit::DistkitRedisKey; # async fn example() -> Result<(), Box> { # let client = redis::Client::open("redis://127.0.0.1/")?; # let conn1 = client.get_connection_manager().await?; # let conn2 = client.get_connection_manager().await?; -let prefix = RedisKey::try_from("my_app".to_string())?; -let key = RedisKey::try_from("connections".to_string())?; +let prefix = DistkitRedisKey::try_from("my_app".to_string())?; +let key = DistkitRedisKey::try_from("connections".to_string())?; // Two independent instances sharing the same prefix. let opts = |conn| StrictInstanceAwareCounterOptions { @@ -236,12 +284,12 @@ small consistency lag. # InstanceAwareCounterTrait, # LaxInstanceAwareCounter, LaxInstanceAwareCounterOptions, # }; -# use distkit::RedisKey; +# use distkit::DistkitRedisKey; # use std::time::Duration; # async fn example() -> Result<(), Box> { # let client = redis::Client::open("redis://127.0.0.1/")?; # let conn = client.get_connection_manager().await?; -let prefix = RedisKey::try_from("my_app".to_string())?; +let prefix = DistkitRedisKey::try_from("my_app".to_string())?; // options: LaxInstanceAwareCounterOptions::new(prefix, conn) would give the same result. let counter = LaxInstanceAwareCounter::new(LaxInstanceAwareCounterOptions { @@ -252,7 +300,7 @@ let counter = LaxInstanceAwareCounter::new(LaxInstanceAwareCounterOptions { allowed_lag: Duration::from_millis(20), }); -let key = RedisKey::try_from("connections".to_string())?; +let key = DistkitRedisKey::try_from("connections".to_string())?; // Returns the local estimate immediately — no Redis round-trip on warm path. let (local_total, mine) = counter.inc(&key, 1).await?; diff --git a/src/__doctest_helpers.rs b/src/__doctest_helpers.rs index dba4070..ab6f305 100644 --- a/src/__doctest_helpers.rs +++ b/src/__doctest_helpers.rs @@ -10,7 +10,7 @@ use std::{ }; use crate::{ - RedisKey, + DistkitRedisKey, counter::{CounterOptions, LaxCounter, StrictCounter}, icounter::{ LaxInstanceAwareCounter, LaxInstanceAwareCounterOptions, StrictInstanceAwareCounter, @@ -22,8 +22,8 @@ fn redis_url() -> String { std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()) } -fn unique_prefix() -> Result { - RedisKey::try_from(format!( +fn unique_prefix() -> Result { + DistkitRedisKey::try_from(format!( "test_{}_{}_{}", uuid::Uuid::new_v4(), SystemTime::now() diff --git a/src/common/activity_tracker.rs b/src/common/activity_tracker.rs index c1f0523..78dbdd8 100644 --- a/src/common/activity_tracker.rs +++ b/src/common/activity_tracker.rs @@ -94,6 +94,7 @@ impl ActivityTracker { pub(crate) fn signal(&self) { let epoch = self.epoch.load(Ordering::Relaxed); if self.last_commited_epoch.load(Ordering::Relaxed) < epoch { + self.is_active.store(true, Ordering::Release); let _ = self.is_active_watch.send(epoch); self.last_commited_epoch.store(epoch, Ordering::Relaxed); } diff --git a/src/common/mod.rs b/src/common/mod.rs index 64a7f92..43f6271 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,9 +1,12 @@ use std::{ ops::{Deref, DerefMut}, - sync::Mutex, + sync::{LazyLock, Mutex}, time::Duration, }; +use redis::aio::ConnectionManager; +use regex::Regex; + mod activity_tracker; pub(crate) use activity_tracker::*; @@ -11,24 +14,161 @@ pub(crate) const EPOCH_CHANGE_INTERVAL: Duration = Duration::from_secs(15); use crate::DistkitError; +static REDIS_KEY_STRIP_RE: LazyLock = + LazyLock::new(|| Regex::new(r":").expect("REDIS_KEY_STRIP_RE is valid")); + +pub(crate) async fn execute_pipeline_with_script_retry<'s, T, I, F>( + conn: &mut ConnectionManager, + script: &'s redis::Script, + items: &[I], + build_invocation: F, +) -> Result +where + T: redis::FromRedisValue, + F: Fn(&I) -> redis::ScriptInvocation<'s>, +{ + let mut pipe = redis::Pipeline::new(); + + for item in items { + pipe.invoke_script(&build_invocation(item)); + } + + match pipe.query_async::(conn).await { + Ok(r) => Ok(r), + Err(err) if err.kind() == redis::ErrorKind::Server(redis::ServerErrorKind::NoScript) => { + let mut retry_pipe = redis::Pipeline::new(); + + retry_pipe.load_script(script).ignore(); + + for item in items { + retry_pipe.invoke_script(&build_invocation(item)); + } + + retry_pipe + .query_async::(conn) + .await + .map_err(DistkitError::RedisError) + } + Err(err) => Err(DistkitError::RedisError(err)), + } +} + /// A validated Redis key. /// -/// Keys must be non-empty, at most 255 characters, and must not contain -/// colons (`:`). Construct via [`TryFrom`]. +/// All Redis-backed distkit operations require keys wrapped in this type. +/// Validation happens at construction time, whether you use +/// [`DistkitRedisKey::new`], [`TryFrom`], or +/// [`DistkitRedisKey::new_or_panic`]. /// -/// `RedisKey` dereferences to [`String`], so all string methods are +/// # Validation Rules +/// +/// - Must not be empty +/// - Must be 255 bytes or shorter +/// - Must not contain `:` because distkit uses colons internally as separators +/// +/// `DistkitRedisKey` dereferences to [`String`], so standard string methods are /// available through auto-deref. +/// +/// # Examples +/// +/// ``` +/// use distkit::DistkitRedisKey; +/// +/// let key = DistkitRedisKey::new("user_123".to_string()).unwrap(); +/// let key = DistkitRedisKey::try_from("api_v2_endpoint".to_string()).unwrap(); +/// let key = DistkitRedisKey::new_or_panic("team_alpha".to_string()); +/// +/// assert!(DistkitRedisKey::try_from("user:123".to_string()).is_err()); +/// assert!(DistkitRedisKey::try_from("".to_string()).is_err()); +/// assert!(DistkitRedisKey::try_from("a".repeat(256)).is_err()); +/// ``` #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq)] -pub struct RedisKey(String); +pub struct DistkitRedisKey(String); + +impl DistkitRedisKey { + /// Returns the default Redis namespace prefix used by distkit. + /// + /// # Examples + /// + /// ``` + /// use distkit::DistkitRedisKey; + /// + /// assert_eq!(*DistkitRedisKey::default_prefix(), "distkit"); + /// ``` + pub fn default_prefix() -> Self { + Self("distkit".to_string()) + } + + /// Fallible constructor. Equivalent to [`TryFrom`]. + /// + /// # Examples + /// + /// ``` + /// use distkit::DistkitRedisKey; + /// + /// let key = DistkitRedisKey::new("orders".to_string())?; + /// assert_eq!(*key, "orders"); + /// # Ok::<(), distkit::DistkitError>(()) + /// ``` + pub fn new(value: String) -> Result { + Self::try_from(value) + } + + /// Panicking constructor for validated keys. + /// + /// # Examples + /// + /// ``` + /// use distkit::DistkitRedisKey; + /// + /// let key = DistkitRedisKey::new_or_panic("orders".to_string()); + /// assert_eq!(*key, "orders"); + /// ``` + pub fn new_or_panic(value: String) -> Self { + Self::try_from(value).expect("invalid DistkitRedisKey") + } + + /// Strips colons from `value`, then validates the sanitized result. + /// + /// # Examples + /// + /// ``` + /// use distkit::DistkitRedisKey; + /// + /// let key = DistkitRedisKey::try_sanitize("user:123".to_string())?; + /// assert_eq!(*key, "user123"); + /// + /// assert!(DistkitRedisKey::try_sanitize(":".to_string()).is_err()); + /// # Ok::<(), distkit::DistkitError>(()) + /// ``` + pub fn try_sanitize(value: String) -> Result { + let sanitized = REDIS_KEY_STRIP_RE.replace_all(&value, "").into_owned(); + Self::try_from(sanitized) + } + + /// Strips colons from `value` and returns the sanitized key. + /// + /// Panics if the sanitized result is still invalid. + /// + /// # Examples + /// + /// ``` + /// use distkit::DistkitRedisKey; + /// + /// let key = DistkitRedisKey::sanitize_or_panic("user:123".to_string()); + /// assert_eq!(*key, "user123"); + /// ``` + pub fn sanitize_or_panic(value: String) -> Self { + Self::try_sanitize(value).expect("sanitized DistkitRedisKey value is still invalid") + } -impl RedisKey { #[cfg(test)] pub(crate) fn from(value: String) -> Self { Self(value) } } -impl Deref for RedisKey { +impl Deref for DistkitRedisKey { type Target = String; fn deref(&self) -> &Self::Target { @@ -36,13 +176,13 @@ impl Deref for RedisKey { } } -impl DerefMut for RedisKey { +impl DerefMut for DistkitRedisKey { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } -impl TryFrom for RedisKey { +impl TryFrom for DistkitRedisKey { type Error = DistkitError; fn try_from(value: String) -> Result { @@ -64,6 +204,10 @@ impl TryFrom for RedisKey { } } +/// Backwards-compatible alias for [`DistkitRedisKey`]. +#[doc(hidden)] +pub type RedisKey = DistkitRedisKey; + #[derive(Clone, Debug, strum_macros::Display)] pub(crate) enum RedisKeyGeneratorTypeKey { #[strum(to_string = "lax_counter")] @@ -78,12 +222,12 @@ pub(crate) enum RedisKeyGeneratorTypeKey { #[derive(Clone, Debug)] pub(crate) struct RedisKeyGenerator { - prefix: RedisKey, + prefix: DistkitRedisKey, key_type: RedisKeyGeneratorTypeKey, } impl RedisKeyGenerator { - pub(crate) fn new(prefix: RedisKey, key_type: RedisKeyGeneratorTypeKey) -> Self { + pub(crate) fn new(prefix: DistkitRedisKey, key_type: RedisKeyGeneratorTypeKey) -> Self { Self { prefix, key_type } } diff --git a/src/comparator.rs b/src/comparator.rs new file mode 100644 index 0000000..b2d3d7a --- /dev/null +++ b/src/comparator.rs @@ -0,0 +1,74 @@ +/// Comparison operator used by conditional counter writes. +/// +/// The comparator is evaluated against the current observed counter value. +/// +/// # Examples +/// +/// ```rust +/// use distkit::CounterComparator; +/// +/// assert!(CounterComparator::Eq(5).matches(5)); +/// assert!(CounterComparator::Lt(5).matches(4)); +/// assert!(CounterComparator::Gt(5).matches(6)); +/// assert!(CounterComparator::Ne(5).matches(4)); +/// assert!(CounterComparator::Nil.matches(42)); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CounterComparator { + /// Matches when the current value is equal to the embedded operand. + Eq(i64), + /// Matches when the current value is less than the embedded operand. + Lt(i64), + /// Matches when the current value is greater than the embedded operand. + Gt(i64), + /// Matches when the current value is not equal to the embedded operand. + Ne(i64), + /// Always matches. + /// + /// This is primarily useful for delegating unconditional APIs through the + /// conditional write path, or for mixing guarded and unguarded writes in a + /// single batch call. + Nil, +} + +impl CounterComparator { + /// Returns whether `current` satisfies this comparator. + pub fn matches(self, current: i64) -> bool { + match self { + Self::Eq(expected) => current == expected, + Self::Lt(expected) => current < expected, + Self::Gt(expected) => current > expected, + Self::Ne(expected) => current != expected, + Self::Nil => true, + } + } + + pub(crate) fn as_lua_parts(self) -> (&'static str, i64) { + match self { + Self::Eq(expected) => ("eq", expected), + Self::Lt(expected) => ("lt", expected), + Self::Gt(expected) => ("gt", expected), + Self::Ne(expected) => ("ne", expected), + Self::Nil => ("nil", 0), + } + } +} + +#[cfg(test)] +mod tests { + use super::CounterComparator; + + #[test] + fn matches_uses_embedded_operand_and_nil_always_matches() { + assert!(CounterComparator::Eq(5).matches(5)); + assert!(CounterComparator::Lt(5).matches(4)); + assert!(CounterComparator::Gt(5).matches(6)); + assert!(CounterComparator::Ne(5).matches(4)); + assert!(CounterComparator::Nil.matches(-99)); + + assert!(!CounterComparator::Eq(5).matches(4)); + assert!(!CounterComparator::Lt(5).matches(5)); + assert!(!CounterComparator::Gt(5).matches(5)); + assert!(!CounterComparator::Ne(5).matches(5)); + } +} diff --git a/src/counter/counter_trait.rs b/src/counter/counter_trait.rs index ca0f54e..4a8c8c3 100644 --- a/src/counter/counter_trait.rs +++ b/src/counter/counter_trait.rs @@ -1,4 +1,4 @@ -use crate::{DistkitError, RedisKey}; +use crate::{CounterComparator, DistkitError, DistkitRedisKey}; /// Async interface for distributed counter operations. /// @@ -15,11 +15,11 @@ pub trait CounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_counter().await?; - /// let key = RedisKey::try_from("visits".to_string())?; + /// let key = DistkitRedisKey::try_from("visits".to_string())?; /// assert_eq!(counter.inc(&key, 1).await?, 1); /// assert_eq!(counter.inc(&key, 9).await?, 10); /// // Negative count is the same as calling dec. @@ -27,7 +27,45 @@ pub trait CounterTrait { /// # Ok(()) /// # } /// ``` - async fn inc(&self, key: &RedisKey, count: i64) -> Result; + async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result; + + /// Conditionally increments the counter by `count` when the current value + /// satisfies `comparator`. + /// + /// Returns the updated total on success, or the current total unchanged + /// when the condition fails. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, counter::CounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_counter().await?; + /// let key = DistkitRedisKey::try_from("inventory".to_string())?; + /// counter.set(&key, 10).await?; + /// + /// assert_eq!( + /// counter.inc_if(&key, CounterComparator::Eq(10), 5).await?, + /// 15 + /// ); + /// assert_eq!( + /// counter.inc_if(&key, CounterComparator::Lt(10), 5).await?, + /// 15 + /// ); + /// assert_eq!( + /// counter.inc_if(&key, CounterComparator::Nil, 5).await?, + /// 20 + /// ); + /// # Ok(()) + /// # } + /// ``` + async fn inc_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result; /// Decrements the counter by `count` and returns the new total. /// @@ -36,11 +74,11 @@ pub trait CounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_counter().await?; - /// let key = RedisKey::try_from("tokens".to_string())?; + /// let key = DistkitRedisKey::try_from("tokens".to_string())?; /// counter.set(&key, 10).await?; /// assert_eq!(counter.dec(&key, 3).await?, 7); /// // Counters can go negative. @@ -48,7 +86,7 @@ pub trait CounterTrait { /// # Ok(()) /// # } /// ``` - async fn dec(&self, key: &RedisKey, count: i64) -> Result; + async fn dec(&self, key: &DistkitRedisKey, count: i64) -> Result; /// Returns the current value of the counter, or `0` if the key does not /// exist. @@ -56,11 +94,11 @@ pub trait CounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_counter().await?; - /// let key = RedisKey::try_from("visits".to_string())?; + /// let key = DistkitRedisKey::try_from("visits".to_string())?; /// // A key that does not exist returns 0. /// assert_eq!(counter.get(&key).await?, 0); /// counter.inc(&key, 5).await?; @@ -68,7 +106,7 @@ pub trait CounterTrait { /// # Ok(()) /// # } /// ``` - async fn get(&self, key: &RedisKey) -> Result; + async fn get(&self, key: &DistkitRedisKey) -> Result; /// Sets the counter to an exact value, overwriting any previous state. /// Returns the value that was set. @@ -76,11 +114,11 @@ pub trait CounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_counter().await?; - /// let key = RedisKey::try_from("inventory".to_string())?; + /// let key = DistkitRedisKey::try_from("inventory".to_string())?; /// counter.inc(&key, 1000).await?; /// // Overwrite with an authoritative count. /// assert_eq!(counter.set(&key, 850).await?, 850); @@ -88,7 +126,45 @@ pub trait CounterTrait { /// # Ok(()) /// # } /// ``` - async fn set(&self, key: &RedisKey, count: i64) -> Result; + async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result; + + /// Conditionally sets the counter to `count` when the current value + /// satisfies `comparator`. + /// + /// Returns the value after evaluation: `count` when the write applied, or + /// the current total unchanged when the condition failed. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, counter::CounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_counter().await?; + /// let key = DistkitRedisKey::try_from("inventory".to_string())?; + /// counter.set(&key, 10).await?; + /// + /// assert_eq!( + /// counter.set_if(&key, CounterComparator::Gt(5), 25).await?, + /// 25 + /// ); + /// assert_eq!( + /// counter.set_if(&key, CounterComparator::Eq(10), 50).await?, + /// 25 + /// ); + /// assert_eq!( + /// counter.set_if(&key, CounterComparator::Nil, 40).await?, + /// 40 + /// ); + /// # Ok(()) + /// # } + /// ``` + async fn set_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result; /// Deletes the counter and returns the value it held before deletion. /// Returns `0` if the key did not exist. @@ -96,11 +172,11 @@ pub trait CounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_counter().await?; - /// let key = RedisKey::try_from("session".to_string())?; + /// let key = DistkitRedisKey::try_from("session".to_string())?; /// counter.set(&key, 42).await?; /// assert_eq!(counter.del(&key).await?, 42); /// // After deletion the key reads back as 0. @@ -110,19 +186,19 @@ pub trait CounterTrait { /// # Ok(()) /// # } /// ``` - async fn del(&self, key: &RedisKey) -> Result; + async fn del(&self, key: &DistkitRedisKey) -> Result; /// Removes all counters under the current prefix. /// /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_counter().await?; - /// let k1 = RedisKey::try_from("a".to_string())?; - /// let k2 = RedisKey::try_from("b".to_string())?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; /// counter.set(&k1, 10).await?; /// counter.set(&k2, 20).await?; /// counter.clear().await?; @@ -132,4 +208,114 @@ pub trait CounterTrait { /// # } /// ``` async fn clear(&self) -> Result<(), DistkitError>; + + /// Returns `(key, value)` for each key in `keys`, in the same order. + /// A missing key returns `(key, 0)`. + async fn get_all<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError>; + + /// Increments each `(key, delta)` pair and returns `(key, new_total)` in + /// the same order. + /// + /// Duplicate keys are processed sequentially in input order, so later + /// entries observe earlier same-call updates. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_counter().await?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; + /// + /// let results = counter.inc_all(&[(&k1, 3), (&k2, 5)]).await?; + /// + /// assert_eq!(results, vec![(&k1, 3), (&k2, 5)]); + /// # Ok(()) + /// # } + /// ``` + async fn inc_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError>; + + /// Conditionally increments each `(key, delta)` pair when the current + /// value satisfies the corresponding comparator. + /// + /// Each tuple is `(key, comparator, delta)`. Evaluation is per-item, + /// results preserve input order, and duplicate keys are processed + /// sequentially in input order. Use [`CounterComparator::Nil`] for + /// unconditional entries in a mixed batch. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, counter::CounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_counter().await?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; + /// counter.set(&k1, 10).await?; + /// + /// let results = counter + /// .inc_all_if(&[ + /// (&k1, CounterComparator::Eq(10), 5), + /// (&k2, CounterComparator::Nil, 2), + /// ]) + /// .await?; + /// + /// assert_eq!(results, vec![(&k1, 15), (&k2, 2)]); + /// # Ok(()) + /// # } + /// ``` + async fn inc_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError>; + + /// Sets each `(key, count)` pair and returns `(key, count)` in the same + /// order. Semantics match `set` for each individual key. + async fn set_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError>; + + /// Conditionally sets each `(key, count)` pair when the current value + /// satisfies the corresponding comparator. + /// + /// Each tuple is `(key, comparator, count)`. Evaluation is per-item and + /// results preserve input order. Use [`CounterComparator::Nil`] for + /// unconditional entries in a mixed batch. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, counter::CounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_counter().await?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; + /// counter.set(&k1, 10).await?; + /// + /// let results = counter + /// .set_all_if(&[ + /// (&k1, CounterComparator::Eq(10), 15), + /// (&k2, CounterComparator::Nil, 20), + /// ]) + /// .await?; + /// + /// assert_eq!(results, vec![(&k1, 15), (&k2, 20)]); + /// # Ok(()) + /// # } + /// ``` + async fn set_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError>; } diff --git a/src/counter/lax_counter.rs b/src/counter/lax_counter.rs index e789d88..8658ccd 100644 --- a/src/counter/lax_counter.rs +++ b/src/counter/lax_counter.rs @@ -1,4 +1,5 @@ use std::{ + collections::HashMap, ops::Deref, sync::{ Arc, Mutex, @@ -12,17 +13,19 @@ use redis::{Script, aio::ConnectionManager}; use tokio::time::Instant; use crate::{ - ActivityTracker, DistkitError, EPOCH_CHANGE_INTERVAL, RedisKey, RedisKeyGenerator, - RedisKeyGeneratorTypeKey, + ActivityTracker, CounterComparator, DistkitError, DistkitRedisKey, EPOCH_CHANGE_INTERVAL, + RedisKeyGenerator, RedisKeyGeneratorTypeKey, counter::{CounterError, CounterOptions, CounterTrait}, - mutex_lock, + execute_pipeline_with_script_retry, mutex_lock, }; +const MAX_BATCH_SIZE: usize = 100; + const GET_LUA: &str = r#" local container_key = KEYS[1] local key = KEYS[2] - return redis.call('HGET', container_key, key) or 0 + return {key, tonumber(redis.call('HGET', container_key, key)) or 0} "#; const COMMIT_STATE_LUA: &str = r#" @@ -51,7 +54,7 @@ const CLEAR_LUA: &str = r#" #[derive(Debug)] struct Commit { - key: RedisKey, + key: DistkitRedisKey, delta: i64, } @@ -79,8 +82,8 @@ struct SingleStore { pub struct LaxCounter { connection_manager: ConnectionManager, key_generator: RedisKeyGenerator, - store: DashMap, - locks: DashMap>>, + store: DashMap, + locks: DashMap>>, get_script: Script, allowed_lag: Duration, commit_state_script: Script, @@ -102,7 +105,7 @@ impl LaxCounter { /// # Examples /// /// ```rust - /// use distkit::{RedisKey, counter::{LaxCounter, CounterOptions}}; + /// use distkit::{DistkitRedisKey, counter::{LaxCounter, CounterOptions}}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -110,7 +113,7 @@ impl LaxCounter { /// .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); /// let client = redis::Client::open(redis_url)?; /// let conn = client.get_connection_manager().await?; - /// let prefix = RedisKey::try_from("my_app".to_string())?; + /// let prefix = DistkitRedisKey::try_from("my_app".to_string())?; /// let counter = LaxCounter::new(CounterOptions::new(prefix, conn)); /// // The background flush task is now running. /// # Ok(()) @@ -218,7 +221,7 @@ impl LaxCounter { }); } - if let Err(err) = counter.flush_to_redis(&mut batch, 100).await { + if let Err(err) = counter.flush_to_redis(&mut batch, MAX_BATCH_SIZE).await { tracing::error!("Failed to flush to redis: {err:?}"); continue; } @@ -255,55 +258,18 @@ impl LaxCounter { } // end method flush_to_redis async fn batch_commit_state(&self, commits: &[Commit]) -> Result<(), DistkitError> { - let mut connection_manager = self.connection_manager.clone(); - - let pipe = self.build_commit_pipeline(commits, false); - - let _: () = match pipe.query_async(&mut connection_manager).await { - Ok(results) => results, - Err(err) => { - if err.kind() != redis::ErrorKind::Server(redis::ServerErrorKind::NoScript) { - return Err(DistkitError::RedisError(err)); - } - - let pipe = self.build_commit_pipeline(commits, true); - - match pipe.query_async::<()>(&mut connection_manager).await { - Ok(results) => results, - Err(err) => { - return Err(DistkitError::RedisError(err)); - } - } - } - }; - - Ok(()) + let mut conn = self.connection_manager.clone(); + let script = &self.commit_state_script; + execute_pipeline_with_script_retry::<(), _, _>(&mut conn, script, commits, |commit| { + let mut inv = script.key(self.key_generator.container_key()); + inv.key(commit.key.as_str()); + inv.arg(commit.delta); + inv + }) + .await } // end method batch_commit_state - #[inline] - fn build_commit_pipeline( - &self, - commits: &[Commit], - should_load_script: bool, - ) -> redis::Pipeline { - let mut pipe = redis::Pipeline::new(); - if should_load_script { - pipe.load_script(&self.commit_state_script).ignore(); - } - - for commit in commits { - pipe.invoke_script( - self.commit_state_script - .key(self.key_generator.container_key()) - .key(commit.key.to_string()) - .arg(commit.delta), - ); - } - - pipe - } - - async fn ensure_valid_state(&self, key: &RedisKey) -> Result<(), DistkitError> { + async fn ensure_valid_state(&self, key: &DistkitRedisKey) -> Result<(), DistkitError> { let lock = self.get_or_create_lock(key).await; let _guard = lock.lock().await; @@ -320,10 +286,10 @@ impl LaxCounter { let mut conn = self.connection_manager.clone(); - let remote_total: i64 = self + let (_, remote_total): (String, i64) = self .get_script .key(self.key_generator.container_key()) - .key(key.to_string()) + .key(key.as_str()) .invoke_async(&mut conn) .await?; @@ -350,7 +316,7 @@ impl LaxCounter { Ok(()) } // end function get_remote_total - async fn get_or_create_lock(&self, key: &RedisKey) -> Arc> { + async fn get_or_create_lock(&self, key: &DistkitRedisKey) -> Arc> { if let Some(lock) = self.locks.get(key) { return lock.clone(); } @@ -360,6 +326,77 @@ impl LaxCounter { .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(()))) .clone() } + + /// Fetches stale/missing keys from Redis in a single pipeline, then updates + /// `self.store` with the fresh remote totals. + async fn batch_refresh_stale(&self, keys: &[&DistkitRedisKey]) -> Result<(), DistkitError> { + if keys.is_empty() { + return Ok(()); + } + + let mut stale_keys = Vec::with_capacity(keys.len()); + + for key in keys { + let Some(store) = self.store.get(*key) else { + stale_keys.push(*key); + continue; + }; + + if let Ok(last_flushed) = mutex_lock(&store.last_flushed, "last_flushed") + && let Some(last_flushed) = last_flushed.deref() + && last_flushed.elapsed() < self.allowed_lag + { + continue; + } + + stale_keys.push(*key); + } + + // To be honest, still contemplating whether to flush to redis here. + // I'd just flush for now to be safe + let mut batch = self.batch.lock().await; + self.flush_to_redis(&mut batch, MAX_BATCH_SIZE).await?; + + let mut conn = self.connection_manager.clone(); + let script = &self.get_script; + + let raw: Vec<(String, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, &stale_keys, |key| { + let mut inv = script.key(self.key_generator.container_key()); + inv.key(key.as_str()); + inv + }) + .await?; + + let map: HashMap = raw.into_iter().collect(); + + for key in stale_keys { + let remote_total = map.get(key.as_str()).copied().unwrap_or(0); + + match self.store.get(key) { + Some(store) => { + store.remote_total.store(remote_total, Ordering::Release); + *mutex_lock(&store.last_updated, "last_updated")? = Instant::now(); + } + None => { + let value = self + .store + .entry((*key).clone()) + .or_insert_with(|| SingleStore { + remote_total: AtomicI64::new(remote_total), + delta: AtomicI64::new(0), + last_updated: Mutex::new(Instant::now()), + last_flushed: Mutex::new(None), + }); + + value.remote_total.store(remote_total, Ordering::Release); + *mutex_lock(&value.last_updated, "last_updated")? = Instant::now(); + } + } + } + + Ok(()) + } } #[async_trait::async_trait] @@ -371,11 +408,11 @@ impl CounterTrait for LaxCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_counter().await?; - /// let key = RedisKey::try_from("hits".to_string())?; + /// let key = DistkitRedisKey::try_from("hits".to_string())?; /// // All three calls are sub-microsecond; no Redis round-trip until flush. /// assert_eq!(counter.inc(&key, 1).await?, 1); /// assert_eq!(counter.inc(&key, 1).await?, 2); @@ -384,7 +421,16 @@ impl CounterTrait for LaxCounter { /// # Ok(()) /// # } /// ``` - async fn inc(&self, key: &RedisKey, count: i64) -> Result { + async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result { + self.inc_if(key, CounterComparator::Nil, count).await + } + + async fn inc_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result { self.activity.signal(); let store = match self.store.get(key) { @@ -408,16 +454,22 @@ impl CounterTrait for LaxCounter { } }; + let remote_total = store.remote_total.load(Ordering::Acquire); + + let current = remote_total + store.delta.load(Ordering::Acquire); + + if !comparator.matches(current) { + return Ok(current); + } + let prev_delta = if count > 0 { store.delta.fetch_add(count, Ordering::AcqRel) } else { store.delta.fetch_sub(count.abs(), Ordering::AcqRel) }; - let total = store.remote_total.load(Ordering::Acquire) + prev_delta + count; - - Ok(total) - } // end function inc + Ok(remote_total + prev_delta + count) + } /// Buffers `-count` locally and returns the updated local estimate without /// a Redis round-trip. Equivalent to `inc(key, -count)`. @@ -425,17 +477,17 @@ impl CounterTrait for LaxCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_counter().await?; - /// let key = RedisKey::try_from("tokens".to_string())?; + /// let key = DistkitRedisKey::try_from("tokens".to_string())?; /// counter.set(&key, 10).await?; /// assert_eq!(counter.dec(&key, 3).await?, 7); /// # Ok(()) /// # } /// ``` - async fn dec(&self, key: &RedisKey, count: i64) -> Result { + async fn dec(&self, key: &DistkitRedisKey, count: i64) -> Result { self.inc(key, -count).await } // end function dec @@ -450,18 +502,18 @@ impl CounterTrait for LaxCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_counter().await?; - /// let key = RedisKey::try_from("hits".to_string())?; + /// let key = DistkitRedisKey::try_from("hits".to_string())?; /// counter.inc(&key, 7).await?; /// // Returns remote_total (0) + pending_delta (7) = 7, no Redis round-trip. /// assert_eq!(counter.get(&key).await?, 7); /// # Ok(()) /// # } /// ``` - async fn get(&self, key: &RedisKey) -> Result { + async fn get(&self, key: &DistkitRedisKey) -> Result { self.activity.signal(); let store = match self.store.get(key) { Some(store) @@ -499,11 +551,11 @@ impl CounterTrait for LaxCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_counter().await?; - /// let key = RedisKey::try_from("inventory".to_string())?; + /// let key = DistkitRedisKey::try_from("inventory".to_string())?; /// counter.inc(&key, 1000).await?; /// // The write is buffered; this process sees the new value immediately. /// assert_eq!(counter.set(&key, 850).await?, 850); @@ -511,7 +563,16 @@ impl CounterTrait for LaxCounter { /// # Ok(()) /// # } /// ``` - async fn set(&self, key: &RedisKey, count: i64) -> Result { + async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result { + self.set_if(key, CounterComparator::Nil, count).await + } + + async fn set_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result { self.activity.signal(); let store = match self.store.get(key) { Some(store) @@ -534,12 +595,17 @@ impl CounterTrait for LaxCounter { } }; - let total = store.remote_total.load(Ordering::Acquire); + let remote_total = store.remote_total.load(Ordering::Acquire); + let current = remote_total + store.delta.load(Ordering::Acquire); + + if !comparator.matches(current) { + return Ok(current); + } - store.delta.store(count - total, Ordering::Release); + store.delta.store(count - remote_total, Ordering::Release); Ok(count) - } // end function set + } /// Cancels any pending local delta for `key`, then immediately deletes /// it from Redis. Returns the final value, including the cancelled delta. @@ -551,11 +617,11 @@ impl CounterTrait for LaxCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_counter().await?; - /// let key = RedisKey::try_from("session".to_string())?; + /// let key = DistkitRedisKey::try_from("session".to_string())?; /// counter.inc(&key, 10).await?; // buffered, not yet in Redis /// // Pending delta (10) is cancelled; Redis is updated immediately. /// assert_eq!(counter.del(&key).await?, 10); @@ -563,7 +629,7 @@ impl CounterTrait for LaxCounter { /// # Ok(()) /// # } /// ``` - async fn del(&self, key: &RedisKey) -> Result { + async fn del(&self, key: &DistkitRedisKey) -> Result { self.activity.signal(); let lock = self.get_or_create_lock(key).await; @@ -583,7 +649,7 @@ impl CounterTrait for LaxCounter { let total: i64 = self .del_script .key(self.key_generator.container_key()) - .key(key.to_string()) + .key(key.as_str()) .invoke_async(&mut conn) .await?; @@ -598,12 +664,12 @@ impl CounterTrait for LaxCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, counter::CounterTrait}; + /// # use distkit::{DistkitRedisKey, counter::CounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_counter().await?; - /// let k1 = RedisKey::try_from("a".to_string())?; - /// let k2 = RedisKey::try_from("b".to_string())?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; /// counter.inc(&k1, 5).await?; /// counter.inc(&k2, 10).await?; /// counter.clear().await?; @@ -632,4 +698,117 @@ impl CounterTrait for LaxCounter { Ok(()) } // end function clear + + async fn get_all<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError> { + if keys.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + self.batch_refresh_stale(keys).await?; + + keys.iter() + .map(|key| { + let store = self.store.get(*key).expect("store populated after refresh"); + Ok(( + *key, + store.remote_total.load(Ordering::Acquire) + + store.delta.load(Ordering::Acquire), + )) + }) + .collect() + } // end function get_all + + async fn inc_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.inc_all_if(&conditional_updates).await + } + + async fn inc_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let keys: Vec<&DistkitRedisKey> = updates.iter().map(|(key, _, _)| *key).collect(); + self.batch_refresh_stale(&keys).await?; + + updates + .iter() + .map(|(key, comparator, count)| { + let store = self.store.get(*key).expect("store populated after refresh"); + let remote_total = store.remote_total.load(Ordering::Acquire); + let current = remote_total + store.delta.load(Ordering::Acquire); + + if comparator.matches(current) { + let prev_delta = if *count > 0 { + store.delta.fetch_add(*count, Ordering::AcqRel) + } else { + store.delta.fetch_sub(count.abs(), Ordering::AcqRel) + }; + + Ok((*key, remote_total + prev_delta + *count)) + } else { + Ok((*key, current)) + } + }) + .collect() + } + + async fn set_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.set_all_if(&conditional_updates).await + } + + async fn set_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let keys: Vec<&DistkitRedisKey> = updates.iter().map(|(key, _, _)| *key).collect(); + self.batch_refresh_stale(&keys).await?; + + updates + .iter() + .map(|(key, comparator, count)| { + let store = self.store.get(*key).expect("store populated after refresh"); + let remote_total = store.remote_total.load(Ordering::Acquire); + let current = remote_total + store.delta.load(Ordering::Acquire); + + if comparator.matches(current) { + store.delta.store(count - remote_total, Ordering::Release); + Ok((*key, *count)) + } else { + Ok((*key, current)) + } + }) + .collect() + } } // end impl CounterTrait for LaxCounter diff --git a/src/counter/mod.rs b/src/counter/mod.rs index 489e31f..f5c3ecf 100644 --- a/src/counter/mod.rs +++ b/src/counter/mod.rs @@ -19,7 +19,7 @@ pub use counter_trait::*; mod error; pub use error::*; -use crate::RedisKey; +use crate::DistkitRedisKey; #[cfg(test)] mod tests; @@ -32,7 +32,7 @@ mod tests; #[derive(Debug, Clone)] pub struct CounterOptions { /// Redis key prefix used to namespace all counter keys. - pub prefix: RedisKey, + pub prefix: DistkitRedisKey, /// Redis connection manager for executing commands. pub connection_manager: ConnectionManager, /// Maximum acceptable staleness for [`LaxCounter`] reads (default 20 ms). @@ -46,7 +46,7 @@ impl CounterOptions { /// # Examples /// /// ```rust - /// use distkit::{RedisKey, counter::CounterOptions}; + /// use distkit::{DistkitRedisKey, counter::CounterOptions}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -54,13 +54,13 @@ impl CounterOptions { /// .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); /// let client = redis::Client::open(redis_url)?; /// let conn = client.get_connection_manager().await?; - /// let prefix = RedisKey::try_from("my_app".to_string())?; + /// let prefix = DistkitRedisKey::try_from("my_app".to_string())?; /// let options = CounterOptions::new(prefix, conn); /// // options.allowed_lag == Duration::from_millis(20) /// # Ok(()) /// # } /// ``` - pub fn new(prefix: RedisKey, connection_manager: ConnectionManager) -> Self { + pub fn new(prefix: DistkitRedisKey, connection_manager: ConnectionManager) -> Self { Self { prefix, connection_manager, diff --git a/src/counter/strict_counter.rs b/src/counter/strict_counter.rs index 4434516..fbef673 100644 --- a/src/counter/strict_counter.rs +++ b/src/counter/strict_counter.rs @@ -1,16 +1,42 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use redis::{Script, aio::ConnectionManager}; use crate::{ - DistkitError, RedisKey, RedisKeyGenerator, RedisKeyGeneratorTypeKey, + CounterComparator, DistkitError, DistkitRedisKey, RedisKeyGenerator, RedisKeyGeneratorTypeKey, counter::{CounterOptions, CounterTrait}, + execute_pipeline_with_script_retry, }; +const HELPER_LUA: &str = r#" + local function compare_values(current, comparator, expected) + if comparator == 'nil' then + return true + elseif comparator == 'eq' then + return current == expected + elseif comparator == 'lt' then + return current < expected + elseif comparator == 'gt' then + return current > expected + elseif comparator == 'ne' then + return current ~= expected + end + + return false + end +"#; + const INC_LUA: &str = r#" local container_key = KEYS[1] local key = KEYS[2] - local count = tonumber(ARGV[1]) or 0 + local comparator = ARGV[1] + local compare_against = tonumber(ARGV[2]) or 0 + local count = tonumber(ARGV[3]) or 0 + + local current = tonumber(redis.call('HGET', container_key, key)) or 0 + if not compare_values(current, comparator, compare_against) then + return current + end return redis.call('HINCRBY', container_key, key, count) "#; @@ -18,18 +44,24 @@ const INC_LUA: &str = r#" const SET_LUA: &str = r#" local container_key = KEYS[1] local key = KEYS[2] - local count = tonumber(ARGV[1]) or 0 + local comparator = ARGV[1] + local compare_against = tonumber(ARGV[2]) or 0 + local count = tonumber(ARGV[3]) or 0 - redis.call('HSET', container_key, key, count) + local current = tonumber(redis.call('HGET', container_key, key)) or 0 + if not compare_values(current, comparator, compare_against) then + return {key, current} + end - return count + redis.call('HSET', container_key, key, count) + return {key, count} "#; const GET_LUA: &str = r#" local container_key = KEYS[1] local key = KEYS[2] - return redis.call('HGET', container_key, key) or 0 + return {key, tonumber(redis.call('HGET', container_key, key)) or 0} "#; const DEL_LUA: &str = r#" @@ -70,7 +102,7 @@ impl StrictCounter { /// # Examples /// /// ```rust - /// use distkit::{RedisKey, counter::{StrictCounter, CounterOptions}}; + /// use distkit::{DistkitRedisKey, counter::{StrictCounter, CounterOptions}}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -78,7 +110,7 @@ impl StrictCounter { /// .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); /// let client = redis::Client::open(redis_url)?; /// let conn = client.get_connection_manager().await?; - /// let prefix = RedisKey::try_from("my_app".to_string())?; + /// let prefix = DistkitRedisKey::try_from("my_app".to_string())?; /// let counter = StrictCounter::new(CounterOptions::new(prefix, conn)); /// # Ok(()) /// # } @@ -91,9 +123,9 @@ impl StrictCounter { } = options; let key_generator = RedisKeyGenerator::new(prefix, RedisKeyGeneratorTypeKey::Strict); - let inc_script = Script::new(INC_LUA); + let inc_script = Script::new(&format!("{HELPER_LUA}\n{INC_LUA}")); let get_script = Script::new(GET_LUA); - let set_script = Script::new(SET_LUA); + let set_script = Script::new(&format!("{HELPER_LUA}\n{SET_LUA}")); let del_script = Script::new(DEL_LUA); let clear_script = Script::new(CLEAR_LUA); @@ -111,58 +143,82 @@ impl StrictCounter { #[async_trait::async_trait] impl CounterTrait for StrictCounter { - async fn inc(&self, key: &RedisKey, count: i64) -> Result { + async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result { + self.inc_if(key, CounterComparator::Nil, count).await + } + + async fn inc_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result { let mut conn = self.connection_manager.clone(); + let (lua_comparator, compare_against) = comparator.as_lua_parts(); let total: i64 = self .inc_script .key(self.key_generator.container_key()) - .key(key.to_string()) + .key(key.as_str()) + .arg(lua_comparator) + .arg(compare_against) .arg(count) .invoke_async(&mut conn) .await?; Ok(total) - } // end function inc + } - async fn dec(&self, key: &RedisKey, count: i64) -> Result { + async fn dec(&self, key: &DistkitRedisKey, count: i64) -> Result { self.inc(key, -count).await - } // end function dec + } - async fn get(&self, key: &RedisKey) -> Result { + async fn get(&self, key: &DistkitRedisKey) -> Result { let mut conn = self.connection_manager.clone(); - let total: i64 = self + let (_, total): (String, i64) = self .get_script .key(self.key_generator.container_key()) - .key(key.to_string()) + .key(key.as_str()) .invoke_async(&mut conn) .await?; Ok(total) } // end function get - async fn set(&self, key: &RedisKey, count: i64) -> Result { + async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result { + self.set_if(key, CounterComparator::Nil, count).await + } + + async fn set_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result { let mut conn = self.connection_manager.clone(); + let (lua_comparator, compare_against) = comparator.as_lua_parts(); - let total: i64 = self + let (_, total): (String, i64) = self .set_script .key(self.key_generator.container_key()) - .key(key.to_string()) + .key(key.as_str()) + .arg(lua_comparator) + .arg(compare_against) .arg(count) .invoke_async(&mut conn) .await?; Ok(total) - } // end function set + } - async fn del(&self, key: &RedisKey) -> Result { + async fn del(&self, key: &DistkitRedisKey) -> Result { let mut conn = self.connection_manager.clone(); let total: i64 = self .del_script .key(self.key_generator.container_key()) - .key(key.to_string()) + .key(key.as_str()) .invoke_async(&mut conn) .await?; @@ -180,4 +236,118 @@ impl CounterTrait for StrictCounter { Ok(()) } // end function clear + + async fn get_all<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError> { + if keys.is_empty() { + return Ok(vec![]); + } + + let mut conn = self.connection_manager.clone(); + let script = &self.get_script; + + let raw: Vec<(String, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, keys, |key| { + let mut inv = script.key(self.key_generator.container_key()); + inv.key(key.as_str()); + inv + }) + .await?; + + let map: HashMap = raw.into_iter().collect(); + + Ok(keys + .iter() + .map(|k| (*k, map.get(k.as_str()).copied().unwrap_or(0))) + .collect()) + } // end function get_all + + async fn inc_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.inc_all_if(&conditional_updates).await + } + + async fn inc_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + let mut conn = self.connection_manager.clone(); + let script = &self.inc_script; + + let raw: Vec = + execute_pipeline_with_script_retry(&mut conn, script, updates, |update| { + let (key, comparator, count) = update; + let (lua_comparator, compare_against) = comparator.as_lua_parts(); + let mut inv = script.key(self.key_generator.container_key()); + inv.key(key.as_str()); + inv.arg(lua_comparator); + inv.arg(compare_against); + inv.arg(*count); + inv + }) + .await?; + + Ok(updates + .iter() + .zip(raw.into_iter()) + .map(|((key, _, _), total)| (*key, total)) + .collect()) + } + + async fn set_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.set_all_if(&conditional_updates).await + } + + async fn set_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + let mut conn = self.connection_manager.clone(); + let script = &self.set_script; + + let raw: Vec<(String, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, updates, |update| { + let (key, comparator, count) = update; + let (lua_comparator, compare_against) = comparator.as_lua_parts(); + let mut inv = script.key(self.key_generator.container_key()); + inv.key(key.as_str()); + inv.arg(lua_comparator); + inv.arg(compare_against); + inv.arg(*count); + inv + }) + .await?; + + let map: HashMap = raw.into_iter().collect(); + + Ok(updates + .iter() + .map(|(k, _, _)| (*k, map.get(k.as_str()).copied().unwrap_or(0))) + .collect()) + } } // end impl CounterTrait for StrictCounter diff --git a/src/counter/tests/common.rs b/src/counter/tests/common.rs index 5c8a5a6..0dc5d2c 100644 --- a/src/counter/tests/common.rs +++ b/src/counter/tests/common.rs @@ -4,7 +4,10 @@ use std::time::{SystemTime, UNIX_EPOCH}; use redis::aio::ConnectionManager; -use crate::{RedisKey, counter::{CounterOptions, LaxCounter, StrictCounter}}; +use crate::{ + DistkitRedisKey, + counter::{CounterOptions, LaxCounter, StrictCounter}, +}; static RUN_ID: OnceLock = OnceLock::new(); @@ -18,8 +21,7 @@ fn run_id() -> u128 { } async fn make_connection() -> ConnectionManager { - let url = std::env::var("REDIS_URL") - .expect("REDIS_URL must be set — run via `make test`"); + let url = std::env::var("REDIS_URL").expect("REDIS_URL must be set — run via `make test`"); let client = redis::Client::open(url).expect("valid Redis URL"); client .get_connection_manager() @@ -30,15 +32,21 @@ async fn make_connection() -> ConnectionManager { pub async fn make_strict_counter(prefix: &str) -> Arc { let conn = make_connection().await; let unique_prefix = format!("{}_{}", run_id(), prefix); - StrictCounter::new(CounterOptions::new(RedisKey::from(unique_prefix), conn)) + StrictCounter::new(CounterOptions::new( + DistkitRedisKey::from(unique_prefix), + conn, + )) } pub async fn make_lax_counter(prefix: &str) -> Arc { let conn = make_connection().await; let unique_prefix = format!("{}_{}", run_id(), prefix); - LaxCounter::new(CounterOptions::new(RedisKey::from(unique_prefix), conn)) + LaxCounter::new(CounterOptions::new( + DistkitRedisKey::from(unique_prefix), + conn, + )) } -pub fn key(name: &str) -> RedisKey { - RedisKey::from(name.to_string()) +pub fn key(name: &str) -> DistkitRedisKey { + DistkitRedisKey::from(name.to_string()) } diff --git a/src/counter/tests/lax_counter.rs b/src/counter/tests/lax_counter.rs index 650c3dc..6fa7b49 100644 --- a/src/counter/tests/lax_counter.rs +++ b/src/counter/tests/lax_counter.rs @@ -1,4 +1,4 @@ -use crate::counter::CounterTrait; +use crate::{CounterComparator, counter::CounterTrait}; use super::common::{key, make_lax_counter}; @@ -159,9 +159,9 @@ async fn chained_inc_and_dec() { let k = key("score"); counter.inc(&k, 10).await.unwrap(); // 10 - counter.dec(&k, 3).await.unwrap(); // 7 - counter.inc(&k, 5).await.unwrap(); // 12 - counter.dec(&k, 2).await.unwrap(); // 10 + counter.dec(&k, 3).await.unwrap(); // 7 + counter.inc(&k, 5).await.unwrap(); // 12 + counter.dec(&k, 2).await.unwrap(); // 10 assert_eq!(counter.get(&k).await.unwrap(), 10); } @@ -424,3 +424,377 @@ async fn clear_then_inc_starts_fresh() { let result = counter.inc(&k, 1).await.unwrap(); assert_eq!(result, 1); } + +// --------------------------------------------------------------------------- +// Flush — additional operations +// --------------------------------------------------------------------------- + +/// set is buffered as a corrective delta and eventually reaches Redis. +#[tokio::test] +async fn set_is_eventually_visible_to_fresh_instance() { + let prefix = "lax_set_flush"; + let k = key("counter"); + + let counter = make_lax_counter(prefix).await; + counter.set(&k, 55).await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let fresh = make_lax_counter(prefix).await; + assert_eq!(fresh.get(&k).await.unwrap(), 55); +} + +/// dec is buffered and eventually visible to a fresh instance. +#[tokio::test] +async fn dec_is_eventually_visible_to_fresh_instance() { + let prefix = "lax_dec_flush"; + let k = key("counter"); + + // Seed a value and wait for it to commit to Redis. + let counter = make_lax_counter(prefix).await; + counter.set(&k, 100).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // Now decrement and let the negative delta flush. + counter.dec(&k, 30).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let fresh = make_lax_counter(prefix).await; + assert_eq!(fresh.get(&k).await.unwrap(), 70); +} + +/// set after a partial flush re-fetches the stale cache and produces the +/// correct corrective delta so the final Redis value matches the target. +#[tokio::test] +async fn set_after_partial_flush_is_correct() { + let prefix = "lax_set_after_flush"; + let k = key("counter"); + + let counter = make_lax_counter(prefix).await; + counter.inc(&k, 10).await.unwrap(); + + // Wait for the flush task to commit inc to Redis and let the cache expire. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // set re-fetches remote_total (now 10) and stores delta = 50 - 10 = 40. + counter.set(&k, 50).await.unwrap(); + assert_eq!(counter.get(&k).await.unwrap(), 50); + + // Wait for the corrective delta to flush. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let fresh = make_lax_counter(prefix).await; + assert_eq!(fresh.get(&k).await.unwrap(), 50); +} + +/// del after set (no prior inc) returns the value that was set. +#[tokio::test] +async fn del_after_set_returns_correct_value() { + let counter = make_lax_counter("lax_del_after_set").await; + let k = key("val"); + + counter.set(&k, 77).await.unwrap(); + let returned = counter.del(&k).await.unwrap(); + assert_eq!(returned, 77); +} + +// --------------------------------------------------------------------------- +// get_all / set_all +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn get_all_empty_returns_empty() { + let counter = make_lax_counter("lax_get_all_empty").await; + assert_eq!(counter.get_all(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn get_all_unknown_keys_return_zero() { + let counter = make_lax_counter("lax_get_all_unknown").await; + let k1 = key("a"); + let k2 = key("b"); + assert_eq!( + counter.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 0), (&k2, 0)] + ); +} + +#[tokio::test] +async fn get_all_returns_correct_values_after_inc() { + let counter = make_lax_counter("lax_get_all_after_inc").await; + let k1 = key("a"); + let k2 = key("b"); + counter.inc(&k1, 5).await.unwrap(); + counter.inc(&k2, 10).await.unwrap(); + assert_eq!( + counter.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 5), (&k2, 10)] + ); +} + +#[tokio::test] +async fn get_all_preserves_input_order() { + let counter = make_lax_counter("lax_get_all_order").await; + let k1 = key("a"); + let k2 = key("b"); + let k3 = key("c"); + counter.inc(&k1, 1).await.unwrap(); + counter.inc(&k2, 2).await.unwrap(); + counter.inc(&k3, 3).await.unwrap(); + assert_eq!( + counter.get_all(&[&k3, &k1, &k2]).await.unwrap(), + vec![(&k3, 3), (&k1, 1), (&k2, 2)] + ); +} + +/// A fresh reader with no local cache must re-fetch from Redis and return the +/// value written by the first instance, not stale zeros. +#[tokio::test] +async fn get_all_fetches_stale_keys_from_redis() { + let prefix = "lax_get_all_stale"; + let k = key("counter"); + + let writer = make_lax_counter(prefix).await; + writer.inc(&k, 42).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let reader = make_lax_counter(prefix).await; + assert_eq!(reader.get_all(&[&k]).await.unwrap(), vec![(&k, 42)]); +} + +#[tokio::test] +async fn get_all_mixed_fresh_and_stale() { + let prefix = "lax_get_all_mixed"; + let k1 = key("a"); + let k2 = key("b"); + + let counter = make_lax_counter(prefix).await; + counter.inc(&k1, 7).await.unwrap(); + counter.inc(&k2, 13).await.unwrap(); + // Wait for flush + cache expiry so both keys are stale on a fresh instance. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let reader = make_lax_counter(prefix).await; + let results = reader.get_all(&[&k1, &k2]).await.unwrap(); + assert_eq!(results, vec![(&k1, 7), (&k2, 13)]); +} + +#[tokio::test] +async fn set_all_empty_returns_empty() { + let counter = make_lax_counter("lax_set_all_empty").await; + assert_eq!(counter.set_all(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn set_all_returns_target_values() { + let counter = make_lax_counter("lax_set_all_returns").await; + let k1 = key("a"); + let k2 = key("b"); + let results = counter.set_all(&[(&k1, 10), (&k2, 20)]).await.unwrap(); + assert_eq!(results, vec![(&k1, 10), (&k2, 20)]); +} + +#[tokio::test] +async fn set_all_subsequent_get_all_is_consistent() { + let counter = make_lax_counter("lax_set_all_consistent").await; + let k1 = key("a"); + let k2 = key("b"); + counter.set_all(&[(&k1, 100), (&k2, 200)]).await.unwrap(); + assert_eq!( + counter.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 100), (&k2, 200)] + ); +} + +/// set_all is eventually visible to a fresh reader after the flush interval. +#[tokio::test] +async fn set_all_is_eventually_flushed_to_redis() { + let prefix = "lax_set_all_flush"; + let k1 = key("a"); + let k2 = key("b"); + + let counter = make_lax_counter(prefix).await; + counter.set_all(&[(&k1, 55), (&k2, 77)]).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let reader = make_lax_counter(prefix).await; + assert_eq!( + reader.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 55), (&k2, 77)] + ); +} + +#[tokio::test] +async fn set_all_on_new_keys_uses_zero_remote_total() { + let counter = make_lax_counter("lax_set_all_new_keys").await; + let k1 = key("x"); + let k2 = key("y"); + // Keys have never been written; remote_total is 0. delta = count - 0 = count. + let results = counter.set_all(&[(&k1, 30), (&k2, 40)]).await.unwrap(); + assert_eq!(results, vec![(&k1, 30), (&k2, 40)]); + assert_eq!( + counter.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 30), (&k2, 40)] + ); +} + +#[tokio::test] +async fn set_all_preserves_input_order() { + let counter = make_lax_counter("lax_set_all_order").await; + let k1 = key("a"); + let k2 = key("b"); + let k3 = key("c"); + let results = counter + .set_all(&[(&k3, 30), (&k1, 10), (&k2, 20)]) + .await + .unwrap(); + assert_eq!(results, vec![(&k3, 30), (&k1, 10), (&k2, 20)]); +} + +#[tokio::test] +async fn inc_if_uses_all_comparators_against_local_view() { + let cases = [ + ("eq", CounterComparator::Eq(10), true), + ("lt", CounterComparator::Lt(11), true), + ("gt", CounterComparator::Gt(10), false), + ("ne", CounterComparator::Ne(9), true), + ("nil", CounterComparator::Nil, true), + ]; + + for (suffix, comparator, should_apply) in cases { + let counter = make_lax_counter(&format!("lax_inc_if_{suffix}")).await; + let k = key("conditional"); + counter.set(&k, 10).await.unwrap(); + + let result = counter.inc_if(&k, comparator, 2).await.unwrap(); + let expected = if should_apply { 12 } else { 10 }; + + assert_eq!(result, expected); + assert_eq!(counter.get(&k).await.unwrap(), expected); + } +} + +#[tokio::test] +async fn inc_if_refreshes_stale_value_before_comparing() { + let prefix = "lax_inc_if_refresh"; + let k = key("hits"); + + let writer = make_lax_counter(prefix).await; + writer.set(&k, 7).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let reader = make_lax_counter(prefix).await; + let result = reader + .inc_if(&k, CounterComparator::Eq(7), 3) + .await + .unwrap(); + + assert_eq!(result, 10); + assert_eq!(reader.get(&k).await.unwrap(), 10); +} + +#[tokio::test] +async fn inc_all_empty_and_inc_all_if_empty_return_empty() { + let counter = make_lax_counter("lax_inc_all_empty").await; + assert_eq!(counter.inc_all(&[]).await.unwrap(), vec![]); + assert_eq!(counter.inc_all_if(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn inc_all_updates_local_view_immediately_and_supports_duplicates() { + let counter = make_lax_counter("lax_inc_all_duplicates").await; + let k = key("hits"); + + let results = counter.inc_all(&[(&k, 1), (&k, 2)]).await.unwrap(); + + assert_eq!(results, vec![(&k, 1), (&k, 3)]); + assert_eq!(counter.get(&k).await.unwrap(), 3); +} + +#[tokio::test] +async fn inc_all_if_refreshes_stale_values_and_processes_entries_sequentially() { + let prefix = "lax_inc_all_if_refresh"; + let k = key("hits"); + + let writer = make_lax_counter(prefix).await; + writer.set(&k, 0).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let reader = make_lax_counter(prefix).await; + let results = reader + .inc_all_if(&[ + (&k, CounterComparator::Eq(0), 1), + (&k, CounterComparator::Eq(1), 2), + (&k, CounterComparator::Gt(10), 5), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k, 1), (&k, 3), (&k, 3)]); + assert_eq!(reader.get(&k).await.unwrap(), 3); +} + +#[tokio::test] +async fn inc_all_if_is_eventually_flushed_for_successful_updates() { + let prefix = "lax_inc_all_if_flush"; + let k1 = key("a"); + let k2 = key("b"); + + let counter = make_lax_counter(prefix).await; + let results = counter + .inc_all_if(&[ + (&k1, CounterComparator::Nil, 4), + (&k2, CounterComparator::Gt(0), 7), + ]) + .await + .unwrap(); + assert_eq!(results, vec![(&k1, 4), (&k2, 0)]); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let reader = make_lax_counter(prefix).await; + assert_eq!( + reader.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 4), (&k2, 0)] + ); +} + +#[tokio::test] +async fn set_all_if_returns_current_values_for_failed_conditions() { + let counter = make_lax_counter("lax_set_all_if_failed").await; + let k = key("hits"); + + counter.inc(&k, 5).await.unwrap(); + let results = counter + .set_all_if(&[(&k, CounterComparator::Gt(10), 20)]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k, 5)]); + assert_eq!(counter.get(&k).await.unwrap(), 5); +} + +#[tokio::test] +async fn set_all_if_is_eventually_flushed_for_successful_updates() { + let prefix = "lax_set_all_if_flush"; + let k1 = key("a"); + let k2 = key("b"); + + let counter = make_lax_counter(prefix).await; + let results = counter + .set_all_if(&[ + (&k1, CounterComparator::Nil, 55), + (&k2, CounterComparator::Gt(0), 77), + ]) + .await + .unwrap(); + assert_eq!(results, vec![(&k1, 55), (&k2, 0)]); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let reader = make_lax_counter(prefix).await; + assert_eq!( + reader.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 55), (&k2, 0)] + ); +} diff --git a/src/counter/tests/strict_counter.rs b/src/counter/tests/strict_counter.rs index 15e01ed..c3e624e 100644 --- a/src/counter/tests/strict_counter.rs +++ b/src/counter/tests/strict_counter.rs @@ -1,4 +1,4 @@ -use crate::counter::CounterTrait; +use crate::{CounterComparator, counter::CounterTrait}; use super::common::{key, make_strict_counter}; @@ -68,9 +68,9 @@ async fn chained_inc_and_dec() { let k = key("score"); counter.inc(&k, 10).await.unwrap(); // 10 - counter.dec(&k, 3).await.unwrap(); // 7 - counter.inc(&k, 5).await.unwrap(); // 12 - counter.dec(&k, 2).await.unwrap(); // 10 + counter.dec(&k, 3).await.unwrap(); // 7 + counter.inc(&k, 5).await.unwrap(); // 12 + counter.dec(&k, 2).await.unwrap(); // 10 let total = counter.get(&k).await.unwrap(); assert_eq!(total, 10); @@ -307,3 +307,116 @@ async fn clear_does_not_affect_other_prefixes() { assert_eq!(counter_b.get(&k).await.unwrap(), 99); } + +#[tokio::test] +async fn inc_if_uses_all_comparators() { + let cases = [ + ("eq", CounterComparator::Eq(10), true), + ("lt", CounterComparator::Lt(11), true), + ("gt", CounterComparator::Gt(10), false), + ("ne", CounterComparator::Ne(9), true), + ("nil", CounterComparator::Nil, true), + ]; + + for (suffix, comparator, should_apply) in cases { + let counter = make_strict_counter(&format!("test_inc_if_{suffix}")).await; + let k = key("conditional"); + counter.set(&k, 10).await.unwrap(); + + let result = counter.inc_if(&k, comparator, 2).await.unwrap(); + let expected = if should_apply { 12 } else { 10 }; + + assert_eq!(result, expected); + assert_eq!(counter.get(&k).await.unwrap(), expected); + } +} + +#[tokio::test] +async fn inc_all_empty_and_inc_all_if_empty_return_empty() { + let counter = make_strict_counter("test_inc_all_empty").await; + assert_eq!(counter.inc_all(&[]).await.unwrap(), vec![]); + assert_eq!(counter.inc_all_if(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn inc_all_supports_duplicate_keys_sequentially() { + let counter = make_strict_counter("test_inc_all_duplicates").await; + let k = key("hits"); + + let results = counter.inc_all(&[(&k, 1), (&k, 2)]).await.unwrap(); + + assert_eq!(results, vec![(&k, 1), (&k, 3)]); + assert_eq!(counter.get(&k).await.unwrap(), 3); +} + +#[tokio::test] +async fn inc_all_if_supports_partial_success_missing_keys_and_duplicates() { + let counter = make_strict_counter("test_inc_all_if_ordered").await; + let k1 = key("a"); + let k2 = key("b"); + + counter.set(&k1, 0).await.unwrap(); + counter.set(&k2, 10).await.unwrap(); + + let results = counter + .inc_all_if(&[ + (&k1, CounterComparator::Eq(0), 1), + (&k1, CounterComparator::Eq(1), 2), + (&k2, CounterComparator::Gt(20), 5), + (&k2, CounterComparator::Nil, 3), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k1, 1), (&k1, 3), (&k2, 10), (&k2, 13)]); + assert_eq!(counter.get(&k1).await.unwrap(), 3); + assert_eq!(counter.get(&k2).await.unwrap(), 13); +} + +#[tokio::test] +async fn set_if_uses_all_comparators() { + let cases = [ + ("eq", CounterComparator::Eq(10), true), + ("lt", CounterComparator::Lt(11), true), + ("gt", CounterComparator::Gt(10), false), + ("ne", CounterComparator::Ne(9), true), + ("nil", CounterComparator::Nil, true), + ]; + + for (suffix, comparator, should_apply) in cases { + let counter = make_strict_counter(&format!("test_set_if_{suffix}")).await; + let k = key("conditional"); + counter.set(&k, 10).await.unwrap(); + + let result = counter.set_if(&k, comparator, 99).await.unwrap(); + let expected = if should_apply { 99 } else { 10 }; + + assert_eq!(result, expected); + assert_eq!(counter.get(&k).await.unwrap(), expected); + } +} + +#[tokio::test] +async fn set_all_if_supports_partial_success_and_missing_keys() { + let counter = make_strict_counter("test_set_all_if_partial").await; + let k1 = key("a"); + let k2 = key("b"); + let k3 = key("c"); + + counter.set(&k1, 10).await.unwrap(); + counter.set(&k2, 20).await.unwrap(); + + let results = counter + .set_all_if(&[ + (&k3, CounterComparator::Nil, 30), + (&k1, CounterComparator::Gt(5), 11), + (&k2, CounterComparator::Lt(10), 99), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k3, 30), (&k1, 11), (&k2, 20)]); + assert_eq!(counter.get(&k1).await.unwrap(), 11); + assert_eq!(counter.get(&k2).await.unwrap(), 20); + assert_eq!(counter.get(&k3).await.unwrap(), 30); +} diff --git a/src/error.rs b/src/error.rs index f0d9cdf..1f8fb56 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,8 +7,8 @@ use crate::counter::CounterError; /// Top-level error type for all distkit operations. #[derive(Debug, thiserror::Error, PartialEq)] pub enum DistkitError { - /// A [`RedisKey`](crate::RedisKey) failed validation (empty, too long, or - /// contains a colon). + /// A [`DistkitRedisKey`](crate::DistkitRedisKey) failed validation + /// (empty, too long, or contains a colon). #[error("Invalid Redis key: {0}")] InvalidRedisKey(String), /// A counter-specific error. See [`CounterError`]. diff --git a/src/icounter/lax_instance_aware_counter.rs b/src/icounter/lax_instance_aware_counter.rs index 60d1b49..5f98bc4 100644 --- a/src/icounter/lax_instance_aware_counter.rs +++ b/src/icounter/lax_instance_aware_counter.rs @@ -9,13 +9,13 @@ use std::sync::{ Arc, atomic::{AtomicI64, Ordering}, }; -use std::time::{Duration, Instant}; +use std::{collections::HashMap, time::Duration, time::Instant}; use dashmap::DashMap; use redis::aio::ConnectionManager; use crate::{ - ActivityTracker, EPOCH_CHANGE_INTERVAL, RedisKey, + ActivityTracker, CounterComparator, DistkitRedisKey, EPOCH_CHANGE_INTERVAL, common::mutex_lock, error::DistkitError, icounter::{ @@ -60,7 +60,7 @@ impl SingleStore { #[derive(Debug, Clone)] pub struct LaxInstanceAwareCounterOptions { /// Redis key prefix used to namespace all counter keys. - pub prefix: RedisKey, + pub prefix: DistkitRedisKey, /// Redis connection manager. pub connection_manager: ConnectionManager, /// Milliseconds without a heartbeat before an instance is considered dead. @@ -82,7 +82,7 @@ impl LaxInstanceAwareCounterOptions { /// # Examples /// /// ```rust - /// use distkit::{RedisKey, icounter::LaxInstanceAwareCounterOptions}; + /// use distkit::{DistkitRedisKey, icounter::LaxInstanceAwareCounterOptions}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -90,13 +90,13 @@ impl LaxInstanceAwareCounterOptions { /// .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); /// let client = redis::Client::open(redis_url)?; /// let conn = client.get_connection_manager().await?; - /// let prefix = RedisKey::try_from("my_app".to_string())?; + /// let prefix = DistkitRedisKey::try_from("my_app".to_string())?; /// let opts = LaxInstanceAwareCounterOptions::new(prefix, conn); /// assert_eq!(opts.dead_instance_threshold_ms, 30_000); /// # Ok(()) /// # } /// ``` - pub fn new(prefix: RedisKey, connection_manager: ConnectionManager) -> Self { + pub fn new(prefix: DistkitRedisKey, connection_manager: ConnectionManager) -> Self { Self { prefix, connection_manager, @@ -120,11 +120,12 @@ impl LaxInstanceAwareCounterOptions { #[derive(Debug)] pub struct LaxInstanceAwareCounter { strict: Arc, - local_store: DashMap, + local_store: DashMap, activity: Arc, flush_interval: Duration, allowed_lag: Duration, - reset_locks: DashMap>>, + reset_locks: DashMap>>, + pending_flushed: tokio::sync::Mutex>, } impl LaxInstanceAwareCounter { @@ -138,7 +139,7 @@ impl LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// use distkit::{RedisKey, icounter::{LaxInstanceAwareCounter, LaxInstanceAwareCounterOptions, InstanceAwareCounterTrait}}; + /// use distkit::{DistkitRedisKey, icounter::{LaxInstanceAwareCounter, LaxInstanceAwareCounterOptions, InstanceAwareCounterTrait}}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -146,7 +147,7 @@ impl LaxInstanceAwareCounter { /// .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); /// let client = redis::Client::open(redis_url)?; /// let conn = client.get_connection_manager().await?; - /// let prefix = RedisKey::try_from("my_app".to_string())?; + /// let prefix = DistkitRedisKey::try_from("my_app".to_string())?; /// let counter = LaxInstanceAwareCounter::new(LaxInstanceAwareCounterOptions::new(prefix, conn)); /// assert!(!counter.instance_id().is_empty()); /// # Ok(()) @@ -175,6 +176,7 @@ impl LaxInstanceAwareCounter { flush_interval, allowed_lag, reset_locks: DashMap::default(), + pending_flushed: tokio::sync::Mutex::new(Vec::new()), }); counter.run_flush_task(); @@ -194,8 +196,6 @@ impl LaxInstanceAwareCounter { let mut interval = tokio::time::interval(flush_interval); interval.tick().await; // skip first immediate tick - let mut pending: Vec<(RedisKey, i64)> = Vec::new(); - loop { let is_active = { let Some(counter) = weak.upgrade() else { break }; @@ -209,33 +209,36 @@ impl LaxInstanceAwareCounter { interval.tick().await; let Some(counter) = weak.upgrade() else { break }; + if let Err(err) = counter.flush().await { + tracing::error!("lax_icounter:flush_task: flush failed: {err}"); + } + } + }); + } - // Collect newly stale deltas (delta already swapped to 0 in local_store). - pending.extend(counter.collect_stale_mark_flushed()); + async fn flush(&self) -> Result<(), DistkitError> { + let mut pending = self.pending_flushed.lock().await; - if pending.is_empty() { - continue; - } + // Collect newly stale deltas (delta already swapped to 0 in local_store). + pending.extend(self.collect_stale_mark_flushed()); - let results = match counter.strict.inc_batch(&mut pending, MAX_BATCH_SIZE).await { - Ok(results) => results, - Err(err) => { - tracing::error!("lax_icounter:flush_task: inc_batch failed: {err}"); - continue; - } - }; + if pending.is_empty() { + return Ok(()); + } - for (key_str, cumulative, instance_count) in results { - if let Ok(key) = RedisKey::try_from(key_str) { - counter.update_local(&key, cumulative, instance_count); - } - } + let results = self.strict.inc_batch(&mut pending, MAX_BATCH_SIZE).await?; + + for (key_str, cumulative, instance_count) in results { + if let Ok(key) = DistkitRedisKey::try_from(key_str) { + self.update_local(&key, cumulative, instance_count); } - }); + } + + Ok(()) } /// Acquire (or create) the per-key reset lock and return an `Arc` to it. - fn get_or_create_reset_lock(&self, key: &RedisKey) -> Arc> { + fn get_or_create_reset_lock(&self, key: &DistkitRedisKey) -> Arc> { if let Some(lock) = self.reset_locks.get(key) { return Arc::clone(&lock); } @@ -246,7 +249,7 @@ impl LaxInstanceAwareCounter { .clone() } // end function get_or_create_reset_lock - fn collect_stale_mark_flushed(&self) -> Vec<(RedisKey, i64)> { + fn collect_stale_mark_flushed(&self) -> Vec<(DistkitRedisKey, i64)> { let now = Instant::now(); self.local_store .iter() @@ -273,7 +276,7 @@ impl LaxInstanceAwareCounter { /// Drains the pending delta for a single key by calling `strict.inc` /// directly. Used before epoch-bumping operations (`set`, `del`, etc.). - async fn flush_key(&self, key: &RedisKey) -> Result<(), DistkitError> { + async fn flush_key(&self, key: &DistkitRedisKey) -> Result<(), DistkitError> { let Some(store) = self.local_store.get(key) else { return Ok(()); }; @@ -294,7 +297,7 @@ impl LaxInstanceAwareCounter { /// Drains pending deltas for all keys regardless of staleness. /// Used by `clear` / `clear_on_instance` before delegating to strict. async fn flush_all_keys(&self) -> Result<(), DistkitError> { - let mut all: Vec<(RedisKey, i64)> = self + let mut all: Vec<(DistkitRedisKey, i64)> = self .local_store .iter() .filter_map(|store| { @@ -316,7 +319,7 @@ impl LaxInstanceAwareCounter { let results = self.strict.inc_batch(&mut all, MAX_BATCH_SIZE).await?; for (key_str, cumulative, instance_count) in results { - if let Ok(key) = RedisKey::try_from(key_str) { + if let Ok(key) = DistkitRedisKey::try_from(key_str) { self.update_local(&key, cumulative, instance_count); } } @@ -325,7 +328,7 @@ impl LaxInstanceAwareCounter { } /// Updates `local_store` with fresh values from the strict counter. - fn update_local(&self, key: &RedisKey, cumulative: i64, instance_count: i64) { + fn update_local(&self, key: &DistkitRedisKey, cumulative: i64, instance_count: i64) { match self.local_store.get(key) { Some(store) => { store.cumulative.store(cumulative, Ordering::Release); @@ -340,6 +343,59 @@ impl LaxInstanceAwareCounter { } } } // end function update_local + + async fn refresh_local_if_needed(&self, key: &DistkitRedisKey) -> Result<(), DistkitError> { + let lock = self.get_or_create_reset_lock(key); + let _guard = lock.lock().await; + + if let Some(store) = self.local_store.get(key) + && mutex_lock(&store.last_flush, "lax_icounter:last_flush")?.elapsed() + < self.allowed_lag + { + return Ok(()); + } + + let (cumulative, instance_count) = self.strict.get(key).await?; + + self.update_local(key, cumulative, instance_count); + + Ok(()) + } + + /// Fetches stale/missing keys from the strict counter in a single batched + /// round-trip, then updates `local_store` for each key. + async fn batch_refresh_stale(&self, keys: &[&DistkitRedisKey]) -> Result<(), DistkitError> { + if keys.is_empty() { + return Ok(()); + } + + let keys: Vec<&DistkitRedisKey> = keys + .iter() + .filter(|key| { + self.local_store + .get(*key) + .and_then(|s| { + mutex_lock(&s.last_flush, "lax_icounter:last_flush") + .ok() + .map(|g| g.elapsed() >= self.allowed_lag) + }) + .unwrap_or(true) + }) + .copied() + .collect(); + + if keys.is_empty() { + return Ok(()); + } + + let results = self.strict.get_batch(&keys).await?; + + for (key, cumulative, instance_count) in results { + self.update_local(key, cumulative, instance_count); + } + + Ok(()) + } } // --------------------------------------------------------------------------- @@ -353,7 +409,7 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_icounter().await?; @@ -376,11 +432,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// // All three calls are sub-microsecond; no Redis round-trip until flush. /// let (c1, s1) = counter.inc(&key, 1).await?; /// let (c2, s2) = counter.inc(&key, 1).await?; @@ -390,22 +446,36 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - async fn inc(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + self.inc_if(key, CounterComparator::Nil, count).await + } + + async fn inc_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let store = match self.local_store.get(key) { - Some(store) => store, - None => { - let lock = self.get_or_create_reset_lock(key); - let _guard = lock.lock().await; + Some(store) + if mutex_lock(&store.last_flush, "lax_icounter:last_flush")?.elapsed() + < self.allowed_lag => + { + store + } + Some(store) => { + drop(store); - if !self.local_store.contains_key(key) { - let (cumulative, instance_count) = self.strict.get(key).await?; + self.refresh_local_if_needed(key).await?; - self.local_store - .entry(key.clone()) - .or_insert_with(|| SingleStore::new(cumulative, instance_count)); - } + self.local_store + .get(key) + .expect("key should be in local_store") + } + None => { + self.refresh_local_if_needed(key).await?; self.local_store .get(key) @@ -413,11 +483,20 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { } }; - let delta = store.delta.fetch_add(count, Ordering::AcqRel) + count; + let delta_before = store.delta.load(Ordering::Acquire); + let current = ( + store.cumulative.load(Ordering::Acquire) + delta_before, + store.instance_count.load(Ordering::Acquire) + delta_before, + ); + if !comparator.matches(current.0) { + return Ok(current); + } + + let delta_after = store.delta.fetch_add(count, Ordering::AcqRel) + count; let cumulative = store.cumulative.load(Ordering::Acquire); let instance_count = store.instance_count.load(Ordering::Acquire); - Ok((cumulative + delta, instance_count + delta)) + Ok((cumulative + delta_after, instance_count + delta_after)) } /// Decrements the counter locally. Equivalent to `inc(key, -count)`. @@ -427,11 +506,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// counter.inc(&key, 10).await?; /// let (cumulative, slice) = counter.dec(&key, 4).await?; /// assert_eq!(cumulative, 6); @@ -439,7 +518,7 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - async fn dec(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + async fn dec(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { self.inc(key, -count).await } @@ -452,11 +531,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// counter.inc(&key, 5).await?; // buffered locally /// // Pending delta flushed first; then strict.set takes over. /// let (cumulative, slice) = counter.set(&key, 100).await?; @@ -465,9 +544,22 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - async fn set(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { - self.activity.signal(); + async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + self.set_if(key, CounterComparator::Nil, count).await + } + async fn set_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError> { + let current = self.get(key).await?; + if !comparator.matches(current.0) { + return Ok(current); + } + + self.activity.signal(); self.flush_key(key).await?; let (cumulative, instance_count) = self.strict.set(key, count).await?; @@ -486,11 +578,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_lax_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 10).await?; /// server_b.inc(&key, 5).await?; /// // Adjusts server_a's local delta to reach 7; no epoch bump. @@ -504,24 +596,39 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// ``` async fn set_on_instance( &self, - key: &RedisKey, + key: &DistkitRedisKey, + count: i64, + ) -> Result<(i64, i64), DistkitError> { + self.set_on_instance_if(key, CounterComparator::Nil, count) + .await + } + + async fn set_on_instance_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, count: i64, ) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let store = match self.local_store.get(key) { - Some(store) => store, - None => { - let lock = self.get_or_create_reset_lock(key); - let _guard = lock.lock().await; + Some(store) + if mutex_lock(&store.last_flush, "lax_icounter:last_flush")?.elapsed() + < self.allowed_lag => + { + store + } + Some(store) => { + drop(store); - if !self.local_store.contains_key(key) { - let (cumulative, instance_count) = self.strict.get(key).await?; + self.refresh_local_if_needed(key).await?; - self.local_store - .entry(key.clone()) - .or_insert_with(|| SingleStore::new(cumulative, instance_count)); - } + self.local_store + .get(key) + .expect("key should be in local_store") + } + None => { + self.refresh_local_if_needed(key).await?; self.local_store .get(key) @@ -529,9 +636,15 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { } }; + let delta = store.delta.load(Ordering::Acquire); + let cumulative = store.cumulative.load(Ordering::Acquire); let instance_count = store.instance_count.load(Ordering::Acquire); + let current = (cumulative + delta, instance_count + delta); + if !comparator.matches(current.1) { + return Ok(current); + } + store.delta.store(count - instance_count, Ordering::Release); - let cumulative = store.cumulative.load(Ordering::Acquire); Ok((cumulative - instance_count + count, count)) } @@ -545,11 +658,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// assert_eq!(counter.get(&key).await?, (0, 0)); /// counter.inc(&key, 5).await?; /// // Returns local estimate (buffered delta included). @@ -557,22 +670,27 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - async fn get(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + async fn get(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let store = match self.local_store.get(key) { - Some(store) => store, - None => { - let lock = self.get_or_create_reset_lock(key); - let _guard = lock.lock().await; + Some(store) + if mutex_lock(&store.last_flush, "lax_icounter:last_flush")?.elapsed() + < self.allowed_lag => + { + store + } + Some(store) => { + drop(store); - if !self.local_store.contains_key(key) { - let (cumulative, instance_count) = self.strict.get(key).await?; + self.refresh_local_if_needed(key).await?; - self.local_store - .entry(key.clone()) - .or_insert_with(|| SingleStore::new(cumulative, instance_count)); - } + self.local_store + .get(key) + .expect("key should be in local_store") + } + None => { + self.refresh_local_if_needed(key).await?; self.local_store .get(key) @@ -595,11 +713,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// counter.inc(&key, 5).await?; // buffered /// let (old_cumulative, _) = counter.del(&key).await?; /// assert_eq!(old_cumulative, 5); @@ -607,7 +725,7 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - async fn del(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + async fn del(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.activity.signal(); self.flush_key(key).await?; let result = self.strict.del(key).await?; @@ -623,11 +741,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_lax_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 3).await?; /// server_b.inc(&key, 7).await?; /// // Flush server_a's pending delta, then remove only its slice. @@ -640,7 +758,7 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - async fn del_on_instance(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + async fn del_on_instance(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.activity.signal(); self.flush_key(key).await?; let result = self.strict.del_on_instance(key).await?; @@ -654,12 +772,12 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::lax_icounter().await?; - /// let k1 = RedisKey::try_from("a".to_string())?; - /// let k2 = RedisKey::try_from("b".to_string())?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; /// counter.inc(&k1, 10).await?; /// counter.inc(&k2, 20).await?; /// counter.clear().await?; @@ -682,11 +800,11 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_lax_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 3).await?; /// server_b.inc(&key, 7).await?; /// // Flush + remove only server_a's contributions; server_b's slice survives. @@ -702,4 +820,220 @@ impl InstanceAwareCounterTrait for LaxInstanceAwareCounter { self.local_store.clear(); Ok(()) } + + async fn get_all<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError> { + if keys.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + self.batch_refresh_stale(keys).await?; + + keys.iter() + .map(|key| { + let store = self + .local_store + .get(*key) + .expect("store populated after refresh"); + let delta = store.delta.load(Ordering::Acquire); + Ok(( + *key, + store.cumulative.load(Ordering::Acquire) + delta, + store.instance_count.load(Ordering::Acquire) + delta, + )) + }) + .collect() + } // end function get_all + + async fn get_all_on_instance<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError> { + self.batch_refresh_stale(keys).await?; + + Ok(keys + .iter() + .map(|key| { + let val = self + .local_store + .get(*key) + .map(|s| { + s.instance_count.load(Ordering::Acquire) + s.delta.load(Ordering::Acquire) + }) + .unwrap_or(0); + (*key, val) + }) + .collect()) + } // end function get_all_on_instance + + async fn inc_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.inc_all_if(&conditional_updates).await + } + + async fn inc_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let keys: Vec<&DistkitRedisKey> = updates.iter().map(|(key, _, _)| *key).collect(); + self.batch_refresh_stale(&keys).await?; + + updates + .iter() + .map(|(key, comparator, count)| { + let store = self + .local_store + .get(*key) + .expect("store populated after refresh"); + let delta_before = store.delta.load(Ordering::Acquire); + let current = ( + store.cumulative.load(Ordering::Acquire) + delta_before, + store.instance_count.load(Ordering::Acquire) + delta_before, + ); + + if comparator.matches(current.0) { + let delta_after = store.delta.fetch_add(*count, Ordering::AcqRel) + *count; + let cumulative = store.cumulative.load(Ordering::Acquire); + let instance_count = store.instance_count.load(Ordering::Acquire); + Ok((*key, cumulative + delta_after, instance_count + delta_after)) + } else { + Ok((*key, current.0, current.1)) + } + }) + .collect() + } + + async fn set_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.set_all_if(&conditional_updates).await + } + + async fn set_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let keys: Vec<&DistkitRedisKey> = updates.iter().map(|(key, _, _)| *key).collect(); + self.batch_refresh_stale(&keys).await?; + + let mut current_map: HashMap = + HashMap::with_capacity(updates.len()); + let mut matched_updates: Vec<(&DistkitRedisKey, i64)> = Vec::new(); + + for (key, comparator, count) in updates { + let store = self + .local_store + .get(*key) + .expect("store populated after refresh"); + let delta = store.delta.load(Ordering::Acquire); + let current = ( + store.cumulative.load(Ordering::Acquire) + delta, + store.instance_count.load(Ordering::Acquire) + delta, + ); + current_map.insert((*key).clone(), current); + + if comparator.matches(current.0) { + matched_updates.push((*key, *count)); + } + } + + let mut applied_map: HashMap = + HashMap::with_capacity(matched_updates.len()); + if !matched_updates.is_empty() { + self.flush().await?; + let batch = self.strict.set_batch(&matched_updates).await?; + for (key, cumulative, instance_count) in &batch { + self.update_local(key, *cumulative, *instance_count); + applied_map.insert((*key).clone(), (*cumulative, *instance_count)); + } + } + + Ok(updates + .iter() + .map(|(key, _, _)| { + let (cumulative, instance_count) = applied_map + .get(*key) + .copied() + .or_else(|| current_map.get(*key).copied()) + .unwrap_or((0, 0)); + (*key, cumulative, instance_count) + }) + .collect()) + } // end function set_all_if + + async fn set_all_on_instance<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.set_all_on_instance_if(&conditional_updates).await + } + + async fn set_all_on_instance_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let keys: Vec<&DistkitRedisKey> = updates.iter().map(|(key, _, _)| *key).collect(); + self.batch_refresh_stale(&keys).await?; + + updates + .iter() + .map(|(key, comparator, count)| { + let store = self + .local_store + .get(*key) + .expect("store populated after refresh"); + let delta = store.delta.load(Ordering::Acquire); + let cumulative = store.cumulative.load(Ordering::Acquire); + let instance_count = store.instance_count.load(Ordering::Acquire); + let current = (cumulative + delta, instance_count + delta); + + if comparator.matches(current.1) { + store.delta.store(count - instance_count, Ordering::Release); + Ok((*key, cumulative - instance_count + count, *count)) + } else { + Ok((*key, current.0, current.1)) + } + }) + .collect() + } // end function set_all_on_instance_if } diff --git a/src/icounter/mod.rs b/src/icounter/mod.rs index 8861850..9683e6f 100644 --- a/src/icounter/mod.rs +++ b/src/icounter/mod.rs @@ -14,7 +14,7 @@ mod lax_instance_aware_counter; pub use lax_instance_aware_counter::*; use uuid::Uuid; -use crate::{DistkitError, RedisKey}; +use crate::{CounterComparator, DistkitError, DistkitRedisKey}; // --------------------------------------------------------------------------- // Trait @@ -35,7 +35,7 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; @@ -54,11 +54,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// let (cumulative_a, slice_a) = server_a.inc(&key, 3).await?; /// assert_eq!(cumulative_a, 3); /// assert_eq!(slice_a, 3); @@ -68,7 +68,46 @@ pub trait InstanceAwareCounterTrait { /// # Ok(()) /// # } /// ``` - async fn inc(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError>; + async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError>; + + /// Conditionally increments this instance's contribution for `key` by + /// `count` when the cumulative total satisfies `comparator`. + /// + /// Returns `(cumulative, instance_count)` after evaluation. If the + /// condition fails, the returned values reflect the current state and no + /// increment is applied. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, icounter::InstanceAwareCounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; + /// counter.set(&key, 10).await?; + /// + /// assert_eq!( + /// counter.inc_if(&key, CounterComparator::Eq(10), 5).await?, + /// (15, 15) + /// ); + /// assert_eq!( + /// counter.inc_if(&key, CounterComparator::Lt(10), 5).await?, + /// (15, 15) + /// ); + /// assert_eq!( + /// counter.inc_if(&key, CounterComparator::Nil, 5).await?, + /// (20, 20) + /// ); + /// # Ok(()) + /// # } + /// ``` + async fn inc_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError>; /// Decrements the counter for `key` by `count` (stale-aware). /// @@ -77,11 +116,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// counter.inc(&key, 10).await?; /// let (cumulative, slice) = counter.dec(&key, 4).await?; /// assert_eq!(cumulative, 6); @@ -89,7 +128,7 @@ pub trait InstanceAwareCounterTrait { /// # Ok(()) /// # } /// ``` - async fn dec(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError>; + async fn dec(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError>; /// Sets the cumulative total for `key` to `count`, bumping the epoch. /// @@ -100,11 +139,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 10).await?; /// server_b.inc(&key, 5).await?; /// // Epoch bumps; all previous per-instance contributions are cleared. @@ -114,7 +153,45 @@ pub trait InstanceAwareCounterTrait { /// # Ok(()) /// # } /// ``` - async fn set(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError>; + async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError>; + + /// Conditionally sets the cumulative total for `key` to `count` when the + /// cumulative total satisfies `comparator`. + /// + /// Returns `(cumulative, instance_count)` after evaluation. If the + /// condition fails, the returned values reflect the current state. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, icounter::InstanceAwareCounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; + /// counter.set(&key, 10).await?; + /// + /// assert_eq!( + /// counter.set_if(&key, CounterComparator::Gt(5), 40).await?, + /// (40, 40) + /// ); + /// assert_eq!( + /// counter.set_if(&key, CounterComparator::Eq(10), 99).await?, + /// (40, 40) + /// ); + /// assert_eq!( + /// counter.set_if(&key, CounterComparator::Nil, 12).await?, + /// (12, 12) + /// ); + /// # Ok(()) + /// # } + /// ``` + async fn set_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError>; /// Sets only this instance's contribution for `key` to `count`, without /// bumping the epoch. @@ -124,11 +201,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 10).await?; /// server_b.inc(&key, 5).await?; /// // No epoch bump: server_b's slice is not evicted. @@ -140,7 +217,52 @@ pub trait InstanceAwareCounterTrait { /// ``` async fn set_on_instance( &self, - key: &RedisKey, + key: &DistkitRedisKey, + count: i64, + ) -> Result<(i64, i64), DistkitError>; + + /// Conditionally sets this instance's contribution for `key` to `count` + /// when the current instance slice satisfies `comparator`. + /// + /// Returns `(cumulative, instance_count)` after evaluation. If the + /// condition fails, the returned values reflect the current state. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, icounter::InstanceAwareCounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; + /// server_a.set_on_instance(&key, 7).await?; + /// server_b.set_on_instance(&key, 5).await?; + /// + /// assert_eq!( + /// server_a + /// .set_on_instance_if(&key, CounterComparator::Eq(7), 9) + /// .await?, + /// (14, 9) + /// ); + /// assert_eq!( + /// server_a + /// .set_on_instance_if(&key, CounterComparator::Gt(10), 50) + /// .await?, + /// (14, 9) + /// ); + /// assert_eq!( + /// server_a + /// .set_on_instance_if(&key, CounterComparator::Nil, 11) + /// .await?, + /// (16, 11) + /// ); + /// # Ok(()) + /// # } + /// ``` + async fn set_on_instance_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, count: i64, ) -> Result<(i64, i64), DistkitError>; @@ -152,11 +274,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// // A missing key returns (0, 0). /// assert_eq!(counter.get(&key).await?, (0, 0)); /// counter.inc(&key, 5).await?; @@ -164,7 +286,7 @@ pub trait InstanceAwareCounterTrait { /// # Ok(()) /// # } /// ``` - async fn get(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError>; + async fn get(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError>; /// Deletes `key` globally, bumping the epoch to invalidate all instances. /// @@ -174,11 +296,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 3).await?; /// server_b.inc(&key, 7).await?; /// let (old_cumulative, _) = server_a.del(&key).await?; @@ -188,7 +310,7 @@ pub trait InstanceAwareCounterTrait { /// # Ok(()) /// # } /// ``` - async fn del(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError>; + async fn del(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError>; /// Removes only this instance's contribution for `key`, without bumping /// the epoch. @@ -199,11 +321,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 3).await?; /// server_b.inc(&key, 7).await?; /// // Only server_a's slice is removed; server_b is unaffected. @@ -213,19 +335,19 @@ pub trait InstanceAwareCounterTrait { /// # Ok(()) /// # } /// ``` - async fn del_on_instance(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError>; + async fn del_on_instance(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError>; /// Clears all keys and all instance state from Redis. /// /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; - /// let k1 = RedisKey::try_from("a".to_string())?; - /// let k2 = RedisKey::try_from("b".to_string())?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; /// counter.inc(&k1, 10).await?; /// counter.inc(&k2, 20).await?; /// counter.clear().await?; @@ -242,11 +364,11 @@ pub trait InstanceAwareCounterTrait { /// # Examples /// /// ```rust - /// # use distkit::{RedisKey, icounter::InstanceAwareCounterTrait}; + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 3).await?; /// server_b.inc(&key, 7).await?; /// // Only server_a's contributions are removed; server_b's slice survives. @@ -256,6 +378,167 @@ pub trait InstanceAwareCounterTrait { /// # } /// ``` async fn clear_on_instance(&self) -> Result<(), DistkitError>; + + /// Returns `(key, cumulative, instance_count)` for each key in `keys`, in + /// the same order. A missing key returns `(key, 0, 0)`. + async fn get_all<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError>; + + /// Returns `(key, instance_count)` for each key in `keys`, in the same + /// order. Pure-local: no Redis round-trip, no staleness check. A key + /// with no local contribution returns `(key, 0)`. + async fn get_all_on_instance<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError>; + + /// Increments each `(key, delta)` pair for this instance and returns + /// `(key, cumulative, instance_count)` in the same order. + /// + /// Duplicate keys are processed sequentially in input order, so later + /// entries observe earlier same-call updates. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{DistkitRedisKey, icounter::InstanceAwareCounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; + /// + /// let results = counter.inc_all(&[(&k1, 3), (&k2, 5)]).await?; + /// + /// assert_eq!(results, vec![(&k1, 3, 3), (&k2, 5, 5)]); + /// # Ok(()) + /// # } + /// ``` + async fn inc_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError>; + + /// Conditionally increments each `(key, delta)` pair when the cumulative + /// total satisfies the corresponding comparator. + /// + /// Each tuple is `(key, comparator, delta)`. Evaluation is per-item, + /// results preserve input order, and duplicate keys are processed + /// sequentially in input order. Use [`CounterComparator::Nil`] for + /// unconditional entries in a mixed batch. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, icounter::InstanceAwareCounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; + /// counter.set(&k1, 10).await?; + /// + /// let results = counter + /// .inc_all_if(&[ + /// (&k1, CounterComparator::Eq(10), 5), + /// (&k2, CounterComparator::Nil, 2), + /// ]) + /// .await?; + /// + /// assert_eq!(results, vec![(&k1, 15, 15), (&k2, 2, 2)]); + /// # Ok(()) + /// # } + /// ``` + async fn inc_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError>; + + /// Sets each `(key, count)` pair globally, bumping the epoch. Semantics + /// match `set` for each individual key. Returns `(key, cumulative, instance_count)` + /// in the same order. + async fn set_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError>; + + /// Conditionally sets each `(key, count)` pair globally when the + /// cumulative total satisfies the corresponding comparator. + /// + /// Each tuple is `(key, comparator, count)`. Evaluation is per-item and + /// results preserve input order. Use [`CounterComparator::Nil`] for + /// unconditional entries in a mixed batch. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, icounter::InstanceAwareCounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; + /// counter.set(&k1, 10).await?; + /// + /// let results = counter + /// .set_all_if(&[ + /// (&k1, CounterComparator::Eq(10), 15), + /// (&k2, CounterComparator::Nil, 20), + /// ]) + /// .await?; + /// + /// assert_eq!(results, vec![(&k1, 15, 15), (&k2, 20, 20)]); + /// # Ok(()) + /// # } + /// ``` + async fn set_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError>; + + /// Sets this instance's contribution for each `(key, count)` pair without + /// bumping the epoch. Other instances' slices are preserved. Returns + /// `(key, cumulative, instance_count)` in the same order. + async fn set_all_on_instance<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError>; + + /// Conditionally sets this instance's contribution for each `(key, count)` + /// pair when the current instance slice satisfies the corresponding + /// comparator. + /// + /// Each tuple is `(key, comparator, count)`. Evaluation is per-item and + /// results preserve input order. Use [`CounterComparator::Nil`] for + /// unconditional entries in a mixed batch. + /// + /// # Examples + /// + /// ```rust + /// # use distkit::{CounterComparator, DistkitRedisKey, icounter::InstanceAwareCounterTrait}; + /// # #[tokio::main] + /// # async fn main() -> Result<(), Box> { + /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; + /// + /// let results = counter + /// .set_all_on_instance_if(&[ + /// (&k1, CounterComparator::Nil, 5), + /// (&k2, CounterComparator::Eq(0), 7), + /// ]) + /// .await?; + /// + /// assert_eq!(results, vec![(&k1, 5, 5), (&k2, 7, 7)]); + /// # Ok(()) + /// # } + /// ``` + async fn set_all_on_instance_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError>; } // --------------------------------------------------------------------------- diff --git a/src/icounter/strict_instance_aware_counter.rs b/src/icounter/strict_instance_aware_counter.rs index 8b35347..dea4002 100644 --- a/src/icounter/strict_instance_aware_counter.rs +++ b/src/icounter/strict_instance_aware_counter.rs @@ -5,17 +5,22 @@ //! key, contributing to a shared cumulative total. When an instance stops //! sending heartbeats, its contribution is automatically removed. -use std::sync::{ - Arc, - atomic::{AtomicI64, AtomicU64, Ordering}, +use std::{ + collections::{HashMap, HashSet}, + sync::{ + Arc, + atomic::{AtomicI64, AtomicU64, Ordering}, + }, }; use dashmap::DashMap; use redis::{Script, aio::ConnectionManager}; use crate::{ - ActivityTracker, EPOCH_CHANGE_INTERVAL, RedisKey, RedisKeyGenerator, RedisKeyGeneratorTypeKey, + ActivityTracker, CounterComparator, DistkitRedisKey, EPOCH_CHANGE_INTERVAL, RedisKeyGenerator, + RedisKeyGeneratorTypeKey, error::DistkitError, + execute_pipeline_with_script_retry, icounter::{InstanceAwareCounterTrait, generate_instance_id}, }; @@ -43,6 +48,8 @@ impl SingleStore { } } +const MAX_BATCH_SIZE: usize = 100; + // --------------------------------------------------------------------------- // Lua helpers — prepended to all scripts except `clear` // --------------------------------------------------------------------------- @@ -84,6 +91,22 @@ local function check_and_zadd(instances_key, instance_id, ts) redis.call('ZADD', instances_key, ts, instance_id) return created end + +local function compare_values(current, comparator, expected) + if comparator == 'nil' then + return true + elseif comparator == 'eq' then + return current == expected + elseif comparator == 'lt' then + return current < expected + elseif comparator == 'gt' then + return current > expected + elseif comparator == 'ne' then + return current ~= expected + end + + return false +end "#; // --------------------------------------------------------------------------- @@ -98,11 +121,14 @@ local keys_key = KEYS[4] local inst_count_key = KEYS[5] local counter_key = ARGV[1] -local delta = tonumber(ARGV[2]) -local local_epoch = tonumber(ARGV[3]) -local dead_threshold = tonumber(ARGV[4]) -local prefix = ARGV[5] -local instance_id = ARGV[6] +local comparator = ARGV[2] +local compare_against = tonumber(ARGV[3]) +local delta = tonumber(ARGV[4]) +local local_epoch = tonumber(ARGV[5]) +local local_count = tonumber(ARGV[6]) or 0 +local dead_threshold = tonumber(ARGV[7]) +local prefix = ARGV[8] +local instance_id = ARGV[9] local ts = now_ms() local instance_created = check_and_zadd(instances_key, instance_id, ts) @@ -110,6 +136,19 @@ delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_thre local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0 local is_stale = (local_epoch ~= redis_epoch) +local cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0 +local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0 + +if instance_created ~= 0 and not is_stale and local_count > 0 then + redis.call('HSET', inst_count_key, counter_key, local_count) + cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, local_count)) + inst_count = local_count + redis.call('SADD', keys_key, counter_key) +end + +if not compare_values(cumulative, comparator, compare_against) then + return {counter_key, cumulative, inst_count, redis_epoch, instance_created, 0} +end local new_inst_count if is_stale then @@ -122,7 +161,7 @@ end local new_cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, delta)) redis.call('SADD', keys_key, counter_key) -return {counter_key, new_cumulative, new_inst_count, redis_epoch, instance_created} +return {counter_key, new_cumulative, new_inst_count, redis_epoch, instance_created, 1} "#; const SET_LUA: &str = r#" @@ -133,19 +172,37 @@ local keys_key = KEYS[4] local inst_count_key = KEYS[5] local counter_key = ARGV[1] -local count = tonumber(ARGV[2]) -local local_epoch = tonumber(ARGV[3]) -local dead_threshold = tonumber(ARGV[4]) -local prefix = ARGV[5] -local instance_id = ARGV[6] -local max_epoch = tonumber(ARGV[7]) +local comparator = ARGV[2] +local compare_against = tonumber(ARGV[3]) +local count = tonumber(ARGV[4]) +local local_epoch = tonumber(ARGV[5]) +local local_count = tonumber(ARGV[6]) or 0 +local dead_threshold = tonumber(ARGV[7]) +local prefix = ARGV[8] +local instance_id = ARGV[9] +local max_epoch = tonumber(ARGV[10]) local ts = now_ms() local instance_created = check_and_zadd(instances_key, instance_id, ts) delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_threshold, ts) -local old_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0 -local new_epoch = old_epoch + 1 +local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0 +local cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0 +local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0 +local is_stale = (local_epoch ~= redis_epoch) + +if instance_created ~= 0 and not is_stale and local_count > 0 then + redis.call('HSET', inst_count_key, counter_key, local_count) + cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, local_count)) + inst_count = local_count + redis.call('SADD', keys_key, counter_key) +end + +if not compare_values(cumulative, comparator, compare_against) then + return {counter_key, cumulative, inst_count, redis_epoch, instance_created, 0} +end + +local new_epoch = redis_epoch + 1 if new_epoch > max_epoch then new_epoch = 0 end @@ -155,7 +212,7 @@ redis.call('HSET', cumulative_key, counter_key, count) redis.call('HSET', inst_count_key, counter_key, count) redis.call('SADD', keys_key, counter_key) -return {count, count, new_epoch, instance_created} +return {counter_key, count, count, new_epoch, instance_created, 1} "#; const SET_ON_INSTANCE_LUA: &str = r#" @@ -166,11 +223,14 @@ local keys_key = KEYS[4] local inst_count_key = KEYS[5] local counter_key = ARGV[1] -local count = tonumber(ARGV[2]) -local local_epoch = tonumber(ARGV[3]) -local dead_threshold = tonumber(ARGV[4]) -local prefix = ARGV[5] -local instance_id = ARGV[6] +local comparator = ARGV[2] +local compare_against = tonumber(ARGV[3]) +local count = tonumber(ARGV[4]) +local local_epoch = tonumber(ARGV[5]) +local local_count = tonumber(ARGV[6]) or 0 +local dead_threshold = tonumber(ARGV[7]) +local prefix = ARGV[8] +local instance_id = ARGV[9] local ts = now_ms() local instance_created = check_and_zadd(instances_key, instance_id, ts) @@ -178,16 +238,27 @@ delete_dead_instances(prefix, instances_key, cumulative_key, keys_key, dead_thre local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or 0 local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0 +local cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0 local is_stale = (local_epoch ~= redis_epoch) -local effective_old = is_stale and 0 or inst_count -local delta = count - effective_old +if instance_created ~= 0 and not is_stale and local_count > 0 then + redis.call('HSET', inst_count_key, counter_key, local_count) + cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, local_count)) + inst_count = local_count + redis.call('SADD', keys_key, counter_key) +end +local current_inst_count = is_stale and 0 or inst_count +if not compare_values(current_inst_count, comparator, compare_against) then + return {counter_key, cumulative, current_inst_count, redis_epoch, instance_created, 0} +end + +local delta = count - current_inst_count redis.call('HSET', inst_count_key, counter_key, count) local new_cumulative = tonumber(redis.call('HINCRBY', cumulative_key, counter_key, delta)) redis.call('SADD', keys_key, counter_key) -return {new_cumulative, count, redis_epoch, instance_created} +return {counter_key, new_cumulative, count, redis_epoch, instance_created, 1} "#; const GET_LUA: &str = r#" @@ -211,7 +282,7 @@ local redis_epoch = tonumber(redis.call('HGET', epoch_key, counter_key) or 0) or local cumulative = tonumber(redis.call('HGET', cumulative_key, counter_key) or 0) or 0 local inst_count = tonumber(redis.call('HGET', inst_count_key, counter_key) or 0) or 0 -return {cumulative, inst_count, redis_epoch, instance_created} +return {counter_key, cumulative, inst_count, redis_epoch, instance_created} "#; const DEL_LUA: &str = r#" @@ -383,7 +454,7 @@ return {counter_key, new_cumulative, new_inst_count, redis_epoch} #[derive(Debug, Clone)] pub struct StrictInstanceAwareCounterOptions { /// Redis key prefix used to namespace all counter keys. - pub prefix: RedisKey, + pub prefix: DistkitRedisKey, /// Redis connection manager. pub connection_manager: ConnectionManager, /// Milliseconds without a heartbeat before an instance is considered dead. @@ -397,7 +468,7 @@ impl StrictInstanceAwareCounterOptions { /// # Examples /// /// ```rust - /// use distkit::{RedisKey, icounter::StrictInstanceAwareCounterOptions}; + /// use distkit::{DistkitRedisKey, icounter::StrictInstanceAwareCounterOptions}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -405,13 +476,13 @@ impl StrictInstanceAwareCounterOptions { /// .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); /// let client = redis::Client::open(redis_url)?; /// let conn = client.get_connection_manager().await?; - /// let prefix = RedisKey::try_from("my_app".to_string())?; + /// let prefix = DistkitRedisKey::try_from("my_app".to_string())?; /// let opts = StrictInstanceAwareCounterOptions::new(prefix, conn); /// assert_eq!(opts.dead_instance_threshold_ms, 30_000); /// # Ok(()) /// # } /// ``` - pub fn new(prefix: RedisKey, connection_manager: ConnectionManager) -> Self { + pub fn new(prefix: DistkitRedisKey, connection_manager: ConnectionManager) -> Self { Self { prefix, connection_manager, @@ -439,7 +510,7 @@ pub struct StrictInstanceAwareCounter { instance_id: String, dead_instance_threshold_ms: u64, /// Per-key in-memory state: epoch, last-seen cumulative, and this instance's count. - local_store: DashMap, + local_store: DashMap, /// Maximum epoch value before wrapping. Set to `u64::MAX / 2`. max_epoch: u64, inc_script: Script, @@ -501,7 +572,7 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// use distkit::{RedisKey, icounter::{StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions}}; + /// use distkit::{DistkitRedisKey, icounter::{StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions}}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -509,7 +580,7 @@ impl StrictInstanceAwareCounter { /// .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string()); /// let client = redis::Client::open(redis_url)?; /// let conn = client.get_connection_manager().await?; - /// let prefix = RedisKey::try_from("my_app".to_string())?; + /// let prefix = DistkitRedisKey::try_from("my_app".to_string())?; /// let counter = StrictInstanceAwareCounter::new(StrictInstanceAwareCounterOptions::new(prefix, conn)); /// assert!(!counter.instance_id().is_empty()); /// # Ok(()) @@ -584,21 +655,27 @@ impl StrictInstanceAwareCounter { // local_store helpers // ----------------------------------------------------------------------- - fn get_local_epoch(&self, key: &RedisKey) -> u64 { + fn get_local_epoch(&self, key: &DistkitRedisKey) -> u64 { self.local_store .get(key) .map(|s| s.epoch.load(Ordering::Acquire)) .unwrap_or(0) } - fn get_local_count(&self, key: &RedisKey) -> i64 { + fn get_local_count(&self, key: &DistkitRedisKey) -> i64 { self.local_store .get(key) .map(|s| s.local_count.load(Ordering::Acquire)) .unwrap_or(0) } - fn update_local_store(&self, key: &RedisKey, epoch: u64, cumulative: i64, local_count: i64) { + fn update_local_store( + &self, + key: &DistkitRedisKey, + epoch: u64, + cumulative: i64, + local_count: i64, + ) { match self.local_store.get(key) { Some(s) => { s.epoch.store(epoch, Ordering::Release); @@ -650,44 +727,13 @@ impl StrictInstanceAwareCounter { }); } - /// Builds a Redis pipeline with one `INC_IF_EPOCH_MATCHES_LUA` invocation per - /// item in `chunk`. `load_script = true` prepends a `LOAD SCRIPT` command to - /// handle cache misses (mirrors `LaxCounter::build_commit_pipeline`). - fn build_recovery_pipeline( - &self, - chunk: &[(RedisKey, i64, u64)], - load_script: bool, - ) -> redis::Pipeline { - let mut pipe = redis::Pipeline::new(); - if load_script { - pipe.load_script(&self.inc_if_epoch_matches_script).ignore(); - } - for (key, count, local_epoch) in chunk { - pipe.invoke_script( - self.inc_if_epoch_matches_script - .key(self.epoch_key()) - .key(self.instances_key()) - .key(self.cumulative_key()) - .key(self.keys_key()) - .key(self.inst_count_key()) - .arg(key.as_str()) - .arg(*count) - .arg(*local_epoch) - .arg(self.dead_instance_threshold_ms) - .arg(self.prefix_str()) - .arg(&self.instance_id), - ); - } - pipe - } - /// Sends recovery increments for all keys in `recoveries` using pipelined /// `INC_IF_EPOCH_MATCHES_LUA` calls, chunked to avoid oversized pipelines. /// After each chunk the returned `(key, cumulative, inst_count, redis_epoch)` /// tuples are used to update `local_store` before the next chunk begins. async fn recover_contributions_batched( &self, - recoveries: Vec<(RedisKey, i64, u64)>, + recoveries: Vec<(DistkitRedisKey, i64, u64)>, chunk_size: usize, ) -> Result<(), DistkitError> { if recoveries.is_empty() { @@ -700,26 +746,29 @@ impl StrictInstanceAwareCounter { while processed < recoveries.len() { let end = (processed + chunk_size).min(recoveries.len()); let chunk = &recoveries[processed..end]; - - let results: Vec<(String, i64, i64, i64)> = { - let pipe = self.build_recovery_pipeline(chunk, false); - match pipe.query_async(&mut conn).await { - Ok(r) => r, - Err(err) => { - if err.kind() != redis::ErrorKind::Server(redis::ServerErrorKind::NoScript) - { - return Err(DistkitError::RedisError(err)); - } - // Script not in cache — reload and retry. - let pipe = self.build_recovery_pipeline(chunk, true); - pipe.query_async(&mut conn).await? - } - } - }; + let script = &self.inc_if_epoch_matches_script; + + let results: Vec<(String, i64, i64, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, chunk, |item| { + let (key, count, local_epoch) = item; + let mut inv = script.key(self.epoch_key()); + inv.key(self.instances_key()); + inv.key(self.cumulative_key()); + inv.key(self.keys_key()); + inv.key(self.inst_count_key()); + inv.arg(key.as_str()); + inv.arg(*count); + inv.arg(*local_epoch); + inv.arg(self.dead_instance_threshold_ms); + inv.arg(self.prefix_str()); + inv.arg(&self.instance_id); + inv + }) + .await?; // Each result carries its own key — no zip required. for (key_str, cumulative, inst_count, redis_epoch) in results { - if let Ok(key) = RedisKey::try_from(key_str) { + if let Ok(key) = DistkitRedisKey::try_from(key_str) { self.update_local_store(&key, redis_epoch as u64, cumulative, inst_count); } } @@ -730,69 +779,47 @@ impl StrictInstanceAwareCounter { Ok(()) } - /// Builds a Redis pipeline with one `inc_script` invocation per item in - /// `chunk`. Because `INC_LUA` now echoes `counter_key` as its first return - /// element, results are self-identifying — no zip required. - fn build_inc_batch_pipeline( + pub(crate) async fn inc_batch( &self, - chunk: &[(RedisKey, i64)], - load_script: bool, - ) -> redis::Pipeline { - let mut pipe = redis::Pipeline::new(); - if load_script { - pipe.load_script(&self.inc_script).ignore(); + increments: &mut Vec<(DistkitRedisKey, i64)>, + max_batch_size: usize, + ) -> Result, DistkitError> { + if increments.is_empty() { + return Ok(vec![]); } - for (key, delta) in chunk { - let local_epoch = self.get_local_epoch(key); - pipe.invoke_script( - self.inc_script - .key(self.epoch_key()) - .key(self.instances_key()) - .key(self.cumulative_key()) - .key(self.keys_key()) - .key(self.inst_count_key()) - .arg(key.as_str()) - .arg(*delta) - .arg(local_epoch) - .arg(self.dead_instance_threshold_ms) - .arg(self.prefix_str()) - .arg(&self.instance_id), + + let mut processed = 0; + let mut output: Vec<(String, i64, i64)> = Vec::with_capacity(increments.len()); + + while processed < increments.len() { + let end = (processed + max_batch_size).min(increments.len()); + let chunk = &increments[processed..end]; + let conditional_chunk: Vec<(&DistkitRedisKey, CounterComparator, i64)> = chunk + .iter() + .map(|(key, delta)| (key, CounterComparator::Nil, *delta)) + .collect(); + let chunk_results = self.inc_if_batch(&conditional_chunk).await?; + + output.extend( + chunk_results + .into_iter() + .map(|(key, cumulative, inst_count)| (key.to_string(), cumulative, inst_count)), ); + + processed = end; } - pipe + + // All chunks succeeded — drain entire input. + increments.drain(..processed); + + Ok(output) } - /// Sends multiple increments in a pipelined batch, chunked to `max_batch_size` per - /// pipeline. Takes `&mut Vec` so successfully committed entries are drained - /// in-place; on failure the remaining entries stay in the vector for the - /// caller to retry. - /// - /// Returns `(counter_key, cumulative, instance_count)` for every entry that - /// was committed. Also updates `local_store` from each result. - /// - /// # Examples - /// - /// ```rust - /// # use distkit::RedisKey; - /// # #[tokio::main] - /// # async fn main() -> Result<(), Box> { - /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; - /// let k1 = RedisKey::try_from("a".to_string())?; - /// let k2 = RedisKey::try_from("b".to_string())?; - /// let mut increments = vec![(k1, 3_i64), (k2, 7_i64)]; - /// let results = counter.inc_batch(&mut increments, 50).await?; - /// // Successful entries are drained from the input vec. - /// assert!(increments.is_empty()); - /// assert_eq!(results.len(), 2); - /// # Ok(()) - /// # } - /// ``` - pub async fn inc_batch( + pub(crate) async fn inc_if_batch<'k>( &self, - increments: &mut Vec<(RedisKey, i64)>, - max_batch_size: usize, - ) -> Result, DistkitError> { - if increments.is_empty() { + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { return Ok(vec![]); } @@ -800,51 +827,297 @@ impl StrictInstanceAwareCounter { let mut conn = self.connection_manager.clone(); let mut processed = 0; - let mut output: Vec<(String, i64, i64)> = Vec::with_capacity(increments.len()); + let mut output = Vec::with_capacity(updates.len()); - while processed < increments.len() { - let end = (processed + max_batch_size).min(increments.len()); - let chunk = &increments[processed..end]; + while processed < updates.len() { + let mut seen = HashSet::new(); + let mut end = processed; + while end < updates.len() && seen.insert(updates[end].0.as_str()) { + end += 1; + } - // Build and run the pipeline inside a block so the `chunk` slice - // borrow ends before we potentially drain `increments`. - // Results: (counter_key, cumulative, inst_count, redis_epoch, instance_created) - let first_attempt = { - let pipe = self.build_inc_batch_pipeline(chunk, false); - pipe.query_async::>(&mut conn) - .await - }; - - let chunk_results: Vec<(String, i64, i64, u64, i64)> = match first_attempt { - Ok(r) => r, - Err(err) => { - if err.kind() != redis::ErrorKind::Server(redis::ServerErrorKind::NoScript) { - return Err(DistkitError::RedisError(err)); - } - // Script not cached — reload and retry. After the drain the - // current chunk is now at indices [0..chunk_len]. - let pipe = self.build_inc_batch_pipeline(chunk, true); - match pipe.query_async(&mut conn).await { - Ok(r) => r, - Err(e) => return Err(DistkitError::RedisError(e)), + let chunk = &updates[processed..end]; + let script = &self.inc_script; + let local_epochs: Vec = chunk + .iter() + .map(|(key, _, _)| self.get_local_epoch(key)) + .collect(); + + let chunk_results: Vec<(String, i64, i64, u64, i64, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, chunk, |update| { + let (key, comparator, delta) = update; + let (lua_comparator, compare_against) = comparator.as_lua_parts(); + let mut inv = script.key(self.epoch_key()); + inv.key(self.instances_key()); + inv.key(self.cumulative_key()); + inv.key(self.keys_key()); + inv.key(self.inst_count_key()); + inv.arg(key.as_str()); + inv.arg(lua_comparator); + inv.arg(compare_against); + inv.arg(*delta); + inv.arg(self.get_local_epoch(key)); + inv.arg(self.get_local_count(key)); + inv.arg(self.dead_instance_threshold_ms); + inv.arg(self.prefix_str()); + inv.arg(&self.instance_id); + inv + }) + .await?; + + for ( + ((key, _, _), local_epoch), + (_, cumulative, inst_count, redis_epoch, _, matched_raw), + ) in chunk + .iter() + .zip(local_epochs.iter()) + .zip(chunk_results.into_iter()) + { + if matched_raw != 0 || *local_epoch == redis_epoch { + self.update_local_store(key, redis_epoch, cumulative, inst_count); + } + + output.push((*key, cumulative, inst_count)); + } + + processed = end; + } + + Ok(output) + } + + pub(crate) async fn get_batch<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError> { + if keys.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let mut conn = self.connection_manager.clone(); + let mut map: HashMap = HashMap::with_capacity(keys.len()); + let mut recovery_keys: Vec<(DistkitRedisKey, i64)> = Vec::new(); + + let mut processed = 0; + while processed < keys.len() { + let end = (processed + MAX_BATCH_SIZE).min(keys.len()); + let chunk = &keys[processed..end]; + let script = &self.get_script; + + let chunk_results: Vec<(String, i64, i64, u64, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, chunk, |key| { + let local_epoch = self.get_local_epoch(key); + let mut inv = script.key(self.epoch_key()); + inv.key(self.instances_key()); + inv.key(self.cumulative_key()); + inv.key(self.keys_key()); + inv.key(self.inst_count_key()); + inv.arg(key.as_str()); + inv.arg(local_epoch); + inv.arg(self.dead_instance_threshold_ms); + inv.arg(self.prefix_str()); + inv.arg(&self.instance_id); + inv + }) + .await?; + + for (key_str, cumulative, inst_count, redis_epoch, instance_created_raw) in + chunk_results + { + if let Ok(key) = DistkitRedisKey::try_from(key_str.clone()) { + let instance_created = instance_created_raw != 0; + let local_epoch = self.get_local_epoch(&key); + let old_local_count = self.get_local_count(&key); + self.update_local_store(&key, redis_epoch, cumulative, inst_count); + if instance_created && local_epoch == redis_epoch && old_local_count > 0 { + recovery_keys.push((key.clone(), old_local_count)); } + map.insert(key_str, (cumulative, inst_count)); } - }; + } + + processed = end; + } + + // Sequential recovery fallback (rare: instance was cleaned up as dead). + for (key, old_count) in recovery_keys { + let (cumulative, inst_count) = self.inc(&key, old_count).await?; + map.insert(key.to_string(), (cumulative, inst_count)); + } + + Ok(keys + .iter() + .map(|k| { + let (cum, inst) = map.get(k.as_str()).copied().unwrap_or((0, 0)); + (*k, cum, inst) + }) + .collect()) + } + + pub(crate) async fn set_batch<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.set_if_batch(&conditional_updates).await + } + + pub(crate) async fn set_if_batch<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let mut conn = self.connection_manager.clone(); + let mut map: HashMap = HashMap::with_capacity(updates.len()); + let mut processed = 0; - for (key_str, cumulative, inst_count, redis_epoch, _) in chunk_results { - if let Ok(key) = RedisKey::try_from(key_str.clone()) { + while processed < updates.len() { + let end = (processed + MAX_BATCH_SIZE).min(updates.len()); + let chunk = &updates[processed..end]; + let script = &self.set_script; + let local_epochs: HashMap = chunk + .iter() + .map(|(key, _, _)| ((*key).clone(), self.get_local_epoch(key))) + .collect(); + + let chunk_results: Vec<(String, i64, i64, u64, i64, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, chunk, |update| { + let (key, comparator, count) = update; + let (lua_comparator, compare_against) = comparator.as_lua_parts(); + let mut inv = script.key(self.epoch_key()); + inv.key(self.instances_key()); + inv.key(self.cumulative_key()); + inv.key(self.keys_key()); + inv.key(self.inst_count_key()); + inv.arg(key.as_str()); + inv.arg(lua_comparator); + inv.arg(compare_against); + inv.arg(*count); + inv.arg(self.get_local_epoch(key)); + inv.arg(self.get_local_count(key)); + inv.arg(self.dead_instance_threshold_ms); + inv.arg(self.prefix_str()); + inv.arg(&self.instance_id); + inv.arg(self.max_epoch); + inv + }) + .await?; + + for (key, cumulative, inst_count, redis_epoch, _, matched_raw) in chunk_results { + let Ok(key) = DistkitRedisKey::try_from(key.clone()) else { + continue; + }; + + let local_epoch = local_epochs.get(&key).copied().unwrap_or(0); + if matched_raw != 0 || local_epoch == redis_epoch { self.update_local_store(&key, redis_epoch, cumulative, inst_count); } - output.push((key_str, cumulative, inst_count)); + + map.insert(key, (cumulative, inst_count)); } processed = end; } - // All chunks succeeded — drain entire input. - increments.drain(..processed); + Ok(updates + .iter() + .map(|(k, _, _)| { + let (cum, inst) = map.get(k).copied().unwrap_or((0, 0)); + (*k, cum, inst) + }) + .collect()) + } - Ok(output) + async fn set_on_instance_batch<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.set_on_instance_if_batch(&conditional_updates).await + } + + async fn set_on_instance_if_batch<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + if updates.is_empty() { + return Ok(vec![]); + } + + self.activity.signal(); + + let mut conn = self.connection_manager.clone(); + let mut map: HashMap = HashMap::with_capacity(updates.len()); + let mut processed = 0; + + while processed < updates.len() { + let end = (processed + MAX_BATCH_SIZE).min(updates.len()); + let chunk = &updates[processed..end]; + let script = &self.set_on_instance_script; + let local_epochs: HashMap = chunk + .iter() + .map(|(key, _, _)| ((*key).clone(), self.get_local_epoch(key))) + .collect(); + + let chunk_results: Vec<(String, i64, i64, u64, i64, i64)> = + execute_pipeline_with_script_retry(&mut conn, script, chunk, |update| { + let (key, comparator, count) = update; + let (lua_comparator, compare_against) = comparator.as_lua_parts(); + let mut inv = script.key(self.epoch_key()); + inv.key(self.instances_key()); + inv.key(self.cumulative_key()); + inv.key(self.keys_key()); + inv.key(self.inst_count_key()); + inv.arg(key.as_str()); + inv.arg(lua_comparator); + inv.arg(compare_against); + inv.arg(*count); + inv.arg(self.get_local_epoch(key)); + inv.arg(self.get_local_count(key)); + inv.arg(self.dead_instance_threshold_ms); + inv.arg(self.prefix_str()); + inv.arg(&self.instance_id); + inv + }) + .await?; + + for (key, cumulative, inst_count, redis_epoch, _, matched_raw) in chunk_results { + let Ok(key) = DistkitRedisKey::try_from(key.clone()) else { + continue; + }; + + let local_epoch = local_epochs.get(&key).copied().unwrap_or(0); + if matched_raw != 0 || local_epoch == redis_epoch { + self.update_local_store(&key, redis_epoch, cumulative, inst_count); + } + + map.insert(key, (cumulative, inst_count)); + } + processed = end; + } + + Ok(updates + .iter() + .map(|(k, _, _)| { + let (cum, inst) = map.get(k).copied().unwrap_or((0, 0)); + (*k, cum, inst) + }) + .collect()) } #[cfg(test)] @@ -871,7 +1144,7 @@ impl StrictInstanceAwareCounter { // The instance was cleaned up while offline. Recover contributions // for all keys that still have a positive local count, but only // when the per-key epoch in Redis still matches — epoch-safe recovery. - let recoveries: Vec<(RedisKey, i64, u64)> = self + let recoveries: Vec<(DistkitRedisKey, i64, u64)> = self .local_store .iter() .filter_map(|e| { @@ -900,7 +1173,7 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; @@ -920,11 +1193,11 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// let (cumulative_a, slice_a) = server_a.inc(&key, 3).await?; /// assert_eq!(cumulative_a, 3); /// assert_eq!(slice_a, 3); @@ -934,18 +1207,34 @@ impl StrictInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - pub async fn inc(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + pub async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + self.inc_if(key, CounterComparator::Nil, count).await + } + + /// Conditionally adds `count` to this instance's contribution for `key` + /// when the cumulative total satisfies `comparator`. + /// + /// Returns `(cumulative, instance_count)` after evaluation. If the + /// condition fails, the returned values reflect the current state. + pub async fn inc_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let mut conn = self.connection_manager.clone(); let local_epoch = self.get_local_epoch(key); + let (lua_comparator, compare_against) = comparator.as_lua_parts(); - let (_, cumulative, inst_count, redis_epoch, instance_created_raw): ( + let (_, cumulative, inst_count, redis_epoch, _, matched_raw): ( String, i64, i64, u64, i64, + i64, ) = self .inc_script .key(self.epoch_key()) @@ -954,22 +1243,19 @@ impl StrictInstanceAwareCounter { .key(self.keys_key()) .key(self.inst_count_key()) .arg(key.as_str()) + .arg(lua_comparator) + .arg(compare_against) .arg(count) .arg(local_epoch) + .arg(self.get_local_count(key)) .arg(self.dead_instance_threshold_ms) .arg(self.prefix_str()) .arg(&self.instance_id) .invoke_async(&mut conn) .await?; - let instance_created = instance_created_raw != 0; - let should_recover = instance_created && local_epoch == redis_epoch; - - let old_local_count = self.get_local_count(key); - self.update_local_store(key, redis_epoch, cumulative, inst_count); - - if should_recover && old_local_count > 0 { - return Box::pin(self.inc(key, old_local_count)).await; + if matched_raw != 0 || local_epoch == redis_epoch { + self.update_local_store(key, redis_epoch, cumulative, inst_count); } Ok((cumulative, inst_count)) @@ -984,11 +1270,11 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 10).await?; /// server_b.inc(&key, 5).await?; /// // Epoch bumps; all previous per-instance contributions are cleared. @@ -998,13 +1284,35 @@ impl StrictInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - pub async fn set(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + pub async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + self.set_if(key, CounterComparator::Nil, count).await + } + + /// Conditionally sets the cumulative total for `key` to `count` when the + /// current cumulative total satisfies `comparator`. + /// + /// Returns `(cumulative, instance_count)` after evaluation. If the + /// condition fails, the returned values reflect the current state. + pub async fn set_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let mut conn = self.connection_manager.clone(); let local_epoch = self.get_local_epoch(key); + let (lua_comparator, compare_against) = comparator.as_lua_parts(); - let (cumulative, inst_count, new_epoch_raw, _): (i64, i64, u64, i64) = self + let (_, cumulative, inst_count, redis_epoch, _, matched_raw): ( + String, + i64, + i64, + u64, + i64, + i64, + ) = self .set_script .key(self.epoch_key()) .key(self.instances_key()) @@ -1012,8 +1320,11 @@ impl StrictInstanceAwareCounter { .key(self.keys_key()) .key(self.inst_count_key()) .arg(key.as_str()) + .arg(lua_comparator) + .arg(compare_against) .arg(count) .arg(local_epoch) + .arg(self.get_local_count(key)) .arg(self.dead_instance_threshold_ms) .arg(self.prefix_str()) .arg(&self.instance_id) @@ -1021,8 +1332,9 @@ impl StrictInstanceAwareCounter { .invoke_async(&mut conn) .await?; - // No recovery: epoch always bumps, so local_epoch != new_epoch - self.update_local_store(key, new_epoch_raw, cumulative, inst_count); + if matched_raw != 0 || local_epoch == redis_epoch { + self.update_local_store(key, redis_epoch, cumulative, inst_count); + } Ok((cumulative, inst_count)) } @@ -1035,11 +1347,11 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 10).await?; /// server_b.inc(&key, 5).await?; /// // No epoch bump: server_b's slice is not evicted. @@ -1051,15 +1363,38 @@ impl StrictInstanceAwareCounter { /// ``` pub async fn set_on_instance( &self, - key: &RedisKey, + key: &DistkitRedisKey, + count: i64, + ) -> Result<(i64, i64), DistkitError> { + self.set_on_instance_if(key, CounterComparator::Nil, count) + .await + } + + /// Conditionally sets this instance's contribution for `key` to `count` + /// when the current instance slice satisfies `comparator`. + /// + /// Returns `(cumulative, instance_count)` after evaluation. If the + /// condition fails, the returned values reflect the current state. + pub async fn set_on_instance_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, count: i64, ) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let mut conn = self.connection_manager.clone(); let local_epoch = self.get_local_epoch(key); + let (lua_comparator, compare_against) = comparator.as_lua_parts(); - let (cumulative, inst_count, redis_epoch_raw, _): (i64, i64, u64, i64) = self + let (_, cumulative, inst_count, redis_epoch, _, matched_raw): ( + String, + i64, + i64, + u64, + i64, + i64, + ) = self .set_on_instance_script .key(self.epoch_key()) .key(self.instances_key()) @@ -1067,16 +1402,20 @@ impl StrictInstanceAwareCounter { .key(self.keys_key()) .key(self.inst_count_key()) .arg(key.as_str()) + .arg(lua_comparator) + .arg(compare_against) .arg(count) .arg(local_epoch) + .arg(self.get_local_count(key)) .arg(self.dead_instance_threshold_ms) .arg(self.prefix_str()) .arg(&self.instance_id) .invoke_async(&mut conn) .await?; - // No recovery: caller is explicitly setting their contribution to a specific value. - self.update_local_store(key, redis_epoch_raw, cumulative, inst_count); + if matched_raw != 0 || local_epoch == redis_epoch { + self.update_local_store(key, redis_epoch, cumulative, inst_count); + } Ok((cumulative, inst_count)) } @@ -1088,11 +1427,11 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// // A missing key returns (0, 0). /// assert_eq!(counter.get(&key).await?, (0, 0)); /// counter.inc(&key, 5).await?; @@ -1100,26 +1439,32 @@ impl StrictInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - pub async fn get(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + pub async fn get(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let mut conn = self.connection_manager.clone(); let local_epoch = self.get_local_epoch(key); - let (cumulative, inst_count, redis_epoch, instance_created_raw): (i64, i64, u64, i64) = - self.get_script - .key(self.epoch_key()) - .key(self.instances_key()) - .key(self.cumulative_key()) - .key(self.keys_key()) - .key(self.inst_count_key()) - .arg(key.as_str()) - .arg(local_epoch) - .arg(self.dead_instance_threshold_ms) - .arg(self.prefix_str()) - .arg(&self.instance_id) - .invoke_async(&mut conn) - .await?; + let (_, cumulative, inst_count, redis_epoch, instance_created_raw): ( + String, + i64, + i64, + u64, + i64, + ) = self + .get_script + .key(self.epoch_key()) + .key(self.instances_key()) + .key(self.cumulative_key()) + .key(self.keys_key()) + .key(self.inst_count_key()) + .arg(key.as_str()) + .arg(local_epoch) + .arg(self.dead_instance_threshold_ms) + .arg(self.prefix_str()) + .arg(&self.instance_id) + .invoke_async(&mut conn) + .await?; let instance_created = instance_created_raw != 0; let should_recover = instance_created && local_epoch == redis_epoch; @@ -1142,11 +1487,11 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 3).await?; /// server_b.inc(&key, 7).await?; /// let (old_cumulative, _) = server_a.del(&key).await?; @@ -1156,7 +1501,7 @@ impl StrictInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - pub async fn del(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + pub async fn del(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let mut conn = self.connection_manager.clone(); @@ -1194,11 +1539,11 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let (server_a, server_b) = distkit::__doctest_helpers::two_strict_icounters().await?; - /// let key = RedisKey::try_from("connections".to_string())?; + /// let key = DistkitRedisKey::try_from("connections".to_string())?; /// server_a.inc(&key, 3).await?; /// server_b.inc(&key, 7).await?; /// // Only server_a's slice is removed; server_b is unaffected. @@ -1208,7 +1553,7 @@ impl StrictInstanceAwareCounter { /// # Ok(()) /// # } /// ``` - pub async fn del_on_instance(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + pub async fn del_on_instance(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.activity.signal(); let mut conn = self.connection_manager.clone(); @@ -1240,12 +1585,12 @@ impl StrictInstanceAwareCounter { /// # Examples /// /// ```rust - /// # use distkit::RedisKey; + /// # use distkit::DistkitRedisKey; /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { /// # let counter = distkit::__doctest_helpers::strict_icounter().await?; - /// let k1 = RedisKey::try_from("a".to_string())?; - /// let k2 = RedisKey::try_from("b".to_string())?; + /// let k1 = DistkitRedisKey::try_from("a".to_string())?; + /// let k2 = DistkitRedisKey::try_from("b".to_string())?; /// counter.inc(&k1, 10).await?; /// counter.inc(&k2, 20).await?; /// counter.clear().await?; @@ -1312,35 +1657,62 @@ impl InstanceAwareCounterTrait for StrictInstanceAwareCounter { self.instance_id() } - async fn inc(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + async fn inc(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { self.inc(key, count).await } - async fn dec(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + async fn inc_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError> { + self.inc_if(key, comparator, count).await + } + + async fn dec(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { self.inc(key, -count).await } - async fn set(&self, key: &RedisKey, count: i64) -> Result<(i64, i64), DistkitError> { + async fn set(&self, key: &DistkitRedisKey, count: i64) -> Result<(i64, i64), DistkitError> { self.set(key, count).await } + async fn set_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError> { + self.set_if(key, comparator, count).await + } + async fn set_on_instance( &self, - key: &RedisKey, + key: &DistkitRedisKey, count: i64, ) -> Result<(i64, i64), DistkitError> { self.set_on_instance(key, count).await } - async fn get(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + async fn set_on_instance_if( + &self, + key: &DistkitRedisKey, + comparator: CounterComparator, + count: i64, + ) -> Result<(i64, i64), DistkitError> { + self.set_on_instance_if(key, comparator, count).await + } + + async fn get(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.get(key).await } - async fn del(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + async fn del(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.del(key).await } - async fn del_on_instance(&self, key: &RedisKey) -> Result<(i64, i64), DistkitError> { + async fn del_on_instance(&self, key: &DistkitRedisKey) -> Result<(i64, i64), DistkitError> { self.del_on_instance(key).await } @@ -1351,4 +1723,66 @@ impl InstanceAwareCounterTrait for StrictInstanceAwareCounter { async fn clear_on_instance(&self) -> Result<(), DistkitError> { self.clear_on_instance().await } + + async fn get_all<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError> { + self.get_batch(keys).await + } + + async fn get_all_on_instance<'k>( + &self, + keys: &[&'k DistkitRedisKey], + ) -> Result, DistkitError> { + let pairs = self.get_batch(keys).await?; + Ok(pairs.into_iter().map(|(k, _, inst)| (k, inst)).collect()) + } + + async fn inc_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + let conditional_updates: Vec<(&DistkitRedisKey, CounterComparator, i64)> = updates + .iter() + .map(|(key, count)| (*key, CounterComparator::Nil, *count)) + .collect(); + + self.inc_all_if(&conditional_updates).await + } + + async fn inc_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + self.inc_if_batch(updates).await + } + + async fn set_all<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + self.set_batch(updates).await + } + + async fn set_all_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + self.set_if_batch(updates).await + } + + async fn set_all_on_instance<'k>( + &self, + updates: &[(&'k DistkitRedisKey, i64)], + ) -> Result, DistkitError> { + self.set_on_instance_batch(updates).await + } + + async fn set_all_on_instance_if<'k>( + &self, + updates: &[(&'k DistkitRedisKey, CounterComparator, i64)], + ) -> Result, DistkitError> { + self.set_on_instance_if_batch(updates).await + } } diff --git a/src/icounter/tests/common.rs b/src/icounter/tests/common.rs index 8b70f9c..0c168d9 100644 --- a/src/icounter/tests/common.rs +++ b/src/icounter/tests/common.rs @@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use redis::aio::ConnectionManager; -use crate::RedisKey; +use crate::DistkitRedisKey; use crate::icounter::{StrictInstanceAwareCounter, StrictInstanceAwareCounterOptions}; static RUN_ID: OnceLock = OnceLock::new(); @@ -30,7 +30,7 @@ pub async fn make_counter(prefix: &str) -> Arc { let conn = make_connection().await; let unique_prefix = format!("{}_{}", run_id(), prefix); StrictInstanceAwareCounter::new(StrictInstanceAwareCounterOptions::new( - RedisKey::from(unique_prefix), + DistkitRedisKey::from(unique_prefix), conn, )) } @@ -43,7 +43,7 @@ pub async fn make_counter_with_opts( let conn = make_connection().await; let unique_prefix = format!("{}_{}", run_id(), prefix); StrictInstanceAwareCounter::new(StrictInstanceAwareCounterOptions { - prefix: RedisKey::from(unique_prefix), + prefix: DistkitRedisKey::from(unique_prefix), connection_manager: conn, dead_instance_threshold_ms: threshold_ms, }) @@ -61,11 +61,11 @@ pub async fn make_pair( let unique_prefix = format!("{}_{}", run_id(), prefix); let c1 = StrictInstanceAwareCounter::new(StrictInstanceAwareCounterOptions::new( - RedisKey::from(unique_prefix.clone()), + DistkitRedisKey::from(unique_prefix.clone()), conn1, )); let c2 = StrictInstanceAwareCounter::new(StrictInstanceAwareCounterOptions::new( - RedisKey::from(unique_prefix), + DistkitRedisKey::from(unique_prefix), conn2, )); (c1, c2) @@ -83,7 +83,7 @@ pub async fn make_pair_with_opts( let unique_prefix = format!("{}_{}", run_id(), prefix); let opts = |conn| StrictInstanceAwareCounterOptions { - prefix: RedisKey::from(unique_prefix.clone()), + prefix: DistkitRedisKey::from(unique_prefix.clone()), connection_manager: conn, dead_instance_threshold_ms: threshold_ms, }; @@ -100,7 +100,10 @@ pub async fn make_n_counters(prefix: &str, n: usize) -> Vec RedisKey { - RedisKey::from(name.to_string()) +pub fn key(name: &str) -> DistkitRedisKey { + DistkitRedisKey::from(name.to_string()) } diff --git a/src/icounter/tests/lax_instance_aware_counter.rs b/src/icounter/tests/lax_instance_aware_counter.rs index 6802c5c..37e712a 100644 --- a/src/icounter/tests/lax_instance_aware_counter.rs +++ b/src/icounter/tests/lax_instance_aware_counter.rs @@ -3,12 +3,13 @@ use std::time::Duration; use tokio::time::sleep; +use crate::CounterComparator; use crate::icounter::{ InstanceAwareCounterTrait, LaxInstanceAwareCounter, LaxInstanceAwareCounterOptions, }; use super::common::{make_connection, run_id}; -use crate::RedisKey; +use crate::DistkitRedisKey; // --------------------------------------------------------------------------- // Test helpers @@ -25,7 +26,7 @@ fn unique_prefix(name: &str) -> String { async fn make_lax_from_prefix(prefix: &str) -> Arc { let conn = make_connection().await; LaxInstanceAwareCounter::new(LaxInstanceAwareCounterOptions { - prefix: RedisKey::from(prefix.to_string()), + prefix: DistkitRedisKey::from(prefix.to_string()), connection_manager: conn, dead_instance_threshold_ms: THRESHOLD_MS, flush_interval: Duration::from_millis(FLUSH_MS), @@ -40,15 +41,19 @@ async fn make_lax(name: &str) -> Arc { /// Two lax counters sharing the same Redis prefix (different instance IDs). async fn make_lax_pair( name: &str, -) -> (Arc, Arc, String) { +) -> ( + Arc, + Arc, + String, +) { let prefix = unique_prefix(name); let c1 = make_lax_from_prefix(&prefix).await; let c2 = make_lax_from_prefix(&prefix).await; (c1, c2, prefix) } -fn key(name: &str) -> RedisKey { - RedisKey::from(name.to_string()) +fn key(name: &str) -> DistkitRedisKey { + DistkitRedisKey::from(name.to_string()) } // --------------------------------------------------------------------------- @@ -212,3 +217,559 @@ async fn clear_on_instance_removes_only_this_instance() { let (cum, _) = reader.get(&k).await.unwrap(); assert_eq!(cum, 10, "only c1's contribution should be removed"); } + +// --------------------------------------------------------------------------- +// dec +// --------------------------------------------------------------------------- + +/// dec subtracts from the local estimate without a Redis round-trip. +#[tokio::test] +async fn dec_returns_local_estimate() { + let c = make_lax("dec_returns_local_estimate").await; + let k = key("hits"); + + c.inc(&k, 10).await.unwrap(); + let (cum, instance) = c.dec(&k, 3).await.unwrap(); + + assert_eq!(cum, 7); + assert_eq!(instance, 7); +} + +// --------------------------------------------------------------------------- +// get on unknown key +// --------------------------------------------------------------------------- + +/// get on a key that was never written returns (0, 0). +#[tokio::test] +async fn get_on_unknown_key_returns_zero() { + let c = make_lax("get_unknown_zero").await; + let k = key("ghost"); + + let (cum, instance) = c.get(&k).await.unwrap(); + assert_eq!(cum, 0); + assert_eq!(instance, 0); +} + +// --------------------------------------------------------------------------- +// del_on_instance +// --------------------------------------------------------------------------- + +/// del_on_instance flushes the pending delta then removes only this instance's +/// contribution; other instances' slices are unaffected. +#[tokio::test] +async fn del_on_instance_removes_only_this_instance_contribution() { + let (c1, c2, prefix) = make_lax_pair("del_on_instance_lax").await; + let k = key("hits"); + + c1.inc(&k, 20).await.unwrap(); + c2.inc(&k, 10).await.unwrap(); + // Flush both so strict counter has c1=20, c2=10, cumulative=30. + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + // del_on_instance removes c1's slice (20); c2's contribution (10) survives. + let (new_cum, removed) = c1.del_on_instance(&k).await.unwrap(); + assert_eq!(removed, 20); + assert_eq!(new_cum, 10); + + // A fresh reader fetches from strict and sees only c2's slice. + let reader = make_lax_from_prefix(&prefix).await; + let (cum, _) = reader.get(&k).await.unwrap(); + assert_eq!(cum, 10); +} + +// --------------------------------------------------------------------------- +// set_on_instance +// --------------------------------------------------------------------------- + +/// set_on_instance adjusts the local delta so this instance's contribution +/// reaches the target without bumping the epoch. +#[tokio::test] +async fn set_on_instance_adjusts_local_count() { + let c = make_lax("set_on_instance_adjusts").await; + let k = key("hits"); + + // Accumulate a delta, then override the instance slice via set_on_instance. + c.inc(&k, 10).await.unwrap(); + let (cum, instance) = c.set_on_instance(&k, 7).await.unwrap(); + assert_eq!(instance, 7); + assert_eq!(cum, 7); + + // get reflects the overwritten local state. + let (get_cum, get_instance) = c.get(&k).await.unwrap(); + assert_eq!(get_cum, 7); + assert_eq!(get_instance, 7); +} + +// --------------------------------------------------------------------------- +// clear_on_instance with unflushed delta +// --------------------------------------------------------------------------- + +/// clear_on_instance flushes a pending delta before removing this instance's +/// contributions so the flush and delete cancel out, leaving 0 for readers. +#[tokio::test] +async fn clear_on_instance_flushes_pending_delta() { + let (c1, _c2, prefix) = make_lax_pair("clear_on_instance_pending_delta").await; + let k = key("hits"); + + // Accumulate a delta without waiting for the background flush. + c1.inc(&k, 50).await.unwrap(); + + // clear_on_instance must flush the 50 first, then remove c1's instance slice. + c1.clear_on_instance().await.unwrap(); + + // The flush committed 50 and the delete immediately removed it — net 0. + let reader = make_lax_from_prefix(&prefix).await; + let (cum, _) = reader.get(&k).await.unwrap(); + assert_eq!(cum, 0); +} + +// --------------------------------------------------------------------------- +// Stale-cumulative divergence — both instances must see the same cumulative +// --------------------------------------------------------------------------- + +/// After both instances have fully flushed, every subsequent get() must return +/// the ground-truth cumulative (3). Only the instance counts differ (1 vs 2). +/// +/// The early flusher (c1) is stuck with a stale cumulative until get() triggers +/// a re-fetch from strict — this test will FAIL until that re-fetch is +/// implemented in get(). +/// +/// Timeline: +/// 1. c1.inc(k, 1) — c1 seeds from strict (0,0); delta=1 +/// 2. c1 flushes — strict: c1=1, cum=1; c1.local: cum=1, inst=1, δ=0 +/// 3. c2.inc(k, 1) — c2 seeds from strict (cum=1, inst_c2=0); delta=1 +/// 4. c2.inc(k, 1) — delta=2 +/// 5. c2 flushes — strict: c1=1, c2=2, cum=3; c2.local: cum=3, inst=2, δ=0 +/// 6. c1.get() → (3, 1) — re-fetched: cache is 100 ms old, allowed_lag=10 ms +/// 7. c2.get() → (3, 2) — fresh from c2's own flush +/// 8. fresh.get() → (3, 0) — ground truth from strict +#[tokio::test] +async fn instances_see_different_cumulatives_after_sequential_flushes() { + let (c1, c2, prefix) = make_lax_pair("stale_cumulative_divergence").await; + let k = key("hits"); + + // Step 1: c1 increments once. + c1.inc(&k, 1).await.unwrap(); + + // Step 2: Wait for c1's background flush to commit to strict. + // After this: strict has c1=1, cum=1; c1.local: cum=1, inst=1, δ=0. + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + // Steps 3 & 4: c2 increments twice, seeding from strict (cum=1, inst_c2=0). + c2.inc(&k, 1).await.unwrap(); + c2.inc(&k, 1).await.unwrap(); + + // Step 5: Wait for c2's background flush to commit to strict. + // After this: strict has c1=1, c2=2, cum=3; c2.local: cum=3, inst=2, δ=0. + // c1 is now idle (δ=0); its cache is ~100 ms old — well past allowed_lag (10 ms). + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + // Step 6: c1's cache has expired (100 ms >> allowed_lag=10 ms). + // get() must re-fetch from strict and return the current ground-truth cumulative. + let (c1_cum, c1_inst) = c1.get(&k).await.unwrap(); + assert_eq!(c1_cum, 3, "c1 must re-fetch from strict — cache is stale"); + assert_eq!(c1_inst, 1, "c1's own instance slice is unchanged"); + + // Step 7: c2's cumulative is fresh from its own flush. + let (c2_cum, c2_inst) = c2.get(&k).await.unwrap(); + assert_eq!(c2_cum, 3, "c2 sees fresh cumulative from its own flush"); + assert_eq!(c2_inst, 2); + + // Step 8: A fresh reader has no local cache — always fetches directly from strict. + let reader = make_lax_from_prefix(&prefix).await; + let (reader_cum, reader_inst) = reader.get(&k).await.unwrap(); + assert_eq!(reader_cum, 3, "fresh reader sees ground-truth total"); + assert_eq!(reader_inst, 0, "fresh reader has no instance contribution"); +} + +/// Mirror of `instances_see_different_cumulatives_after_sequential_flushes`: +/// c2 flushes first this time, so c2 ends up with the stale cache. +/// +/// This test will FAIL until get() implements a staleness re-fetch. +/// +/// Timeline: +/// 1. c2.inc(k, 1) × 2 — c2 seeds from strict (0,0); delta=2 +/// 2. c2 flushes — strict: c2=2, cum=2; c2.local: cum=2, inst=2, δ=0 +/// 3. c1.inc(k, 1) — c1 seeds from strict (cum=2, inst_c1=0); delta=1 +/// 4. c1 flushes — strict: c1=1, c2=2, cum=3; c1.local: cum=3, inst=1, δ=0 +/// c2 is idle (δ=0), cache frozen at cum=2 +/// 5. c2.get() → (3, 2) — must re-fetch; cache is ~100 ms old +/// 6. c1.get() → (3, 1) — fresh from c1's own flush +/// 7. fresh.get() → (3, 0) +#[tokio::test] +async fn early_flusher_sees_stale_cumulative_from_reversed_flush_order() { + let (c1, c2, prefix) = make_lax_pair("stale_cumulative_reversed").await; + let k = key("hits"); + + // Steps 1 & 2: c2 increments twice and flushes first. + c2.inc(&k, 1).await.unwrap(); + c2.inc(&k, 1).await.unwrap(); + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + // strict: c2=2, cum=2; c2.local: cum=2, inst=2, δ=0 + + // Step 3: c1 increments after c2 has already flushed, seeding from strict (cum=2). + c1.inc(&k, 1).await.unwrap(); + + // Step 4: c1 flushes. strict: c1=1, c2=2, cum=3; c1.local: cum=3, inst=1. + // c2 is idle (δ=0) — the flush task skips it every tick. + // Use a longer wait so c1's contribution is definitely visible in Redis + // before c2.get() triggers a staleness re-fetch. + sleep(Duration::from_millis(FLUSH_MS * 15)).await; + + // Step 5: c2's cache is ~100 ms old (>> allowed_lag=10 ms). + // get() must re-fetch from strict and return the current ground-truth cumulative. + let (c2_cum, c2_inst) = c2.get(&k).await.unwrap(); + assert_eq!(c2_cum, 3, "c2 must re-fetch from strict — cache is stale"); + assert_eq!(c2_inst, 2, "c2's own instance slice is unchanged"); + + // Step 6: c1's cumulative is fresh from its own flush. + let (c1_cum, c1_inst) = c1.get(&k).await.unwrap(); + assert_eq!(c1_cum, 3, "c1 sees fresh cumulative from its own flush"); + assert_eq!(c1_inst, 1); + + // Step 7: Ground truth. + let reader = make_lax_from_prefix(&prefix).await; + let (reader_cum, reader_inst) = reader.get(&k).await.unwrap(); + assert_eq!(reader_cum, 3, "fresh reader sees ground-truth total"); + assert_eq!(reader_inst, 0); +} + +// --------------------------------------------------------------------------- +// get_all +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn get_all_empty_returns_empty() { + let c = make_lax("get_all_empty").await; + assert_eq!(c.get_all(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn get_all_unknown_keys_return_zero_zero() { + let c = make_lax("get_all_unknown").await; + let k1 = key("a"); + let k2 = key("b"); + assert_eq!( + c.get_all(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 0, 0), (&k2, 0, 0)] + ); +} + +#[tokio::test] +async fn get_all_returns_correct_values_after_inc() { + let c = make_lax("get_all_after_inc").await; + let k1 = key("a"); + let k2 = key("b"); + c.inc(&k1, 5).await.unwrap(); + c.inc(&k2, 10).await.unwrap(); + let results = c.get_all(&[&k1, &k2]).await.unwrap(); + assert_eq!(results, vec![(&k1, 5, 5), (&k2, 10, 10)]); +} + +#[tokio::test] +async fn get_all_preserves_input_order() { + let c = make_lax("get_all_order").await; + let k1 = key("a"); + let k2 = key("b"); + let k3 = key("c"); + c.inc(&k1, 1).await.unwrap(); + c.inc(&k2, 2).await.unwrap(); + c.inc(&k3, 3).await.unwrap(); + let results = c.get_all(&[&k3, &k1, &k2]).await.unwrap(); + assert_eq!(results, vec![(&k3, 3, 3), (&k1, 1, 1), (&k2, 2, 2)]); +} + +/// A fresh reader with no local cache must batch-fetch from the strict counter +/// and return the written values — not stale zeros. +#[tokio::test] +async fn get_all_fetches_stale_keys_from_redis() { + let (writer, _, prefix) = make_lax_pair("get_all_stale").await; + let k = key("hits"); + + writer.inc(&k, 42).await.unwrap(); + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + let reader = make_lax_from_prefix(&prefix).await; + let results = reader.get_all(&[&k]).await.unwrap(); + assert_eq!(results[0].1, 42); +} + +// --------------------------------------------------------------------------- +// get_all_on_instance +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn get_all_on_instance_empty_returns_empty() { + let c = make_lax("goi_empty").await; + assert_eq!(c.get_all_on_instance(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn get_all_on_instance_unknown_keys_return_zero() { + let c = make_lax("goi_unknown").await; + let k1 = key("a"); + let k2 = key("b"); + // No Redis call — purely local; missing keys are 0. + assert_eq!( + c.get_all_on_instance(&[&k1, &k2]).await.unwrap(), + vec![(&k1, 0), (&k2, 0)] + ); +} + +#[tokio::test] +async fn get_all_on_instance_returns_instance_count_plus_delta() { + let c = make_lax("goi_delta").await; + let k = key("hits"); + c.inc(&k, 5).await.unwrap(); + // instance_count is seeded from strict (0) + delta (5) = 5. + assert_eq!(c.get_all_on_instance(&[&k]).await.unwrap(), vec![(&k, 5)]); +} + +#[tokio::test] +async fn get_all_on_instance_unaffected_by_other_instances() { + let (c1, c2, _) = make_lax_pair("goi_isolation").await; + let k = key("hits"); + + c2.inc(&k, 100).await.unwrap(); + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + // c1 has no local contribution for this key; result must be 0. + assert_eq!(c1.get_all_on_instance(&[&k]).await.unwrap(), vec![(&k, 0)]); +} + +// --------------------------------------------------------------------------- +// set_all_on_instance +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn set_all_on_instance_empty_returns_empty() { + let c = make_lax("soi_empty").await; + assert_eq!(c.set_all_on_instance(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn set_all_on_instance_basic_correctness() { + let c = make_lax("soi_basic").await; + let k1 = key("a"); + let k2 = key("b"); + let results = c.set_all_on_instance(&[(&k1, 7), (&k2, 3)]).await.unwrap(); + assert_eq!(results, vec![(&k1, 7, 7), (&k2, 3, 3)]); +} + +/// Other instances' slices must not be overwritten by set_all_on_instance. +#[tokio::test] +async fn set_all_on_instance_preserves_other_slices() { + let (c1, c2, prefix) = make_lax_pair("soi_preserves").await; + let k = key("hits"); + + c2.inc(&k, 10).await.unwrap(); + sleep(Duration::from_millis(FLUSH_MS * 10)).await; + + // c1 sets its own slice to 4. c2's slice (10) must survive. + c1.set_all_on_instance(&[(&k, 4)]).await.unwrap(); + // Long enough for c1 to flush delta=4 but well under the 300 ms dead-instance + // threshold (c2 last heartbeat ~10 ms ago, c1 last heartbeat ~105 ms ago). + sleep(Duration::from_millis(FLUSH_MS * 10)).await; + + // A fresh reader sees both contributions. + let reader = make_lax_from_prefix(&prefix).await; + let (cum, _) = reader.get(&k).await.unwrap(); + assert_eq!(cum, 14); +} + +/// Stale keys are batch-refreshed before computing the delta. +#[tokio::test] +async fn set_all_on_instance_stale_keys_are_batch_refreshed() { + let (c1, c2, _) = make_lax_pair("soi_stale_refresh").await; + let k = key("hits"); + + c2.inc(&k, 5).await.unwrap(); + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + // c1 has never touched this key — it is seeded from strict during set_all_on_instance. + let results = c1.set_all_on_instance(&[(&k, 3)]).await.unwrap(); + // instance_count for c1 = 0 (fresh), cumulative from strict = 5. + // Expected: (5 - 0 + 3, 3) = (8, 3). + assert_eq!(results[0].2, 3, "instance slice set correctly"); +} + +// --------------------------------------------------------------------------- +// set_all +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn set_all_empty_returns_empty() { + let c = make_lax("sa_empty").await; + assert_eq!(c.set_all(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn set_all_basic_correctness() { + let c = make_lax("sa_basic").await; + let k1 = key("a"); + let k2 = key("b"); + let results = c.set_all(&[(&k1, 10), (&k2, 20)]).await.unwrap(); + assert_eq!(results, vec![(&k1, 10, 10), (&k2, 20, 20)]); +} + +/// After set_all the new values must be visible to another instance immediately. +#[tokio::test] +async fn set_all_visible_to_other_instances() { + let (c1, c2, _) = make_lax_pair("sa_visible").await; + let k = key("hits"); + + c1.set_all(&[(&k, 99)]).await.unwrap(); + + let (cum, _) = c2.get(&k).await.unwrap(); + assert_eq!(cum, 99); +} + +/// set_all flushes any pending delta for the affected keys before calling strict.set. +#[tokio::test] +async fn set_all_flushes_pending_delta_first() { + let (c1, c2, _) = make_lax_pair("sa_flush_first").await; + let k = key("hits"); + + // c1 accumulates delta=5 (not yet flushed). + c1.inc(&k, 5).await.unwrap(); + // set_all must flush the pending delta before set, then set to 20. + c1.set_all(&[(&k, 20)]).await.unwrap(); + + // c2 must see 20, not 20+5. + let (cum, _) = c2.get(&k).await.unwrap(); + assert_eq!(cum, 20); +} + +#[tokio::test] +async fn inc_if_uses_all_comparators_against_local_view() { + let cases = [ + ("eq", CounterComparator::Eq(10), true), + ("lt", CounterComparator::Lt(11), true), + ("gt", CounterComparator::Gt(10), false), + ("ne", CounterComparator::Ne(9), true), + ("nil", CounterComparator::Nil, true), + ]; + + for (suffix, comparator, should_apply) in cases { + let c = make_lax(&format!("lax_inc_if_{suffix}")).await; + let k = key("hits"); + c.set(&k, 10).await.unwrap(); + + let (cum, inst) = c.inc_if(&k, comparator, 2).await.unwrap(); + let expected = if should_apply { (12, 12) } else { (10, 10) }; + + assert_eq!((cum, inst), expected); + assert_eq!(c.get(&k).await.unwrap(), expected); + } +} + +#[tokio::test] +async fn set_if_success_is_visible_to_other_instances_immediately() { + let (c1, c2, _) = make_lax_pair("lax_set_if_visible").await; + let k = key("hits"); + + c1.inc(&k, 5).await.unwrap(); + let result = c1.set_if(&k, CounterComparator::Eq(5), 20).await.unwrap(); + + assert_eq!(result, (20, 20)); + assert_eq!(c2.get(&k).await.unwrap().0, 20); +} + +#[tokio::test] +async fn inc_all_empty_and_inc_all_if_empty_return_empty() { + let c = make_lax("lax_inc_all_empty").await; + assert_eq!(c.inc_all(&[]).await.unwrap(), vec![]); + assert_eq!(c.inc_all_if(&[]).await.unwrap(), vec![]); +} + +#[tokio::test] +async fn inc_all_updates_local_view_immediately_and_supports_duplicates() { + let c = make_lax("lax_inc_all_duplicates").await; + let k = key("hits"); + + let results = c.inc_all(&[(&k, 1), (&k, 2)]).await.unwrap(); + + assert_eq!(results, vec![(&k, 1, 1), (&k, 3, 3)]); + assert_eq!(c.get(&k).await.unwrap(), (3, 3)); +} + +#[tokio::test] +async fn inc_all_if_uses_stale_aware_local_cumulative_and_is_sequential() { + let (c1, c2, _) = make_lax_pair("lax_inc_all_if_ordered").await; + let k = key("hits"); + + c2.inc(&k, 5).await.unwrap(); + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + let results = c1 + .inc_all_if(&[ + (&k, CounterComparator::Eq(5), 1), + (&k, CounterComparator::Eq(6), 2), + (&k, CounterComparator::Gt(20), 4), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k, 6, 1), (&k, 8, 3), (&k, 8, 3)]); + assert_eq!(c1.get(&k).await.unwrap(), (8, 3)); +} + +#[tokio::test] +async fn inc_all_if_successes_are_eventually_visible_after_flush() { + let (c1, c2, _) = make_lax_pair("lax_inc_all_if_flush").await; + let k1 = key("a"); + let k2 = key("b"); + + let results = c1 + .inc_all_if(&[ + (&k1, CounterComparator::Nil, 4), + (&k2, CounterComparator::Gt(0), 7), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k1, 4, 4), (&k2, 0, 0)]); + + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + assert_eq!(c2.get(&k1).await.unwrap().0, 4); + assert_eq!(c2.get(&k2).await.unwrap().0, 0); +} + +#[tokio::test] +async fn set_all_if_supports_partial_success() { + let (c1, c2, _) = make_lax_pair("lax_set_all_if_partial").await; + let k1 = key("a"); + let k2 = key("b"); + + c1.set_all(&[(&k1, 10), (&k2, 20)]).await.unwrap(); + + let results = c1 + .set_all_if(&[ + (&k1, CounterComparator::Nil, 11), + (&k2, CounterComparator::Lt(10), 99), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k1, 11, 11), (&k2, 20, 20)]); + assert_eq!(c2.get(&k1).await.unwrap().0, 11); + assert_eq!(c2.get(&k2).await.unwrap().0, 20); +} + +#[tokio::test] +async fn set_all_on_instance_if_refreshes_stale_state_before_comparing() { + let (c1, c2, _) = make_lax_pair("lax_set_all_on_instance_if_refresh").await; + let k = key("hits"); + + c2.inc(&k, 5).await.unwrap(); + sleep(Duration::from_millis(FLUSH_MS * 5)).await; + + let results = c1 + .set_all_on_instance_if(&[(&k, CounterComparator::Eq(0), 3)]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k, 8, 3)]); + assert_eq!(c1.get(&k).await.unwrap(), (8, 3)); +} diff --git a/src/icounter/tests/mod.rs b/src/icounter/tests/mod.rs index 5947758..a7287de 100644 --- a/src/icounter/tests/mod.rs +++ b/src/icounter/tests/mod.rs @@ -1,3 +1,3 @@ pub mod common; -mod strict_instance_aware_counter; mod lax_instance_aware_counter; +mod strict_instance_aware_counter; diff --git a/src/icounter/tests/strict_instance_aware_counter.rs b/src/icounter/tests/strict_instance_aware_counter.rs index 4cf2ec6..816df1f 100644 --- a/src/icounter/tests/strict_instance_aware_counter.rs +++ b/src/icounter/tests/strict_instance_aware_counter.rs @@ -1,5 +1,8 @@ use std::time::Duration; +use crate::CounterComparator; +use crate::icounter::InstanceAwareCounterTrait; + use super::common::{ key, make_counter, make_n_counters, make_n_counters_with_opts, make_pair, make_pair_with_opts, }; @@ -968,3 +971,158 @@ async fn no_recovery_for_live_instance() { c1.clear().await.unwrap(); } + +#[tokio::test] +async fn inc_if_uses_all_comparators_against_cumulative() { + let cases = [ + ("eq", CounterComparator::Eq(10), true), + ("lt", CounterComparator::Lt(11), true), + ("gt", CounterComparator::Gt(10), false), + ("ne", CounterComparator::Ne(9), true), + ("nil", CounterComparator::Nil, true), + ]; + + for (suffix, comparator, should_apply) in cases { + let c = make_counter(&format!("strict_inc_if_{suffix}")).await; + let k = key("hits"); + c.set(&k, 10).await.unwrap(); + + let (cum, inst) = c.inc_if(&k, comparator, 2).await.unwrap(); + let expected = if should_apply { (12, 12) } else { (10, 10) }; + + assert_eq!((cum, inst), expected); + assert_eq!(c.get(&k).await.unwrap(), expected); + c.clear().await.unwrap(); + } +} + +#[tokio::test] +async fn inc_all_empty_and_inc_all_if_empty_return_empty() { + let c = make_counter("strict_inc_all_empty").await; + assert_eq!(c.inc_all(&[]).await.unwrap(), vec![]); + assert_eq!(c.inc_all_if(&[]).await.unwrap(), vec![]); + c.clear().await.unwrap(); +} + +#[tokio::test] +async fn inc_all_returns_ordered_results_and_supports_duplicates() { + let c = make_counter("strict_inc_all_duplicates").await; + let k = key("hits"); + + let results = c.inc_all(&[(&k, 1), (&k, 2)]).await.unwrap(); + + assert_eq!(results, vec![(&k, 1, 1), (&k, 3, 3)]); + assert_eq!(c.get(&k).await.unwrap(), (3, 3)); + c.clear().await.unwrap(); +} + +#[tokio::test] +async fn inc_all_if_compares_against_cumulative_and_processes_duplicates_sequentially() { + let (c1, c2) = make_pair("strict_inc_all_if_ordered").await; + let k1 = key("a"); + let k2 = key("b"); + + c2.inc(&k2, 5).await.unwrap(); + c1.set(&k1, 0).await.unwrap(); + + let results = c1 + .inc_all_if(&[ + (&k1, CounterComparator::Eq(0), 1), + (&k1, CounterComparator::Eq(1), 2), + (&k2, CounterComparator::Gt(10), 4), + (&k2, CounterComparator::Eq(5), 3), + ]) + .await + .unwrap(); + + assert_eq!( + results, + vec![(&k1, 1, 1), (&k1, 3, 3), (&k2, 5, 0), (&k2, 8, 3)] + ); + assert_eq!(c1.get(&k1).await.unwrap(), (3, 3)); + assert_eq!(c1.get(&k2).await.unwrap(), (8, 3)); + assert_eq!(c2.get(&k2).await.unwrap(), (8, 5)); + + c1.clear().await.unwrap(); +} + +#[tokio::test] +async fn set_on_instance_if_compares_against_instance_slice() { + let (c1, c2) = make_pair("strict_set_on_instance_if").await; + let k = key("hits"); + + c1.set_on_instance(&k, 7).await.unwrap(); + c2.set_on_instance(&k, 5).await.unwrap(); + + let result = c1 + .set_on_instance_if(&k, CounterComparator::Gt(6), 9) + .await + .unwrap(); + assert_eq!(result, (14, 9)); + + let failed = c1 + .set_on_instance_if(&k, CounterComparator::Eq(8), 50) + .await + .unwrap(); + assert_eq!(failed, (14, 9)); + + let unconditional = c1 + .set_on_instance_if(&k, CounterComparator::Nil, 11) + .await + .unwrap(); + assert_eq!(unconditional, (16, 11)); + + c1.clear().await.unwrap(); +} + +#[tokio::test] +async fn set_all_if_supports_partial_success_and_missing_keys() { + let c = make_counter("strict_set_all_if_partial").await; + let k1 = key("a"); + let k2 = key("b"); + let k3 = key("c"); + + c.set(&k1, 10).await.unwrap(); + c.set(&k2, 20).await.unwrap(); + + let results = c + .set_all_if(&[ + (&k3, CounterComparator::Nil, 30), + (&k1, CounterComparator::Gt(5), 11), + (&k2, CounterComparator::Lt(10), 99), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k3, 30, 30), (&k1, 11, 11), (&k2, 20, 20)]); + assert_eq!(c.get(&k1).await.unwrap(), (11, 11)); + assert_eq!(c.get(&k2).await.unwrap(), (20, 20)); + assert_eq!(c.get(&k3).await.unwrap(), (30, 30)); + + c.clear().await.unwrap(); +} + +#[tokio::test] +async fn set_all_on_instance_if_supports_partial_success() { + let (c1, c2) = make_pair("strict_set_all_on_instance_if").await; + let k1 = key("a"); + let k2 = key("b"); + + c1.set_on_instance(&k1, 4).await.unwrap(); + c2.set_on_instance(&k1, 5).await.unwrap(); + c2.set_on_instance(&k2, 3).await.unwrap(); + + let results = c1 + .set_all_on_instance_if(&[ + (&k2, CounterComparator::Eq(1), 10), + (&k1, CounterComparator::Nil, 7), + ]) + .await + .unwrap(); + + assert_eq!(results, vec![(&k2, 3, 0), (&k1, 12, 7)]); + assert_eq!(c2.get(&k1).await.unwrap().0, 12); + assert_eq!(c2.get(&k2).await.unwrap(), (3, 3)); + + c1.clear().await.unwrap(); +} diff --git a/src/lib.rs b/src/lib.rs index 1f96edb..49afc53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,8 @@ mod common; pub use common::*; +mod comparator; +pub use comparator::*; #[cfg(feature = "counter")] pub mod counter;