Skip to content

Commit a377e8a

Browse files
committed
Replace use of Map<String, String> with ProxyMap of headers.
This allows us to support repeated headers. Signed-off-by: Hiram Chirino <hiram@hiramchirino.com>
1 parent a1c0c62 commit a377e8a

File tree

12 files changed

+273
-147
lines changed

12 files changed

+273
-147
lines changed

src/main/java/io/roastedroot/proxywasm/ABI.java

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import com.dylibso.chicory.wasm.InvalidException;
1313
import java.nio.ByteBuffer;
1414
import java.nio.charset.StandardCharsets;
15-
import java.util.HashMap;
1615
import java.util.Map;
1716

1817
@HostModule("env")
@@ -675,19 +674,18 @@ int proxyGetHeaderMapSize(int mapType, int returnSize) {
675674
try {
676675

677676
// Get the header map based on the map type
678-
Map<String, String> header = getMap(mapType);
677+
ProxyMap header = getMap(mapType);
679678
if (header == null) {
680679
return WasmResult.BAD_ARGUMENT.getValue();
681680
}
682681

683682
// to clone the headers so that they don't change on while we process them in the loop
684-
final Map<String, String> cloneMap = new HashMap<>();
683+
var cloneMap = new ArrayProxyMap(header);
685684
int totalBytesLen = U32_LEN; // Start with space for the count
686685

687-
for (Map.Entry<String, String> entry : header.entrySet()) {
686+
for (Map.Entry<String, String> entry : cloneMap.entries()) {
688687
String key = entry.getKey();
689688
String value = entry.getValue();
690-
cloneMap.put(key, value);
691689
totalBytesLen += U32_LEN + U32_LEN; // keyLen + valueLen
692690
totalBytesLen += key.length() + 1 + value.length() + 1; // key + \0 + value + \0
693691
}
@@ -717,19 +715,18 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) {
717715
try {
718716

719717
// Get the header map based on the map type
720-
Map<String, String> header = getMap(mapType);
718+
ProxyMap header = getMap(mapType);
721719
if (header == null) {
722720
return WasmResult.NOT_FOUND.getValue();
723721
}
724722

725723
// to clone the headers so that they don't change on while we process them in the loop
726-
final Map<String, String> cloneMap = new HashMap<>();
724+
var cloneMap = new ArrayProxyMap(header);
727725
int totalBytesLen = U32_LEN; // Start with space for the count
728726

729-
for (Map.Entry<String, String> entry : header.entrySet()) {
727+
for (Map.Entry<String, String> entry : cloneMap.entries()) {
730728
String key = entry.getKey();
731729
String value = entry.getValue();
732-
cloneMap.put(key, value);
733730
totalBytesLen += U32_LEN + U32_LEN; // keyLen + valueLen
734731
totalBytesLen += key.length() + 1 + value.length() + 1; // key + \0 + value + \0
735732
}
@@ -745,7 +742,7 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) {
745742
int dataPtr = lenPtr + ((U32_LEN + U32_LEN) * cloneMap.size());
746743

747744
// Write each key-value pair to memory
748-
for (Map.Entry<String, String> entry : cloneMap.entrySet()) {
745+
for (Map.Entry<String, String> entry : cloneMap.entries()) {
749746
String key = entry.getKey();
750747
String value = entry.getValue();
751748

@@ -802,14 +799,14 @@ int proxySetHeaderMapPairs(int mapType, int ptr, int size) {
802799

803800
try {
804801
// Get the header map based on the map type
805-
Map<String, String> headerMap = getMap(mapType);
802+
ProxyMap headerMap = getMap(mapType);
806803
if (headerMap == null) {
807804
return WasmResult.BAD_ARGUMENT.getValue();
808805
}
809806

810807
// Decode the map content and set each key-value pair
811-
Map<String, String> newMap = decodeMap(ptr, size);
812-
for (Map.Entry<String, String> entry : newMap.entrySet()) {
808+
ProxyMap newMap = decodeMap(ptr, size);
809+
for (Map.Entry<String, String> entry : newMap.entries()) {
813810
headerMap.put(entry.getKey(), entry.getValue());
814811
}
815812

@@ -837,7 +834,7 @@ int proxyGetHeaderMapValue(
837834
int mapType, int keyDataPtr, int keySize, int valueDataPtr, int valueSize) {
838835
try {
839836
// Get the header map based on the map type
840-
Map<String, String> headerMap = getMap(mapType);
837+
ProxyMap headerMap = getMap(mapType);
841838
if (headerMap == null) {
842839
return WasmResult.BAD_ARGUMENT.getValue();
843840
}
@@ -895,7 +892,7 @@ int proxyReplaceHeaderMapValue(
895892
int mapType, int keyDataPtr, int keySize, int valueDataPtr, int valueSize) {
896893
try {
897894
// Get the header map based on the map type
898-
Map<String, String> headerMap = getMap(mapType);
895+
ProxyMap headerMap = getMap(mapType);
899896
if (headerMap == null) {
900897
return WasmResult.BAD_ARGUMENT.getValue();
901898
}
@@ -907,7 +904,7 @@ int proxyReplaceHeaderMapValue(
907904
String value = readString(valueDataPtr, valueSize);
908905

909906
// Replace value in map
910-
var copy = new HashMap<>(headerMap);
907+
var copy = new ArrayProxyMap(headerMap);
911908
copy.put(key, value);
912909
setMap(mapType, copy);
913910

@@ -933,7 +930,7 @@ int proxyReplaceHeaderMapValue(
933930
int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) {
934931
try {
935932
// Get the header map based on the map type
936-
Map<String, String> headerMap = getMap(mapType);
933+
ProxyMap headerMap = getMap(mapType);
937934
if (headerMap == null) {
938935
return WasmResult.NOT_FOUND.getValue();
939936
}
@@ -945,7 +942,7 @@ int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) {
945942
}
946943

947944
// Remove key from map
948-
var copy = new HashMap<>(headerMap);
945+
var copy = new ArrayProxyMap(headerMap);
949946
copy.remove(key);
950947
setMap(mapType, copy);
951948

@@ -964,7 +961,7 @@ int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) {
964961
* @param mapType The type of map to get
965962
* @return The header map
966963
*/
967-
private Map<String, String> getMap(int mapType) {
964+
private ProxyMap getMap(int mapType) {
968965

969966
var knownType = MapType.fromInt(mapType);
970967
if (knownType == null) {
@@ -999,7 +996,7 @@ private Map<String, String> getMap(int mapType) {
999996
* @param map The header map to set
1000997
* @return WasmResult indicating success or failure
1001998
*/
1002-
private WasmResult setMap(int mapType, Map<String, String> map) {
999+
private WasmResult setMap(int mapType, ProxyMap map) {
10031000
var knownType = MapType.fromInt(mapType);
10041001
if (knownType == null) {
10051002
return handler.setCustomHeaders(mapType, map);
@@ -1043,9 +1040,9 @@ private WasmResult setMap(int mapType, Map<String, String> map) {
10431040
* @return The decoded map containing string keys and values
10441041
* @throws WasmException if there is an error accessing memory
10451042
*/
1046-
private HashMap<String, String> decodeMap(int addr, int mem_size) throws WasmException {
1043+
private ProxyMap decodeMap(int addr, int mem_size) throws WasmException {
10471044
if (mem_size < U32_LEN) {
1048-
return new HashMap<>();
1045+
return new ArrayProxyMap();
10491046
}
10501047

10511048
// Read header size (number of entries)
@@ -1055,11 +1052,11 @@ private HashMap<String, String> decodeMap(int addr, int mem_size) throws WasmExc
10551052
// mapSize + (key1_size + value1_size) * mapSize
10561053
long dataOffset = U32_LEN + (U32_LEN + U32_LEN) * mapSize;
10571054
if (dataOffset >= mem_size) {
1058-
return new HashMap<>();
1055+
return new ArrayProxyMap();
10591056
}
10601057

10611058
// Create result map with initial capacity
1062-
var result = new HashMap<String, String>((int) mapSize);
1059+
var result = new ArrayProxyMap((int) mapSize);
10631060

10641061
// Process each entry
10651062
for (int i = 0; i < mapSize; i++) {
@@ -1086,7 +1083,7 @@ private HashMap<String, String> decodeMap(int addr, int mem_size) throws WasmExc
10861083
dataOffset += valueSize + 1;
10871084

10881085
// Add to result map
1089-
result.put(key, value);
1086+
result.add(key, value);
10901087
}
10911088

10921089
return result;
@@ -1282,8 +1279,7 @@ int proxySendLocalResponse(
12821279
}
12831280

12841281
// Get and decode additional headers from memory
1285-
HashMap<String, String> additionalHeaders =
1286-
decodeMap(additionalHeadersMapData, additionalHeadersSize);
1282+
ProxyMap additionalHeaders = decodeMap(additionalHeadersMapData, additionalHeadersSize);
12871283

12881284
// Send the response through the handler
12891285
WasmResult result =
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package io.roastedroot.proxywasm;
2+
3+
import java.util.ArrayList;
4+
import java.util.Map;
5+
import java.util.Objects;
6+
7+
public class ArrayProxyMap implements ProxyMap {
8+
9+
final ArrayList<Map.Entry<String, String>> entries;
10+
11+
public ArrayProxyMap() {
12+
this.entries = new ArrayList<>();
13+
}
14+
15+
public ArrayProxyMap(int mapSize) {
16+
this.entries = new ArrayList<>(mapSize);
17+
}
18+
19+
public ArrayProxyMap(ProxyMap other) {
20+
this(other.size());
21+
for (Map.Entry<String, String> entry : other.entries()) {
22+
add(entry.getKey(), entry.getValue());
23+
}
24+
}
25+
26+
public ArrayProxyMap(Map<String, String> other) {
27+
this(other.size());
28+
for (Map.Entry<String, String> entry : other.entrySet()) {
29+
add(entry.getKey(), entry.getValue());
30+
}
31+
}
32+
33+
@Override
34+
public int size() {
35+
return entries.size();
36+
}
37+
38+
@Override
39+
public void add(String key, String value) {
40+
entries.add(Map.entry(key, value));
41+
}
42+
43+
@Override
44+
public void put(String key, String value) {
45+
this.remove(key);
46+
entries.add(Map.entry(key, value));
47+
}
48+
49+
@Override
50+
public Iterable<? extends Map.Entry<String, String>> entries() {
51+
return entries;
52+
}
53+
54+
@Override
55+
public String get(String key) {
56+
return entries.stream()
57+
.filter(x -> x.getKey().equals(key))
58+
.map(Map.Entry::getValue)
59+
.findFirst()
60+
.orElse(null);
61+
}
62+
63+
@Override
64+
public void remove(String key) {
65+
entries.removeIf(x -> x.getKey().equals(key));
66+
}
67+
68+
@Override
69+
public boolean equals(Object o) {
70+
if (o == null || getClass() != o.getClass()) {
71+
return false;
72+
}
73+
ArrayProxyMap that = (ArrayProxyMap) o;
74+
return Objects.equals(entries, that.entries);
75+
}
76+
77+
@Override
78+
public int hashCode() {
79+
return Objects.hashCode(entries);
80+
}
81+
82+
@Override
83+
public String toString() {
84+
return entries.toString();
85+
}
86+
}

0 commit comments

Comments
 (0)