diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index f839a3f7f1a..ad144bf372c 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -45,6 +45,11 @@ constexpr auto kInfoUnmergedWeight = "Unmerged weight"; constexpr auto kInfoObservations = "Observations"; constexpr auto kInfoTotalCompressions = "Total compressions"; constexpr auto kNan = "nan"; + +constexpr const char *errParseLowCutQuantile = "error parsing low_cut_percentile"; +constexpr const char *errParseHighCutQuantile = "error parsing high_cut_percentile"; +constexpr const char *errCutQuantileRange = "low_cut_percentile and high_cut_percentile should be in [0,1]"; +constexpr const char *errLowCutQuantileLess = "low_cut_percentile should be lower than high_cut_percentile"; } // namespace class CommandTDigestCreate : public Commander { @@ -492,6 +497,67 @@ class CommandTDigestMerge : public Commander { TDigestMergeOptions options_; }; +class CommandTDigestTrimmedMean : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() != 4) { + return {Status::RedisParseErr, errWrongNumOfArguments}; + } + + key_name_ = args[1]; + + auto low_cut_quantile = ParseFloat(args[2]); + if (!low_cut_quantile || std::isnan(*low_cut_quantile)) { + return {Status::RedisParseErr, errParseLowCutQuantile}; + } + low_cut_quantile_ = *low_cut_quantile; + + auto high_cut_quantile = ParseFloat(args[3]); + if (!high_cut_quantile || std::isnan(*high_cut_quantile)) { + return {Status::RedisParseErr, errParseHighCutQuantile}; + } + high_cut_quantile_ = *high_cut_quantile; + + if (!std::isfinite(low_cut_quantile_) || low_cut_quantile_ < 0.0 || low_cut_quantile_ > 1.0) { + return {Status::RedisParseErr, errCutQuantileRange}; + } + if (!std::isfinite(high_cut_quantile_) || high_cut_quantile_ < 0.0 || high_cut_quantile_ > 1.0) { + return {Status::RedisParseErr, errCutQuantileRange}; + } + if (low_cut_quantile_ >= high_cut_quantile_) { + return {Status::RedisParseErr, errLowCutQuantileLess}; + } + + return Status::OK(); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + TDigest tdigest(srv->storage, conn->GetNamespace()); + TDigestTrimmedMeanResult result; + + auto s = tdigest.TrimmedMean(ctx, key_name_, low_cut_quantile_, high_cut_quantile_, &result); + if (!s.ok()) { + if (s.IsNotFound()) { + return {Status::RedisExecErr, errKeyNotFound}; + } + return {Status::RedisExecErr, s.ToString()}; + } + + if (!result.mean.has_value()) { + *output = redis::BulkString(kNan); + } else { + *output = redis::BulkString(util::Float2String(*result.mean)); + } + + return Status::OK(); + } + + private: + std::string key_name_; + double low_cut_quantile_; + double high_cut_quantile_; +}; + std::vector GetMergeKeyRange(const std::vector &args) { auto numkeys = ParseInt(args[2], 10).ValueOr(0); return {{1, 1, 1}, {3, 2 + numkeys, 1}}; @@ -507,6 +573,7 @@ REDIS_REGISTER_COMMANDS(TDigest, MakeCmdAttr("tdigest.crea MakeCmdAttr("tdigest.byrevrank", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.byrank", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.quantile", -3, "read-only", 1, 1, 1), + MakeCmdAttr("tdigest.trimmed_mean", 4, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.reset", 2, "write", 1, 1, 1), MakeCmdAttr("tdigest.merge", -4, "write", GetMergeKeyRange)); } // namespace redis diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc index c44ff6b2823..dcdc0f44c34 100644 --- a/src/types/redis_tdigest.cc +++ b/src/types/redis_tdigest.cc @@ -759,6 +759,42 @@ rocksdb::Status TDigest::applyNewCentroids(ObserverOrUniquePtrGetLockManager(), ns_key); + if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + + if (metadata.total_observations == 0) { + result->mean.reset(); + return rocksdb::Status::OK(); + } + + if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + } + + // Dump centroids and create DummyCentroids wrapper for TDigest algorithm + std::vector centroids; + if (auto status = dumpCentroids(ctx, ns_key, metadata, ¢roids); !status.ok()) { + return status; + } + auto dump_centroids = DummyCentroids(metadata, centroids); + auto trimmed_mean_result = TDigestTrimmedMean(dump_centroids, low_cut_quantile, high_cut_quantile); + if (!trimmed_mean_result) { + return rocksdb::Status::InvalidArgument(trimmed_mean_result.Msg()); + } + + result->mean = *trimmed_mean_result; + return rocksdb::Status::OK(); +} + std::string TDigest::internalSegmentGuardPrefixKey(const TDigestMetadata& metadata, const std::string& ns_key, SegmentType seg) const { std::string prefix_key; diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h index 236ec9eb462..a949c5832d1 100644 --- a/src/types/redis_tdigest.h +++ b/src/types/redis_tdigest.h @@ -53,6 +53,10 @@ struct TDigestQuantitleResult { std::optional> quantiles; }; +struct TDigestTrimmedMeanResult { + std::optional mean; +}; + class TDigest : public SubKeyScanner { public: using Slice = rocksdb::Slice; @@ -85,6 +89,8 @@ class TDigest : public SubKeyScanner { std::vector* result); rocksdb::Status ByRank(engine::Context& ctx, const Slice& digest_name, const std::vector& inputs, std::vector* result); + rocksdb::Status TrimmedMean(engine::Context& ctx, const Slice& digest_name, double low_cut_quantile, + double high_cut_quantile, TDigestTrimmedMeanResult* result); rocksdb::Status GetMetaData(engine::Context& context, const Slice& digest_name, TDigestMetadata* metadata); private: diff --git a/src/types/tdigest.h b/src/types/tdigest.h index b9c4a6fcbee..1f152c1e00d 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -309,3 +309,46 @@ inline Status TDigestRank(TD&& td, const std::vector& inputs, std::vecto } return Status::OK(); } + +template +inline StatusOr TDigestTrimmedMean(TD&& td, double low_cut_quantile, double high_cut_quantile) { + if (td.Size() == 0) { + return std::numeric_limits::quiet_NaN(); + } + + const double total_weight = td.TotalWeight(); + const double leftmost_weight = std::floor(total_weight * low_cut_quantile); + const double rightmost_weight = std::ceil(total_weight * high_cut_quantile); + + double count_done = 0.0; + double trimmed_sum = 0.0; + double trimmed_count = 0.0; + + auto iter = td.Begin(); + while (iter->Valid()) { + auto centroid = GET_OR_RET(iter->GetCentroid()); + const double n_weight = centroid.weight; + double count_add = n_weight; + + // Keep only the portion of this centroid that overlaps with the trimmed rank range. + count_add -= std::min(std::max(0.0, leftmost_weight - count_done), count_add); + count_add = std::min(std::max(0.0, rightmost_weight - count_done), count_add); + + count_done += n_weight; + + trimmed_sum += centroid.mean * count_add; + trimmed_count += count_add; + + if (count_done >= rightmost_weight) { + break; + } + + iter->Next(); + } + + if (trimmed_count == 0.0) { + return std::numeric_limits::quiet_NaN(); + } + + return trimmed_sum / trimmed_count; +} diff --git a/tests/cppunit/types/tdigest_test.cc b/tests/cppunit/types/tdigest_test.cc index ed12df1368f..9fc1fb1f729 100644 --- a/tests/cppunit/types/tdigest_test.cc +++ b/tests/cppunit/types/tdigest_test.cc @@ -524,3 +524,81 @@ TEST_F(RedisTDigestTest, ByRank_And_ByRevRank) { EXPECT_EQ(result[0], 1.0) << "Rank 0 should be minimum"; EXPECT_TRUE(std::isinf(result[3])) << "Rank >= total_weight should be infinity"; } + +TEST_F(RedisTDigestTest, TrimmedMean) { + std::string test_digest_name = "test_digest_trimmed_mean" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + std::vector values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + status = tdigest_->Add(*ctx_, test_digest_name, values); + ASSERT_TRUE(status.ok()) << status.ToString(); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_NEAR(*result.mean, 5.5, 0.01); + + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.0, 1.0, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_NEAR(*result.mean, 5.5, 0.01); + + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.25, 0.75, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_NEAR(*result.mean, 5.5, 0.01); +} + +TEST_F(RedisTDigestTest, TrimmedMeanEmptyDigest) { + std::string test_digest_name = "test_digest_trimmed_mean_empty" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_FALSE(result.mean.has_value()); +} + +TEST_F(RedisTDigestTest, TrimmedMeanUnorderedInput) { + std::string test_digest_name = "test_digest_trimmed_mean_unordered" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + std::vector values = {5, 2, 8, 1, 9, 3, 7, 4, 6, 10}; + status = tdigest_->Add(*ctx_, test_digest_name, values); + ASSERT_TRUE(status.ok()) << status.ToString(); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.1, 0.9, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + EXPECT_NEAR(*result.mean, 5.5, 0.01); +} + +TEST_F(RedisTDigestTest, TrimmedMeanComplexInput) { + std::string test_digest_name = "test_digest_trimmed_mean_complex" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + std::vector values = {-10, 5, -3, 5, 0, 5, 3, -5, 10, -10}; + status = tdigest_->Add(*ctx_, test_digest_name, values); + ASSERT_TRUE(status.ok()) << status.ToString(); + + redis::TDigestTrimmedMeanResult result; + status = tdigest_->TrimmedMean(*ctx_, test_digest_name, 0.2, 0.8, &result); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_TRUE(result.mean.has_value()); + ASSERT_FALSE(std::isnan(*result.mean)); + EXPECT_NEAR(*result.mean, 5.0 / 6.0, 0.01); +} diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index bcc4a832caa..6c58d8402bb 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -40,6 +40,11 @@ const ( errMsgKeyNotExist = "key does not exist" errNumkeysMustBePositive = "numkeys need to be a positive integer" errCompressionParameterMustBePositive = "compression parameter needs to be a positive integer" + errMsgParseLowCutQuantile = "error parsing low_cut_percentile" + errMsgParseHighCutQuantile = "error parsing high_cut_percentile" + errMsgLowCutQuantileRange = "low_cut_percentile and high_cut_percentile should be in [0,1]" + errMsgHighCutQuantileRange = "low_cut_percentile and high_cut_percentile should be in [0,1]" + errMsgLowCutQuantileLess = "low_cut_percentile should be lower than high_cut_percentile" ) type tdigestInfo struct { @@ -717,6 +722,118 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { require.EqualValues(t, expected[i], rank, "REVRANK mismatch at index %d", i) } }) + + t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) { + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) { + emptyKey := "tdigest_empty" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, "compression", "100").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", "0.9") + require.NoError(t, result.Err()) + require.Equal(t, "nan", result.Val()) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) { + key := "tdigest_basic" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) + require.NoError(t, err) + require.InDelta(t, 5.5, mean, 0.01) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) { + key := "tdigest_no_trim" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) + require.NoError(t, err) + require.InDelta(t, 5.5, mean, 0.01) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) { + key := "tdigest_skewed" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", "1", "1", "1", "10", "100").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) + require.NoError(t, err) + require.InDelta(t, 2.8, mean, 0.01) + }) + + t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t *testing.T) { + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key", "0.1").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", "key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg) + }) + + t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t *testing.T) { + key := "tdigest_invalid" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5").Err()) + + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "-0.1", "0.9").Err(), errMsgLowCutQuantileRange) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "1.1").Err(), errMsgHighCutQuantileRange) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.9", "0.1").Err(), errMsgLowCutQuantileLess) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5").Err(), errMsgLowCutQuantileLess) + }) + + t.Run("TDIGEST.TRIMMED_MEAN invalid quantile parsing", func(t *testing.T) { + key := "tdigest_invalid_parse" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5").Err()) + + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "abc", "0.9").Err(), errMsgParseLowCutQuantile) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "nan", "0.9").Err(), errMsgParseLowCutQuantile) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "abc").Err(), errMsgParseHighCutQuantile) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "nan").Err(), errMsgParseHighCutQuantile) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) { + key := "tdigest_single" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42", "42").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) + require.NoError(t, err) + require.InDelta(t, 42.0, mean, 0.01) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) { + key := "tdigest_extreme" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6") + require.NoError(t, result.Err()) + mean, err := strconv.ParseFloat(result.Val().(string), 64) + require.NoError(t, err) + require.InDelta(t, 5.5, mean, 0.01) + }) + + t.Run("TDIGEST.TRIMMED_MEAN with nearly equal quantiles", func(t *testing.T) { + key := "tdigest_nearly_equal" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "1000").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", "3", "4", "5", "6", "7", "8", "9", "10").Err()) + + result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.5", "0.5000000001") + require.NoError(t, result.Err()) + require.Equal(t, "6", result.Val()) + }) } func TestTDigestByRankAndByRevRank(t *testing.T) {