Skip to content

Commit 922a5e1

Browse files
google-genai-botcopybara-github
authored andcommitted
fix:Refactor State.java to correctly merge state and delta for map interface methods
PiperOrigin-RevId: 871435115
1 parent 4ac1dd2 commit 922a5e1

File tree

2 files changed

+258
-28
lines changed

2 files changed

+258
-28
lines changed

core/src/main/java/com/google/adk/sessions/State.java

Lines changed: 115 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,68 +51,135 @@ public State(ConcurrentMap<String, Object> state, ConcurrentMap<String, Object>
5151
@Override
5252
public void clear() {
5353
state.clear();
54+
// Delta should likely be cleared too if we are clearing the state,
55+
// or we might want to mark everything as removed in delta.
56+
// Given the Python implementation doesn't have clear, and this is a local view,
57+
// clearing both seems appropriate to reset the object.
58+
delta.clear();
5459
}
5560

5661
@Override
5762
public boolean containsKey(Object key) {
63+
if (delta.containsKey(key)) {
64+
return delta.get(key) != REMOVED;
65+
}
5866
return state.containsKey(key);
5967
}
6068

6169
@Override
6270
public boolean containsValue(Object value) {
63-
return state.containsValue(value);
71+
// This is expensive but necessary for correctness with the merged view.
72+
return values().contains(value);
6473
}
6574

6675
@Override
6776
public Set<Entry<String, Object>> entrySet() {
68-
return state.entrySet();
77+
// This provides a snapshot, not a live view backed by the map, which differs from standard Map
78+
// contract.
79+
// However, given the complexity of merging two concurrent maps, this is a reasonable compromise
80+
// for this specific implementation.
81+
// TODO: Consider implementing a live view if needed.
82+
Map<String, Object> merged = new ConcurrentHashMap<>(state);
83+
for (Entry<String, Object> entry : delta.entrySet()) {
84+
if (entry.getValue() == REMOVED) {
85+
merged.remove(entry.getKey());
86+
} else {
87+
merged.put(entry.getKey(), entry.getValue());
88+
}
89+
}
90+
return merged.entrySet();
6991
}
7092

7193
@Override
7294
public boolean equals(Object o) {
7395
if (o == this) {
7496
return true;
7597
}
76-
if (!(o instanceof State other)) {
98+
if (!(o instanceof Map)) {
7799
return false;
78100
}
79-
return state.equals(other.state);
101+
Map<?, ?> other = (Map<?, ?>) o;
102+
// We can't easily rely on state.equals() because our "content" is merged.
103+
// Validating equality against another Map requires checking the merged view.
104+
if (size() != other.size()) {
105+
return false;
106+
}
107+
try {
108+
for (Entry<String, Object> e : entrySet()) {
109+
String key = e.getKey();
110+
Object value = e.getValue();
111+
if (value == null) {
112+
if (!(other.get(key) == null && other.containsKey(key))) return false;
113+
} else {
114+
if (!value.equals(other.get(key))) return false;
115+
}
116+
}
117+
} catch (ClassCastException | NullPointerException unused) {
118+
return false;
119+
}
120+
return true;
80121
}
81122

82123
@Override
83124
public Object get(Object key) {
125+
if (delta.containsKey(key)) {
126+
Object value = delta.get(key);
127+
return value == REMOVED ? null : value;
128+
}
84129
return state.get(key);
85130
}
86131

87132
@Override
88133
public int hashCode() {
89-
return state.hashCode();
134+
// Similar to equals, we need to calculate hash code based on the merged entry set.
135+
int h = 0;
136+
for (Entry<String, Object> entry : entrySet()) {
137+
h += entry.hashCode();
138+
}
139+
return h;
90140
}
91141

92142
@Override
93143
public boolean isEmpty() {
94-
return state.isEmpty();
144+
if (delta.isEmpty()) {
145+
return state.isEmpty();
146+
}
147+
// If delta is not empty, we need to check if it effectively removes everything from state
148+
// or adds something.
149+
return size() == 0;
95150
}
96151

97152
@Override
98153
public Set<String> keySet() {
99-
return state.keySet();
154+
// Snapshot view
155+
Map<String, Object> merged = new ConcurrentHashMap<>(state);
156+
for (Entry<String, Object> entry : delta.entrySet()) {
157+
if (entry.getValue() == REMOVED) {
158+
merged.remove(entry.getKey());
159+
} else {
160+
merged.put(entry.getKey(), entry.getValue());
161+
}
162+
}
163+
return merged.keySet();
100164
}
101165

102166
@Override
103167
public Object put(String key, Object value) {
104-
Object oldValue = state.put(key, value);
168+
// Current value logic needs to check delta first to return correct "oldValue"
169+
Object oldValue = get(key);
170+
state.put(key, value);
105171
delta.put(key, value);
106172
return oldValue;
107173
}
108174

109175
@Override
110176
public Object putIfAbsent(String key, Object value) {
111-
Object existingValue = state.putIfAbsent(key, value);
112-
if (existingValue == null) {
113-
delta.put(key, value);
177+
Object currentValue = get(key);
178+
if (currentValue == null) {
179+
put(key, value);
180+
return null;
114181
}
115-
return existingValue;
182+
return currentValue;
116183
}
117184

118185
@Override
@@ -123,47 +190,67 @@ public void putAll(Map<? extends String, ? extends Object> m) {
123190

124191
@Override
125192
public Object remove(Object key) {
126-
if (state.containsKey(key)) {
193+
Object oldValue = get(key);
194+
// We explicitly check for containment in the merged view to ensure we return the correct old
195+
// value.
196+
if (state.containsKey(key) || (delta.containsKey(key) && delta.get(key) != REMOVED)) {
127197
delta.put((String) key, REMOVED);
128198
}
129-
return state.remove(key);
199+
200+
// We remove from the state map to keep it consistent with the write-through behavior of this
201+
// class.
202+
state.remove(key);
203+
return oldValue;
130204
}
131205

132206
@Override
133207
public boolean remove(Object key, Object value) {
134-
boolean removed = state.remove(key, value);
135-
if (removed) {
136-
delta.put((String) key, REMOVED);
208+
Object currentValue = get(key);
209+
if (Objects.equals(currentValue, value) && (currentValue != null || containsKey(key))) {
210+
remove(key);
211+
return true;
137212
}
138-
return removed;
213+
return false;
139214
}
140215

141216
@Override
142217
public boolean replace(String key, Object oldValue, Object newValue) {
143-
boolean replaced = state.replace(key, oldValue, newValue);
144-
if (replaced) {
145-
delta.put(key, newValue);
218+
Object currentValue = get(key);
219+
if (Objects.equals(currentValue, oldValue) && (currentValue != null || containsKey(key))) {
220+
put(key, newValue);
221+
return true;
146222
}
147-
return replaced;
223+
return false;
148224
}
149225

150226
@Override
151227
public Object replace(String key, Object value) {
152-
Object oldValue = state.replace(key, value);
153-
if (oldValue != null) {
154-
delta.put(key, value);
228+
Object currentValue = get(key);
229+
if (currentValue != null || containsKey(key)) {
230+
put(key, value);
231+
return currentValue;
155232
}
156-
return oldValue;
233+
return null;
157234
}
158235

159236
@Override
160237
public int size() {
161-
return state.size();
238+
// Expensive, but accurate merged size.
239+
return entrySet().size();
162240
}
163241

164242
@Override
165243
public Collection<Object> values() {
166-
return state.values();
244+
// Snapshot view
245+
Map<String, Object> merged = new ConcurrentHashMap<>(state);
246+
for (Entry<String, Object> entry : delta.entrySet()) {
247+
if (entry.getValue() == REMOVED) {
248+
merged.remove(entry.getKey());
249+
} else {
250+
merged.put(entry.getKey(), entry.getValue());
251+
}
252+
}
253+
return merged.values();
167254
}
168255

169256
public boolean hasDelta() {
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.sessions;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.concurrent.ConcurrentMap;
25+
import org.junit.Test;
26+
import org.junit.runner.RunWith;
27+
import org.junit.runners.JUnit4;
28+
29+
@RunWith(JUnit4.class)
30+
public final class StateDiffTest {
31+
32+
@Test
33+
public void get_returnsValueFromDeltaIfPresent() {
34+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
35+
stateMap.put("key", "initialValue");
36+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
37+
deltaMap.put("key", "newValue");
38+
State state = new State(stateMap, deltaMap);
39+
40+
assertThat(state.get("key")).isEqualTo("newValue");
41+
}
42+
43+
@Test
44+
public void get_returnsValueFromStateIfNotInDelta() {
45+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
46+
stateMap.put("key", "initialValue");
47+
State state = new State(stateMap);
48+
49+
assertThat(state.get("key")).isEqualTo("initialValue");
50+
}
51+
52+
@Test
53+
public void get_returnsNullIfKeyInDeltaAsRemoved() {
54+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
55+
stateMap.put("key", "initialValue");
56+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
57+
deltaMap.put("key", State.REMOVED);
58+
State state = new State(stateMap, deltaMap);
59+
60+
assertThat(state.get("key")).isNull();
61+
}
62+
63+
@Test
64+
public void containsKey_respectsDelta() {
65+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
66+
stateMap.put("key1", "value1");
67+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
68+
deltaMap.put("key1", State.REMOVED);
69+
deltaMap.put("key2", "value2");
70+
State state = new State(stateMap, deltaMap);
71+
72+
assertThat(state.containsKey("key1")).isFalse();
73+
assertThat(state.containsKey("key2")).isTrue();
74+
}
75+
76+
@Test
77+
public void size_respectsDelta() {
78+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
79+
stateMap.put("key1", "value1");
80+
stateMap.put("key2", "value2");
81+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
82+
deltaMap.put("key1", State.REMOVED);
83+
deltaMap.put("key3", "value3");
84+
State state = new State(stateMap, deltaMap);
85+
86+
assertThat(state.size()).isEqualTo(2); // key2, key3
87+
}
88+
89+
@Test
90+
public void isEmpty_respectsDelta() {
91+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
92+
stateMap.put("key1", "value1");
93+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
94+
deltaMap.put("key1", State.REMOVED);
95+
State state = new State(stateMap, deltaMap);
96+
97+
assertThat(state.isEmpty()).isTrue();
98+
}
99+
100+
@Test
101+
public void entrySet_reflectsMergedState() {
102+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
103+
stateMap.put("key1", "value1");
104+
stateMap.put("key2", "value2");
105+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
106+
deltaMap.put("key1", "newValue1");
107+
deltaMap.put("key3", "value3");
108+
State state = new State(stateMap, deltaMap);
109+
110+
Map<String, Object> expected = new HashMap<>();
111+
expected.put("key1", "newValue1");
112+
expected.put("key2", "value2");
113+
expected.put("key3", "value3");
114+
115+
assertThat(state.entrySet()).containsExactlyElementsIn(expected.entrySet());
116+
}
117+
118+
@Test
119+
public void keySet_reflectsMergedState() {
120+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
121+
stateMap.put("key1", "value1");
122+
stateMap.put("key2", "value2");
123+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
124+
deltaMap.put("key1", State.REMOVED);
125+
deltaMap.put("key3", "value3");
126+
State state = new State(stateMap, deltaMap);
127+
128+
assertThat(state.keySet()).containsExactly("key2", "key3");
129+
}
130+
131+
@Test
132+
public void values_reflectsMergedState() {
133+
ConcurrentMap<String, Object> stateMap = new ConcurrentHashMap<>();
134+
stateMap.put("key1", "value1");
135+
stateMap.put("key2", "value2");
136+
ConcurrentMap<String, Object> deltaMap = new ConcurrentHashMap<>();
137+
deltaMap.put("key1", "newValue1");
138+
deltaMap.put("key3", "value3");
139+
State state = new State(stateMap, deltaMap);
140+
141+
assertThat(state.values()).containsExactly("newValue1", "value2", "value3");
142+
}
143+
}

0 commit comments

Comments
 (0)