Skip to content

Commit db17679

Browse files
committed
fix: Pre-merge stateDelta before onUserMessageCallback (#1099)
The stateDelta was not merged into session state before onUserMessageCallback was invoked, causing plugins to see null values when reading caller-provided state entries. This fix pre-merges stateDelta into the session's in-memory state before creating the InvocationContext. Since the session is already a copy from getSession(), this is safe and does not affect the persistence path which still happens via EventActions in appendNewMessageToSession. Added test: onUserMessageCallback_withStateDelta_seesMergedState
1 parent 88eb0f5 commit db17679

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ protected Flowable<Event> runAsyncImpl(
448448
BaseAgent rootAgent = this.agent;
449449
String invocationId = InvocationContext.newInvocationContextId();
450450

451+
// Pre-merge stateDelta so onUserMessageCallback can access it.
452+
// Safe: session is a copy; persistence still happens via appendNewMessageToSession.
453+
if (stateDelta != null && !stateDelta.isEmpty()) {
454+
stateDelta.forEach((key, value) -> session.state().put(key, value));
455+
}
456+
451457
// Create initial context
452458
InvocationContext initialContext =
453459
newInvocationContextBuilder(session)

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,38 @@ public void beforeRunCallback_withStateDelta_seesMergedState() {
877877
assertThat(sessionInCallback.state()).containsEntry("number", 123);
878878
}
879879

880+
@Test
881+
public void onUserMessageCallback_withStateDelta_seesMergedState() {
882+
ArgumentCaptor<InvocationContext> contextCaptor =
883+
ArgumentCaptor.forClass(InvocationContext.class);
884+
when(plugin.onUserMessageCallback(contextCaptor.capture(), any())).thenReturn(Maybe.empty());
885+
886+
ImmutableMap<String, Object> stateDelta =
887+
ImmutableMap.of("callback_key", "callback_value", "number", 123);
888+
889+
var unused =
890+
runner
891+
.runAsync(
892+
"user",
893+
session.id(),
894+
createContent("test with state"),
895+
RunConfig.builder().build(),
896+
stateDelta)
897+
.toList()
898+
.blockingGet();
899+
900+
// Verify onUserMessageCallback was called
901+
verify(plugin).onUserMessageCallback(any(), any());
902+
903+
// Verify the context passed to onUserMessageCallback has the merged state
904+
InvocationContext capturedContext = contextCaptor.getValue();
905+
Session sessionInCallback = capturedContext.session();
906+
907+
// Verify state delta was merged before onUserMessageCallback was invoked
908+
assertThat(sessionInCallback.state()).containsEntry("callback_key", "callback_value");
909+
assertThat(sessionInCallback.state()).containsEntry("number", 123);
910+
}
911+
880912
@Test
881913
public void runAsync_ensureEventsAreAppendedInOrder() throws Exception {
882914
Event event1 = TestUtils.createEvent("1");

0 commit comments

Comments
 (0)