diff --git a/.gitignore b/.gitignore index 524f096..37b730b 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml hs_err_pid* replay_pid* + +# maven build directories +target/* diff --git a/pom.xml b/pom.xml index 2f1e60f..f699fb8 100755 --- a/pom.xml +++ b/pom.xml @@ -22,6 +22,7 @@ proxy-socket-core + proxy-socket-udp proxy-socket-guava diff --git a/proxy-socket-core/pom.xml b/proxy-socket-core/pom.xml index ccee5d6..b101eed 100644 --- a/proxy-socket-core/pom.xml +++ b/proxy-socket-core/pom.xml @@ -33,6 +33,24 @@ + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.3.0 + + + + test-jar + + + + + + + diff --git a/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/ProxyProtocolMetricsListener.java b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/ProxyProtocolMetricsListener.java index 11ca5ab..f199f52 100644 --- a/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/ProxyProtocolMetricsListener.java +++ b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/ProxyProtocolMetricsListener.java @@ -5,6 +5,8 @@ package net.airvantage.proxysocket.core; import net.airvantage.proxysocket.core.v2.ProxyHeader; + +import java.net.InetAddress; import java.net.InetSocketAddress; /** @@ -16,4 +18,7 @@ default void onHeaderParsed(ProxyHeader header) {} default void onParseError(Exception e) {} default void onCacheHit(InetSocketAddress client) {} default void onCacheMiss(InetSocketAddress client) {} + default void onUntrustedProxy(InetAddress proxy) {} + default void onTrustedProxy(InetAddress proxy) {} + default void onLocal(InetAddress proxy) {} } diff --git a/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Decoder.java b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Decoder.java index 95f0edd..7e21bfe 100644 --- a/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Decoder.java +++ b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Decoder.java @@ -65,17 +65,11 @@ public static ProxyHeader parse(byte[] data, int offset, int length, boolean par throw new ProxyProtocolParseException("Invalid version"); } - Command command; - switch (cmd) { - case 0x00: - // Early return for LOCAL command - return new ProxyHeader(Command.LOCAL, AddressFamily.AF_UNSPEC, TransportProtocol.UNSPEC, null, null, null, PROTOCOL_SIGNATURE_FIXED_LENGTH); - case 0x01: - command = Command.PROXY; - break; - default: - throw new ProxyProtocolParseException("Invalid command"); - } + Command command = switch (cmd) { + case 0x00 -> Command.LOCAL; + case 0x01 -> Command.PROXY; + default -> throw new ProxyProtocolParseException("Invalid command"); + }; // Byte 14: address family and protocol int famProto = data[pos++] & 0xFF; diff --git a/proxy-socket-core/src/main/java/net/airvantage/proxysocket/tools/SubnetPredicate.java b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/tools/SubnetPredicate.java new file mode 100644 index 0000000..7f8cde6 --- /dev/null +++ b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/tools/SubnetPredicate.java @@ -0,0 +1,160 @@ +/** + * BSD-3-Clause License. + * Copyright (c) 2025 Semtech + */ +package net.airvantage.proxysocket.tools; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.function.Predicate; + +/** + * Predicate compatible class that tests whether an InetSocketAddress belongs to a given subnet (CIDR). + * Supports both IPv4 and IPv6 CIDR notation. + * + *

Example usage: + *

