From 1e15d0860f8abcc15da5fc3e9e7e177b31aff4d0 Mon Sep 17 00:00:00 2001 From: singsangssong Date: Fri, 5 Dec 2025 18:27:22 +0900 Subject: [PATCH] FEAT: Add support for shard key in ketama hashing --- .../spy/memcached/ArcusKetamaNodeLocator.java | 32 ++++++- .../memcached/ArcusReplKetamaNodeLocator.java | 49 +++++++++-- .../net/spy/memcached/ConnectionFactory.java | 5 ++ .../memcached/ConnectionFactoryBuilder.java | 23 ++++- .../memcached/DefaultConnectionFactory.java | 5 ++ .../java/net/spy/memcached/KeyValidator.java | 3 +- .../ArcusKetamaNodeLocatorConfiguration.java | 19 +++++ ...cusReplKetamaNodeLocatorConfiguration.java | 19 +++++ .../spy/memcached/ArcusKetamaHashingTest.java | 8 +- .../memcached/ArcusKetamaNodeLocatorTest.java | 85 ++++++++++++++----- .../net/spy/memcached/ArcusShardKeyTest.java | 51 +++++++++++ 11 files changed, 260 insertions(+), 39 deletions(-) create mode 100644 src/test/java/net/spy/memcached/ArcusShardKeyTest.java diff --git a/src/main/java/net/spy/memcached/ArcusKetamaNodeLocator.java b/src/main/java/net/spy/memcached/ArcusKetamaNodeLocator.java index 8d0d1585d..e0b586786 100644 --- a/src/main/java/net/spy/memcached/ArcusKetamaNodeLocator.java +++ b/src/main/java/net/spy/memcached/ArcusKetamaNodeLocator.java @@ -58,6 +58,7 @@ public final class ArcusKetamaNodeLocator extends SpyObject implements NodeLocat private final ArcusKetamaNodeLocatorConfiguration config; private final Lock lock = new ReentrantLock(); + private final boolean enableShardKey; public ArcusKetamaNodeLocator(List nodes) { this(nodes, new ArcusKetamaNodeLocatorConfiguration()); @@ -69,6 +70,7 @@ public ArcusKetamaNodeLocator(List nodes, allNodes = nodes; ketamaNodes = new TreeMap<>(); config = conf; + enableShardKey = conf.isShardKeyEnabled(); int numReps = config.getNodeRepetitions(); // Ketama does some special work with md5 where it reuses chunks. @@ -94,6 +96,7 @@ private ArcusKetamaNodeLocator(TreeMap> smn, ketamaNodes = smn; allNodes = an; config = conf; + enableShardKey = conf.isShardKeyEnabled(); /* ENABLE_MIGRATION if */ existNodes = new HashSet<>(); @@ -151,7 +154,29 @@ public SortedMap> getKetamaNodes() { } public MemcachedNode getPrimary(final String k) { - return getNodeForKey(hashAlg.hash(k)); + String shardKey = getShardKey(k); + return getNodeForKey(hashAlg.hash(shardKey)); + } + + String getShardKey(String key) { + if (!enableShardKey) { + return key; + } + + if (key == null) { + return null; + } + + int left = key.indexOf('{'); + if (left == -1) { + return key; + } + int right = key.indexOf('}', left + 1); + if (right == -1 || right == left + 1) { + return key; + } + + return key.substring(left + 1, right); } MemcachedNode getNodeForKey(long hash) { @@ -565,9 +590,10 @@ class KetamaIterator implements Iterator { public KetamaIterator(final String k, final int t) { super(); - hashVal = hashAlg.hash(k); + String shardKey = getShardKey(k); + hashVal = hashAlg.hash(shardKey); remainingTries = t; - key = k; + key = shardKey; } private void nextHash() { diff --git a/src/main/java/net/spy/memcached/ArcusReplKetamaNodeLocator.java b/src/main/java/net/spy/memcached/ArcusReplKetamaNodeLocator.java index f8568a42f..8307d7299 100644 --- a/src/main/java/net/spy/memcached/ArcusReplKetamaNodeLocator.java +++ b/src/main/java/net/spy/memcached/ArcusReplKetamaNodeLocator.java @@ -59,16 +59,19 @@ public final class ArcusReplKetamaNodeLocator extends SpyObject implements NodeL private final Collection toDeleteGroups; private final HashAlgorithm hashAlg = HashAlgorithm.KETAMA_HASH; - private final ArcusReplKetamaNodeLocatorConfiguration config - = new ArcusReplKetamaNodeLocatorConfiguration(); + private final ArcusReplKetamaNodeLocatorConfiguration config; private final Lock lock = new ReentrantLock(); + private final boolean enableShardKey; - public ArcusReplKetamaNodeLocator(List nodes) { + public ArcusReplKetamaNodeLocator(List nodes, + ArcusReplKetamaNodeLocatorConfiguration conf) { super(); allNodes = nodes; ketamaGroups = new TreeMap<>(); allGroups = new ConcurrentHashMap<>(); + config = conf; + enableShardKey = conf.isShardKeyEnabled(); // create all memcached replica group for (MemcachedNode node : nodes) { @@ -105,12 +108,15 @@ public ArcusReplKetamaNodeLocator(List nodes) { private ArcusReplKetamaNodeLocator(TreeMap> kg, ConcurrentHashMap ag, - Collection an) { + Collection an, + ArcusReplKetamaNodeLocatorConfiguration conf) { super(); ketamaGroups = kg; allGroups = ag; allNodes = an; toDeleteGroups = new HashSet<>(); + config = conf; + enableShardKey = conf.isShardKeyEnabled(); /* ENABLE_MIGRATION if */ alterNodes = new HashSet<>(); @@ -172,11 +178,13 @@ public Collection getMasterNodes() { } public MemcachedNode getPrimary(final String k) { - return getNodeForKey(hashAlg.hash(k), ReplicaPick.MASTER); + String shardKey = getShardKey(k); + return getNodeForKey(hashAlg.hash(shardKey), ReplicaPick.MASTER); } public MemcachedNode getPrimary(final String k, ReplicaPick pick) { - return getNodeForKey(hashAlg.hash(k), pick); + String shardKey = getShardKey(k); + return getNodeForKey(hashAlg.hash(shardKey), pick); } private MemcachedNode getNodeForKey(long hash, ReplicaPick pick) { @@ -231,7 +239,7 @@ public NodeLocator getReadonlyCopy() { nodesCopy.add(new MemcachedNodeROImpl(node)); } - return new ArcusReplKetamaNodeLocator(ketamaCopy, groupsCopy, nodesCopy); + return new ArcusReplKetamaNodeLocator(ketamaCopy, groupsCopy, nodesCopy, config); } finally { lock.unlock(); } @@ -297,6 +305,27 @@ public void switchoverReplGroup(MemcachedReplicaGroup group) { lock.unlock(); } + String getShardKey(String key) { + if (!enableShardKey) { + return key; + } + + if (key == null) { + return null; + } + + int left = key.indexOf('{'); + if (left == -1) { + return key; + } + int right = key.indexOf('}', left + 1); + if (right == -1 || right == left + 1) { + return key; + } + + return key.substring(left + 1, right); + } + private void insertNodeIntoGroup(MemcachedNode node) { /* ENABLE_MIGRATION if */ if (migrationInProgress) { @@ -716,9 +745,11 @@ private class ReplKetamaIterator implements Iterator { public ReplKetamaIterator(final String k, ReplicaPick p, final int t) { super(); - hashVal = hashAlg.hash(k); + + String shardKey = getShardKey(k); + hashVal = hashAlg.hash(shardKey); remainingTries = t; - key = k; + key = shardKey; pick = p; } diff --git a/src/main/java/net/spy/memcached/ConnectionFactory.java b/src/main/java/net/spy/memcached/ConnectionFactory.java index 4f975d6c5..45029c7da 100644 --- a/src/main/java/net/spy/memcached/ConnectionFactory.java +++ b/src/main/java/net/spy/memcached/ConnectionFactory.java @@ -120,6 +120,11 @@ MemcachedNode createMemcachedNode(String name, */ boolean getDnsCacheTtlCheck(); + /** + * If true, the shard key logic will be used for hashing. + */ + boolean isShardKeyEnabled(); + /** * Observers that should be established at the time of connection * instantiation. diff --git a/src/main/java/net/spy/memcached/ConnectionFactoryBuilder.java b/src/main/java/net/spy/memcached/ConnectionFactoryBuilder.java index 356e7deb5..2cb5b13b7 100644 --- a/src/main/java/net/spy/memcached/ConnectionFactoryBuilder.java +++ b/src/main/java/net/spy/memcached/ConnectionFactoryBuilder.java @@ -34,6 +34,8 @@ import net.spy.memcached.protocol.ascii.AsciiOperationFactory; import net.spy.memcached.protocol.binary.BinaryOperationFactory; import net.spy.memcached.transcoders.Transcoder; +import net.spy.memcached.util.ArcusKetamaNodeLocatorConfiguration; +import net.spy.memcached.util.ArcusReplKetamaNodeLocatorConfiguration; /** * Builder for more easily configuring a ConnectionFactory. @@ -61,6 +63,7 @@ public class ConnectionFactoryBuilder { private boolean useNagle = false; private boolean keepAlive = false; private boolean dnsCacheTtlCheck = true; + private boolean enableShardKey = false; private long maxReconnectDelay = 1; private int readBufSize = -1; @@ -493,6 +496,11 @@ public ConnectionFactoryBuilder setDnsCacheTtlCheck(boolean dnsCacheTtlCheck) { return this; } + public ConnectionFactoryBuilder enableShardKey(boolean shardKey) { + this.enableShardKey = shardKey; + return this; + } + /** * Get the ConnectionFactory set up with the provided parameters. */ @@ -547,10 +555,16 @@ public NodeLocator createLocator(List nodes) { // This locator uses ArcusReplKetamaNodeLocatorConfiguration // which builds keys off the server's group name, not // its ip:port. - return new ArcusReplKetamaNodeLocator(nodes); + ArcusReplKetamaNodeLocatorConfiguration conf = + new ArcusReplKetamaNodeLocatorConfiguration(); + conf.enableShardKey(enableShardKey); + return new ArcusReplKetamaNodeLocator(nodes, conf); } /* ENABLE_REPLICATION end */ - return new ArcusKetamaNodeLocator(nodes); + ArcusKetamaNodeLocatorConfiguration conf = + new ArcusKetamaNodeLocatorConfiguration(); + conf.enableShardKey(enableShardKey); + return new ArcusKetamaNodeLocator(nodes, conf); default: throw new IllegalStateException( "Unhandled locator type: " + locator); @@ -627,6 +641,11 @@ public boolean getDnsCacheTtlCheck() { return dnsCacheTtlCheck; } + @Override + public boolean isShardKeyEnabled() { + return enableShardKey; + } + @Override public long getMaxReconnectDelay() { return maxReconnectDelay; diff --git a/src/main/java/net/spy/memcached/DefaultConnectionFactory.java b/src/main/java/net/spy/memcached/DefaultConnectionFactory.java index def50fcff..6a74252de 100644 --- a/src/main/java/net/spy/memcached/DefaultConnectionFactory.java +++ b/src/main/java/net/spy/memcached/DefaultConnectionFactory.java @@ -312,6 +312,11 @@ public boolean getKeepAlive() { return false; } + @Override + public boolean isShardKeyEnabled() { + return false; + } + @Override public boolean getDnsCacheTtlCheck() { return true; diff --git a/src/main/java/net/spy/memcached/KeyValidator.java b/src/main/java/net/spy/memcached/KeyValidator.java index 0bf46ff7d..70a82ac3a 100644 --- a/src/main/java/net/spy/memcached/KeyValidator.java +++ b/src/main/java/net/spy/memcached/KeyValidator.java @@ -75,7 +75,8 @@ public void validateKey(String key) { } if (!(('a' <= b && b <= 'z') || ('A' <= b && b <= 'Z') || ('0' <= b && b <= '9') || - (b == '_') || (b == '-') || (b == '+') || (b == '.'))) { + (b == '_') || (b == '-') || (b == '+') || (b == '.') || + (b == '{') || (b == '}'))) { throw new IllegalArgumentException( "Key contains invalid prefix: ``" + key + "''"); } diff --git a/src/main/java/net/spy/memcached/util/ArcusKetamaNodeLocatorConfiguration.java b/src/main/java/net/spy/memcached/util/ArcusKetamaNodeLocatorConfiguration.java index ad1240dee..c1b7ff6de 100644 --- a/src/main/java/net/spy/memcached/util/ArcusKetamaNodeLocatorConfiguration.java +++ b/src/main/java/net/spy/memcached/util/ArcusKetamaNodeLocatorConfiguration.java @@ -24,6 +24,8 @@ public class ArcusKetamaNodeLocatorConfiguration extends DefaultKetamaNodeLocatorConfiguration { + private boolean enableShardKey = false; + /** * insert a node from the internal node-address map. * @@ -41,6 +43,23 @@ public void insertNode(MemcachedNode node) { public void removeNode(MemcachedNode node) { super.socketAddresses.remove(node); } + /** + * Returns whether the shard key feature is enabled. + * + * @return true if the shard key feature is enabled; false otherwise. + */ + public boolean isShardKeyEnabled() { + return enableShardKey; + } + + /** + * Sets the enable status of the shard key feature. + * + * @param useShardKey true to enable the shard key feature; false to disable. + */ + public void enableShardKey(boolean useShardKey) { + enableShardKey = useShardKey; + } public class NodeNameComparator implements Comparator { /** diff --git a/src/main/java/net/spy/memcached/util/ArcusReplKetamaNodeLocatorConfiguration.java b/src/main/java/net/spy/memcached/util/ArcusReplKetamaNodeLocatorConfiguration.java index 1c43fa6f9..78b92676b 100644 --- a/src/main/java/net/spy/memcached/util/ArcusReplKetamaNodeLocatorConfiguration.java +++ b/src/main/java/net/spy/memcached/util/ArcusReplKetamaNodeLocatorConfiguration.java @@ -31,6 +31,25 @@ public class ArcusReplKetamaNodeLocatorConfiguration implements KetamaNodeLocatorConfiguration { private static final int NUM_REPS = 160; + private boolean enableShardKey = false; + + /** + * Returns whether the shard key feature is enabled. + * + * @return true if the shard key feature is enabled; false otherwise. + */ + public boolean isShardKeyEnabled() { + return enableShardKey; + } + + /** + * Sets the enable status of the shard key feature. + * + * @param useShardKey true to enable the shard key feature; false to disable. + */ + public void enableShardKey(boolean useShardKey) { + enableShardKey = useShardKey; + } public String getKeyForNode(MemcachedNode node, int repetition) { ArcusReplNodeAddress addr = (ArcusReplNodeAddress) node.getSocketAddress(); diff --git a/src/test/java/net/spy/memcached/ArcusKetamaHashingTest.java b/src/test/java/net/spy/memcached/ArcusKetamaHashingTest.java index ce2127dfb..0c1a64df4 100644 --- a/src/test/java/net/spy/memcached/ArcusKetamaHashingTest.java +++ b/src/test/java/net/spy/memcached/ArcusKetamaHashingTest.java @@ -8,6 +8,8 @@ import java.util.SortedMap; import java.util.SortedSet; +import net.spy.memcached.util.ArcusKetamaNodeLocatorConfiguration; + import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -75,8 +77,10 @@ private void runThisManyNodes(List stringNode1, List stringNode2 MemcachedNode oddManOut = larger.get(larger.size() - 1); assertFalse(smaller.contains(oddManOut)); - ArcusKetamaNodeLocator lgLocator = new ArcusKetamaNodeLocator(larger); - ArcusKetamaNodeLocator smLocator = new ArcusKetamaNodeLocator(smaller); + ArcusKetamaNodeLocator lgLocator = + new ArcusKetamaNodeLocator(larger, new ArcusKetamaNodeLocatorConfiguration()); + ArcusKetamaNodeLocator smLocator = + new ArcusKetamaNodeLocator(smaller, new ArcusKetamaNodeLocatorConfiguration()); SortedMap> lgMap = lgLocator.getKetamaNodes(); SortedMap> smMap = smLocator.getKetamaNodes(); diff --git a/src/test/java/net/spy/memcached/ArcusKetamaNodeLocatorTest.java b/src/test/java/net/spy/memcached/ArcusKetamaNodeLocatorTest.java index 21e2a20e8..4d338ce32 100644 --- a/src/test/java/net/spy/memcached/ArcusKetamaNodeLocatorTest.java +++ b/src/test/java/net/spy/memcached/ArcusKetamaNodeLocatorTest.java @@ -23,11 +23,14 @@ import java.util.Collections; import java.util.List; +import net.spy.memcached.util.ArcusKetamaNodeLocatorConfiguration; + import org.junit.jupiter.api.Test; import static net.spy.memcached.ExpectationsUtil.buildExpectations; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -50,7 +53,8 @@ protected void setupNodes(int n) { })); } - locator = new ArcusKetamaNodeLocator(Arrays.asList(nodes)); + locator = new ArcusKetamaNodeLocator(Arrays.asList(nodes), + new ArcusKetamaNodeLocatorConfiguration()); } @Test @@ -154,19 +158,68 @@ private MemcachedNode[] mockNodes(String servers[]) { return nodes; } + @Test + void testShardKey_Extraction() { + setupNodes(10); + ArcusKetamaNodeLocatorConfiguration conf = new ArcusKetamaNodeLocatorConfiguration(); + conf.enableShardKey(true); + + locator = new ArcusKetamaNodeLocator(Arrays.asList(nodes), conf); + ArcusKetamaNodeLocator arcusLocator = (ArcusKetamaNodeLocator) locator; + + assertEquals("100", arcusLocator.getShardKey("user:{100}:profile")); + assertEquals("100", arcusLocator.getShardKey("{100}")); + assertEquals("groupA", arcusLocator.getShardKey("prefix:{groupA}:key")); + assertEquals("bar", arcusLocator.getShardKey("foo{bar}{zap}")); + assertEquals("{bar", arcusLocator.getShardKey("foo{{bar}}zap")); + assertEquals(" 100 ", arcusLocator.getShardKey("user:{ 100 }:data")); + + List fallbackKeys = Arrays.asList( + "user:100:profile", + "user:{}", + "user:{100", + "user:}100{" + ); + for (String key : fallbackKeys) { + assertEquals(key, arcusLocator.getShardKey(key)); + } + } + + @Test + void testShardKey_Distribution() { + setupNodes(10); + ArcusKetamaNodeLocatorConfiguration conf = new ArcusKetamaNodeLocatorConfiguration(); + conf.enableShardKey(true); + + locator = new ArcusKetamaNodeLocator(Arrays.asList(nodes), conf); + + List keys = Arrays.asList( + "data:{myGroup}:1", + "data:{myGroup}:2", + "user:{myGroup}:profile", + "prefix:{myGroup}:{ignored}:data", + "{myGroup}" + ); + + MemcachedNode expectedNode = locator.getPrimary(keys.get(0)); + for (String key : keys) { + MemcachedNode actualNode = locator.getPrimary(key); + assertSame(expectedNode, actualNode); + } + + String otherKey = "data:{otherGroup}:1"; + MemcachedNode otherNode = locator.getPrimary(otherKey); + assertNotSame(expectedNode, otherNode); + } + @Test void testLibKetamaCompatTwo() { String servers[] = { - "10.0.1.1:11211", - "10.0.1.2:11211", - "10.0.1.3:11211", - "10.0.1.4:11211", - "10.0.1.5:11211", - "10.0.1.6:11211", - "10.0.1.7:11211", - "10.0.1.8:11211" + "10.0.1.1:11211", "10.0.1.2:11211", "10.0.1.3:11211", "10.0.1.4:11211", + "10.0.1.5:11211", "10.0.1.6:11211", "10.0.1.7:11211", "10.0.1.8:11211" }; - locator = new ArcusKetamaNodeLocator(Arrays.asList(mockNodes(servers))); + locator = new ArcusKetamaNodeLocator(Arrays.asList(mockNodes(servers)), + new ArcusKetamaNodeLocatorConfiguration()); String[][] exp = { {"0", "10.0.1.1:11211"}, @@ -3397,18 +3450,6 @@ void testLibKetamaCompatTwo() { {"995143", "10.0.1.7:11211"}, {"995376", "10.0.1.4:11211"}, {"995609", "10.0.1.1:11211"}, - {"995842", "10.0.1.6:11211"}, - {"996075", "10.0.1.6:11211"}, - {"996308", "10.0.1.6:11211"}, - {"996541", "10.0.1.2:11211"}, - {"996774", "10.0.1.6:11211"}, - {"997007", "10.0.1.7:11211"}, - {"997240", "10.0.1.2:11211"}, - {"997473", "10.0.1.1:11211"}, - {"997706", "10.0.1.4:11211"}, - {"999104", "10.0.1.8:11211"}, - {"999337", "10.0.1.4:11211"}, - {"999570", "10.0.1.6:11211"}, {"999803", "10.0.1.4:11211"} }; diff --git a/src/test/java/net/spy/memcached/ArcusShardKeyTest.java b/src/test/java/net/spy/memcached/ArcusShardKeyTest.java new file mode 100644 index 000000000..b87f8e48c --- /dev/null +++ b/src/test/java/net/spy/memcached/ArcusShardKeyTest.java @@ -0,0 +1,51 @@ +package net.spy.memcached; + +import net.spy.memcached.collection.BaseIntegrationTest; +import net.spy.memcached.ops.APIType; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class ArcusShardKeyTest extends BaseIntegrationTest { + + @BeforeEach + protected void setUp() throws Exception { + assumeTrue(BaseIntegrationTest.USE_ZK); + super.setUp(); + + ConnectionFactoryBuilder cfb = new ConnectionFactoryBuilder(); + cfb.setArcusReplEnabled(true); + cfb.enableShardKey(true); + mc = ArcusClient.createArcusClient(ZK_ADDRESS, SERVICE_CODE, cfb); + } + + @Test + void testShardKey_SetAndGet() throws Exception { + String key1 = "user:{groupA}:test1"; + String key2 = "user:{groupA}:test2"; + String key3 = "order:{groupA}:test3"; + String value1 = "test1"; + String value2 = "test2"; + String value3 = "test3"; + + MemcachedNode node1 = mc.getMemcachedConnection().getPrimaryNode(key1, APIType.SET); + MemcachedNode node2 = mc.getMemcachedConnection().getPrimaryNode(key2, APIType.SET); + MemcachedNode node3 = mc.getMemcachedConnection().getPrimaryNode(key3, APIType.SET); + + assertSame(node1, node2); + assertSame(node1, node3); + + assertTrue(mc.set(key1, 60, value1).get()); + assertTrue(mc.set(key2, 60, value2).get()); + assertTrue(mc.set(key3, 60, value3).get()); + + assertEquals(value1, mc.get(key1)); + assertEquals(value2, mc.get(key2)); + assertEquals(value3, mc.get(key3)); + } +}