Skip to content

Commit 9611a18

Browse files
committed
test: add VertexAiClientTest for URL encoding of user parameters
10 tests covering all 5 methods that construct URLs with user-supplied values. Verifies query parameter injection (& = characters), path traversal (../ sequences), and normal values pass through correctly.
1 parent 9168126 commit 9611a18

File tree

2 files changed

+257
-10
lines changed

2 files changed

+257
-10
lines changed

core/src/main/java/com/google/adk/sessions/VertexAiClient.java

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,17 +112,22 @@ private static String encodeParam(String value) {
112112
Maybe<JsonNode> listSessions(String reasoningEngineId, String userId) {
113113
return performApiRequest(
114114
"GET",
115-
"reasoningEngines/" + reasoningEngineId
116-
+ "/sessions?filter=user_id=" + encodeParam(userId),
115+
"reasoningEngines/"
116+
+ reasoningEngineId
117+
+ "/sessions?filter=user_id="
118+
+ encodeParam(userId),
117119
"")
118120
.flatMapMaybe(VertexAiClient::getJsonResponse);
119121
}
120122

121123
Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId) {
122124
return performApiRequest(
123125
"GET",
124-
"reasoningEngines/" + reasoningEngineId
125-
+ "/sessions/" + encodeParam(sessionId) + "/events",
126+
"reasoningEngines/"
127+
+ reasoningEngineId
128+
+ "/sessions/"
129+
+ encodeParam(sessionId)
130+
+ "/events",
126131
"")
127132
.doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse))
128133
.flatMapMaybe(VertexAiClient::getJsonResponse);
@@ -131,17 +136,15 @@ Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId) {
131136
Maybe<JsonNode> getSession(String reasoningEngineId, String sessionId) {
132137
return performApiRequest(
133138
"GET",
134-
"reasoningEngines/" + reasoningEngineId
135-
+ "/sessions/" + encodeParam(sessionId),
139+
"reasoningEngines/" + reasoningEngineId + "/sessions/" + encodeParam(sessionId),
136140
"")
137141
.flatMapMaybe(apiResponse -> getJsonResponse(apiResponse));
138142
}
139143

140144
Completable deleteSession(String reasoningEngineId, String sessionId) {
141145
return performApiRequest(
142146
"DELETE",
143-
"reasoningEngines/" + reasoningEngineId
144-
+ "/sessions/" + encodeParam(sessionId),
147+
"reasoningEngines/" + reasoningEngineId + "/sessions/" + encodeParam(sessionId),
145148
"")
146149
.doOnSuccess(ApiResponse::close)
147150
.ignoreElement();
@@ -150,8 +153,11 @@ Completable deleteSession(String reasoningEngineId, String sessionId) {
150153
Completable appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
151154
return performApiRequest(
152155
"POST",
153-
"reasoningEngines/" + reasoningEngineId
154-
+ "/sessions/" + encodeParam(sessionId) + ":appendEvent",
156+
"reasoningEngines/"
157+
+ reasoningEngineId
158+
+ "/sessions/"
159+
+ encodeParam(sessionId)
160+
+ ":appendEvent",
155161
eventJson)
156162
.flatMapCompletable(
157163
response -> {
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
package com.google.adk.sessions;
2+
3+
import static com.google.common.truth.Truth.assertThat;
4+
import static org.mockito.ArgumentMatchers.anyString;
5+
import static org.mockito.Mockito.verify;
6+
import static org.mockito.Mockito.when;
7+
8+
import okhttp3.MediaType;
9+
import okhttp3.ResponseBody;
10+
import org.junit.Before;
11+
import org.junit.Test;
12+
import org.junit.runner.RunWith;
13+
import org.junit.runners.JUnit4;
14+
import org.mockito.ArgumentCaptor;
15+
import org.mockito.Mock;
16+
import org.mockito.MockitoAnnotations;
17+
18+
/**
19+
* Unit tests for URL encoding in {@link VertexAiClient}.
20+
*
21+
* <p>Verifies that userId and sessionId values are properly URL-encoded before being concatenated
22+
* into API request paths, preventing query parameter injection and path traversal attacks.
23+
*/
24+
@RunWith(JUnit4.class)
25+
public class VertexAiClientTest {
26+
27+
private static final MediaType JSON_MEDIA_TYPE =
28+
MediaType.parse("application/json; charset=utf-8");
29+
30+
@Mock private HttpApiClient mockApiClient;
31+
32+
private VertexAiClient client;
33+
34+
@Before
35+
public void setUp() {
36+
MockitoAnnotations.openMocks(this);
37+
client = new VertexAiClient("test-project", "test-location", mockApiClient);
38+
}
39+
40+
/** Returns a mock ApiResponse with the given JSON body. */
41+
private static ApiResponse responseWithBody(String body) {
42+
return new ApiResponse() {
43+
@Override
44+
public ResponseBody getResponseBody() {
45+
return ResponseBody.create(JSON_MEDIA_TYPE, body);
46+
}
47+
48+
@Override
49+
public void close() {}
50+
};
51+
}
52+
53+
// ---------------------------------------------------------------------------
54+
// listSessions: userId encoding
55+
// ---------------------------------------------------------------------------
56+
57+
@Test
58+
public void listSessions_encodesUserIdWithQueryInjection() {
59+
String maliciousUserId = "user&extra=value";
60+
when(mockApiClient.request(anyString(), anyString(), anyString()))
61+
.thenReturn(responseWithBody("{\"sessions\": []}"));
62+
63+
client.listSessions("123", maliciousUserId).blockingGet();
64+
65+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
66+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
67+
68+
String path = pathCaptor.getValue();
69+
// The ampersand must be encoded as %26, not left raw
70+
assertThat(path).contains("user%26extra%3Dvalue");
71+
assertThat(path).doesNotContain("&extra=value");
72+
}
73+
74+
@Test
75+
public void listSessions_encodesUserIdWithSpaces() {
76+
String userIdWithSpaces = "user name with spaces";
77+
when(mockApiClient.request(anyString(), anyString(), anyString()))
78+
.thenReturn(responseWithBody("{\"sessions\": []}"));
79+
80+
client.listSessions("123", userIdWithSpaces).blockingGet();
81+
82+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
83+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
84+
85+
String path = pathCaptor.getValue();
86+
// Spaces must be encoded (as + or %20)
87+
assertThat(path).doesNotContain(" name ");
88+
assertThat(path).contains("filter=user_id=user");
89+
}
90+
91+
@Test
92+
public void listSessions_normalUserIdPassesThroughCorrectly() {
93+
String normalUserId = "user123";
94+
when(mockApiClient.request(anyString(), anyString(), anyString()))
95+
.thenReturn(responseWithBody("{\"sessions\": []}"));
96+
97+
client.listSessions("456", normalUserId).blockingGet();
98+
99+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
100+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
101+
102+
String path = pathCaptor.getValue();
103+
assertThat(path).isEqualTo("reasoningEngines/456/sessions?filter=user_id=user123");
104+
}
105+
106+
// ---------------------------------------------------------------------------
107+
// getSession: sessionId encoding
108+
// ---------------------------------------------------------------------------
109+
110+
@Test
111+
public void getSession_encodesSessionIdWithPathTraversal() {
112+
String maliciousSessionId = "../../secret";
113+
when(mockApiClient.request(anyString(), anyString(), anyString()))
114+
.thenReturn(
115+
responseWithBody(
116+
"{\"name\": \"sessions/safe\", \"updateTime\": \"2024-12-12T12:12:12.123456Z\"}"));
117+
118+
client.getSession("123", maliciousSessionId).blockingGet();
119+
120+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
121+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
122+
123+
String path = pathCaptor.getValue();
124+
// Path traversal characters must be encoded
125+
assertThat(path).doesNotContain("../../");
126+
assertThat(path).contains("..%2F..%2Fsecret");
127+
}
128+
129+
@Test
130+
public void getSession_encodesSessionIdWithSlashes() {
131+
String sessionIdWithSlashes = "session/with/slashes";
132+
when(mockApiClient.request(anyString(), anyString(), anyString()))
133+
.thenReturn(
134+
responseWithBody(
135+
"{\"name\": \"sessions/safe\", \"updateTime\": \"2024-12-12T12:12:12.123456Z\"}"));
136+
137+
client.getSession("123", sessionIdWithSlashes).blockingGet();
138+
139+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
140+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
141+
142+
String path = pathCaptor.getValue();
143+
// Slashes in sessionId must be encoded as %2F
144+
assertThat(path).contains("session%2Fwith%2Fslashes");
145+
}
146+
147+
@Test
148+
public void getSession_normalSessionIdPassesThroughCorrectly() {
149+
String normalSessionId = "abc123";
150+
when(mockApiClient.request(anyString(), anyString(), anyString()))
151+
.thenReturn(
152+
responseWithBody(
153+
"{\"name\": \"sessions/abc123\", \"updateTime\": \"2024-12-12T12:12:12.123456Z\"}"));
154+
155+
client.getSession("456", normalSessionId).blockingGet();
156+
157+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
158+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
159+
160+
String path = pathCaptor.getValue();
161+
assertThat(path).isEqualTo("reasoningEngines/456/sessions/abc123");
162+
}
163+
164+
// ---------------------------------------------------------------------------
165+
// deleteSession: sessionId encoding
166+
// ---------------------------------------------------------------------------
167+
168+
@Test
169+
public void deleteSession_encodesSessionIdWithSpecialCharacters() {
170+
String maliciousSessionId = "session&admin=true";
171+
when(mockApiClient.request(anyString(), anyString(), anyString()))
172+
.thenReturn(responseWithBody(""));
173+
174+
client.deleteSession("123", maliciousSessionId).blockingAwait();
175+
176+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
177+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
178+
179+
String path = pathCaptor.getValue();
180+
assertThat(path).doesNotContain("&admin=true");
181+
assertThat(path).contains("session%26admin%3Dtrue");
182+
}
183+
184+
// ---------------------------------------------------------------------------
185+
// listEvents: sessionId encoding
186+
// ---------------------------------------------------------------------------
187+
188+
@Test
189+
public void listEvents_encodesSessionIdWithPathTraversal() {
190+
String maliciousSessionId = "../other-engine/sessions/target/events";
191+
when(mockApiClient.request(anyString(), anyString(), anyString()))
192+
.thenReturn(responseWithBody("{\"sessionEvents\": []}"));
193+
194+
client.listEvents("123", maliciousSessionId).blockingGet();
195+
196+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
197+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
198+
199+
String path = pathCaptor.getValue();
200+
// The slashes and dots must be encoded, not treated as path separators
201+
assertThat(path).doesNotContain("../other-engine");
202+
assertThat(path).startsWith("reasoningEngines/123/sessions/");
203+
assertThat(path).endsWith("/events");
204+
}
205+
206+
// ---------------------------------------------------------------------------
207+
// appendEvent: sessionId encoding
208+
// ---------------------------------------------------------------------------
209+
210+
@Test
211+
public void appendEvent_encodesSessionIdWithSpecialCharacters() {
212+
String maliciousSessionId = "sess%00ion";
213+
when(mockApiClient.request(anyString(), anyString(), anyString()))
214+
.thenReturn(responseWithBody("{}"));
215+
216+
client.appendEvent("123", maliciousSessionId, "{}").blockingAwait();
217+
218+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
219+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
220+
221+
String path = pathCaptor.getValue();
222+
// The % must itself be encoded as %25
223+
assertThat(path).contains("sess%2500ion");
224+
assertThat(path).endsWith(":appendEvent");
225+
}
226+
227+
@Test
228+
public void appendEvent_normalSessionIdPassesThroughCorrectly() {
229+
String normalSessionId = "session42";
230+
when(mockApiClient.request(anyString(), anyString(), anyString()))
231+
.thenReturn(responseWithBody("{}"));
232+
233+
client.appendEvent("789", normalSessionId, "{}").blockingAwait();
234+
235+
ArgumentCaptor<String> pathCaptor = ArgumentCaptor.forClass(String.class);
236+
verify(mockApiClient).request(anyString(), pathCaptor.capture(), anyString());
237+
238+
String path = pathCaptor.getValue();
239+
assertThat(path).isEqualTo("reasoningEngines/789/sessions/session42:appendEvent");
240+
}
241+
}

0 commit comments

Comments
 (0)