Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/commands/cmd_pubsub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,22 @@ class CommandPubSub : public Commander {
std::string subcommand_;
};

class CommandSPublish : public Commander {
public:
Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, [[maybe_unused]] Connection *conn,
std::string *output) override {
uint16_t slot = 0;
if (srv->GetConfig()->cluster_enabled) {
slot = GetSlotIdFromKey(args_[1]);
}

int receivers = srv->SPublish(args_[1], args_[2], slot);
*output = redis::Integer(receivers);

return Status::OK();
}
};

REDIS_REGISTER_COMMANDS(Pubsub, MakeCmdAttr<CommandPublish>("publish", 3, "read-only", NO_KEY),
MakeCmdAttr<CommandMPublish>("mpublish", -3, "read-only", NO_KEY),
MakeCmdAttr<CommandSubscribe>("subscribe", -2, "read-only no-multi no-script", NO_KEY),
Expand All @@ -260,6 +276,7 @@ REDIS_REGISTER_COMMANDS(Pubsub, MakeCmdAttr<CommandPublish>("publish", 3, "read-
MakeCmdAttr<CommandPUnSubscribe>("punsubscribe", -1, "read-only no-multi no-script", NO_KEY),
MakeCmdAttr<CommandSSubscribe>("ssubscribe", -2, "read-only no-multi no-script", NO_KEY),
MakeCmdAttr<CommandSUnSubscribe>("sunsubscribe", -1, "read-only no-multi no-script", NO_KEY),
MakeCmdAttr<CommandSPublish>("spublish", 3, "read-only ok-loading", NO_KEY),
MakeCmdAttr<CommandPubSub>("pubsub", -2, "read-only no-script", NO_KEY), )

} // namespace redis
27 changes: 27 additions & 0 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,33 @@ void Server::ListSChannelSubscribeNum(const std::vector<std::string> &channels,
}
}

int Server::SPublish(const std::string &channel, const std::string &msg, uint16_t slot) {
assert((config_->cluster_enabled && slot < HASH_SLOTS_SIZE) || slot == 0);

int cnt = 0;
std::lock_guard<std::mutex> guard(pubsub_shard_channels_mu_);

auto iter = pubsub_shard_channels_[slot].find(channel);
if (iter == pubsub_shard_channels_[slot].end()) {
return cnt;
}

std::string reply;
reply.append(redis::MultiLen(3));
reply.append(redis::BulkString("smessage"));
reply.append(redis::BulkString(channel));
reply.append(redis::BulkString(msg));

for (const auto &conn_ctx : iter->second) {
auto s = conn_ctx.owner->Reply(conn_ctx.fd, reply);
if (s.IsOK()) {
cnt++;
}
}

return cnt;
}

void Server::BlockOnKey(const std::string &key, redis::Connection *conn) {
std::lock_guard<std::mutex> guard(blocking_keys_mu_);

Expand Down
1 change: 1 addition & 0 deletions src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ class Server {
size_t GetPubSubPatternSize() const { return pubsub_patterns_.size(); }
void SSubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
void SUnsubscribeChannel(const std::string &channel, redis::Connection *conn, uint16_t slot);
int SPublish(const std::string &channel, const std::string &msg, uint16_t slot);
void GetSChannelsByPattern(const std::string &pattern, std::vector<std::string> *channels);
void ListSChannelSubscribeNum(const std::vector<std::string> &channels,
std::vector<ChannelSubscribeNum> *channel_subscribe_nums);
Expand Down
82 changes: 82 additions & 0 deletions tests/gocase/unit/pubsub/pubsubshard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,85 @@ func TestPubSubShard(t *testing.T) {
}
})
}

func TestSPublish(t *testing.T) {
ctx := context.Background()

srv := util.StartServer(t, map[string]string{})
defer srv.Close()
rdb := srv.NewClient()
defer func() { require.NoError(t, rdb.Close()) }()

t.Run("SPUBLISH to no subscribers", func(t *testing.T) {
// Should return 0 when no subscribers
n, err := rdb.Do(ctx, "SPUBLISH", "mychannel", "hello").Int()
require.NoError(t, err)
require.EqualValues(t, 0, n)
})

t.Run("SPUBLISH to one subscriber", func(t *testing.T) {
pubsub := rdb.SSubscribe(ctx, "mychannel")
defer pubsub.Close()

// Receive the subscription message
receiveType(t, pubsub, &redis.Subscription{})
require.EqualValues(t, 1, receiveType(t, pubsub, &redis.Subscription{}).Count)

// Publish message
n, err := rdb.Do(ctx, "SPUBLISH", "mychannel", "hello world").Int()
require.NoError(t, err)
require.EqualValues(t, 1, n)

// Receive the message
msg := receiveType(t, pubsub, &redis.Message{})
require.EqualValues(t, "mychannel", msg.Channel)
require.EqualValues(t, "hello world", msg.Payload)
})

t.Run("SPUBLISH to multiple subscribers", func(t *testing.T) {
channel := "testchannel{tag}"

pubsub1 := rdb.SSubscribe(ctx, channel)
defer pubsub1.Close()
receiveType(t, pubsub1, &redis.Subscription{})

pubsub2 := rdb.SSubscribe(ctx, channel)
defer pubsub2.Close()
receiveType(t, pubsub2, &redis.Subscription{})

// Publish message
n, err := rdb.Do(ctx, "SPUBLISH", channel, "message from spublish").Int()
require.NoError(t, err)
require.EqualValues(t, 2, n)

// Both subscribers should receive the message
msg1 := receiveType(t, pubsub1, &redis.Message{})
require.EqualValues(t, "message from spublish", msg1.Payload)

msg2 := receiveType(t, pubsub2, &redis.Message{})
require.EqualValues(t, "message from spublish", msg2.Payload)
})

t.Run("SPUBLISH with cluster enabled", func(t *testing.T) {
csrv := util.StartServer(t, map[string]string{"cluster-enabled": "yes"})
defer csrv.Close()
crdb := csrv.NewClient()
defer func() { require.NoError(t, crdb.Close()) }()

nodeID := "test_node_id_12345678901234567890123456789012"
require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODEID", nodeID).Err())
clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383", nodeID, csrv.Host(), csrv.Port())
require.NoError(t, crdb.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err())

pubsub := crdb.SSubscribe(ctx, "mychannel{tag}")
defer pubsub.Close()
receiveType(t, pubsub, &redis.Subscription{})

n, err := crdb.Do(ctx, "SPUBLISH", "mychannel{tag}", "cluster message").Int()
require.NoError(t, err)
require.EqualValues(t, 1, n)

msg := receiveType(t, pubsub, &redis.Message{})
require.EqualValues(t, "cluster message", msg.Payload)
})
}
Loading