Skip to content

Commit a98a809

Browse files
google-genai-botcopybara-github
authored andcommitted
fix:Refactor State.java to correctly merge state and delta for map interface methods
PiperOrigin-RevId: 871435115
1 parent 4ac1dd2 commit a98a809

2 files changed

Lines changed: 321 additions & 28 deletions

File tree

core/src/main/java/com/google/adk/sessions/State.java

Lines changed: 178 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,68 +51,135 @@ public State(ConcurrentMap<String, Object> state, ConcurrentMap<String, Object>
5151
@Override
5252
public void clear() {
5353
state.clear();
54+
// Delta should likely be cleared too if we are clearing the state,
55+
// or we might want to mark everything as removed in delta.
56+
// Given the Python implementation doesn't have clear, and this is a local view,
57+
// clearing both seems appropriate to reset the object.
58+
delta.clear();
5459
}
5560

5661
@Override
5762
public boolean containsKey(Object key) {
63+
if (delta.containsKey(key)) {
64+
return delta.get(key) != REMOVED;
65+
}
5866
return state.containsKey(key);
5967
}
6068

6169
@Override
6270
public boolean containsValue(Object value) {
63-
return state.containsValue(value);
71+
// This is expensive but necessary for correctness with the merged view.
72+
return values().contains(value);
6473
}
6574

6675
@Override
6776
public Set<Entry<String, Object>> entrySet() {
68-
return state.entrySet();
77+
// This provides a snapshot, not a live view backed by the map, which differs from standard Map
78+
// contract.
79+
// However, given the complexity of merging two concurrent maps, this is a reasonable compromise
80+
// for this specific implementation.
81+
// TODO: Consider implementing a live view if needed.
82+
Map<String, Object> merged = new ConcurrentHashMap<>(state);
83+
for (Entry<String, Object> entry : delta.entrySet()) {
84+
if (entry.getValue() == REMOVED) {
85+
merged.remove(entry.getKey());
86+
} else {
87+
merged.put(entry.getKey(), entry.getValue());
88+
}
89+
}
90+
return merged.entrySet();
6991
}
7092

7193
@Override
7294
public boolean equals(Object o) {
7395
if (o == this) {
7496
return true;
7597
}
76-
if (!(o instanceof State other)) {
98+
if (!(o instanceof Map)) {
7799
return false;
78100
}
79-
return state.equals(other.state);
101+
Map<?, ?> other = (Map<?, ?>) o;
102+
// We can't easily rely on state.equals() because our "content" is merged.
103+
// Validating equality against another Map requires checking the merged view.
104+
if (size() != other.size()) {
105+
return false;
106+
}
107+
try {
108+
for (Entry<String, Object> e : entrySet()) {
109+
String key = e.getKey();
110+
Object value = e.getValue();
111+
if (value == null) {
112+
if (!(other.get(key) == null && other.containsKey(key))) return false;
113+
} else {
114+
if (!value.equals(other.get(key))) return false;
115+
}
116+
}
117+
} catch (ClassCastException | NullPointerException unused) {
118+
return false;
119+
}
120+
return true;
80121
}
81122

82123
@Override
83124
public Object get(Object key) {
125+
if (delta.containsKey(key)) {
126+
Object value = delta.get(key);
127+
return value == REMOVED ? null : value;
128+
}
84129
return state.get(key);
85130
}
86131

87132
@Override
88133
public int hashCode() {
89-
return state.hashCode();
134+
// Similar to equals, we need to calculate hash code based on the merged entry set.
135+
int h = 0;
136+
for (Entry<String, Object> entry : entrySet()) {
137+
h += entry.hashCode();
138+
}
139+
return h;
90140
}
91141

92142
@Override
93143
public boolean isEmpty() {
94-
return state.isEmpty();
144+
if (delta.isEmpty()) {
145+
return state.isEmpty();
146+
}
147+
// If delta is not empty, we need to check if it effectively removes everything from state
148+
// or adds something.
149+
return size() == 0;
95150
}
96151

97152
@Override
98153
public Set<String> keySet() {
99-
return state.keySet();
154+
// Snapshot view
155+
Map<String, Object> merged = new ConcurrentHashMap<>(state);
156+
for (Entry<String, Object> entry : delta.entrySet()) {
157+
if (entry.getValue() == REMOVED) {
158+
merged.remove(entry.getKey());
159+
} else {
160+
merged.put(entry.getKey(), entry.getValue());
161+
}
162+
}
163+
return merged.keySet();
100164
}
101165

102166
@Override
103167
public Object put(String key, Object value) {
104-
Object oldValue = state.put(key, value);
168+
// Current value logic needs to check delta first to return correct "oldValue"
169+
Object oldValue = get(key);
170+
state.put(key, value);
105171
delta.put(key, value);
106172
return oldValue;
107173
}
108174

109175
@Override
110176
public Object putIfAbsent(String key, Object value) {
111-
Object existingValue = state.putIfAbsent(key, value);
112-
if (existingValue == null) {
113-
delta.put(key, value);
177+
Object currentValue = get(key);
178+
if (currentValue == null) {
179+
put(key, value);
180+
return null;
114181
}
115-
return existingValue;
182+
return currentValue;
116183
}
117184

118185
@Override
@@ -123,47 +190,130 @@ public void putAll(Map<? extends String, ? extends Object> m) {
123190

124191
@Override
125192
public Object remove(Object key) {
126-
if (state.containsKey(key)) {
193+
Object oldValue = get(key);
194+
// Explicitly check for containment in the *merged* view (via get != null or containsKey)
195+
// before marking as removed, though strictly speaking marking as removed is safe even if not
196+
// present.
197+
// But we need to return the correct oldValue.
198+
199+
if (state.containsKey(key) || (delta.containsKey(key) && delta.get(key) != REMOVED)) {
127200
delta.put((String) key, REMOVED);
128201
}
129-
return state.remove(key);
202+
// We should probably NOT remove from state to keep "state" as the committed version?
203+
// The original code did:
204+
// delta.put((String) key, REMOVED);
205+
// return state.remove(key);
206+
//
207+
// If we want state to represent the "base" and delta the "changes", we should probably
208+
// ONLY update delta. However, the original code updated BOTH.
209+
// "A State object that also keeps track of the changes to the state."
210+
// If it updates both, then "state" is presumably the "current live state" AND "delta" is the
211+
// log.
212+
// PROPOSAL: Keep updating both to minimize regression risk on the write path,
213+
// but ensure read path prioritizes delta (which might be redundant if state is always updated,
214+
// UNLESS delta has entries that state doesn't? Or if we want to support a mode where only delta
215+
// is updated?)
216+
//
217+
// Wait, the original code:
218+
// Object oldValue = state.put(key, value);
219+
// delta.put(key, value);
220+
//
221+
// This implies `state` IS updated.
222+
//
223+
// If `state` is updated, why did we need to check `delta` for reads?
224+
// "State.java seems to ignore the delta map for the most part. This looks wrong."
225+
//
226+
// SCENARIO 1: `state` is the "committed" state from DB. `delta` accumulates changes.
227+
// If we update `state` in place, then `get(key)` on `state` returns the new value.
228+
// So why claim it ignores delta?
229+
//
230+
// Maybe `State` is initialized with a `state` map that is SHARED or comes from a source
231+
// that shouldn't be mutated?
232+
//
233+
// OR, maybe the `state` map passed in constructor is NOT updated by some other code path?
234+
//
235+
// Let's re-read the Python code.
236+
// Python:
237+
// self._value[key] = value
238+
// self._delta[key] = value
239+
//
240+
// It updates BOTH.
241+
//
242+
// Python `__getitem__`:
243+
// if key in self._delta: return self._delta[key]
244+
// return self._value[key]
245+
//
246+
// If both are updated, `self._value[key]` should correspond to `self._delta[key]`.
247+
//
248+
// UNLESS `delta` is populated with something that IS NOT in `value`.
249+
// Example: We load a session. State is {a:1}. We create a `State` object.
250+
// We put b:2. state={a:1, b:2}, delta={b:2}.
251+
// get(b) -> state.get(b) -> 2. Correct.
252+
//
253+
// SCENARIO 2: Replay / Deferred execution.
254+
// Maybe someone populates `delta` directly?
255+
// Constructor: `public State(ConcurrentMap<String, Object> state, ConcurrentMap<String, Object>
256+
// delta)`
257+
//
258+
// If I pass a `delta` map that already has stuff, but `state` doesn't.
259+
// State({a:1}, {a:2})
260+
// Java get(a) -> state.get(a) -> 1. WRONG. Should be 2.
261+
//
262+
// Correct. This is the scenario. The `delta` might contain provisional changes not yet applied
263+
// to `state`.
264+
265+
state.remove(key);
266+
return oldValue;
130267
}
131268

132269
@Override
133270
public boolean remove(Object key, Object value) {
134-
boolean removed = state.remove(key, value);
135-
if (removed) {
136-
delta.put((String) key, REMOVED);
271+
Object currentValue = get(key);
272+
if (Objects.equals(currentValue, value) && (currentValue != null || containsKey(key))) {
273+
remove(key);
274+
return true;
137275
}
138-
return removed;
276+
return false;
139277
}
140278

141279
@Override
142280
public boolean replace(String key, Object oldValue, Object newValue) {
143-
boolean replaced = state.replace(key, oldValue, newValue);
144-
if (replaced) {
145-
delta.put(key, newValue);
281+
Object currentValue = get(key);
282+
if (Objects.equals(currentValue, oldValue) && (currentValue != null || containsKey(key))) {
283+
put(key, newValue);
284+
return true;
146285
}
147-
return replaced;
286+
return false;
148287
}
149288

150289
@Override
151290
public Object replace(String key, Object value) {
152-
Object oldValue = state.replace(key, value);
153-
if (oldValue != null) {
154-
delta.put(key, value);
291+
Object currentValue = get(key);
292+
if (currentValue != null || containsKey(key)) {
293+
put(key, value);
294+
return currentValue;
155295
}
156-
return oldValue;
296+
return null;
157297
}
158298

159299
@Override
160300
public int size() {
161-
return state.size();
301+
// Expensive, but accurate merged size.
302+
return entrySet().size();
162303
}
163304

164305
@Override
165306
public Collection<Object> values() {
166-
return state.values();
307+
// Snapshot view
308+
Map<String, Object> merged = new ConcurrentHashMap<>(state);
309+
for (Entry<String, Object> entry : delta.entrySet()) {
310+
if (entry.getValue() == REMOVED) {
311+
merged.remove(entry.getKey());
312+
} else {
313+
merged.put(entry.getKey(), entry.getValue());
314+
}
315+
}
316+
return merged.values();
167317
}
168318

169319
public boolean hasDelta() {

0 commit comments

Comments
 (0)