+ * // Single subnet
+ * Predicate predicate = new SubnetPredicate("10.0.0.0/8");
+ *
+ * // Multiple subnets
+ * Predicate predicate =
+ *     new SubnetPredicate("10.0.0.0/8")
+ *         .or(new SubnetPredicate("192.168.0.0/16"))
+ *         .or(new SubnetPredicate("2001:db8::/32"))
+ * );
+ * 
+ * + * Note: it's possible to create a SubnetPredicate with a hostname instead of an IP address, + * the address will be resolved to an IP address using InetAddress.getByName(hostname). + * If the hostname is not resolvable, an IllegalArgumentException will be thrown. + * But if the hostname resolves to a mix of IPv4/IPv6 addresses or multiple addresses, + * the predicate will only match the first address found. + * + * Thread-safety: This class is immutable and thread-safe. + */ +public class SubnetPredicate implements Predicate { + private final byte[] networkAddress; + private final int prefixLength; + private final int addressLength; // 4 for IPv4, 16 for IPv6 + + /** + * Creates a predicate for the given CIDR subnet. + * + * @param cidr CIDR notation string (e.g., "10.0.0.0/8" or "2001:db8::/32") + * @throws IllegalArgumentException if the CIDR notation is invalid + */ + public SubnetPredicate(String cidr) { + if (cidr == null || cidr.isEmpty()) { + throw new IllegalArgumentException("CIDR notation cannot be null or empty"); + } + + int slashIndex = cidr.indexOf('/'); + if (slashIndex == -1) { + throw new IllegalArgumentException("Invalid CIDR notation: missing '/' separator"); + } + + String addressPart = cidr.substring(0, slashIndex); + String prefixLengthPart = cidr.substring(slashIndex + 1); + + try { + this.prefixLength = Integer.parseInt(prefixLengthPart); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid prefix length: " + prefixLengthPart, e); + } + + byte[] rawAddress; + try { + InetAddress addr = InetAddress.getByName(addressPart); + rawAddress = addr.getAddress(); + } catch (UnknownHostException e) { + throw new IllegalArgumentException("Invalid IP address: " + addressPart, e); + } + + this.addressLength = rawAddress.length; + + // Validate prefix length + int maxPrefixLength = addressLength * 8; // converts address length to bits + if (prefixLength < 0 || prefixLength > maxPrefixLength) { + throw new IllegalArgumentException( + "Invalid prefix length " + prefixLength + " for address type (must be 0-" + maxPrefixLength + ")" + ); + } + + // Apply the mask to the network address to normalize it + this.networkAddress = applyMask(rawAddress); + } + + /** + * Tests whether the given socket address belongs to this subnet. + * + * @param socketAddress the socket address to test + * @return true if the address is in this subnet, false otherwise + */ + @Override + public boolean test(InetSocketAddress socketAddress) { + if (socketAddress == null) { + return false; + } + + InetAddress address = socketAddress.getAddress(); + if (address == null) { + return false; + } + + byte[] testAddress = address.getAddress(); + + // Different address families don't match + if (testAddress.length != addressLength) { + return false; + } + + byte[] maskedTestAddress = applyMask(testAddress); + + // Compare network portions + for (int i = 0; i < networkAddress.length; i++) { + if (networkAddress[i] != maskedTestAddress[i]) { + return false; + } + } + + return true; + } + + /** + * Applies a subnet mask to an IP address. + * + * @param address the raw IP address bytes + * @return the masked address bytes + */ + private byte[] applyMask(byte[] address) { + byte[] result = new byte[address.length]; + + int fullBytes = prefixLength / 8; + int remainingBits = prefixLength % 8; + + // Copy the full bytes + System.arraycopy(address, 0, result, 0, fullBytes); + + // Apply mask to the partial byte if any + if (remainingBits > 0) { + int mask = 0xFF << (8 - remainingBits); + result[fullBytes] = (byte) (address[fullBytes] & mask); + } + + // Remaining bytes are already 0 + return result; + } + + @Override + public String toString() { + try { + InetAddress addr = InetAddress.getByAddress(networkAddress); + return addr.getHostAddress() + "/" + prefixLength; + } catch (UnknownHostException e) { + return "SubnetPredicate[invalid]"; + } + } +} + diff --git a/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/cache/ConcurrentMapProxyAddressCache.java b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/tools/cache/ConcurrentMapProxyAddressCache.java similarity index 95% rename from proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/cache/ConcurrentMapProxyAddressCache.java rename to proxy-socket-core/src/main/java/net/airvantage/proxysocket/tools/cache/ConcurrentMapProxyAddressCache.java index 4060190..9fbbf3a 100644 --- a/proxy-socket-core/src/main/java/net/airvantage/proxysocket/core/cache/ConcurrentMapProxyAddressCache.java +++ b/proxy-socket-core/src/main/java/net/airvantage/proxysocket/tools/cache/ConcurrentMapProxyAddressCache.java @@ -2,7 +2,7 @@ * BSD-3-Clause License. * Copyright (c) 2025 Semtech */ -package net.airvantage.proxysocket.core.cache; +package net.airvantage.proxysocket.tools.cache; import net.airvantage.proxysocket.core.ProxyAddressCache; import java.net.InetSocketAddress; diff --git a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/AwsProxyEncoderHelper.java b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/AwsProxyEncoderHelper.java index 885e144..5eedf99 100644 --- a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/AwsProxyEncoderHelper.java +++ b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/AwsProxyEncoderHelper.java @@ -1,5 +1,5 @@ -/* - * MIT License +/** + * BSD-3-Clause License. * Copyright (c) 2025 Semtech * * Helper class to encode PROXY protocol v2 headers using AWS ProProt library. @@ -28,9 +28,10 @@ public final class AwsProxyEncoderHelper { private final Header header = new Header(); public AwsProxyEncoderHelper command(ProxyHeader.Command cmd) { - this.command = cmd == ProxyHeader.Command.LOCAL - ? ProxyProtocolSpec.Command.LOCAL - : ProxyProtocolSpec.Command.PROXY; + this.command = switch (cmd) { + case LOCAL -> ProxyProtocolSpec.Command.LOCAL; + case PROXY -> ProxyProtocolSpec.Command.PROXY; + }; return this; } @@ -73,16 +74,10 @@ public AwsProxyEncoderHelper addTlv(int type, byte[] value) { public byte[] build() throws IOException { header.setCommand(command); - header.setAddressFamily(family); - header.setTransportProtocol(protocol); - // AWS ProProt validates addresses even for LOCAL command, set dummy values - if (command == ProxyProtocolSpec.Command.LOCAL && source == null) { - header.setSrcAddress(new byte[]{0, 0, 0, 0}); - header.setDstAddress(new byte[]{0, 0, 0, 0}); - header.setSrcPort(0); - header.setDstPort(0); - } else { + if (command != ProxyProtocolSpec.Command.LOCAL) { + header.setAddressFamily(family); + header.setTransportProtocol(protocol); if (source != null) { header.setSrcAddress(source.getAddress().getAddress()); header.setSrcPort(source.getPort()); @@ -92,6 +87,12 @@ public byte[] build() throws IOException { header.setDstAddress(destination.getAddress().getAddress()); header.setDstPort(destination.getPort()); } + } else { + // Spec clearly state that for LOCAL command, we + // 1. must discard the protocol block including the family and + // 2. \x00 is expected to be used for the protocol field. + header.setAddressFamily(ProxyProtocolSpec.AddressFamily.AF_UNSPEC); + header.setTransportProtocol(ProxyProtocolSpec.TransportProtocol.UNSPEC); } ByteArrayOutputStream out = new ByteArrayOutputStream(); diff --git a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2DecoderTest.java b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2DecoderTest.java index 03ed2b0..f5c33cc 100644 --- a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2DecoderTest.java +++ b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2DecoderTest.java @@ -1,7 +1,7 @@ -/* - * MIT License +/** + * BSD-3-Clause License. * Copyright (c) 2025 Semtech - + * * Validation of ProxyProtocolV2Decoder against hardcoded headers for known cases */ package net.airvantage.proxysocket.core.v2; diff --git a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Test.java b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Test.java index ce25b5d..9446234 100644 --- a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Test.java +++ b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/v2/ProxyProtocolV2Test.java @@ -1,7 +1,7 @@ -/* - * MIT License +/** + * BSD-3-Clause License. * Copyright (c) 2025 Semtech - + * * Validation of ProxyProtocolV2Decoder using AWS ProProt library */ package net.airvantage.proxysocket.core.v2; diff --git a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/tools/SubnetPredicateTest.java b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/tools/SubnetPredicateTest.java new file mode 100644 index 0000000..cccff54 --- /dev/null +++ b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/tools/SubnetPredicateTest.java @@ -0,0 +1,233 @@ +/** + * BSD-3-Clause License. + * Copyright (c) 2025 Semtech + */ +package net.airvantage.proxysocket.tools; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.*; + +class SubnetPredicateTest { + + // ========== Parameterized Tests for Matching Addresses ========== + + @ParameterizedTest(name = "{0} should match {1}") + @CsvSource({ + // IPv4 - Single Host /32 + "192.168.1.100/32, 192.168.1.100", + // IPv4 - Class C /24 + "192.168.1.0/24, 192.168.1.0", + "192.168.1.0/24, 192.168.1.1", + "192.168.1.0/24, 192.168.1.100", + "192.168.1.0/24, 192.168.1.255", + // IPv4 - Class B /16 + "172.16.0.0/16, 172.16.0.0", + "172.16.0.0/16, 172.16.1.1", + "172.16.0.0/16, 172.16.255.255", + // IPv4 - Class A /8 + "10.0.0.0/8, 10.0.0.0", + "10.0.0.0/8, 10.0.0.1", + "10.0.0.0/8, 10.255.255.255", + "10.0.0.0/8, 10.123.45.67", + // IPv4 - /25 + "192.168.1.0/25, 192.168.1.0", + "192.168.1.0/25, 192.168.1.127", + // IPv4 - /23 + "192.168.0.0/23, 192.168.0.0", + "192.168.0.0/23, 192.168.0.255", + "192.168.0.0/23, 192.168.1.0", + "192.168.0.0/23, 192.168.1.255", + // IPv4 - /0 (matches all IPv4) + "0.0.0.0/0, 0.0.0.0", + "0.0.0.0/0, 1.2.3.4", + "0.0.0.0/0, 192.168.1.1", + "0.0.0.0/0, 255.255.255.255", + // IPv6 - Single Host /128 + "2001:db8::1/128, 2001:db8::1", + // IPv6 - /64 + "2001:db8::/64, 2001:db8::1", + "2001:db8::/64, 2001:db8::ffff:ffff:ffff:ffff", + "2001:db8::/64, 2001:db8:0:0:1234:5678:9abc:def0", + // IPv6 - /32 + "2001:db8::/32, 2001:db8::", + "2001:db8::/32, 2001:db8:ffff:ffff:ffff:ffff:ffff:ffff", + // IPv6 - Loopback + "::1/128, ::1", + // IPv6 - /0 (matches all IPv6) + "::/0, ::", + "::/0, ::1", + "::/0, 2001:db8::1", + "::/0, ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff" + }) + void testAddressShouldMatch(String cidr, String address) throws UnknownHostException { + SubnetPredicate predicate = new SubnetPredicate(cidr); + assertTrue(predicate.test(addr(address, 8080)), + address + " should match " + cidr); + } + + @ParameterizedTest(name = "{0} should NOT match {1}") + @CsvSource({ + // IPv4 - Single Host /32 + "192.168.1.100/32, 192.168.1.101", + "192.168.1.100/32, 192.168.1.99", + // IPv4 - Class C /24 + "192.168.1.0/24, 192.168.0.255", + "192.168.1.0/24, 192.168.2.0", + "192.168.1.0/24, 192.167.1.1", + // IPv4 - Class B /16 + "172.16.0.0/16, 172.15.255.255", + "172.16.0.0/16, 172.17.0.0", + // IPv4 - Class A /8 + "10.0.0.0/8, 9.255.255.255", + "10.0.0.0/8, 11.0.0.0", + // IPv4 - /25 + "192.168.1.0/25, 192.168.1.128", + "192.168.1.0/25, 192.168.1.255", + // IPv4 - /23 + "192.168.0.0/23, 192.168.2.0", + "192.168.0.0/23, 192.167.255.255", + // IPv4 - /0 (doesn't match IPv6) + "0.0.0.0/0, ::1", + // IPv6 - Single Host /128 + "2001:db8::1/128, 2001:db8::2", + // IPv6 - /64 + "2001:db8::/64, 2001:db8:0:1::1", + // IPv6 - /32 + "2001:db8::/32, 2001:db9::", + "2001:db8::/32, 2001:db7:ffff:ffff:ffff:ffff:ffff:ffff", + // IPv6 - Loopback + "::1/128, ::2", + // IPv6 - /0 (doesn't match IPv4) + "::/0, 192.168.1.1" + }) + void testAddressShouldNotMatch(String cidr, String address) throws UnknownHostException { + SubnetPredicate predicate = new SubnetPredicate(cidr); + assertFalse(predicate.test(addr(address, 8080)), + address + " should NOT match " + cidr); + } + + // ========== Edge Cases ========== + + @Test + void testNullSocketAddress() { + SubnetPredicate predicate = new SubnetPredicate("192.168.1.0/24"); + assertFalse(predicate.test(null)); + } + + @Test + void testPortIsIgnored() throws UnknownHostException { + SubnetPredicate predicate = new SubnetPredicate("192.168.1.0/24"); + + assertTrue(predicate.test(addr("192.168.1.100", 80))); + assertTrue(predicate.test(addr("192.168.1.100", 443))); + assertTrue(predicate.test(addr("192.168.1.100", 8080))); + assertTrue(predicate.test(addr("192.168.1.100", 65535))); + } + + @Test + void testIPv4vsIPv6_NoMatch() throws UnknownHostException { + SubnetPredicate ipv4Predicate = new SubnetPredicate("192.168.1.0/24"); + SubnetPredicate ipv6Predicate = new SubnetPredicate("2001:db8::/32"); + + // IPv4 predicate doesn't match IPv6 address + assertFalse(ipv4Predicate.test(addr("2001:db8::1", 8080))); + + // IPv6 predicate doesn't match IPv4 address + assertFalse(ipv6Predicate.test(addr("192.168.1.1", 8080))); + } + + @Test + void testPredicateComposition_Or() throws UnknownHostException { + SubnetPredicate predicate1 = new SubnetPredicate("192.168.1.0/24"); + SubnetPredicate predicate2 = new SubnetPredicate("10.0.0.0/8"); + Predicate combined = predicate1.or(predicate2); + + assertTrue(combined.test(addr("192.168.1.100", 8080))); + assertTrue(combined.test(addr("10.20.30.40", 8080))); + assertFalse(combined.test(addr("172.16.0.1", 8080))); + } + + @Test + void testPredicateComposition_And() throws UnknownHostException { + SubnetPredicate predicate1 = new SubnetPredicate("192.168.0.0/16"); + SubnetPredicate predicate2 = new SubnetPredicate("192.168.1.0/24"); + Predicate combined = predicate1.and(predicate2); + + assertTrue(combined.test(addr("192.168.1.100", 8080))); + assertFalse(combined.test(addr("192.168.2.100", 8080))); + } + + @Test + void testPredicateComposition_Negate() throws UnknownHostException { + SubnetPredicate predicate = new SubnetPredicate("192.168.1.0/24"); + + assertTrue(predicate.test(addr("192.168.1.100", 8080))); + assertFalse(predicate.negate().test(addr("192.168.1.100", 8080))); + + assertFalse(predicate.test(addr("192.168.2.100", 8080))); + assertTrue(predicate.negate().test(addr("192.168.2.100", 8080))); + } + + @Test + void testToString_IPv4() { + SubnetPredicate predicate = new SubnetPredicate("192.168.1.0/24"); + assertEquals("192.168.1.0/24", predicate.toString()); + } + + @Test + void testToString_IPv6() { + SubnetPredicate predicate = new SubnetPredicate("2001:db8::/32"); + assertTrue(predicate.toString().contains("2001:db8")); + assertTrue(predicate.toString().contains("/32")); + } + + // ========== Invalid Input Tests ========== + + @Test + void testInvalidCIDR_NullInput() { + assertThrows(IllegalArgumentException.class, () -> new SubnetPredicate(null)); + } + + @ParameterizedTest(name = "Invalid CIDR: ''{0}''") + @ValueSource(strings = { + "", + "192.168.1.0", // Missing slash + "192.168.1.0/abc", // Invalid prefix + "192.168.1.0/33", // Prefix too large for IPv4 + "2001:db8::/129", // Prefix too large for IPv6 + "192.168.1.0/-1", // Negative prefix + "256.256.256.256/24" // Invalid IP address + }) + void testInvalidCIDR(String cidr) { + assertThrows(IllegalArgumentException.class, () -> new SubnetPredicate(cidr)); + } + + @Test + void testNonNormalizedCIDR_StillWorks() throws UnknownHostException { + // 192.168.1.100/24 is not normalized (should be 192.168.1.0/24) + // but the implementation should normalize it + SubnetPredicate predicate = new SubnetPredicate("192.168.1.100/24"); + + assertTrue(predicate.test(addr("192.168.1.0", 8080))); + assertTrue(predicate.test(addr("192.168.1.100", 8080))); + assertTrue(predicate.test(addr("192.168.1.255", 8080))); + assertFalse(predicate.test(addr("192.168.2.0", 8080))); + } + + // ========== Helper Methods ========== + + private InetSocketAddress addr(String host, int port) throws UnknownHostException { + return new InetSocketAddress(InetAddress.getByName(host), port); + } +} + diff --git a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/cache/ConcurrentMapProxyAddressCacheTest.java b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/tools/cache/ConcurrentMapProxyAddressCacheTest.java similarity index 96% rename from proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/cache/ConcurrentMapProxyAddressCacheTest.java rename to proxy-socket-core/src/test/java/net/airvantage/proxysocket/tools/cache/ConcurrentMapProxyAddressCacheTest.java index 0875bae..158cd08 100644 --- a/proxy-socket-core/src/test/java/net/airvantage/proxysocket/core/cache/ConcurrentMapProxyAddressCacheTest.java +++ b/proxy-socket-core/src/test/java/net/airvantage/proxysocket/tools/cache/ConcurrentMapProxyAddressCacheTest.java @@ -1,8 +1,8 @@ -/* - * MIT License +/** + * BSD-3-Clause License. * Copyright (c) 2025 Semtech */ -package net.airvantage.proxysocket.core.cache; +package net.airvantage.proxysocket.tools.cache; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; diff --git a/proxy-socket-guava/src/test/java/net/airvantage/proxysocket/guava/GuavaProxyAddressCacheTest.java b/proxy-socket-guava/src/test/java/net/airvantage/proxysocket/guava/GuavaProxyAddressCacheTest.java index 245beda..159b259 100644 --- a/proxy-socket-guava/src/test/java/net/airvantage/proxysocket/guava/GuavaProxyAddressCacheTest.java +++ b/proxy-socket-guava/src/test/java/net/airvantage/proxysocket/guava/GuavaProxyAddressCacheTest.java @@ -1,5 +1,5 @@ -/* - * MIT License +/** + * BSD-3-Clause License. * Copyright (c) 2025 Semtech */ package net.airvantage.proxysocket.guava; diff --git a/proxy-socket-udp/pom.xml b/proxy-socket-udp/pom.xml new file mode 100644 index 0000000..372385e --- /dev/null +++ b/proxy-socket-udp/pom.xml @@ -0,0 +1,65 @@ + + + 4.0.0 + + net.airvantage + proxy-socket-java + 1.0.0-SNAPSHOT + + proxy-socket-udp + Proxy Protocol - UDP + jar + + + + net.airvantage + proxy-socket-core + ${project.version} + + + net.airvantage + proxy-socket-core + ${project.version} + test-jar + test + + + org.slf4j + slf4j-api + 2.0.17 + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + + org.slf4j + slf4j-simple + 2.0.17 + test + + + org.mockito + mockito-core + 5.2.0 + test + + + + com.amazonaws.proprot + proprot + 1.0 + test + + + + + + + diff --git a/proxy-socket-udp/src/main/java/net/airvantage/proxysocket/udp/ProxyDatagramSocket.java b/proxy-socket-udp/src/main/java/net/airvantage/proxysocket/udp/ProxyDatagramSocket.java new file mode 100644 index 0000000..a26b9a2 --- /dev/null +++ b/proxy-socket-udp/src/main/java/net/airvantage/proxysocket/udp/ProxyDatagramSocket.java @@ -0,0 +1,121 @@ +/** + * BSD-3-Clause License. + * Copyright (c) 2025 Semtech + */ +package net.airvantage.proxysocket.udp; + +import net.airvantage.proxysocket.core.ProxyAddressCache; +import net.airvantage.proxysocket.core.ProxyProtocolMetricsListener; +import net.airvantage.proxysocket.core.ProxyProtocolParseException; +import net.airvantage.proxysocket.core.v2.ProxyHeader; +import net.airvantage.proxysocket.core.v2.ProxyProtocolV2Decoder; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.DatagramPacket; +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.PortUnreachableException; +import java.net.SocketAddress; +import java.net.SocketException; +import java.net.SocketTimeoutException; +import java.nio.channels.IllegalBlockingModeException; +import java.util.function.Predicate; + +/** + * DatagramSocket that strips Proxy Protocol v2 headers and exposes real client address. + * + * Thread-safety: This class is thread-safe to the extent that {@link DatagramSocket} + * is documented as thread-safe for concurrent send/receive by the JDK. The internal + * cache and metrics listener are expected to be thread-safe. The implementation does + * not mutate shared state beyond those collaborators. + */ +public class ProxyDatagramSocket extends DatagramSocket { + private static final Logger LOG = LoggerFactory.getLogger(ProxyDatagramSocket.class); + + private final ProxyAddressCache addressCache; + private final ProxyProtocolMetricsListener metrics; + private final Predicate trustedProxyPredicate; + + public ProxyDatagramSocket(SocketAddress bindaddr, ProxyAddressCache cache, ProxyProtocolMetricsListener metrics, Predicate predicate) throws SocketException { + super(bindaddr); + this.addressCache = cache; + this.metrics = metrics; + this.trustedProxyPredicate = predicate; + } + + public ProxyDatagramSocket(ProxyAddressCache cache, ProxyProtocolMetricsListener metrics, Predicate predicate) throws SocketException { + this(new InetSocketAddress(0), cache, metrics, predicate); + } + + public ProxyDatagramSocket(int port, ProxyAddressCache cache, ProxyProtocolMetricsListener metrics, Predicate predicate) throws SocketException { + this(port, null, cache, metrics, predicate); + } + + public ProxyDatagramSocket(int port, java.net.InetAddress laddr, ProxyAddressCache cache, ProxyProtocolMetricsListener metrics, Predicate predicate) throws SocketException { + this(new InetSocketAddress(laddr, port), cache, metrics, predicate); + } + + @Override + public void receive(DatagramPacket packet) + throws IOException, SocketTimeoutException, PortUnreachableException, IllegalBlockingModeException { + + super.receive(packet); + + try { + InetSocketAddress lbAddress = (InetSocketAddress) packet.getSocketAddress(); + if (trustedProxyPredicate != null && !trustedProxyPredicate.test(lbAddress)) { + // Untrusted source: do not parse, deliver original packet + LOG.debug("Untrusted proxy source; delivering original packet."); + if (metrics != null) metrics.onUntrustedProxy(lbAddress.getAddress()); + return; + } + + ProxyHeader header = ProxyProtocolV2Decoder.parse(packet.getData(), packet.getOffset(), packet.getLength()); + if (metrics != null) metrics.onHeaderParsed(header); + + if (header.isLocal()) { + // LOCAL: not proxied + if (metrics != null) metrics.onLocal(lbAddress.getAddress()); + } + if (header.isProxy() && header.getProtocol() == ProxyHeader.TransportProtocol.DGRAM) { + if (metrics != null) metrics.onTrustedProxy(lbAddress.getAddress()); + + InetSocketAddress realClient = header.getSourceAddress(); + if (realClient != null) { // could be null if address family is unspecified or unix + if (addressCache != null) addressCache.put(realClient, lbAddress); + packet.setSocketAddress(realClient); + } + } + + int headerLen = header.getHeaderLength(); + LOG.trace("Stripping header: {} bytes, original length: {}", headerLen, packet.getLength()); + packet.setData(packet.getData(), packet.getOffset() + headerLen, packet.getLength() - headerLen); + } catch (ProxyProtocolParseException e) { + LOG.warn("Proxy socket parse error; delivering original packet.", e); + if (metrics != null) metrics.onParseError(e); + } + } + + @Override + public void send(DatagramPacket packet) throws IOException { + InetSocketAddress client = (InetSocketAddress) packet.getSocketAddress(); + InetSocketAddress lb = addressCache != null ? addressCache.get(client) : null; + + if (lb != null) { + packet.setSocketAddress(lb); + if (metrics != null) metrics.onCacheHit(client); + } else if (addressCache != null) { + // Cache miss: unable to map client to load balancer address, + LOG.warn("Cache miss for client {}; unable to map to load balancer address, dropping packet.", client); + if (metrics != null) metrics.onCacheMiss(client); + return; + // } else { + // No cache: deliver original packet + } + super.send(packet); + } +} diff --git a/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/ProxyDatagramSocketTest.java b/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/ProxyDatagramSocketTest.java new file mode 100644 index 0000000..655f0c2 --- /dev/null +++ b/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/ProxyDatagramSocketTest.java @@ -0,0 +1,263 @@ +/** + * BSD-3-Clause License. + * Copyright (c) 2025 Semtech + */ +package net.airvantage.proxysocket.udp; + +import net.airvantage.proxysocket.core.ProxyAddressCache; +import net.airvantage.proxysocket.core.ProxyProtocolMetricsListener; +import net.airvantage.proxysocket.core.v2.ProxyHeader; +import net.airvantage.proxysocket.core.v2.AwsProxyEncoderHelper; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import java.net.DatagramPacket; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for ProxyDatagramSocket IP address mapping and cache behavior. + */ +class ProxyDatagramSocketTest { + + private ProxyDatagramSocket socket; + private ProxyAddressCache mockCache; + private ProxyProtocolMetricsListener mockMetrics; + + private InetSocketAddress realClient; + private InetSocketAddress serviceAddress; + private InetSocketAddress backendAddress; + private int localPort; + + private byte[] buffer = new byte[2048]; + private byte[] proxyHeader; + + @BeforeEach + void setUp() throws Exception { + mockCache = mock(ProxyAddressCache.class); + mockMetrics = mock(ProxyProtocolMetricsListener.class); + + realClient = new InetSocketAddress(InetAddress.getLoopbackAddress(), 12345); + serviceAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), 54321); + socket = new ProxyDatagramSocket(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), mockCache, mockMetrics, null); + localPort = socket.getLocalPort(); + backendAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), localPort); + + proxyHeader = new AwsProxyEncoderHelper() + .family(ProxyHeader.AddressFamily.AF_INET) + .socket(ProxyHeader.TransportProtocol.DGRAM) + .source(realClient) + .destination(serviceAddress) + .build(); + } + + @AfterEach + void tearDown() { + if (socket != null && !socket.isClosed()) { + socket.close(); + } + } + + @Test + void receive_withValidProxyHeader() throws Exception { + // Arrange + byte[] payload = "test-data".getBytes(StandardCharsets.UTF_8); + byte[] packet = Utility.createPacket(proxyHeader, payload); + Utility.sendPacket(packet, serviceAddress, backendAddress); + + // Act + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + socket.receive(receivePacket); + + // Assert - cache should be populated with realClient -> lbAddress mapping + ArgumentCaptor clientCaptor = ArgumentCaptor.forClass(InetSocketAddress.class); + ArgumentCaptor lbCaptor = ArgumentCaptor.forClass(InetSocketAddress.class); + verify(mockCache).put(clientCaptor.capture(), lbCaptor.capture()); + + assertEquals(realClient, clientCaptor.getValue()); + assertEquals(serviceAddress.getAddress(), lbCaptor.getValue().getAddress()); + assertEquals(serviceAddress.getPort(), lbCaptor.getValue().getPort()); + + // Verify packet was modified to show real client address + assertEquals(realClient, receivePacket.getSocketAddress()); + + // Verify payload was stripped of proxy header + assertEquals(payload.length, receivePacket.getLength()); + assertArrayEquals(payload, + java.util.Arrays.copyOfRange(receivePacket.getData(), + receivePacket.getOffset(), + receivePacket.getOffset() + receivePacket.getLength())); + + verify(mockMetrics).onHeaderParsed(any(ProxyHeader.class)); + verify(mockMetrics).onTrustedProxy(serviceAddress.getAddress()); + verify(mockMetrics, never()).onUntrustedProxy(any()); + verify(mockMetrics, never()).onParseError(any()); + verify(mockMetrics, never()).onLocal(any()); + } + + @Test + void send_withCacheHit_usesLoadBalancerAddress() throws Exception { + byte[] payload = "response".getBytes(StandardCharsets.UTF_8); + + // Mock cache to return lb address + when(mockCache.get(realClient)).thenReturn(serviceAddress); + + // Create a receiver to verify the packet destination + java.net.DatagramSocket receiver = new java.net.DatagramSocket(serviceAddress); + receiver.setSoTimeout(1000); + + try { + // Act - send to real client, should be redirected to LB + DatagramPacket sendPacket = new DatagramPacket(payload, payload.length, realClient); + socket.send(sendPacket); + + // Verify packet was sent to LB address + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + receiver.receive(receivePacket); + + // Assert + assertArrayEquals(payload, + java.util.Arrays.copyOfRange(receivePacket.getData(), 0, receivePacket.getLength())); + + // Verify cache was queried + verify(mockCache).get(realClient); + + // Verify metrics - cache hit + verify(mockMetrics).onCacheHit(realClient); + verify(mockMetrics, never()).onCacheMiss(any()); + } finally { + receiver.close(); + } + } + + @Test + void send_withCacheMiss_dropsPacket() throws Exception { + // Arrange + byte[] payload = "response".getBytes(StandardCharsets.UTF_8); + + // Mock cache to return null (cache miss) + when(mockCache.get(realClient)).thenReturn(null); + + // Create a receiver at the client address to verify packet is NOT sent + java.net.DatagramSocket receiver = new java.net.DatagramSocket(realClient); + receiver.setSoTimeout(500); // Short timeout since we expect no packet + + try { + // Act - send to client address + DatagramPacket sendPacket = new DatagramPacket(payload, payload.length, realClient); + socket.send(sendPacket); + + // Try to receive - should timeout since packet was dropped + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + + assertThrows(java.net.SocketTimeoutException.class, () -> { + receiver.receive(receivePacket); + }, "Expected packet to be dropped on cache miss"); + + // Verify cache was queried + verify(mockCache).get(realClient); + + // Verify metrics - cache miss + verify(mockMetrics).onCacheMiss(realClient); + verify(mockMetrics, never()).onCacheHit(any()); + } finally { + receiver.close(); + } + } + + @Test + void receive_withLocalCommand_doesNotPopulateCache() throws Exception { + // Arrange - create LOCAL command (not proxied) + byte[] payload = "local".getBytes(StandardCharsets.UTF_8); + byte[] localProxyHeader = new AwsProxyEncoderHelper() + .command(ProxyHeader.Command.LOCAL) + .build(); + byte[] packet = Utility.createPacket(localProxyHeader, payload); + Utility.sendPacket(packet, backendAddress); + + // Act + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + socket.receive(receivePacket); + + // Assert - cache should NOT be populated for LOCAL commands + verify(mockCache, never()).put(any(), any()); + + // But metrics should still be called + verify(mockMetrics).onHeaderParsed(any()); + + // Payload should be stripped of header + assertEquals(payload.length, receivePacket.getLength()); + } + + @Test + void receive_withTcpProtocol_doesNotPopulateCache() throws Exception { + // Arrange - create header with TCP (not DGRAM) protocol + byte[] payload = "tcp".getBytes(StandardCharsets.UTF_8); + byte[] tcpProxyHeader = new AwsProxyEncoderHelper() + .family(ProxyHeader.AddressFamily.AF_INET) + .socket(ProxyHeader.TransportProtocol.STREAM) // TCP, not UDP + .source(realClient) + .destination(serviceAddress) + .build(); + + byte[] packet = Utility.createPacket(tcpProxyHeader, payload); + Utility.sendPacket(packet, backendAddress); + + // Act + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + socket.receive(receivePacket); + + // Assert - cache should NOT be populated for non-DGRAM protocols + verify(mockCache, never()).put(any(), any()); + + // Metrics should still be called + verify(mockMetrics).onHeaderParsed(any()); + } + + + @Test + void receive_withValidProxyHeader_callsMetricsOnHeaderParsed() throws Exception { + // Arrange + byte[] payload = "test".getBytes(StandardCharsets.UTF_8); + byte[] packet = Utility.createPacket(proxyHeader, payload); + Utility.sendPacket(packet, backendAddress); + + // Act + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + socket.receive(receivePacket); + + // Assert - onHeaderParsed should be called + ArgumentCaptor headerCaptor = ArgumentCaptor.forClass(ProxyHeader.class); + verify(mockMetrics).onHeaderParsed(headerCaptor.capture()); + + ProxyHeader capturedHeader = headerCaptor.getValue(); + assertNotNull(capturedHeader); + assertEquals(ProxyHeader.TransportProtocol.DGRAM, capturedHeader.getProtocol()); + assertEquals(realClient, capturedHeader.getSourceAddress()); + } + + @Test + void receive_withInvalidData_callsMetricsOnParseError() throws Exception { + // Arrange - send garbage data + byte[] garbage = "not-a-proxy-header".getBytes(StandardCharsets.UTF_8); + Utility.sendPacket(garbage, backendAddress); + + // Act + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + socket.receive(receivePacket); + + // Assert - onParseError should be called + verify(mockMetrics).onParseError(any(Exception.class)); + + // Original packet should be delivered unchanged + assertEquals(garbage.length, receivePacket.getLength()); + } + +} + diff --git a/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/ProxyDatagramSocketUnTrustedProxyTest.java b/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/ProxyDatagramSocketUnTrustedProxyTest.java new file mode 100644 index 0000000..e27122a --- /dev/null +++ b/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/ProxyDatagramSocketUnTrustedProxyTest.java @@ -0,0 +1,103 @@ +/** + * BSD-3-Clause License. + * Copyright (c) 2025 Semtech + */ +package net.airvantage.proxysocket.udp; + +import net.airvantage.proxysocket.core.ProxyAddressCache; +import net.airvantage.proxysocket.core.ProxyProtocolMetricsListener; +import net.airvantage.proxysocket.core.v2.ProxyHeader; +import net.airvantage.proxysocket.core.v2.AwsProxyEncoderHelper; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.DatagramPacket; +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for ProxyDatagramSocket with untrusted proxy source. + */ +class ProxyDatagramSocketUnTrustedProxyTest { + + private ProxyDatagramSocket socket; + private ProxyAddressCache mockCache; + private ProxyProtocolMetricsListener mockMetrics; + private InetSocketAddress realClient; + private InetSocketAddress serviceAddress; + private InetSocketAddress backendAddress; + private int localPort; + private byte[] buffer = new byte[2048]; + + @BeforeEach + void setUp() throws Exception { + mockCache = mock(ProxyAddressCache.class); + mockMetrics = mock(ProxyProtocolMetricsListener.class); + + realClient = new InetSocketAddress(InetAddress.getLoopbackAddress(), 12345); + serviceAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), 54321); + socket = new ProxyDatagramSocket(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), mockCache, mockMetrics, addr -> false); + localPort = socket.getLocalPort(); + backendAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), localPort); + } + + @AfterEach + void tearDown() { + if (socket != null && !socket.isClosed()) { + socket.close(); + } + } + + @Test + void receive_withUntrustedProxy_doesNotStripHeader() throws Exception { + InetAddress lbAddress; + + byte[] payload = "test".getBytes(StandardCharsets.UTF_8); + byte[] proxyHeader = new AwsProxyEncoderHelper() + .family(ProxyHeader.AddressFamily.AF_INET) + .socket(ProxyHeader.TransportProtocol.DGRAM) + .source(realClient) + .destination(serviceAddress) + .build(); + + byte[] packet = Utility.createPacket(proxyHeader, payload); + + try (DatagramSocket sender = new DatagramSocket(realClient)) { + sender.send(new DatagramPacket(packet, packet.length, backendAddress)); + lbAddress = sender.getLocalAddress(); + } + + // Act + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + socket.receive(receivePacket); + + // Assert - no metrics should be called for untrusted sources + verify(mockMetrics, never()).onHeaderParsed(any()); + verify(mockMetrics, never()).onParseError(any()); + verify(mockMetrics, never()).onTrustedProxy(any()); + verify(mockMetrics).onUntrustedProxy(lbAddress); + + // Packet length should include proxy header (not stripped) + assertEquals(packet.length, receivePacket.getLength()); + } + + @Test + void receive_withUntrustedProxy_doesNotParse() throws Exception { + byte[] packet = "Not a proxy header".getBytes(StandardCharsets.UTF_8); + Utility.sendPacket(packet, backendAddress); + + // Act + DatagramPacket receivePacket = new DatagramPacket(buffer, buffer.length); + assertDoesNotThrow(() -> socket.receive(receivePacket), "No ProxyProtocolException should be thrown"); + + // Packet length should be the same as the original packet length + assertEquals(packet.length, receivePacket.getLength()); + } +} + diff --git a/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/Utility.java b/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/Utility.java new file mode 100644 index 0000000..fec669b --- /dev/null +++ b/proxy-socket-udp/src/test/java/net/airvantage/proxysocket/udp/Utility.java @@ -0,0 +1,81 @@ +/** + * BSD-3-Clause License. + * Copyright (c) 2025 Semtech + */ +package net.airvantage.proxysocket.udp; + +import java.io.IOException; +import java.net.DatagramPacket; +import java.net.DatagramSocket; +import java.net.InetSocketAddress; + +public final class Utility { + private Utility() {} + + /** + * Creates a packet by combining proxy header and payload. + */ + public static byte[] createPacket(byte[] proxyHeader, byte[] payload) { + byte[] packet = new byte[proxyHeader.length + payload.length]; + System.arraycopy(proxyHeader, 0, packet, 0, proxyHeader.length); + System.arraycopy(payload, 0, packet, proxyHeader.length, payload.length); + return packet; + } + + /** + * Sends a packet to the specified destination using an ephemeral DatagramSocket. + */ + public static void sendPacket(byte[] packet, InetSocketAddress destination) throws IOException { + try (DatagramSocket sender = new DatagramSocket()) { + sender.send(new DatagramPacket(packet, packet.length, destination)); + } + } + + /** + * Sends a packet to the specified destination using a DatagramSocket bound to the specified source address. + */ + public static void sendPacket(byte[] packet, InetSocketAddress source, InetSocketAddress destination) throws IOException { + try (DatagramSocket sender = new DatagramSocket(source)) { + sender.send(new DatagramPacket(packet, packet.length, destination)); + } + } + + /** + * Returns a hexdump of the specified data for debugging purposes. + * @param data The data to dump. + * @param offset The offset into the data to start dumping. + * @param length The length of the data to dump. + * @return A string containing the hexdump. + */ + public static String hexdump(byte[] data, int offset, int length) { + StringBuilder sb = new StringBuilder(); + for (int i = offset; i < offset + length; i += 16) { + // Offset + sb.append(String.format("%08x ", i)); + + // Hex bytes (first 8) + for (int j = 0; j < 8 && i + j < offset + length; j++) { + sb.append(String.format("%02x ", data[i + j])); + } + sb.append(" "); + + // Hex bytes (second 8) + for (int j = 8; j < 16 && i + j < offset + length; j++) { + sb.append(String.format("%02x ", data[i + j])); + } + + // Padding if last line is incomplete + int remaining = 16 - Math.min(16, offset + length - i); + sb.append(" ".repeat(Math.max(0, remaining))); + + // ASCII representation + sb.append(" |"); + for (int j = 0; j < 16 && i + j < offset + length; j++) { + byte b = data[i + j]; + sb.append((b >= 32 && b < 127) ? (char) b : '.'); + } + sb.append("|\n"); + } + return sb.toString(); + } +} \ No newline at end of file