diff --git a/src/commands/cmd_pubsub.cc b/src/commands/cmd_pubsub.cc index 0a0cdebd54f..f9902d4322e 100644 --- a/src/commands/cmd_pubsub.cc +++ b/src/commands/cmd_pubsub.cc @@ -252,6 +252,22 @@ class CommandPubSub : public Commander { std::string subcommand_; }; +class CommandSPublish : public Commander { + public: + Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, [[maybe_unused]] Connection *conn, + std::string *output) override { + uint16_t slot = 0; + if (srv->GetConfig()->cluster_enabled) { + slot = GetSlotIdFromKey(args_[1]); + } + + int receivers = srv->SPublish(args_[1], args_[2], slot); + *output = redis::Integer(receivers); + + return Status::OK(); + } +}; + REDIS_REGISTER_COMMANDS(Pubsub, MakeCmdAttr("publish", 3, "read-only", NO_KEY), MakeCmdAttr("mpublish", -3, "read-only", NO_KEY), MakeCmdAttr("subscribe", -2, "read-only no-multi no-script", NO_KEY), @@ -260,6 +276,7 @@ REDIS_REGISTER_COMMANDS(Pubsub, MakeCmdAttr("publish", 3, "read- MakeCmdAttr("punsubscribe", -1, "read-only no-multi no-script", NO_KEY), MakeCmdAttr("ssubscribe", -2, "read-only no-multi no-script", NO_KEY), MakeCmdAttr("sunsubscribe", -1, "read-only no-multi no-script", NO_KEY), + MakeCmdAttr("spublish", 3, "read-only ok-loading", NO_KEY), MakeCmdAttr("pubsub", -2, "read-only no-script", NO_KEY), ) } // namespace redis diff --git a/src/server/server.cc b/src/server/server.cc index cc1f5b67793..12ee2ea88cf 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -619,6 +619,33 @@ void Server::ListSChannelSubscribeNum(const std::vector &channels, } } +int Server::SPublish(const std::string &channel, const std::string &msg, uint16_t slot) { + assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0); + + int cnt = 0; + std::lock_guard guard(pubsub_shard_channels_mu_); + + auto iter = pubsub_shard_channels_[slot].find(channel); + if (iter == pubsub_shard_channels_[slot].end()) { + return cnt; + } + + std::string reply; + reply.append(redis::MultiLen(3)); + reply.append(redis::BulkString("smessage")); + reply.append(redis::BulkString(channel)); + reply.append(redis::BulkString(msg)); + + for (const auto &conn_ctx : iter->second) { + auto s = conn_ctx.owner->Reply(conn_ctx.fd, reply); + if (s.IsOK()) { + cnt++; + } + } + + return cnt; +} + void Server::BlockOnKey(const std::string &key, redis::Connection *conn) { std::lock_guard guard(blocking_keys_mu_); diff --git a/src/server/server.h b/src/server/server.h index 63b83e28306..0c93a39377b 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -227,6 +227,7 @@ class Server { size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); } void SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot); void SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot); + int SPublish(const std::string &channel, const std::string &msg, uint16_t slot); void GetSChannelsByPattern(const std::string &pattern, std::vector *channels); void ListSChannelSubscribeNum(const std::vector &channels, std::vector *channel_subscribe_nums); diff --git a/tests/gocase/unit/pubsub/pubsubshard_test.go b/tests/gocase/unit/pubsub/pubsubshard_test.go index e55cd70f6fd..587c2d4a474 100644 --- a/tests/gocase/unit/pubsub/pubsubshard_test.go +++ b/tests/gocase/unit/pubsub/pubsubshard_test.go @@ -162,3 +162,85 @@ func TestPubSubShard(t *testing.T) { } }) } + +func TestSPublish(t *testing.T) { + ctx := context.Background() + + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("SPUBLISH to no subscribers", func(t *testing.T) { + // Should return 0 when no subscribers + n, err := rdb.Do(ctx, "SPUBLISH", "mychannel", "hello").Int() + require.NoError(t, err) + require.EqualValues(t, 0, n) + }) + + t.Run("SPUBLISH to one subscriber", func(t *testing.T) { + pubsub := rdb.SSubscribe(ctx, "mychannel") + defer pubsub.Close() + + // Receive the subscription message + receiveType(t, pubsub, &redis.Subscription{}) + require.EqualValues(t, 1, receiveType(t, pubsub, &redis.Subscription{}).Count) + + // Publish message + n, err := rdb.Do(ctx, "SPUBLISH", "mychannel", "hello world").Int() + require.NoError(t, err) + require.EqualValues(t, 1, n) + + // Receive the message + msg := receiveType(t, pubsub, &redis.Message{}) + require.EqualValues(t, "mychannel", msg.Channel) + require.EqualValues(t, "hello world", msg.Payload) + }) + + t.Run("SPUBLISH to multiple subscribers", func(t *testing.T) { + channel := "testchannel{tag}" + + pubsub1 := rdb.SSubscribe(ctx, channel) + defer pubsub1.Close() + receiveType(t, pubsub1, &redis.Subscription{}) + + pubsub2 := rdb.SSubscribe(ctx, channel) + defer pubsub2.Close() + receiveType(t, pubsub2, &redis.Subscription{}) + + // Publish message + n, err := rdb.Do(ctx, "SPUBLISH", channel, "message from spublish").Int() + require.NoError(t, err) + require.EqualValues(t, 2, n) + + // Both subscribers should receive the message + msg1 := receiveType(t, pubsub1, &redis.Message{}) + require.EqualValues(t, "message from spublish", msg1.Payload) + + msg2 := receiveType(t, pubsub2, &redis.Message{}) + require.EqualValues(t, "message from spublish", msg2.Payload) + }) + + t.Run("SPUBLISH with cluster enabled", func(t *testing.T) { + csrv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer csrv.Close() + crdb := csrv.NewClient() + defer func() { require.NoError(t, crdb.Close()) }() + + nodeID := "test_node_id_12345678901234567890123456789012" + require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODEID", nodeID).Err()) + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383", nodeID, csrv.Host(), csrv.Port()) + require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + pubsub := crdb.SSubscribe(ctx, "mychannel{tag}") + defer pubsub.Close() + receiveType(t, pubsub, &redis.Subscription{}) + + n, err := crdb.Do(ctx, "SPUBLISH", "mychannel{tag}", "cluster message").Int() + require.NoError(t, err) + require.EqualValues(t, 1, n) + + msg := receiveType(t, pubsub, &redis.Message{}) + require.EqualValues(t, "cluster message", msg.Payload) + }) +}