Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 161 additions & 56 deletions ai/src/main/java/com/google/genkit/ai/GenerateAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.genkit.ai.middleware.GenerateNext;
import com.google.genkit.ai.middleware.GenerateParams;
import com.google.genkit.ai.middleware.GenerationMiddleware;
import com.google.genkit.ai.middleware.ModelNext;
import com.google.genkit.ai.middleware.ModelParams;
import com.google.genkit.ai.middleware.ToolNext;
import com.google.genkit.ai.middleware.ToolParams;
import com.google.genkit.ai.telemetry.ModelTelemetryHelper;
import com.google.genkit.core.*;
import com.google.genkit.core.tracing.SpanMetadata;
Expand Down Expand Up @@ -94,42 +101,62 @@ public ModelResponse run(
throw new GenkitException("GenerateActionOptions cannot be null");
}

// Resolve middleware names from the registry. The Dev UI's Middleware panel sends
// selected middleware as a list of names in `options.use`. Each name is looked up
// in the "middleware" value bucket (registered via Genkit.Builder.middleware(...)).
// Mirrors the JS SDK's resolveMiddleware() in js/ai/src/generate/action.ts.
final List<GenerationMiddleware> middlewares = resolveMiddlewares(options.getUse());

// Core: run the full tool loop (possibly streaming). model.run() inside is wrapped
// by the wrapModel chain; tool execution is wrapped by the wrapTool chain.
GenerateNext core =
(gctx, gparams) ->
runIterations(gctx, gparams.getRequest(), gparams.getOnChunk(), middlewares);

// Outermost: wrapGenerate chain
GenerateNext chain = chainGenerate(middlewares, core);

int initialMsgIdx = options.getMessages() != null ? options.getMessages().size() : 0;
return chain.apply(ctx, new GenerateParams(options, 0, initialMsgIdx, streamCallback));
}

