Skip to content

Commit f4cd1b7

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Forward state delta from all events to parent session instead of just the last event
PiperOrigin-RevId: 900231745
1 parent 1a3dd61 commit f4cd1b7

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

core/src/main/java/com/google/adk/tools/AgentTool.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
166166
.sessionService()
167167
.createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null)
168168
.flatMapPublisher(session -> runner.runAsync(session.userId(), session.id(), content))
169+
.doOnNext(
170+
event -> {
171+
if (event.actions() != null
172+
&& event.actions().stateDelta() != null
173+
&& !event.actions().stateDelta().isEmpty()) {
174+
updateState(event.actions().stateDelta(), toolContext.state());
175+
}
176+
})
169177
.lastElement()
170178
.map(Optional::of)
171179
.defaultIfEmpty(Optional.empty())
@@ -177,13 +185,6 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
177185
Event lastEvent = optionalLastEvent.get();
178186
Optional<String> outputText = lastEvent.content().map(Content::text);
179187

180-
// Forward state delta to parent session.
181-
if (lastEvent.actions() != null
182-
&& lastEvent.actions().stateDelta() != null
183-
&& !lastEvent.actions().stateDelta().isEmpty()) {
184-
updateState(lastEvent.actions().stateDelta(), toolContext.state());
185-
}
186-
187188
if (outputText.isEmpty()) {
188189
return ImmutableMap.of();
189190
}

core/src/test/java/com/google/adk/tools/AgentToolTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,46 @@ public void call_withSkipSummarizationAndStateDelta_propagatesStateAndSetsSkipSu
495495
assertThat(toolContext.actions().skipSummarization()).hasValue(true);
496496
}
497497

498+
@Test
499+
public void call_withMultipleStateDeltasInResponse_propagatesAllStateDeltas() throws Exception {
500+
AfterAgentCallback firstCallback =
501+
(callbackContext) -> {
502+
callbackContext.state().put("key1", "val1");
503+
return Maybe.empty();
504+
};
505+
AfterAgentCallback secondCallback =
506+
(callbackContext) -> {
507+
callbackContext.state().put("key2", "val2");
508+
return Maybe.empty();
509+
};
510+
LlmAgent firstAgent =
511+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
512+
.name("first_agent")
513+
.afterAgentCallback(firstCallback)
514+
.build();
515+
LlmAgent secondAgent =
516+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
517+
.name("second_agent")
518+
.afterAgentCallback(secondCallback)
519+
.build();
520+
SequentialAgent sequentialAgent =
521+
SequentialAgent.builder()
522+
.name("sequence")
523+
.description("Process the query through multiple steps")
524+
.subAgents(ImmutableList.of(firstAgent, secondAgent))
525+
.build();
526+
ToolContext toolContext = createToolContext(sequentialAgent);
527+
assertThat(toolContext.state()).isEmpty();
528+
529+
Map<String, Object> unused =
530+
AgentTool.create(sequentialAgent)
531+
.runAsync(ImmutableMap.of("request", "test"), toolContext)
532+
.blockingGet();
533+
534+
assertThat(toolContext.state()).containsEntry("key1", "val1");
535+
assertThat(toolContext.state()).containsEntry("key2", "val2");
536+
}
537+
498538
@Test
499539
public void
500540
declaration_sequentialAgentWithFirstSubAgentInputSchema_returnsDeclarationWithSchema() {

0 commit comments

Comments
 (0)