From 6a2198fe22f8cb4e7e61aed0287571e4c34b481d Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Thu, 25 Dec 2025 22:38:43 +0800 Subject: [PATCH 01/14] feat(tdigest): implement TDIGEST.TRIMMED_MEAN command --- src/commands/cmd_tdigest.cc | 52 ++++++++++++++ src/types/redis_tdigest.cc | 35 ++++++++++ src/types/redis_tdigest.h | 6 ++ src/types/tdigest.h | 70 +++++++++++++++++++ .../gocase/unit/type/tdigest/tdigest_test.go | 63 +++++++++++++++++ 5 files changed, 226 insertions(+) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index 64dfafcd7e8..12065574f3e 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -412,6 +412,57 @@ class CommandTDigestMerge : public Commander { TDigestMergeOptions options_; }; +class CommandTDigestTrimmedMean : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() != 4) { + return {Status::RedisParseErr, errWrongNumOfArguments}; + } + + key_name_ = args[1]; + + auto low_cut_quantile = ParseFloat(args[2]); + if (!low_cut_quantile) { + return {Status::RedisParseErr, errValueIsNotFloat}; + } + low_cut_quantile_ = *low_cut_quantile; + + auto high_cut_quantile = ParseFloat(args[3]); + if (!high_cut_quantile) { + return {Status::RedisParseErr, errValueIsNotFloat}; + } + high_cut_quantile_ = *high_cut_quantile; + + return Status::OK(); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + TDigest tdigest(srv->storage, conn->GetNamespace()); + TDigestTrimmedMeanResult result; + + auto s = tdigest.TrimmedMean(ctx, key_name_, low_cut_quantile_, high_cut_quantile_, &result); + if (!s.ok()) { + if (s.IsNotFound()) { + return {Status::RedisExecErr, errKeyNotFound}; + } + return {Status::RedisExecErr, s.ToString()}; + } + + if (!result.mean.has_value()) { + *output = redis::BulkString(kNan); + } else { + *output = redis::BulkString(util::Float2String(*result.mean)); + } + + return Status::OK(); + } + + private: + std::string key_name_; + double low_cut_quantile_; + double high_cut_quantile_; +}; + std::vector GetMergeKeyRange(const std::vector &args) { auto numkeys = ParseInt(args[2], 10).ValueOr(0); return {{1, 1, 1}, {3, 2 + numkeys, 1}}; @@ -425,6 +476,7 @@ REDIS_REGISTER_COMMANDS(TDigest, MakeCmdAttr("tdigest.crea MakeCmdAttr("tdigest.revrank", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.rank", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.quantile", -3, "read-only", 1, 1, 1), + MakeCmdAttr("tdigest.trimmed_mean", 4, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.reset", 2, "write", 1, 1, 1), MakeCmdAttr("tdigest.merge", -4, "write", GetMergeKeyRange)); } // namespace redis diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc index fec7aef1c9b..51b14c0e750 100644 --- a/src/types/redis_tdigest.cc +++ b/src/types/redis_tdigest.cc @@ -725,6 +725,41 @@ rocksdb::Status TDigest::applyNewCentroids(ObserverOrUniquePtrGetLockManager(), ns_key); + if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + + if (metadata.total_observations == 0) { + return rocksdb::Status::OK(); + } + + if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + } + + // Dump centroids and create DummyCentroids wrapper for TDigest algorithm + std::vector centroids; + if (auto status = dumpCentroids(ctx, ns_key, metadata, ¢roids); !status.ok()) { + return status; + } + auto dump_centroids = DummyCentroids(metadata, centroids); + auto trimmed_mean_result = TDigestTrimmedMean(dump_centroids, low_cut_quantile, high_cut_quantile); + if (!trimmed_mean_result) { + return rocksdb::Status::InvalidArgument(trimmed_mean_result.Msg()); + } + + result->mean = *trimmed_mean_result; + return rocksdb::Status::OK(); +} + std::string TDigest::internalSegmentGuardPrefixKey(const TDigestMetadata& metadata, const std::string& ns_key, SegmentType seg) const { std::string prefix_key; diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h index 5daaed80c81..81917d6b8ea 100644 --- a/src/types/redis_tdigest.h +++ b/src/types/redis_tdigest.h @@ -53,6 +53,10 @@ struct TDigestQuantitleResult { std::optional> quantiles; }; +struct TDigestTrimmedMeanResult { + std::optional mean; +}; + class TDigest : public SubKeyScanner { public: using Slice = rocksdb::Slice; @@ -79,6 +83,8 @@ class TDigest : public SubKeyScanner { const TDigestMergeOptions& options); rocksdb::Status Rank(engine::Context& ctx, const Slice& digest_name, const std::vector& inputs, bool reverse, std::vector& result); + rocksdb::Status TrimmedMean(engine::Context& ctx, const Slice& digest_name, double low_cut_quantile, + double high_cut_quantile, TDigestTrimmedMeanResult* result); rocksdb::Status GetMetaData(engine::Context& context, const Slice& digest_name, TDigestMetadata* metadata); private: diff --git a/src/types/tdigest.h b/src/types/tdigest.h index d77b673f7a8..8ec68fc84e3 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -276,3 +276,73 @@ inline Status TDigestRank(TD&& td, const std::vector& inputs, bool rever return TDigestRankImpl(std::forward(td), inputs, result); } } + + +template +inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, double high_cut_quantile) { + if (td.Size() == 0) { + return Status{Status::InvalidArgument, "empty tdigest"}; + } + + // Validate quantile parameters + if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) { + return Status{Status::InvalidArgument, "low cut quantile must be between 0 and 1"}; + } + if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) { + return Status{Status::InvalidArgument, "high cut quantile must be between 0 and 1"}; + } + if (low_cut_quantile >= high_cut_quantile) { + return Status{Status::InvalidArgument, "low cut quantile must be less than high cut quantile"}; + } + + // Get boundary values for trimming + double low_boundary; + double high_boundary; + + // For 0 and 1 quantiles, use exact min/max values + if (low_cut_quantile == 0.0) { + low_boundary = td.Min(); + } else { + auto low_result = TDigestQuantile(std::forward(td), low_cut_quantile); + if (!low_result) { + return low_result; + } + low_boundary = *low_result; + } + + if (high_cut_quantile == 1.0) { + high_boundary = td.Max(); + } else { + auto high_result = TDigestQuantile(std::forward(td), high_cut_quantile); + if (!high_result) { + return high_result; + } + high_boundary = *high_result; + } + + // Calculate trimmed mean by iterating through centroids + auto iter = td.Begin(); + double total_weight_in_range = 0; + double weighted_sum = 0; + + while (iter->Valid()) { + auto centroid = GET_OR_RET(iter->GetCentroid()); + + // Check if centroid falls within the trimmed range + // For full range (0 to 1), include all centroids + if ((low_cut_quantile == 0.0 && high_cut_quantile == 1.0) || + (centroid.mean >= low_boundary && centroid.mean <= high_boundary)) { + total_weight_in_range += centroid.weight; + weighted_sum += centroid.mean * centroid.weight; + } + + iter->Next(); + } + + // Check if we have any data in the trimmed range + if (total_weight_in_range == 0) { + return std::numeric_limits::quiet_NaN(); + } + + return weighted_sum / total_weight_in_range; +} diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 335ee0eff75..472d530dbe1 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -716,4 +716,67 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.EqualValues(t, expected[i], rank, "REVRANK mismatch at index %d", i) } }) + + t.Run("tdigest.trimmed_mean with different arguments", func(t *testing.T) { + keyPrefix := "tdigest_trimmed_mean_" + + // Test invalid arguments + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key", "0.1").Err(), errMsgWrongNumberArg) + + // Test non-existent key + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist) + + // Test with empty tdigest + emptyKey := keyPrefix + "empty" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, "compression", "100").Err()) + rsp := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", "0.9") + require.NoError(t, rsp.Err()) + result, err := rsp.Result() + require.NoError(t, err) + require.Equal(t, "nan", result) + + // Test with sample data + key1 := keyPrefix + "test1" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key1, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key1, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + // Test trimmed mean with trimming + rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "0.9") + require.NoError(t, rsp.Err()) + result, err = rsp.Result() + require.NoError(t, err) + mean, err := strconv.ParseFloat(result.(string), 64) + require.NoError(t, err) + require.InDelta(t, 5.5, mean, 1.0) + + // Test with no trimming + rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0", "1") + require.NoError(t, rsp.Err()) + result, err = rsp.Result() + require.NoError(t, err) + mean, err = strconv.ParseFloat(result.(string), 64) + require.NoError(t, err) + require.InDelta(t, 5.5, mean, 0.1) + + // Test with invalid quantile ranges + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.9", "0.1").Err(), "low cut quantile must be less than high cut quantile") + + // Test with skewed data + key2 := keyPrefix + "skewed" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "1", "1", "1", "1", "1", "10", "100").Err()) + + // Test trimming with outliers + rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key2, "0.2", "0.8") + require.NoError(t, rsp.Err()) + result, err = rsp.Result() + require.NoError(t, err) + mean, err = strconv.ParseFloat(result.(string), 64) + require.NoError(t, err) + require.Less(t, mean, 50.0) + }) } From 178b1ee47c87ddcb9c60160f3588e8ed67d0aae5 Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Sun, 18 Jan 2026 22:21:56 +0800 Subject: [PATCH 02/14] tdigest: refine trimmed mean implementation and tests --- src/commands/cmd_tdigest.cc | 8 +- src/types/redis_tdigest.cc | 2 +- src/types/tdigest.h | 24 +--- tests/cppunit/types/tdigest_test.cc | 29 ++++ .../gocase/unit/type/tdigest/tdigest_test.go | 126 +++++++++++------- 5 files changed, 117 insertions(+), 72 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index 3692e2ff9c2..f615c4bc1d7 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -498,21 +498,21 @@ class CommandTDigestTrimmedMean : public Commander { if (args.size() != 4) { return {Status::RedisParseErr, errWrongNumOfArguments}; } - + key_name_ = args[1]; - + auto low_cut_quantile = ParseFloat(args[2]); if (!low_cut_quantile) { return {Status::RedisParseErr, errValueIsNotFloat}; } low_cut_quantile_ = *low_cut_quantile; - + auto high_cut_quantile = ParseFloat(args[3]); if (!high_cut_quantile) { return {Status::RedisParseErr, errValueIsNotFloat}; } high_cut_quantile_ = *high_cut_quantile; - + return Status::OK(); } diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc index 8d6ecae5b96..addaf7bbedb 100644 --- a/src/types/redis_tdigest.cc +++ b/src/types/redis_tdigest.cc @@ -784,7 +784,7 @@ rocksdb::Status TDigest::TrimmedMean(engine::Context& ctx, const Slice& digest_n if (auto status = dumpCentroids(ctx, ns_key, metadata, ¢roids); !status.ok()) { return status; } - auto dump_centroids = DummyCentroids(metadata, centroids); + auto dump_centroids = DummyCentroids(metadata, centroids); auto trimmed_mean_result = TDigestTrimmedMean(dump_centroids, low_cut_quantile, high_cut_quantile); if (!trimmed_mean_result) { return rocksdb::Status::InvalidArgument(trimmed_mean_result.Msg()); diff --git a/src/types/tdigest.h b/src/types/tdigest.h index 57c3fe32b79..f938b5b8c13 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -310,22 +310,12 @@ inline Status TDigestRank(TD&& td, const std::vector& inputs, std::vecto return Status::OK(); } -template -inline Status TDigestRank(TD&& td, const std::vector& inputs, bool reverse, std::vector& result) { - if (reverse) { - return TDigestRankImpl(std::forward(td), inputs, result); - } else { - return TDigestRankImpl(std::forward(td), inputs, result); - } -} - template inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, double high_cut_quantile) { if (td.Size() == 0) { return Status{Status::InvalidArgument, "empty tdigest"}; } - // Validate quantile parameters if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) { return Status{Status::InvalidArgument, "low cut quantile must be between 0 and 1"}; } @@ -336,15 +326,13 @@ inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, dou return Status{Status::InvalidArgument, "low cut quantile must be less than high cut quantile"}; } - // Get boundary values for trimming - double low_boundary; - double high_boundary; + double low_boundary = std::numeric_limits::quiet_NaN(); + double high_boundary = std::numeric_limits::quiet_NaN(); - // For 0 and 1 quantiles, use exact min/max values if (low_cut_quantile == 0.0) { low_boundary = td.Min(); } else { - auto low_result = TDigestQuantile(std::forward(td), low_cut_quantile); + auto low_result = TDigestQuantile(td, low_cut_quantile); if (!low_result) { return low_result; } @@ -354,14 +342,13 @@ inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, dou if (high_cut_quantile == 1.0) { high_boundary = td.Max(); } else { - auto high_result = TDigestQuantile(std::forward(td), high_cut_quantile); + auto high_result = TDigestQuantile(td, high_cut_quantile); if (!high_result) { return high_result; } high_boundary = *high_result; } - // Calculate trimmed mean by iterating through centroids auto iter = td.Begin(); double total_weight_in_range = 0; double weighted_sum = 0; @@ -369,8 +356,6 @@ inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, dou while (iter->Valid()) { auto centroid = GET_OR_RET(iter->GetCentroid()); - // Check if centroid falls within the trimmed range - // For full range (0 to 1), include all centroids if ((low_cut_quantile == 0.0 && high_cut_quantile == 1.0) || (centroid.mean >= low_boundary && centroid.mean <= high_boundary)) { total_weight_in_range += centroid.weight; @@ -380,7 +365,6 @@ inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, dou iter->Next(); } - // Check if we have any data in the trimmed range if (total_weight_in_range == 0) { return std::numeric_limits::quiet_NaN(); } diff --git a/tests/cppunit/types/tdigest_test.cc b/tests/cppunit/types/tdigest_test.cc index ed12df1368f..d704f84f4aa 100644 --- a/tests/cppunit/types/tdigest_test.cc +++ b/tests/cppunit/types/tdigest_test.cc @@ -524,3 +524,32 @@ TEST_F(RedisTDigestTest, ByRank_And_ByRevRank) { EXPECT_EQ(result[0], 1.0) << "Rank 0 should be minimum"; EXPECT_TRUE(std::isinf(result[3])) << "Rank >= total_weight should be infinity"; } + +TEST_F(RedisTDigestTest, TrimmedMean) { + std::string test_digest_name = "test_digest_trimmed_mean" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + std::vector values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + status = tdigest_->Add(*ctx_, test_digest_name, values); + ASSERT_TRUE(status.ok()) << status.ToString(); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_NEAR(*result.mean, 5.5, 1.0) << "Trimmed mean should be approximately 5.5"; + + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.0, 1.0, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_NEAR(*result.mean, 5.5, 0.1) << "Full range should equal complete mean"; + + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.25, 0.75, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_GT(*result.mean, 3.0) << "Trimmed mean should be greater than 3.0"; + EXPECT_LT(*result.mean, 8.0) << "Trimmed mean should be less than 8.0"; +} diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 205ed1edd6b..f883f3c02c8 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -718,67 +718,99 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { } }) - t.Run("tdigest.trimmed_mean with different arguments", func(t *testing.T) { - keyPrefix := "tdigest_trimmed_mean_" - - // Test invalid arguments - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg) - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key").Err(), errMsgWrongNumberArg) - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"key", "0.1").Err(), errMsgWrongNumberArg) - - // Test non-existent key - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", keyPrefix+"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist) + t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) { + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist) + }) - // Test with empty tdigest - emptyKey := keyPrefix + "empty" + t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) { + emptyKey := "tdigest_empty" require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, "compression", "100").Err()) - rsp := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", "0.9") - require.NoError(t, rsp.Err()) - result, err := rsp.Result() - require.NoError(t, err) - require.Equal(t, "nan", result) - // Test with sample data - key1 := keyPrefix + "test1" - require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key1, "compression", "100").Err()) - require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key1, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", "0.9") + require.NoError(t, result.Err()) + require.Equal(t, "nan", result.Val()) + }) - // Test trimmed mean with trimming - rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "0.9") - require.NoError(t, rsp.Err()) - result, err = rsp.Result() - require.NoError(t, err) - mean, err := strconv.ParseFloat(result.(string), 64) + t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) { + key := "tdigest_basic" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) require.InDelta(t, 5.5, mean, 1.0) + }) - // Test with no trimming - rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0", "1") - require.NoError(t, rsp.Err()) - result, err = rsp.Result() - require.NoError(t, err) - mean, err = strconv.ParseFloat(result.(string), 64) + t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) { + key := "tdigest_no_trim" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) require.InDelta(t, 5.5, mean, 0.1) + }) - // Test with invalid quantile ranges - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1") - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1") - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key1, "0.9", "0.1").Err(), "low cut quantile must be less than high cut quantile") + t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) { + key := "tdigest_skewed" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", "1", "1", "1", "10", "100").Err()) - // Test with skewed data - key2 := keyPrefix + "skewed" - require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2, "compression", "100").Err()) - require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "1", "1", "1", "1", "1", "10", "100").Err()) + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) + require.NoError(t, err) + require.Less(t, mean, 50.0) + }) - // Test trimming with outliers - rsp = rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key2, "0.2", "0.8") - require.NoError(t, rsp.Err()) - result, err = rsp.Result() + t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t *testing.T) { + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key", "0.1").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg) + }) + + t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t *testing.T) { + key := "tdigest_invalid" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5").Err()) + + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut quantile") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut quantile") + }) + + t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) { + key := "tdigest_single" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) - mean, err = strconv.ParseFloat(result.(string), 64) + require.InDelta(t, 42.0, mean, 0.001) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) { + key := "tdigest_extreme" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6") + require.NoError(t, result.Err()) + meanStr := result.Val().(string) + if meanStr == "nan" { + return + } + mean, err := strconv.ParseFloat(meanStr, 64) require.NoError(t, err) - require.Less(t, mean, 50.0) + require.Greater(t, mean, 0.0) }) } From 3f01bd614a242011952021b54f54825d0f2cc6ad Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Sun, 8 Mar 2026 11:38:05 +0800 Subject: [PATCH 03/14] refactor(tdigest): move TRIMMED_MEAN quantile validation to parse step --- src/commands/cmd_tdigest.cc | 10 ++++++++++ src/types/tdigest.h | 10 ---------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index f615c4bc1d7..f42bc507d07 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -513,6 +513,16 @@ class CommandTDigestTrimmedMean : public Commander { } high_cut_quantile_ = *high_cut_quantile; + if (low_cut_quantile_ < 0.0 || low_cut_quantile_ > 1.0) { + return {Status::RedisParseErr, "low cut quantile must be between 0 and 1"}; + } + if (high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { + return {Status::RedisParseErr, "high cut quantile must be between 0 and 1"}; + } + if (low_cut_quantile_ >= high_cut_quantile_) { + return {Status::RedisParseErr, "low cut quantile must be less than high cut quantile"}; + } + return Status::OK(); } diff --git a/src/types/tdigest.h b/src/types/tdigest.h index f938b5b8c13..e5231eee55e 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -316,16 +316,6 @@ inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, dou return Status{Status::InvalidArgument, "empty tdigest"}; } - if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) { - return Status{Status::InvalidArgument, "low cut quantile must be between 0 and 1"}; - } - if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) { - return Status{Status::InvalidArgument, "high cut quantile must be between 0 and 1"}; - } - if (low_cut_quantile >= high_cut_quantile) { - return Status{Status::InvalidArgument, "low cut quantile must be less than high cut quantile"}; - } - double low_boundary = std::numeric_limits::quiet_NaN(); double high_boundary = std::numeric_limits::quiet_NaN(); From af7b44fdf9fcb6d763bc0998b1d8e814348bf3e8 Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Sun, 8 Mar 2026 13:12:52 +0800 Subject: [PATCH 04/14] fix(tdigest): fix TRIMMED_MEAN algorithm and parameter validation --- src/commands/cmd_tdigest.cc | 2 +- src/types/tdigest.h | 52 +++++++++++++++---------------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index f42bc507d07..3d0367b44ea 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -519,7 +519,7 @@ class CommandTDigestTrimmedMean : public Commander { if (high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { return {Status::RedisParseErr, "high cut quantile must be between 0 and 1"}; } - if (low_cut_quantile_ >= high_cut_quantile_) { + if (DoubleCompare(low_cut_quantile_, high_cut_quantile_) >= 0) { return {Status::RedisParseErr, "low cut quantile must be less than high cut quantile"}; } diff --git a/src/types/tdigest.h b/src/types/tdigest.h index e5231eee55e..01b049623c3 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -313,51 +313,41 @@ inline Status TDigestRank(TD&& td, const std::vector& inputs, std::vecto template inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, double high_cut_quantile) { if (td.Size() == 0) { - return Status{Status::InvalidArgument, "empty tdigest"}; + return std::numeric_limits::quiet_NaN(); } - double low_boundary = std::numeric_limits::quiet_NaN(); - double high_boundary = std::numeric_limits::quiet_NaN(); + const double total_weight = td.TotalWeight(); + const double leftmost_weight = std::floor(total_weight * low_cut_quantile); + const double rightmost_weight = std::ceil(total_weight * high_cut_quantile); - if (low_cut_quantile == 0.0) { - low_boundary = td.Min(); - } else { - auto low_result = TDigestQuantile(td, low_cut_quantile); - if (!low_result) { - return low_result; - } - low_boundary = *low_result; - } - - if (high_cut_quantile == 1.0) { - high_boundary = td.Max(); - } else { - auto high_result = TDigestQuantile(td, high_cut_quantile); - if (!high_result) { - return high_result; - } - high_boundary = *high_result; - } + double count_done = 0.0; + double trimmed_sum = 0.0; + double trimmed_count = 0.0; auto iter = td.Begin(); - double total_weight_in_range = 0; - double weighted_sum = 0; - while (iter->Valid()) { auto centroid = GET_OR_RET(iter->GetCentroid()); + const double n_weight = centroid.weight; + double count_add = n_weight; + + count_add -= std::min(std::max(0.0, leftmost_weight - count_done), count_add); + count_add = std::min(std::max(0.0, rightmost_weight - count_done), count_add); - if ((low_cut_quantile == 0.0 && high_cut_quantile == 1.0) || - (centroid.mean >= low_boundary && centroid.mean <= high_boundary)) { - total_weight_in_range += centroid.weight; - weighted_sum += centroid.mean * centroid.weight; + count_done += n_weight; + + trimmed_sum += centroid.mean * count_add; + trimmed_count += count_add; + + if (count_done >= rightmost_weight) { + break; } iter->Next(); } - if (total_weight_in_range == 0) { + if (trimmed_count == 0.0) { return std::numeric_limits::quiet_NaN(); } - return weighted_sum / total_weight_in_range; + return trimmed_sum / trimmed_count; } From a7f773fa46f1a5ea1a683e37d7203d3546b47ab3 Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Sun, 8 Mar 2026 13:13:29 +0800 Subject: [PATCH 05/14] test(tdigest): improve TRIMMED_MEAN test coverage and precision --- tests/cppunit/types/tdigest_test.cc | 57 +++++++++++++++++-- .../gocase/unit/type/tdigest/tdigest_test.go | 19 +++---- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/tests/cppunit/types/tdigest_test.cc b/tests/cppunit/types/tdigest_test.cc index d704f84f4aa..9fc1fb1f729 100644 --- a/tests/cppunit/types/tdigest_test.cc +++ b/tests/cppunit/types/tdigest_test.cc @@ -540,16 +540,65 @@ TEST_F(RedisTDigestTest, TrimmedMean) { status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result); ASSERT_TRUE(status.ok()) << status.ToString(); ASSERT_TRUE(result.mean.has_value()); - EXPECT_NEAR(*result.mean, 5.5, 1.0) << "Trimmed mean should be approximately 5.5"; + EXPECT_NEAR(*result.mean, 5.5, 0.01); status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.0, 1.0, &result); ASSERT_TRUE(status.ok()) << status.ToString(); ASSERT_TRUE(result.mean.has_value()); - EXPECT_NEAR(*result.mean, 5.5, 0.1) << "Full range should equal complete mean"; + EXPECT_NEAR(*result.mean, 5.5, 0.01); status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.25, 0.75, &result); ASSERT_TRUE(status.ok()) << status.ToString(); ASSERT_TRUE(result.mean.has_value()); - EXPECT_GT(*result.mean, 3.0) << "Trimmed mean should be greater than 3.0"; - EXPECT_LT(*result.mean, 8.0) << "Trimmed mean should be less than 8.0"; + EXPECT_NEAR(*result.mean, 5.5, 0.01); +} + +TEST_F(RedisTDigestTest, TrimmedMeanEmptyDigest) { + std::string test_digest_name = "test_digest_trimmed_mean_empty" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_FALSE(result.mean.has_value()); +} + +TEST_F(RedisTDigestTest, TrimmedMeanUnorderedInput) { + std::string test_digest_name = "test_digest_trimmed_mean_unordered" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + std::vector values = {5, 2, 8, 1, 9, 3, 7, 4, 6, 10}; + status = tdigest_->Add(*ctx_, test_digest_name, values); + ASSERT_TRUE(status.ok()) << status.ToString(); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_NEAR(*result.mean, 5.5, 0.01); +} + +TEST_F(RedisTDigestTest, TrimmedMeanComplexInput) { + std::string test_digest_name = "test_digest_trimmed_mean_complex" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + std::vector values = {-10, 5, -3, 5, 0, 5, 3, -5, 10, -10}; + status = tdigest_->Add(*ctx_, test_digest_name, values); + ASSERT_TRUE(status.ok()) << status.ToString(); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.2, 0.8, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + ASSERT_FALSE(std::isnan(*result.mean)); + EXPECT_NEAR(*result.mean, 5.0 / 6.0, 0.01); } diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index f883f3c02c8..662ab92ea47 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -740,7 +740,7 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.NoError(t, result.Err()) mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) - require.InDelta(t, 5.5, mean, 1.0) + require.InDelta(t, 5.5, mean, 0.01) }) t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) { @@ -752,7 +752,7 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.NoError(t, result.Err()) mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) - require.InDelta(t, 5.5, mean, 0.1) + require.InDelta(t, 5.5, mean, 0.01) }) t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) { @@ -764,7 +764,7 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.NoError(t, result.Err()) mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) - require.Less(t, mean, 50.0) + require.InDelta(t, 2.8, mean, 0.01) }) t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t *testing.T) { @@ -788,13 +788,13 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) { key := "tdigest_single" require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) - require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42", "42").Err()) result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9") require.NoError(t, result.Err()) mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) - require.InDelta(t, 42.0, mean, 0.001) + require.InDelta(t, 42.0, mean, 0.01) }) t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) { @@ -804,13 +804,10 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6") require.NoError(t, result.Err()) - meanStr := result.Val().(string) - if meanStr == "nan" { - return - } - mean, err := strconv.ParseFloat(meanStr, 64) + mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) - require.Greater(t, mean, 0.0) + require.False(t, math.IsNaN(mean)) + require.InDelta(t, 5.5, mean, 0.01) }) } From 75c60a737ce8703cd787a0c5222b9d53fd1176dd Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Sun, 8 Mar 2026 13:19:24 +0800 Subject: [PATCH 06/14] test(tdigest): extract TRIMMED_MEAN error messages as constants --- tests/gocase/unit/type/tdigest/tdigest_test.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 662ab92ea47..65f76f71093 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -40,6 +40,9 @@ const ( errMsgKeyNotExist = "key does not exist" errNumkeysMustBePositive = "numkeys need to be a positive integer" errCompressionParameterMustBePositive = "compression parameter needs to be a positive integer" + errMsgLowCutQuantileRange = "low cut quantile must be between 0 and 1" + errMsgHighCutQuantileRange = "high cut quantile must be between 0 and 1" + errMsgLowCutQuantileLess = "low cut quantile must be less than high cut quantile" ) type tdigestInfo struct { @@ -779,10 +782,10 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5").Err()) - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1") - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1") - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut quantile") - require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut quantile") + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "-0.1", "0.9").Err(), errMsgLowCutQuantileRange) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "1.1").Err(), errMsgHighCutQuantileRange) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.9", "0.1").Err(), errMsgLowCutQuantileLess) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5").Err(), errMsgLowCutQuantileLess) }) t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) { From 7d024634b5526e8e9ea5abaca579f95f6d5ec6c2 Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Mon, 9 Mar 2026 21:59:31 +0800 Subject: [PATCH 07/14] test(tdigest): extract TRIMMED_MEAN error messages as constants --- src/commands/cmd_tdigest.cc | 6 +++--- src/commands/error_constants.h | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index 3d0367b44ea..a195ca5a424 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -514,13 +514,13 @@ class CommandTDigestTrimmedMean : public Commander { high_cut_quantile_ = *high_cut_quantile; if (low_cut_quantile_ < 0.0 || low_cut_quantile_ > 1.0) { - return {Status::RedisParseErr, "low cut quantile must be between 0 and 1"}; + return {Status::RedisParseErr, errLowCutQuantileRange}; } if (high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { - return {Status::RedisParseErr, "high cut quantile must be between 0 and 1"}; + return {Status::RedisParseErr, errHighCutQuantileRange}; } if (DoubleCompare(low_cut_quantile_, high_cut_quantile_) >= 0) { - return {Status::RedisParseErr, "low cut quantile must be less than high cut quantile"}; + return {Status::RedisParseErr, errLowCutQuantileLess}; } return Status::OK(); diff --git a/src/commands/error_constants.h b/src/commands/error_constants.h index bde89a6d7f0..98c9a7a126d 100644 --- a/src/commands/error_constants.h +++ b/src/commands/error_constants.h @@ -54,4 +54,7 @@ inline constexpr const char *errParsingNumkeys = "error parsing numkeys"; inline constexpr const char *errNumkeysMustBePositive = "numkeys need to be a positive integer"; inline constexpr const char *errWrongKeyword = "wrong keyword"; inline constexpr const char *errInvalidRankValue = "rank needs to be non-negative"; +inline constexpr const char *errLowCutQuantileRange = "low cut quantile must be between 0 and 1"; +inline constexpr const char *errHighCutQuantileRange = "high cut quantile must be between 0 and 1"; +inline constexpr const char *errLowCutQuantileLess = "low cut quantile must be less than high cut quantile"; } // namespace redis From 8a3e9b11882295fe88717429e5ae1d67be1a2150 Mon Sep 17 00:00:00 2001 From: chakkk309 Date: Mon, 9 Mar 2026 23:24:08 +0800 Subject: [PATCH 08/14] Update src/types/redis_tdigest.cc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/types/redis_tdigest.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc index addaf7bbedb..dcdc0f44c34 100644 --- a/src/types/redis_tdigest.cc +++ b/src/types/redis_tdigest.cc @@ -771,6 +771,7 @@ rocksdb::Status TDigest::TrimmedMean(engine::Context& ctx, const Slice& digest_n } if (metadata.total_observations == 0) { + result->mean.reset(); return rocksdb::Status::OK(); } From fedf4b868976ee854fa86349c5be7964cef319ed Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Mon, 9 Mar 2026 23:24:26 +0800 Subject: [PATCH 09/14] fix(tdigest): handle non-finite quantile inputs --- src/commands/cmd_tdigest.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index a195ca5a424..42c322328de 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -513,10 +513,10 @@ class CommandTDigestTrimmedMean : public Commander { } high_cut_quantile_ = *high_cut_quantile; - if (low_cut_quantile_ < 0.0 || low_cut_quantile_ > 1.0) { + if (!std::isfinite(low_cut_quantile_) || low_cut_quantile_ < 0.0 || low_cut_quantile_ > 1.0) { return {Status::RedisParseErr, errLowCutQuantileRange}; } - if (high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { + if (!std::isfinite(high_cut_quantile_) || high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { return {Status::RedisParseErr, errHighCutQuantileRange}; } if (DoubleCompare(low_cut_quantile_, high_cut_quantile_) >= 0) { From 6fb6a9e357de4d7cdd5ee8cab492e6de2c0246cd Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Thu, 12 Mar 2026 22:39:30 +0800 Subject: [PATCH 10/14] chore(types): clarify trimmed mean overlap logic --- src/types/tdigest.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/types/tdigest.h b/src/types/tdigest.h index 01b049623c3..1f152c1e00d 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -330,6 +330,7 @@ inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, dou const double n_weight = centroid.weight; double count_add = n_weight; + // Keep only the portion of this centroid that overlaps with the trimmed rank range. count_add -= std::min(std::max(0.0, leftmost_weight - count_done), count_add); count_add = std::min(std::max(0.0, rightmost_weight - count_done), count_add); From f5b60b06f9d6d8922ed13663bf2be1db94659e10 Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Thu, 12 Mar 2026 22:41:39 +0800 Subject: [PATCH 11/14] chore(tests): remove redundant tdigest NaN assertion --- tests/gocase/unit/type/tdigest/tdigest_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 65f76f71093..68c83102b32 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -809,7 +809,6 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.NoError(t, result.Err()) mean, err := strconv.ParseFloat(result.Val().(string), 64) require.NoError(t, err) - require.False(t, math.IsNaN(mean)) require.InDelta(t, 5.5, mean, 0.01) }) } From 49cdccb2cba5e251bb811992dc704779c7dbcb75 Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Thu, 12 Mar 2026 23:06:36 +0800 Subject: [PATCH 12/14] fix(commands): align tdigest trimmed mean errors with redis --- src/commands/cmd_tdigest.cc | 6 ++-- src/commands/error_constants.h | 11 +++++-- .../gocase/unit/type/tdigest/tdigest_test.go | 29 +++++++++++++++++-- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index 42c322328de..a30b46e4d0d 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -503,13 +503,13 @@ class CommandTDigestTrimmedMean : public Commander { auto low_cut_quantile = ParseFloat(args[2]); if (!low_cut_quantile) { - return {Status::RedisParseErr, errValueIsNotFloat}; + return {Status::RedisParseErr, errParseLowCutQuantile}; } low_cut_quantile_ = *low_cut_quantile; auto high_cut_quantile = ParseFloat(args[3]); if (!high_cut_quantile) { - return {Status::RedisParseErr, errValueIsNotFloat}; + return {Status::RedisParseErr, errParseHighCutQuantile}; } high_cut_quantile_ = *high_cut_quantile; @@ -519,7 +519,7 @@ class CommandTDigestTrimmedMean : public Commander { if (!std::isfinite(high_cut_quantile_) || high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { return {Status::RedisParseErr, errHighCutQuantileRange}; } - if (DoubleCompare(low_cut_quantile_, high_cut_quantile_) >= 0) { + if (low_cut_quantile_ >= high_cut_quantile_) { return {Status::RedisParseErr, errLowCutQuantileLess}; } diff --git a/src/commands/error_constants.h b/src/commands/error_constants.h index 98c9a7a126d..6b91aecf666 100644 --- a/src/commands/error_constants.h +++ b/src/commands/error_constants.h @@ -54,7 +54,12 @@ inline constexpr const char *errParsingNumkeys = "error parsing numkeys"; inline constexpr const char *errNumkeysMustBePositive = "numkeys need to be a positive integer"; inline constexpr const char *errWrongKeyword = "wrong keyword"; inline constexpr const char *errInvalidRankValue = "rank needs to be non-negative"; -inline constexpr const char *errLowCutQuantileRange = "low cut quantile must be between 0 and 1"; -inline constexpr const char *errHighCutQuantileRange = "high cut quantile must be between 0 and 1"; -inline constexpr const char *errLowCutQuantileLess = "low cut quantile must be less than high cut quantile"; +inline constexpr const char *errParseLowCutQuantile = "T-Digest: error parsing low_cut_percentile"; +inline constexpr const char *errParseHighCutQuantile = "T-Digest: error parsing high_cut_percentile"; +inline constexpr const char *errLowCutQuantileRange = + "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]"; +inline constexpr const char *errHighCutQuantileRange = + "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]"; +inline constexpr const char *errLowCutQuantileLess = + "T-Digest: low_cut_percentile should be lower than high_cut_percentile"; } // namespace redis diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 68c83102b32..07ef000e4d6 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -40,9 +40,11 @@ const ( errMsgKeyNotExist = "key does not exist" errNumkeysMustBePositive = "numkeys need to be a positive integer" errCompressionParameterMustBePositive = "compression parameter needs to be a positive integer" - errMsgLowCutQuantileRange = "low cut quantile must be between 0 and 1" - errMsgHighCutQuantileRange = "high cut quantile must be between 0 and 1" - errMsgLowCutQuantileLess = "low cut quantile must be less than high cut quantile" + errMsgParseLowCutQuantile = "T-Digest: error parsing low_cut_percentile" + errMsgParseHighCutQuantile = "T-Digest: error parsing high_cut_percentile" + errMsgLowCutQuantileRange = "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]" + errMsgHighCutQuantileRange = "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]" + errMsgLowCutQuantileLess = "T-Digest: low_cut_percentile should be lower than high_cut_percentile" ) type tdigestInfo struct { @@ -788,6 +790,17 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5").Err(), errMsgLowCutQuantileLess) }) + t.Run("TDIGEST.TRIMMED_MEAN invalid quantile parsing", func(t *testing.T) { + key := "tdigest_invalid_parse" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5").Err()) + + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "abc", "0.9").Err(), errMsgParseLowCutQuantile) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "nan", "0.9").Err(), errMsgParseLowCutQuantile) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "abc").Err(), errMsgParseHighCutQuantile) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "nan").Err(), errMsgParseHighCutQuantile) + }) + t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) { key := "tdigest_single" require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) @@ -811,6 +824,16 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.NoError(t, err) require.InDelta(t, 5.5, mean, 0.01) }) + + t.Run("TDIGEST.TRIMMED_MEAN with nearly equal quantiles", func(t *testing.T) { + key := "tdigest_nearly_equal" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "1000").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5000000001") + require.NoError(t, result.Err()) + require.Equal(t, "nan", result.Val()) + }) } func TestTDigestByRankAndByRevRank(t *testing.T) { From 9b7ac7c1dd2106f1fb16232ea949d3c5c943c31a Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Thu, 12 Mar 2026 23:17:28 +0800 Subject: [PATCH 13/14] fix(tdigest): handle nan quantile inputs --- src/commands/cmd_tdigest.cc | 4 ++-- tests/gocase/unit/type/tdigest/tdigest_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index a30b46e4d0d..751a964b00c 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -502,13 +502,13 @@ class CommandTDigestTrimmedMean : public Commander { key_name_ = args[1]; auto low_cut_quantile = ParseFloat(args[2]); - if (!low_cut_quantile) { + if (!low_cut_quantile || std::isnan(*low_cut_quantile)) { return {Status::RedisParseErr, errParseLowCutQuantile}; } low_cut_quantile_ = *low_cut_quantile; auto high_cut_quantile = ParseFloat(args[3]); - if (!high_cut_quantile) { + if (!high_cut_quantile || std::isnan(*high_cut_quantile)) { return {Status::RedisParseErr, errParseHighCutQuantile}; } high_cut_quantile_ = *high_cut_quantile; diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 07ef000e4d6..f8b47978efd 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -832,7 +832,7 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5000000001") require.NoError(t, result.Err()) - require.Equal(t, "nan", result.Val()) + require.Equal(t, "6", result.Val()) }) } From 592558d5ecb986a7159f1e37f52c330e57bc9af3 Mon Sep 17 00:00:00 2001 From: Jackie Yan Date: Thu, 12 Mar 2026 23:40:08 +0800 Subject: [PATCH 14/14] style(commands): simplify tdigest trimmed mean errors --- src/commands/cmd_tdigest.cc | 9 +++++++-- src/commands/error_constants.h | 8 -------- tests/gocase/unit/type/tdigest/tdigest_test.go | 10 +++++----- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index 751a964b00c..ad144bf372c 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -45,6 +45,11 @@ constexpr auto kInfoUnmergedWeight = "Unmerged weight"; constexpr auto kInfoObservations = "Observations"; constexpr auto kInfoTotalCompressions = "Total compressions"; constexpr auto kNan = "nan"; + +constexpr const char *errParseLowCutQuantile = "error parsing low_cut_percentile"; +constexpr const char *errParseHighCutQuantile = "error parsing high_cut_percentile"; +constexpr const char *errCutQuantileRange = "low_cut_percentile and high_cut_percentile should be in [0,1]"; +constexpr const char *errLowCutQuantileLess = "low_cut_percentile should be lower than high_cut_percentile"; } // namespace class CommandTDigestCreate : public Commander { @@ -514,10 +519,10 @@ class CommandTDigestTrimmedMean : public Commander { high_cut_quantile_ = *high_cut_quantile; if (!std::isfinite(low_cut_quantile_) || low_cut_quantile_ < 0.0 || low_cut_quantile_ > 1.0) { - return {Status::RedisParseErr, errLowCutQuantileRange}; + return {Status::RedisParseErr, errCutQuantileRange}; } if (!std::isfinite(high_cut_quantile_) || high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { - return {Status::RedisParseErr, errHighCutQuantileRange}; + return {Status::RedisParseErr, errCutQuantileRange}; } if (low_cut_quantile_ >= high_cut_quantile_) { return {Status::RedisParseErr, errLowCutQuantileLess}; diff --git a/src/commands/error_constants.h b/src/commands/error_constants.h index 6b91aecf666..bde89a6d7f0 100644 --- a/src/commands/error_constants.h +++ b/src/commands/error_constants.h @@ -54,12 +54,4 @@ inline constexpr const char *errParsingNumkeys = "error parsing numkeys"; inline constexpr const char *errNumkeysMustBePositive = "numkeys need to be a positive integer"; inline constexpr const char *errWrongKeyword = "wrong keyword"; inline constexpr const char *errInvalidRankValue = "rank needs to be non-negative"; -inline constexpr const char *errParseLowCutQuantile = "T-Digest: error parsing low_cut_percentile"; -inline constexpr const char *errParseHighCutQuantile = "T-Digest: error parsing high_cut_percentile"; -inline constexpr const char *errLowCutQuantileRange = - "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]"; -inline constexpr const char *errHighCutQuantileRange = - "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]"; -inline constexpr const char *errLowCutQuantileLess = - "T-Digest: low_cut_percentile should be lower than high_cut_percentile"; } // namespace redis diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index f8b47978efd..6c58d8402bb 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -40,11 +40,11 @@ const ( errMsgKeyNotExist = "key does not exist" errNumkeysMustBePositive = "numkeys need to be a positive integer" errCompressionParameterMustBePositive = "compression parameter needs to be a positive integer" - errMsgParseLowCutQuantile = "T-Digest: error parsing low_cut_percentile" - errMsgParseHighCutQuantile = "T-Digest: error parsing high_cut_percentile" - errMsgLowCutQuantileRange = "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]" - errMsgHighCutQuantileRange = "T-Digest: low_cut_percentile and high_cut_percentile should be in [0,1]" - errMsgLowCutQuantileLess = "T-Digest: low_cut_percentile should be lower than high_cut_percentile" + errMsgParseLowCutQuantile = "error parsing low_cut_percentile" + errMsgParseHighCutQuantile = "error parsing high_cut_percentile" + errMsgLowCutQuantileRange = "low_cut_percentile and high_cut_percentile should be in [0,1]" + errMsgHighCutQuantileRange = "low_cut_percentile and high_cut_percentile should be in [0,1]" + errMsgLowCutQuantileLess = "low_cut_percentile should be lower than high_cut_percentile" ) type tdigestInfo struct {