/**
* Executes the tool-call loop. Each iteration wraps the model call in the {@code wrapModel}
* middleware chain and tool execution in the {@code wrapTool} chain.
*/
private ModelResponse runIterations(
ActionContext ctx,
GenerateActionOptions options,
Consumer<ModelResponseChunk> streamCallback,
List<GenerationMiddleware> middlewares)
throws GenkitException {

String modelName = options.getModel();
if (modelName == null || modelName.isEmpty()) {
throw new GenkitException("Model name is required");
}

// Resolve the model action key
String modelKey = resolveModelKey(modelName);

// Look up the model in the registry
Action<?, ?, ?> action = registry.lookupAction(modelKey);
if (action == null) {
throw new GenkitException("Model not found: " + modelName + " (key: " + modelKey + ")");
}

if (!(action instanceof Model)) {
throw new GenkitException("Action is not a model: " + modelKey);
}
final Model model = (Model) action;

Model model = (Model) action;

// Build the model request from the options
ModelRequest request = buildModelRequest(options);

logger.debug("Generating with model: {}", modelKey);

// Determine if we should return tool requests without executing them
boolean returnToolRequests = Boolean.TRUE.equals(options.getReturnToolRequests());

// Get max turns for tool loop (default to 5)
int maxTurns = options.getMaxTurns() != null ? options.getMaxTurns() : 5;
int turn = 0;

String flowName = ctx.getFlowName();

while (turn < maxTurns) {
// Create span metadata for the model call
SpanMetadata spanMetadata =
SpanMetadata.builder()
.name(modelName)
Expand All @@ -144,57 +171,54 @@ public ModelResponse run(
final ModelRequest currentRequest = request;
final String spanPath = "/generate/" + modelName;

// Run the model wrapped in a span
// Run the model wrapped in a span and through the wrapModel middleware chain.
ModelResponse response =
Tracer.runInNewSpan(
ctx,
spanMetadata,
request,
(spanCtx, req) -> {
ActionContext newCtx = ctx.withSpanContext(spanCtx);
if (streamCallback != null && model.supportsStreaming()) {
return ModelTelemetryHelper.runWithTelemetryStreaming(
modelName,
flowName,
spanPath,
currentRequest,
r -> model.run(newCtx, r, streamCallback));
} else {
return ModelTelemetryHelper.runWithTelemetry(
modelName, flowName, spanPath, currentRequest, r -> model.run(newCtx, r));
}
ModelNext modelCore =
(mctx, mparams) -> {
ModelRequest mreq = mparams.getRequest();
Consumer<ModelResponseChunk> sc = mparams.getStreamCallback();
if (sc != null && model.supportsStreaming()) {
return ModelTelemetryHelper.runWithTelemetryStreaming(
modelName, flowName, spanPath, mreq, r -> model.run(mctx, r, sc));
} else {
return ModelTelemetryHelper.runWithTelemetry(
modelName, flowName, spanPath, mreq, r -> model.run(mctx, r));
}
};
ModelNext wrappedModel = chainModel(middlewares, modelCore);
return wrappedModel.apply(newCtx, new ModelParams(currentRequest, streamCallback));
});

// Check if the model requested tool calls
List<Part> toolRequestParts = extractToolRequestParts(response);

// If no tool requests or we should return them without executing, return
// response
if (toolRequestParts.isEmpty() || returnToolRequests) {
return response;
}

// Check if we have tools to execute
if (options.getTools() == null || options.getTools().isEmpty()) {
// No tools available, return response with tool requests
return response;
}

// Execute tools
List<Part> toolResponseParts = executeTools(ctx, toolRequestParts, options.getTools());
// Execute tools through the wrapTool chain
List<Part> toolResponseParts =
executeTools(ctx, toolRequestParts, options.getTools(), middlewares);

// Add the assistant message with tool requests
Message assistantMessage = response.getMessage();
List<Message> updatedMessages = new ArrayList<>(request.getMessages());
updatedMessages.add(assistantMessage);

// Add tool response message
Message toolResponseMessage = new Message();
toolResponseMessage.setRole(Role.TOOL);
toolResponseMessage.setContent(toolResponseParts);
updatedMessages.add(toolResponseMessage);

// Update request with new messages for next turn
request =
ModelRequest.builder()
.messages(updatedMessages)
Expand All @@ -209,6 +233,80 @@ public ModelResponse run(
throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded");
}

/**
 * Turns the middleware references carried on a generate request into middleware instances.
 *
 * <p>A reference is either a JSON string holding the middleware name or a JSON object with a
 * non-null {@code name} field. Each name is looked up in the registry's {@code "middleware"}
 * value bucket and, when the stored value is a {@link GenerationMiddleware}, the result of its
 * {@code newInstance()} is added to the returned list. Null references and empty names are
 * skipped silently; references of any other shape and names that do not resolve to a
 * {@link GenerationMiddleware} are logged at warn level and skipped.
 *
 * @param refs the raw references; may be {@code null} or empty
 * @return the resolved middlewares in request order; an empty list when {@code refs} is
 *     {@code null} or empty
 */
private List<GenerationMiddleware> resolveMiddlewares(List<JsonNode> refs) {
  if (refs == null || refs.isEmpty()) {
    return List.of();
  }

  List<GenerationMiddleware> instances = new ArrayList<>(refs.size());
  for (JsonNode node : refs) {
    if (node == null || node.isNull()) {
      continue;
    }

    // Pull the name out of whichever of the two supported shapes was sent.
    final String middlewareName;
    if (node.isTextual()) {
      middlewareName = node.asText();
    } else if (node.isObject() && node.hasNonNull("name")) {
      middlewareName = node.get("name").asText();
    } else {
      logger.warn("Unrecognized middleware reference shape: {}", node);
      continue;
    }
    if (middlewareName == null || middlewareName.isEmpty()) {
      continue;
    }

    Object registered = registry.lookupValue("middleware", middlewareName);
    if (!(registered instanceof GenerationMiddleware)) {
      logger.warn(
          "Middleware '{}' was requested but is not registered. "
              + "Register via Genkit.Builder.middleware(...).",
          middlewareName);
      continue;
    }

    // Ask the registered middleware for a new instance so state is scoped to this call.
    instances.add(((GenerationMiddleware) registered).newInstance());
  }
  return instances;
}

/**
 * Composes the {@code wrapGenerate} hooks of the given middlewares around {@code core}.
 *
 * <p>The list is walked from last to first so that the first middleware ends up outermost. With
 * an empty list {@code core} itself is returned.
 *
 * @param middlewares the middlewares to apply, outermost first
 * @param core the innermost handler, invoked once every hook has delegated to its {@code next}
 * @return the composed handler
 */
private static GenerateNext chainGenerate(
    List<GenerationMiddleware> middlewares, GenerateNext core) {
  GenerateNext head = core;
  int index = middlewares.size();
  while (index > 0) {
    index--;
    final GenerationMiddleware layer = middlewares.get(index);
    final GenerateNext inner = head;
    head = (ctx, params) -> layer.wrapGenerate(ctx, params, inner);
  }
  return head;
}

/**
 * Composes the {@code wrapModel} hooks of the given middlewares around {@code core}.
 *
 * <p>The list is walked from last to first so that the first middleware ends up outermost. With
 * an empty list {@code core} itself is returned.
 *
 * @param middlewares the middlewares to apply, outermost first
 * @param core the innermost handler, invoked once every hook has delegated to its {@code next}
 * @return the composed handler
 */
private static ModelNext chainModel(List<GenerationMiddleware> middlewares, ModelNext core) {
  ModelNext head = core;
  int index = middlewares.size();
  while (index > 0) {
    index--;
    final GenerationMiddleware layer = middlewares.get(index);
    final ModelNext inner = head;
    head = (ctx, params) -> layer.wrapModel(ctx, params, inner);
  }
  return head;
}

/**
 * Composes the {@code wrapTool} hooks of the given middlewares around {@code core}.
 *
 * <p>The list is walked from last to first so that the first middleware ends up outermost. With
 * an empty list {@code core} itself is returned.
 *
 * @param middlewares the middlewares to apply, outermost first
 * @param core the innermost handler, invoked once every hook has delegated to its {@code next}
 * @return the composed handler
 */
private static ToolNext chainTool(List<GenerationMiddleware> middlewares, ToolNext core) {
  ToolNext head = core;
  int index = middlewares.size();
  while (index > 0) {
    index--;
    final GenerationMiddleware layer = middlewares.get(index);
    final ToolNext inner = head;
    head = (ctx, params) -> layer.wrapTool(ctx, params, inner);
  }
  return head;
}

/** Extracts tool request parts from a model response. */
private List<Part> extractToolRequestParts(ModelResponse response) {
List<Part> toolRequestParts = new ArrayList<>();
Expand All @@ -224,20 +322,45 @@ private List<Part> extractToolRequestParts(ModelResponse response) {
return toolRequestParts;
}

/** Executes tools and returns the response parts. */
/** Executes tools through the wrapTool middleware chain and returns the response parts. */
private List<Part> executeTools(
ActionContext ctx, List<Part> toolRequestParts, List<String> toolNames) {
ActionContext ctx,
List<Part> toolRequestParts,
List<String> toolNames,
List<GenerationMiddleware> middlewares) {
List<Part> responseParts = new ArrayList<>();

// Core tool invocation — runs after all wrapTool middleware
ToolNext toolCore =
(tctx, tparams) -> {
Tool<?, ?> tool = tparams.getTool();
ToolRequest toolReq = tparams.getRequest();
Object toolInput = toolReq.getInput();

// Convert input if necessary
if (toolInput instanceof Map
&& tool.getInputClass() != null
&& !Map.class.isAssignableFrom(tool.getInputClass())) {
toolInput = objectMapper.convertValue(toolInput, tool.getInputClass());
}
Comment on lines +340 to +345
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a local objectMapper directly to convert the tool input can lead to deserialization failures if the input contains custom types, custom date formats, or custom naming strategies registered globally. It is highly recommended to use JsonUtils.convert(...) instead, which is consistent with Genkit.java and uses the centrally configured object mapper.

          // Convert input if necessary
          Class<?> inputClass = tool.getInputClass();
          if (inputClass != null && toolInput != null && !inputClass.isInstance(toolInput)) {
            toolInput = JsonUtils.convert(toolInput, inputClass);
          }


@SuppressWarnings("unchecked")
Tool<Object, Object> typedTool = (Tool<Object, Object>) tool;
Object result = typedTool.run(tctx, toolInput);

Part responsePart = new Part();
responsePart.setToolResponse(
new ToolResponse(toolReq.getRef(), toolReq.getName(), result));
return responsePart;
};
ToolNext wrappedTool = chainTool(middlewares, toolCore);

for (Part toolRequestPart : toolRequestParts) {
ToolRequest toolRequest = toolRequestPart.getToolRequest();
String toolName = toolRequest.getName();
Object toolInput = toolRequest.getInput();

// Find the tool
Tool<?, ?> tool = findTool(toolName, toolNames);
if (tool == null) {
// Tool not found, create an error response
Part errorPart = new Part();
ToolResponse errorResponse =
new ToolResponse(
Expand All @@ -249,26 +372,8 @@ private List<Part> executeTools(
}

try {
// Execute the tool
@SuppressWarnings("unchecked")
Tool<Object, Object> typedTool = (Tool<Object, Object>) tool;

// Convert input if necessary
Object convertedInput = toolInput;
if (toolInput instanceof Map
&& tool.getInputClass() != null
&& !Map.class.isAssignableFrom(tool.getInputClass())) {
convertedInput = objectMapper.convertValue(toolInput, tool.getInputClass());
}

Object result = typedTool.run(ctx, convertedInput);

// Create tool response part
Part responsePart = new Part();
ToolResponse toolResponse = new ToolResponse(toolRequest.getRef(), toolName, result);
responsePart.setToolResponse(toolResponse);
Part responsePart = wrappedTool.apply(ctx, new ToolParams(toolRequestPart, tool));
responseParts.add(responsePart);

logger.debug("Executed tool '{}' successfully", toolName);
} catch (Exception e) {
logger.error("Tool execution failed for '{}': {}", toolName, e.getMessage());
Expand Down
27 changes: 27 additions & 0 deletions ai/src/main/java/com/google/genkit/ai/GenerateActionOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -81,6 +82,23 @@ public class GenerateActionOptions {
@JsonProperty("stepName")
private String stepName;

/**
 * Middleware references. In the JS SDK, this is a list of {@code ModelMiddleware} functions or
 * name strings. The Dev UI populates this field with the names of middlewares the user has
 * selected in the Middleware panel (which correspond to middlewares registered via {@code
 * Genkit.Builder.middleware(...)} under the {@code "middleware"} value bucket).
 *
 * <p>Each element may be a JSON string (middleware name) or a JSON object with a {@code name}
 * field, matching the shape sent by the Dev UI's Middleware panel.
 *
 * <p>At runtime, {@link GenerateAction#run} resolves each name via {@code
 * registry.lookupValue("middleware", name)} and dispatches the {@code wrapGenerate}/{@code
 * wrapModel}/{@code wrapTool} hooks around the model invocation.
 */
@JsonProperty("use")
private List<JsonNode> use;

/** Default constructor for JSON deserialization. */
public GenerateActionOptions() {}

Expand Down Expand Up @@ -109,6 +127,7 @@ public GenerateActionOptions withMessages(List<Message> newMessages) {
copy.returnToolRequests = this.returnToolRequests;
copy.maxTurns = this.maxTurns;
copy.stepName = this.stepName;
copy.use = this.use;
return copy;
}

Expand Down Expand Up @@ -216,4 +235,12 @@ public String getStepName() {
public void setStepName(String stepName) {
this.stepName = stepName;
}

/**
 * Returns the middleware references attached to these options.
 *
 * @return the stored list itself (not a copy), or {@code null} when none has been set
 */
public List<JsonNode> getUse() {
  return this.use;
}

/**
 * Replaces the middleware references attached to these options.
 *
 * @param use the references to store; the list is kept by reference, not copied
 */
public void setUse(List<JsonNode> use) {
this.use = use;
}
}
Loading
Loading