diff --git a/src/services/file_search_tool.rs b/src/services/file_search_tool.rs index 2fcf61f5..1efdb609 100644 --- a/src/services/file_search_tool.rs +++ b/src/services/file_search_tool.rs @@ -175,9 +175,11 @@ impl FileSearchToolArguments { /// Parse arguments from a JSON string (as received from model output). /// - /// Returns `None` if parsing fails or if required fields are missing. - pub fn parse(arguments_json: &str) -> Option { - serde_json::from_str(arguments_json).ok() + /// Returns `Err` if parsing fails or if required fields are missing; the + /// caller turns that into a spec-shaped failure rather than dropping the + /// call. + pub fn parse(arguments_json: &str) -> Result { + serde_json::from_str(arguments_json) } /// Generate a complete OpenAI-compatible function tool definition. @@ -1001,7 +1003,7 @@ fn build_file_search_call_output( /// /// The format matches OpenAI's Responses API streaming format where each /// output item is sent as an `response.output_item.done` event. -fn format_file_search_call_sse_event(output: &FileSearchCallOutput) -> Option { +fn format_file_search_call_sse_event(output: &FileSearchCallOutput) -> Bytes { // Create the SSE event data // OpenAI sends output items as part of the response stream with type "response.output_item.done" let event_data = serde_json::json!({ @@ -1010,9 +1012,52 @@ fn format_file_search_call_sse_event(output: &FileSearchCallOutput) -> Option crate::services::server_tools::ToolExecutionHandle { + let id = call_id.to_string(); + let failed_item = FileSearchCallOutput { + type_: FileSearchCallOutputType::FileSearchCall, + id: id.clone(), + queries: Vec::new(), + status: WebSearchStatus::Failed, + results: None, + }; + let events = vec![ + format_file_search_in_progress_event(&id, 0), + format_file_search_call_sse_event(&failed_item), + ]; + + let continuation_item = ResponsesInputItem::FunctionCallOutput(FunctionCallOutput { + type_: FunctionCallOutputType::FunctionCallOutput, + id: Some(id.clone()), + call_id: id.clone(), + output: crate::services::server_tools::invalid_arguments_text("file_search", error), + status: None, + }); + let result = crate::services::server_tools::ToolCallResult { + call_id: id, + continuation_items: vec![continuation_item], + }; + + crate::services::server_tools::ToolExecutionHandle { + events: Box::pin(futures_util::stream::iter(events)), + result: Box::pin(async move { Ok(result) }), + } } /// Format a `response.file_search_call.in_progress` SSE event. @@ -1143,10 +1188,22 @@ fn inject_citation_annotations(chunk: &[u8], tracker: &CitationTracker) -> Bytes /// /// All fields except `query` are optional. This function uses [`FileSearchToolArguments::parse()`] /// to deserialize the arguments, ensuring consistency with the schema sent to the model. +/// Outcome of inspecting a `function_call` item named `file_search`. +/// +/// `Invalid` carries the call id and reason so the executor can synthesize +/// a `file_search_call` with status `failed` rather than dropping the call +/// (which would strand the loop) or aborting the whole turn. `None` means +/// the item is not a file_search call and should pass through untouched. +#[derive(Debug, Clone)] +pub enum FileSearchCallDetection { + Valid(Box), + Invalid { id: String, error: String }, +} + pub fn parse_file_search_tool_call( value: &Value, vector_store_ids: &[String], -) -> Option { +) -> Option { // Check if this is a function call let obj = value.as_object()?; @@ -1172,17 +1229,23 @@ pub fn parse_file_search_tool_call( // Parse arguments using the schema-defined struct let arguments_str = obj.get("arguments")?.as_str()?; - let args = FileSearchToolArguments::parse(arguments_str)?; - - Some(FileSearchToolCall { - id, - query: args.query, - vector_store_ids: vector_store_ids.to_vec(), - max_num_results: args.max_num_results.map(|v| v as usize), - score_threshold: args.score_threshold, - filters: args.filters, - ranking_options: args.ranking_options, - }) + match FileSearchToolArguments::parse(arguments_str) { + Ok(args) => Some(FileSearchCallDetection::Valid(Box::new( + FileSearchToolCall { + id, + query: args.query, + vector_store_ids: vector_store_ids.to_vec(), + max_num_results: args.max_num_results.map(|v| v as usize), + score_threshold: args.score_threshold, + filters: args.filters, + ranking_options: args.ranking_options, + }, + ))), + Err(e) => Some(FileSearchCallDetection::Invalid { + id, + error: format!("could not parse `arguments` (expected {{\"query\": \"...\"}}): {e}"), + }), + } } /// Check if a streaming chunk contains file_search tool calls. @@ -1192,7 +1255,7 @@ pub fn parse_file_search_tool_call( pub fn detect_file_search_in_chunk( chunk: &[u8], vector_store_ids: &[String], -) -> Vec { +) -> Vec { let Some(chunk_str) = std::str::from_utf8(chunk).ok() else { return Vec::new(); }; @@ -1238,18 +1301,26 @@ pub fn detect_file_search_in_chunk( .and_then(|v| v.as_str()) .unwrap_or("unknown") .to_string(); - if let Some(arguments_str) = json.get("arguments").and_then(|a| a.as_str()) - && let Some(args) = FileSearchToolArguments::parse(arguments_str) - { - found_calls.push(FileSearchToolCall { - id, - query: args.query, - vector_store_ids: vector_store_ids.to_vec(), - max_num_results: args.max_num_results.map(|v| v as usize), - score_threshold: args.score_threshold, - filters: args.filters, - ranking_options: args.ranking_options, - }); + if let Some(arguments_str) = json.get("arguments").and_then(|a| a.as_str()) { + match FileSearchToolArguments::parse(arguments_str) { + Ok(args) => found_calls.push(FileSearchCallDetection::Valid(Box::new( + FileSearchToolCall { + id, + query: args.query, + vector_store_ids: vector_store_ids.to_vec(), + max_num_results: args.max_num_results.map(|v| v as usize), + score_threshold: args.score_threshold, + filters: args.filters, + ranking_options: args.ranking_options, + }, + ))), + Err(e) => found_calls.push(FileSearchCallDetection::Invalid { + id, + error: format!( + "could not parse `arguments` (expected {{\"query\": \"...\"}}): {e}" + ), + }), + } } } @@ -1315,8 +1386,10 @@ pub fn check_response_for_file_search( // Check output array (Responses API format) if let Some(output) = json.get("output").and_then(|o| o.as_array()) { for item in output { - if let Some(tool_call) = parse_file_search_tool_call(item, vector_store_ids) { - tool_calls.push(tool_call); + if let Some(FileSearchCallDetection::Valid(tool_call)) = + parse_file_search_tool_call(item, vector_store_ids) + { + tool_calls.push(*tool_call); } } } @@ -1328,8 +1401,10 @@ pub fn check_response_for_file_search( && let Some(tc_array) = message.get("tool_calls").and_then(|t| t.as_array()) { for tc in tc_array { - if let Some(tool_call) = parse_file_search_tool_call(tc, vector_store_ids) { - tool_calls.push(tool_call); + if let Some(FileSearchCallDetection::Valid(tool_call)) = + parse_file_search_tool_call(tc, vector_store_ids) + { + tool_calls.push(*tool_call); } } } @@ -1404,10 +1479,21 @@ impl crate::services::server_tools::ServerExecutedTool for FileSearchExecutor { let vector_store_ids = self.context.get_vector_store_ids(); detect_file_search_in_chunk(event, &vector_store_ids) .into_iter() - .map(|tc| crate::services::server_tools::DetectedToolCall { - tool_name: "file_search", - call_id: tc.id.clone(), - arguments: serde_json::to_value(&tc).unwrap_or(Value::Null), + .map(|detection| match detection { + FileSearchCallDetection::Valid(tc) => { + crate::services::server_tools::DetectedToolCall::new( + "file_search", + tc.id.clone(), + serde_json::to_value(&*tc).unwrap_or(Value::Null), + ) + } + FileSearchCallDetection::Invalid { id, error } => { + crate::services::server_tools::DetectedToolCall::invalid( + "file_search", + id, + error, + ) + } }) .collect() } @@ -1420,6 +1506,12 @@ impl crate::services::server_tools::ServerExecutedTool for FileSearchExecutor { crate::services::server_tools::ToolExecutionHandle, crate::services::server_tools::ToolError, > { + // The model emitted a `file_search` call we recognized but couldn't + // parse. Surface a `file_search_call` with status `failed` and feed + // the error back so the loop continues — never drop it or abort. + if let Some(error) = &call.invalid { + return Ok(synthesize_file_search_invalid_handle(&call.call_id, error)); + } let tool_call: FileSearchToolCall = serde_json::from_value(call.arguments).map_err(|e| { crate::services::server_tools::ToolError::InvalidCall(format!( @@ -1536,9 +1628,9 @@ impl crate::services::server_tools::ServerExecutedTool for FileSearchExecutor { raw, include_results, ); - if let Some(evt) = format_file_search_call_sse_event(&call_output) { - let _ = event_tx.send(evt).await; - } + let _ = event_tx + .send(format_file_search_call_sse_event(&call_output)) + .await; } // Emit the completed event. @@ -1666,6 +1758,26 @@ mod tests { use super::*; use crate::streaming::SseBuffer; + /// Unwrap a detection to its valid call, panicking if it's invalid — + /// keeps the detection-shape tests terse. + fn expect_valid(detection: FileSearchCallDetection) -> FileSearchToolCall { + match detection { + FileSearchCallDetection::Valid(tc) => *tc, + FileSearchCallDetection::Invalid { error, .. } => { + panic!("expected a valid file_search call, got invalid: {error}") + } + } + } + + /// Like [`detect_file_search_in_chunk`] but unwraps every detection to a + /// valid call for the happy-path detection tests. + fn detect_valid(chunk: &[u8], vector_store_ids: &[String]) -> Vec { + detect_file_search_in_chunk(chunk, vector_store_ids) + .into_iter() + .map(expect_valid) + .collect() + } + #[test] fn test_parse_file_search_tool_call() { let json = serde_json::json!({ @@ -1679,7 +1791,7 @@ mod tests { let result = parse_file_search_tool_call(&json, &vector_store_ids); assert!(result.is_some()); - let tool_call = result.unwrap(); + let tool_call = expect_valid(result.unwrap()); assert_eq!(tool_call.id, "call_123"); assert_eq!(tool_call.query, "revenue growth in Q3"); assert_eq!(tool_call.vector_store_ids, vector_store_ids); @@ -1698,6 +1810,24 @@ mod tests { assert!(result.is_none()); } + #[test] + fn test_parse_file_search_tool_call_invalid_arguments() { + // `query` is a number, not a string → fails to deserialize. + let json = serde_json::json!({ + "type": "function_call", + "name": "file_search", + "call_id": "call_bad", + "arguments": "{\"query\": 123}" + }); + let FileSearchCallDetection::Invalid { id, error } = + parse_file_search_tool_call(&json, &["vs_abc".to_string()]).unwrap() + else { + panic!("expected an invalid file_search call"); + }; + assert_eq!(id, "call_bad"); + assert!(!error.is_empty()); + } + #[test] fn test_parse_file_search_tool_call_with_max_results() { let json = serde_json::json!({ @@ -1708,7 +1838,7 @@ mod tests { }); let vector_store_ids = vec!["vs_abc".to_string()]; - let result = parse_file_search_tool_call(&json, &vector_store_ids).unwrap(); + let result = expect_valid(parse_file_search_tool_call(&json, &vector_store_ids).unwrap()); assert_eq!(result.max_num_results, Some(5)); } @@ -1723,7 +1853,7 @@ mod tests { }); let vector_store_ids = vec!["vs_abc".to_string()]; - let result = parse_file_search_tool_call(&json, &vector_store_ids).unwrap(); + let result = expect_valid(parse_file_search_tool_call(&json, &vector_store_ids).unwrap()); assert_eq!(result.query, "policy document"); assert_eq!(result.score_threshold, Some(0.85)); @@ -1740,7 +1870,7 @@ mod tests { }); let vector_store_ids = vec!["vs_abc".to_string()]; - let result = parse_file_search_tool_call(&json, &vector_store_ids).unwrap(); + let result = expect_valid(parse_file_search_tool_call(&json, &vector_store_ids).unwrap()); assert_eq!(result.query, "budget report"); assert!(result.filters.is_some()); @@ -1765,7 +1895,7 @@ mod tests { }); let vector_store_ids = vec!["vs_abc".to_string()]; - let result = parse_file_search_tool_call(&json, &vector_store_ids).unwrap(); + let result = expect_valid(parse_file_search_tool_call(&json, &vector_store_ids).unwrap()); assert_eq!(result.query, "meeting notes"); assert!(result.filters.is_some()); @@ -1789,7 +1919,7 @@ mod tests { }); let vector_store_ids = vec!["vs_abc".to_string(), "vs_def".to_string()]; - let result = parse_file_search_tool_call(&json, &vector_store_ids).unwrap(); + let result = expect_valid(parse_file_search_tool_call(&json, &vector_store_ids).unwrap()); assert_eq!(result.id, "call_full"); assert_eq!(result.query, "quarterly earnings"); @@ -1810,7 +1940,7 @@ mod tests { }); let vector_store_ids = vec!["vs_abc".to_string()]; - let result = parse_file_search_tool_call(&json, &vector_store_ids).unwrap(); + let result = expect_valid(parse_file_search_tool_call(&json, &vector_store_ids).unwrap()); assert_eq!(result.query, "simple search"); assert!(result.max_num_results.is_none()); @@ -1823,7 +1953,7 @@ mod tests { let chunk = b"data: {\"output\": [{\"type\": \"function_call\", \"name\": \"file_search\", \"call_id\": \"call_abc\", \"arguments\": \"{\\\"query\\\": \\\"test query\\\"}\"}]}\n\n"; let vector_store_ids = vec!["vs_test".to_string()]; - let results = detect_file_search_in_chunk(chunk, &vector_store_ids); + let results = detect_valid(chunk, &vector_store_ids); assert_eq!(results.len(), 1); assert_eq!(results[0].id, "call_abc"); @@ -1855,7 +1985,7 @@ mod tests { "#; let vector_store_ids = vec!["vs_test".to_string()]; - let results = detect_file_search_in_chunk(chunk, &vector_store_ids); + let results = detect_valid(chunk, &vector_store_ids); assert_eq!(results.len(), 1); assert_eq!(results[0].id, "fc_abc123"); @@ -1871,7 +2001,7 @@ mod tests { "#; - let results = detect_file_search_in_chunk(chunk, &["vs_123".to_string()]); + let results = detect_valid(chunk, &["vs_123".to_string()]); assert_eq!(results.len(), 1); assert_eq!(results[0].query, "quarterly sales"); @@ -1897,7 +2027,7 @@ mod tests { "#; let vector_store_ids = vec!["vs_prod".to_string()]; - let results = detect_file_search_in_chunk(chunk, &vector_store_ids); + let results = detect_valid(chunk, &vector_store_ids); assert_eq!(results.len(), 1); // parse_file_search_tool_call prefers call_id over id @@ -1936,7 +2066,7 @@ mod tests { "#; let vector_store_ids = vec!["vs_finance".to_string()]; - let results = detect_file_search_in_chunk(chunk, &vector_store_ids); + let results = detect_valid(chunk, &vector_store_ids); assert_eq!(results.len(), 2); assert_eq!(results[0].id, "call_1"); @@ -1953,7 +2083,7 @@ mod tests { "#; let vector_store_ids = vec!["vs_data".to_string()]; - let results = detect_file_search_in_chunk(chunk, &vector_store_ids); + let results = detect_valid(chunk, &vector_store_ids); // Should only detect the 2 file_search calls, not get_weather or message assert_eq!(results.len(), 2); @@ -2164,7 +2294,7 @@ mod tests { results: None, }; - let sse_event = format_file_search_call_sse_event(&output).unwrap(); + let sse_event = format_file_search_call_sse_event(&output); let event_str = std::str::from_utf8(&sse_event).unwrap(); // Check SSE format diff --git a/src/services/mcp/executor.rs b/src/services/mcp/executor.rs index 591bbe44..ffbb1476 100644 --- a/src/services/mcp/executor.rs +++ b/src/services/mcp/executor.rs @@ -343,7 +343,13 @@ impl McpExecutor { error = %err, "MCP approval gate failing closed" ); - return self.synthesize_failed_call(binding, call_id, tool_name, arguments, err); + return self.synthesize_failed_call( + &binding.server_label, + call_id, + tool_name, + arguments, + err, + ); }; let approval_id = format!("mcpr_{}", uuid::Uuid::new_v4().simple()); @@ -468,7 +474,7 @@ impl McpExecutor { /// error JSON so the model sees a clean refusal. fn synthesize_failed_call( &self, - binding: &ServerBinding, + server_label: &str, call_id: &str, tool_name: &str, raw_args: &Value, @@ -482,7 +488,7 @@ impl McpExecutor { let failed_item = mcp_call_item( &item_id, - &binding.server_label, + server_label, tool_name, raw_args, "failed", @@ -994,6 +1000,27 @@ impl ServerExecutedTool for McpExecutor { .cloned() .unwrap_or(Value::Null); + // Recognized MCP call whose arguments couldn't be parsed: render a + // spec-shaped `mcp_call` failure and feed the error back so the loop + // continues — never drop it or dispatch null args to the server. + // Detection verifies the binding before marking a call invalid, so + // `resolve_binding` succeeds in practice; fall back to the sanitized + // label rather than `?` so the contract (always `Ok`, never `Err`) + // holds even for calls built via the public `invalid()` constructor. + if let Some(error) = &call.invalid { + let server_label = self + .resolve_binding(&sanitized_label) + .map(|b| b.server_label.clone()) + .unwrap_or_else(|| sanitized_label.clone()); + return self.synthesize_failed_call( + &server_label, + &call.call_id, + &tool_name, + &raw_args, + crate::services::server_tools::invalid_arguments_text("mcp", error), + ); + } + let binding = self .resolve_binding(&sanitized_label) .cloned() @@ -1186,23 +1213,39 @@ fn detect_in_chunk(chunk: &[u8], bindings: &[ServerBinding]) -> Vec(raw_args_str) { + Ok(parsed_args) => vec![DetectedToolCall::new( + "mcp", + call_id, + serde_json::json!({ + "__mcp_label": sanitized_label, + "__mcp_tool": tool_name, + "__mcp_args": parsed_args, + }), + )], + Err(e) => { + let mut detected = DetectedToolCall::new( + "mcp", + call_id, + serde_json::json!({ + "__mcp_label": sanitized_label, + "__mcp_tool": tool_name, + "__mcp_args": Value::Null, + }), + ); + detected.invalid = Some(format!("could not parse MCP tool `arguments` as JSON: {e}")); + vec![detected] + } + } } /// Concatenate every `data:` field on an SSE event into a single @@ -1862,15 +1905,15 @@ mod tests { let executor = McpExecutor::with_persistence(service, &payload, None, None, DEFAULT_CALL_TIMEOUT_SECS); - let call = DetectedToolCall { - tool_name: "mcp", - call_id: "c1".into(), - arguments: serde_json::json!({ + let call = DetectedToolCall::new( + "mcp", + "c1", + serde_json::json!({ "__mcp_label": "atlassian", "__mcp_tool": "jira_create", "__mcp_args": {"summary": "bug"}, }), - }; + ); let ctx = ToolContext { original_payload: payload, }; @@ -1933,6 +1976,108 @@ mod tests { } } + #[tokio::test] + async fn invalid_arguments_synthesize_failed_call_without_dispatch() { + use futures_util::StreamExt; + + // A call detection marked invalid (unparseable arguments) must + // render a `mcp_call` with status `failed` and feed the error back — + // without dispatching to the server. + let payload: CreateResponsesPayload = serde_json::from_value(serde_json::json!({ + "tools": [{"type":"mcp","server_label":"atlassian","server_url":"https://x"}] + })) + .unwrap(); + let executor = McpExecutor::new(McpService::new(), &payload); + + let mut call = DetectedToolCall::new( + "mcp", + "c1", + serde_json::json!({ + "__mcp_label": "atlassian", + "__mcp_tool": "jira_create", + "__mcp_args": serde_json::Value::Null, + }), + ); + call.invalid = Some("could not parse MCP tool `arguments` as JSON".to_string()); + let ctx = ToolContext { + original_payload: payload, + }; + let handle = executor.execute(call, &ctx).await.expect("handle returned"); + + let mut events = handle.events; + let mut terminal: Option = None; + let mut lifecycle_types: Vec = Vec::new(); + while let Some(bytes) = events.next().await { + let text = std::str::from_utf8(&bytes).unwrap(); + for line in text.lines() { + if let Some(rest) = line.strip_prefix("data:") { + let v: serde_json::Value = serde_json::from_str(rest.trim()).unwrap(); + match v.get("type").and_then(|t| t.as_str()) { + Some(t) if t.starts_with("response.mcp_call.") => { + lifecycle_types.push(t.to_string()); + } + Some("response.output_item.done") => terminal = Some(v["item"].clone()), + _ => {} + } + } + } + } + + assert!( + lifecycle_types + .iter() + .any(|t| t == "response.mcp_call.failed"), + "expected a failed lifecycle, got {lifecycle_types:?}" + ); + let terminal = terminal.expect("a terminal mcp_call item"); + assert_eq!(terminal["status"], "failed"); + assert!( + terminal["error"] + .as_str() + .unwrap() + .contains("Invalid arguments") + ); + + let result = handle.result.await.expect("result resolves"); + assert_eq!(result.call_id, "c1"); + let ResponsesInputItem::FunctionCallOutput(fco) = &result.continuation_items[0] else { + panic!("expected FunctionCallOutput continuation"); + }; + assert_eq!(fco.call_id, "c1"); + assert!(fco.output.contains("error")); + } + + #[tokio::test] + async fn invalid_call_without_resolvable_binding_still_returns_ok() { + // Honour the `execute()` contract unconditionally: a call marked + // invalid via the public `DetectedToolCall::invalid()` constructor + // carries `Value::Null` args (so no `__mcp_label` and no resolvable + // binding). It must still feed back a failed `mcp_call` rather than + // returning `Err` and aborting the turn. + let payload: CreateResponsesPayload = serde_json::from_value(serde_json::json!({ + "tools": [{"type":"mcp","server_label":"atlassian","server_url":"https://x"}] + })) + .unwrap(); + let executor = McpExecutor::new(McpService::new(), &payload); + + let call = DetectedToolCall::invalid("mcp", "c2", "unparseable arguments"); + let ctx = ToolContext { + original_payload: payload, + }; + + let handle = executor + .execute(call, &ctx) + .await + .expect("execute must return Ok even without a resolvable binding"); + let result = handle.result.await.expect("result resolves"); + assert_eq!(result.call_id, "c2"); + let ResponsesInputItem::FunctionCallOutput(fco) = &result.continuation_items[0] else { + panic!("expected FunctionCallOutput continuation"); + }; + assert_eq!(fco.call_id, "c2"); + assert!(fco.output.contains("error")); + } + fn executor() -> McpExecutor { let payload: CreateResponsesPayload = serde_json::from_value(serde_json::json!({ "tools": [{"type":"mcp","server_label":"atlassian","server_url":"https://x"}] diff --git a/src/services/mcp/tool_search/mod.rs b/src/services/mcp/tool_search/mod.rs index 05632314..3a65955b 100644 --- a/src/services/mcp/tool_search/mod.rs +++ b/src/services/mcp/tool_search/mod.rs @@ -257,19 +257,31 @@ impl ServerExecutedTool for ToolSearchExecutor { .get("arguments") .and_then(|v| v.as_str()) .unwrap_or("{}"); - let arguments: Value = serde_json::from_str(args_str).unwrap_or_else(|e| { - tracing::warn!( - error = %e, - arguments = %args_str, - "tool_search call arguments are not valid JSON; treating as empty" - ); - Value::Null - }); - vec![DetectedToolCall { - tool_name: TOOL_SEARCH_FUNCTION_NAME, - call_id, - arguments, - }] + // Distinguish a malformed call (unparseable JSON) from a well-formed + // one that simply matches nothing: the former is surfaced as an + // explicit `failed` the model can recover from; the latter runs + // normally and returns an empty tool list. + match serde_json::from_str::(args_str) { + Ok(arguments) => { + vec![DetectedToolCall::new( + TOOL_SEARCH_FUNCTION_NAME, + call_id, + arguments, + )] + } + Err(e) => { + tracing::warn!( + error = %e, + arguments = %args_str, + "tool_search call arguments are not valid JSON; surfacing as a failed call" + ); + vec![DetectedToolCall::invalid( + TOOL_SEARCH_FUNCTION_NAME, + call_id, + format!("could not parse `arguments` as JSON: {e}"), + )] + } + } } async fn execute( @@ -277,6 +289,57 @@ impl ServerExecutedTool for ToolSearchExecutor { call: DetectedToolCall, _ctx: &ToolContext, ) -> Result { + // A malformed tool_search call (unparseable arguments) is surfaced as + // an explicit `failed` tool_search_call — distinct from a well-formed + // search that matches nothing — plus a function_call_output the model + // can read and correct on its next turn. + if let Some(error) = &call.invalid { + let call_index = self.output_index.fetch_add(1, Ordering::Relaxed); + let call_item_id = next_item_id("ts"); + let failed_item = |status: &str| { + serde_json::json!({ + "type": "tool_search_call", + "id": call_item_id, + "call_id": Value::Null, + "execution": "server", + "arguments": call.arguments, + "status": status, + "error": error, + }) + }; + let events = vec![ + sse_output_item( + "response.output_item.added", + call_index, + self.next_seq(), + failed_item("in_progress"), + ), + sse_output_item( + "response.output_item.done", + call_index, + self.next_seq(), + failed_item("failed"), + ), + ]; + let continuation = ResponsesInputItem::FunctionCallOutput(FunctionCallOutput { + type_: FunctionCallOutputType::FunctionCallOutput, + id: None, + call_id: call.call_id.clone(), + output: crate::services::server_tools::invalid_arguments_text( + TOOL_SEARCH_FUNCTION_NAME, + error, + ), + status: None, + }); + let result = ToolCallResult { + call_id: call.call_id.clone(), + continuation_items: vec![continuation], + }; + return Ok(ToolExecutionHandle { + events: Box::pin(futures_util::stream::iter(events)), + result: Box::pin(async move { Ok(result) }), + }); + } let query = call .arguments .get("query") @@ -653,14 +716,76 @@ mod tests { assert!(calls.is_empty()); } + #[test] + fn detect_marks_malformed_arguments_invalid() { + let exec = executor_with_catalog(); + // Unparseable JSON in `arguments` — distinct from an empty query. + let ev = serde_json::json!({ + "type": "response.output_item.done", + "item": { + "type": "function_call", + "id": "fc_1", + "call_id": "call_bad", + "name": TOOL_SEARCH_FUNCTION_NAME, + "arguments": "{not valid json", + } + }); + let calls = exec.detect(format!("data: {ev}\n\n").as_bytes(), &ctx()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].call_id, "call_bad"); + assert!( + calls[0].invalid.is_some(), + "malformed tool_search args should be marked invalid" + ); + } + #[tokio::test] - async fn execute_emits_call_and_output_items_and_stashes_defs() { + async fn execute_invalid_emits_failed_call() { let exec = executor_with_catalog(); - let call = DetectedToolCall { - tool_name: TOOL_SEARCH_FUNCTION_NAME, - call_id: "call_1".to_string(), - arguments: serde_json::json!({"query": "search jira issues"}), + let call = DetectedToolCall::invalid( + TOOL_SEARCH_FUNCTION_NAME, + "call_bad", + "could not parse `arguments` as JSON", + ); + let handle = exec.execute(call, &ctx()).await.expect("execute"); + let events: Vec = handle.events.collect().await; + let parsed: Vec = events + .iter() + .flat_map(|b| { + String::from_utf8_lossy(b) + .lines() + .filter_map(|l| l.strip_prefix("data:").map(str::trim)) + .map(|d| serde_json::from_str(d).unwrap()) + .collect::>() + }) + .collect(); + // The terminal tool_search_call reports status `failed` with an error. + let done = parsed + .iter() + .find(|e| { + e["type"] == "response.output_item.done" && e["item"]["type"] == "tool_search_call" + }) + .expect("a tool_search_call done event"); + assert_eq!(done["item"]["status"], "failed"); + assert!(done["item"]["error"].as_str().unwrap().contains("parse")); + + // The continuation feeds the error back for self-correction. + let result = handle.result.await.expect("result"); + assert_eq!(result.call_id, "call_bad"); + let ResponsesInputItem::FunctionCallOutput(out) = &result.continuation_items[0] else { + panic!("expected a FunctionCallOutput continuation item"); }; + assert!(out.output.contains("Invalid arguments")); + } + + #[tokio::test] + async fn execute_emits_call_and_output_items_and_stashes_defs() { + let exec = executor_with_catalog(); + let call = DetectedToolCall::new( + TOOL_SEARCH_FUNCTION_NAME, + "call_1", + serde_json::json!({"query": "search jira issues"}), + ); let handle = exec.execute(call, &ctx()).await.expect("execute"); let events: Vec = handle.events.collect().await; let joined: String = events @@ -684,11 +809,11 @@ mod tests { #[tokio::test] async fn apply_to_continuation_injects_discovered_tools_and_output() { let exec = executor_with_catalog(); - let call = DetectedToolCall { - tool_name: TOOL_SEARCH_FUNCTION_NAME, - call_id: "call_1".to_string(), - arguments: serde_json::json!({"query": "search jira"}), - }; + let call = DetectedToolCall::new( + TOOL_SEARCH_FUNCTION_NAME, + "call_1", + serde_json::json!({"query": "search jira"}), + ); let handle = exec.execute(call, &ctx()).await.unwrap(); let _ = handle.events.collect::>().await; let result = handle.result.await.unwrap(); diff --git a/src/services/server_tools/mod.rs b/src/services/server_tools/mod.rs index 055216d0..7149663e 100644 --- a/src/services/server_tools/mod.rs +++ b/src/services/server_tools/mod.rs @@ -195,6 +195,50 @@ pub struct DetectedToolCall { /// Tool-specific arguments payload — JSON value or any other structure /// the tool needs to execute. Each tool's `execute()` interprets this. pub arguments: Value, + /// `Some(error)` when the tool recognized this call by name but could + /// not parse its arguments. A `detect()` MUST mark such a call here + /// rather than dropping it (a dropped call leaves the loop reporting a + /// false `completed` with no feedback to the model). The tool's + /// `execute()` MUST, when this is set, render its spec-shaped failure + /// item (e.g. `shell_call_output` with a non-zero exit, `web_search_call` + /// / `file_search_call` with status `failed`, `mcp_call` with `error`) + /// plus a `function_call_output` carrying the error, and return `Ok` + /// (never `Err`) so the loop continues and the model can self-correct. + pub invalid: Option, +} + +impl DetectedToolCall { + /// A well-formed detected call routed to the tool for execution. + pub fn new(tool_name: &'static str, call_id: impl Into, arguments: Value) -> Self { + Self { + tool_name, + call_id: call_id.into(), + arguments, + invalid: None, + } + } + + /// A call recognized by name but whose arguments could not be parsed. + /// See [`DetectedToolCall::invalid`] for the contract each tool's + /// `execute()` must honor. + pub fn invalid( + tool_name: &'static str, + call_id: impl Into, + error: impl Into, + ) -> Self { + Self { + tool_name, + call_id: call_id.into(), + arguments: Value::Null, + invalid: Some(error.into()), + } + } +} + +/// Standard human-readable message for an unparseable tool call, fed back +/// to the model in the `function_call_output` so it can correct the call. +pub fn invalid_arguments_text(tool_name: &str, error: &str) -> String { + format!("Invalid arguments for tool `{tool_name}`: {error}") } /// Result of executing one tool call. @@ -256,6 +300,12 @@ pub trait ServerExecutedTool: Send + Sync { /// tool's type that it contains. /// /// Called for every event of every iteration. Must be cheap. + /// + /// Contract: when an event carries a call this tool recognizes by name + /// but whose arguments cannot be parsed, return it via + /// [`DetectedToolCall::invalid`] — never silently drop it. Dropping a + /// recognized call makes the loop end as a false `completed` with no + /// feedback, stranding the model. fn detect(&self, event: &[u8], ctx: &ToolContext) -> Vec; /// Execute one detected tool call. @@ -263,6 +313,11 @@ pub trait ServerExecutedTool: Send + Sync { /// Returns a handle exposing progress events plus the final result. /// The orchestrator forwards the events to the client and awaits the /// result to build the continuation payload. + /// + /// Contract: when `call.invalid` is `Some`, emit this tool's spec-shaped + /// failure item and return `Ok` with a `function_call_output` carrying + /// the error in `continuation_items` — do not return `Err` (that aborts + /// the whole turn) and do not run the underlying tool. async fn execute( &self, call: DetectedToolCall, diff --git a/src/services/shell_tool.rs b/src/services/shell_tool.rs index c1ee825c..62d45bd1 100644 --- a/src/services/shell_tool.rs +++ b/src/services/shell_tool.rs @@ -643,8 +643,8 @@ impl ResolvedShellArgs { impl ShellToolArguments { pub const FUNCTION_NAME: &'static str = "shell"; - pub fn parse(arguments_json: &str) -> Option { - serde_json::from_str(arguments_json).ok() + pub fn parse(arguments_json: &str) -> Result { + serde_json::from_str(arguments_json) } /// Resolve the parsed arguments into a flat call shape. Returns @@ -937,7 +937,34 @@ struct ShellToolCall { args: ResolvedShellArgs, } -fn parse_shell_tool_call(value: &Value) -> Option { +/// Outcome of inspecting a `function_call` item named `shell`. +/// +/// `Invalid` carries the call id and a human-readable reason so the +/// executor can synthesize a spec-shaped error `shell_call_output` rather +/// than dropping the call (which would strand the agentic loop reporting a +/// false `completed`). A `None` from [`parse_shell_tool_call`] means the +/// item is genuinely not a shell call and should pass through untouched. +#[derive(Debug, Clone)] +enum ShellCallDetection { + Valid(ShellToolCall), + Invalid { id: String, error: String }, +} + +/// Truncate a raw arguments blob for inclusion in an error message so a +/// pathological payload can't bloat the fed-back error text. +fn truncate_for_error(s: &str) -> String { + const MAX: usize = 256; + if s.len() <= MAX { + return s.to_string(); + } + let mut end = MAX; + while !s.is_char_boundary(end) { + end -= 1; + } + format!("{}…", &s[..end]) +} + +fn parse_shell_tool_call(value: &Value) -> Option { let obj = value.as_object()?; if obj.get("type").and_then(|t| t.as_str())? != "function_call" { return None; @@ -952,11 +979,30 @@ fn parse_shell_tool_call(value: &Value) -> Option { .unwrap_or("unknown") .to_string(); let arguments_str = obj.get("arguments")?.as_str()?; - let args = ShellToolArguments::parse(arguments_str)?.resolve()?; - Some(ShellToolCall { id, args }) + let parsed = match ShellToolArguments::parse(arguments_str) { + Ok(parsed) => parsed, + Err(e) => { + return Some(ShellCallDetection::Invalid { + id, + error: format!( + "could not parse `arguments` into the expected \ + {{\"action\": {{\"commands\": [...]}}}} shape: {e}. Received: {}", + truncate_for_error(arguments_str) + ), + }); + } + }; + match parsed.resolve() { + Some(args) => Some(ShellCallDetection::Valid(ShellToolCall { id, args })), + None => Some(ShellCallDetection::Invalid { + id, + error: "`action.commands` was empty; provide at least one non-empty command line" + .to_string(), + }), + } } -fn detect_shell_in_chunk(chunk: &[u8]) -> Vec { +fn detect_shell_in_chunk(chunk: &[u8]) -> Vec { let Ok(chunk_str) = std::str::from_utf8(chunk) else { return Vec::new(); }; @@ -1176,6 +1222,99 @@ fn format_shell_call_output_item( })) } +/// Build a self-contained [`ToolExecutionHandle`] for a `shell` call whose +/// arguments couldn't be parsed (or had no commands). No container is +/// touched: it synthesizes the spec-shaped `shell_call` + `shell_call_output` +/// pair — modeled as a failed command (`outcome: exit{exit_code: 2}`, the +/// conventional shell "invalid usage" code) with the reason in `stderr`, +/// per OpenAI's "preserve non-zero exit outputs so the model can reason +/// about recovery steps" guidance — and a matching `function_call_output` +/// continuation item so the model sees the error and can retry. +fn synthesize_invalid_args_handle(call_id: &str, error: &str) -> ToolExecutionHandle { + // exit_code 2 == conventional shell "incorrect usage". + const INVALID_ARGS_EXIT_CODE: i32 = 2; + let id = call_id.to_string(); + + // `added` then `done`, output-before-call (same ordering invariant as the + // success path: the call resolving must mean its output is already on the + // wire). `killed: false` + a non-zero exit yields status `completed` with + // an `exit` outcome. + let events = vec![ + format_shell_call_item( + ItemLifecycle::Added, + &id, + &id, + 0, + &[], + None, + None, + None, + None, + "in_progress", + None, + Some("model"), + ), + format_shell_call_output_item( + ItemLifecycle::Added, + &id, + &id, + 0, + 0, + "", + "", + &[], + false, + None, + Some("gateway"), + ), + format_shell_call_output_item( + ItemLifecycle::Done, + &id, + &id, + 0, + INVALID_ARGS_EXIT_CODE, + "", + error, + &[], + false, + None, + Some("gateway"), + ), + format_shell_call_item( + ItemLifecycle::Done, + &id, + &id, + 0, + &[], + None, + None, + None, + None, + "completed", + None, + Some("model"), + ), + ]; + + let combined = format!("exit_code: {INVALID_ARGS_EXIT_CODE}\nstdout:\n\nstderr:\n{error}"); + let cont_item = ResponsesInputItem::FunctionCallOutput(FunctionCallOutput { + type_: FunctionCallOutputType::FunctionCallOutput, + id: Some(id.clone()), + call_id: id.clone(), + output: combined, + status: None, + }); + let result = ToolCallResult { + call_id: id, + continuation_items: vec![cont_item], + }; + + ToolExecutionHandle { + events: Box::pin(futures_util::stream::iter(events)), + result: Box::pin(async move { Ok(result) }), + } +} + /// Emit the spec-canonical `output_item.done` events for both the /// `shell_call` and `shell_call_output` items when a shell call fails /// before producing real output (boot failure, passthrough misconfig, @@ -1699,18 +1838,23 @@ impl ServerExecutedTool for ShellExecutor { fn detect(&self, event: &[u8], _ctx: &ToolContext) -> Vec { detect_shell_in_chunk(event) .into_iter() - .map(|tc| DetectedToolCall { - tool_name: ShellToolArguments::FUNCTION_NAME, - call_id: tc.id.clone(), - arguments: serde_json::json!({ - "id": tc.id, - "commands": tc.args.commands, - "stdin": tc.args.stdin, - "timeout_ms": tc.args.timeout_ms, - "max_output_length": tc.args.max_output_length, - "env": tc.args.env, - "working_directory": tc.args.working_directory, - }), + .map(|detection| match detection { + ShellCallDetection::Valid(tc) => DetectedToolCall::new( + ShellToolArguments::FUNCTION_NAME, + tc.id.clone(), + serde_json::json!({ + "id": tc.id, + "commands": tc.args.commands, + "stdin": tc.args.stdin, + "timeout_ms": tc.args.timeout_ms, + "max_output_length": tc.args.max_output_length, + "env": tc.args.env, + "working_directory": tc.args.working_directory, + }), + ), + ShellCallDetection::Invalid { id, error } => { + DetectedToolCall::invalid(ShellToolArguments::FUNCTION_NAME, id, error) + } }) .collect() } @@ -1720,6 +1864,12 @@ impl ServerExecutedTool for ShellExecutor { call: DetectedToolCall, _ctx: &ToolContext, ) -> Result { + // The model emitted a `shell` call we recognized but couldn't parse. + // Surface a spec-shaped error `shell_call_output` and feed the error + // back so the loop continues and the model can retry — never drop it. + if let Some(error) = &call.invalid { + return Ok(synthesize_invalid_args_handle(&call.call_id, error)); + } let commands: Vec = call .arguments .get("commands") @@ -3002,7 +3152,9 @@ mod tests { "call_id": "call_abc", "arguments": "{\"action\": {\"commands\": [\"echo hi\"]}}" }); - let tc = parse_shell_tool_call(&v).unwrap(); + let ShellCallDetection::Valid(tc) = parse_shell_tool_call(&v).unwrap() else { + panic!("expected a valid shell call"); + }; assert_eq!(tc.id, "call_abc"); assert_eq!(tc.args.commands, vec!["echo hi".to_string()]); assert!(tc.args.stdin.is_none()); @@ -3017,7 +3169,9 @@ mod tests { "call_id": "call_xyz", "arguments": "{\"action\": {\"commands\": [\"cd /tmp\", \"ls /\"], \"timeout_ms\": 1500, \"max_output_length\": 2000, \"env\": {\"FOO\": \"bar\"}, \"working_directory\": \"/tmp\"}}" }); - let tc = parse_shell_tool_call(&v).unwrap(); + let ShellCallDetection::Valid(tc) = parse_shell_tool_call(&v).unwrap() else { + panic!("expected a valid shell call"); + }; assert_eq!( tc.args.commands, vec!["cd /tmp".to_string(), "ls /".to_string()] @@ -3097,6 +3251,97 @@ mod tests { assert!(parse_shell_tool_call(&v).is_none()); } + #[test] + fn double_encoded_action_detected_as_invalid() { + // `action` is a JSON *string* instead of an object (double-encoded). + let v = serde_json::json!({ + "type": "function_call", + "name": "shell", + "call_id": "call_bad", + "arguments": "{\"action\": \"{\\\"action\\\": {\\\"commands\\\": [\\\"echo hi\\\"]}}\"}" + }); + let ShellCallDetection::Invalid { id, error } = parse_shell_tool_call(&v).unwrap() else { + panic!("expected an invalid shell call"); + }; + assert_eq!(id, "call_bad"); + assert!(!error.is_empty()); + } + + #[test] + fn empty_commands_detected_as_invalid() { + let v = serde_json::json!({ + "type": "function_call", + "name": "shell", + "call_id": "call_empty", + "arguments": "{\"action\": {\"commands\": []}}" + }); + let ShellCallDetection::Invalid { id, error } = parse_shell_tool_call(&v).unwrap() else { + panic!("expected an invalid shell call"); + }; + assert_eq!(id, "call_empty"); + assert!(error.contains("commands")); + } + + #[test] + fn detect_shell_in_chunk_marks_malformed_invalid() { + let chunk = br#"data: {"type": "response.output_item.done", "item": {"type": "function_call", "name": "shell", "call_id": "call_z", "arguments": "{\"action\": \"oops\"}"}} + +"#; + let found = detect_shell_in_chunk(chunk); + assert_eq!(found.len(), 1); + assert!(matches!(found[0], ShellCallDetection::Invalid { .. })); + } + + #[tokio::test] + async fn synthesize_invalid_args_handle_emits_failed_command() { + let handle = synthesize_invalid_args_handle("call_99", "bad args here"); + + // Drain the client-facing event stream. + let events: Vec = handle.events.collect().await; + let parsed: Vec = events + .iter() + .map(|b| { + let text = std::str::from_utf8(b).unwrap(); + let data = text + .lines() + .find_map(|l| l.strip_prefix("data:").map(str::trim)) + .unwrap(); + serde_json::from_str(data).unwrap() + }) + .collect(); + + // The `shell_call_output` `done` reports a failed command: status + // completed, an `exit` outcome with the conventional non-zero code, + // and the reason in stderr. + let output_done = parsed + .iter() + .find(|e| { + e["type"] == "response.output_item.done" && e["item"]["type"] == "shell_call_output" + }) + .expect("a shell_call_output done event"); + assert_eq!(output_done["item"]["status"], "completed"); + let content = &output_done["item"]["output"][0]; + assert_eq!(content["outcome"]["type"], "exit"); + assert_eq!(content["outcome"]["exit_code"], 2); + assert!( + content["stderr"] + .as_str() + .unwrap() + .contains("bad args here"), + "stderr should carry the error: {content:?}" + ); + + // The continuation feeds the error back, paired to the same call_id. + let result = handle.result.await.unwrap(); + assert_eq!(result.call_id, "call_99"); + let ResponsesInputItem::FunctionCallOutput(out) = &result.continuation_items[0] else { + panic!("expected a FunctionCallOutput continuation item"); + }; + assert_eq!(out.call_id, "call_99"); + assert!(out.output.contains("exit_code: 2")); + assert!(out.output.contains("bad args here")); + } + #[test] fn preprocess_rewrites_shell_tool_to_function() { let payload_json = serde_json::json!({ diff --git a/src/services/web_search_tool.rs b/src/services/web_search_tool.rs index 5cb4bbc6..ad1d35ef 100644 --- a/src/services/web_search_tool.rs +++ b/src/services/web_search_tool.rs @@ -35,8 +35,8 @@ pub struct WebSearchToolArguments { impl WebSearchToolArguments { pub const FUNCTION_NAME: &'static str = "web_search"; - pub fn parse(arguments_json: &str) -> Option { - serde_json::from_str(arguments_json).ok() + pub fn parse(arguments_json: &str) -> Result { + serde_json::from_str(arguments_json) } pub fn function_description() -> &'static str { @@ -145,12 +145,24 @@ struct WebSearchToolCall { query: String, } +/// Outcome of inspecting a `function_call` item named `web_search`. +/// +/// `Invalid` carries the call id and reason so the executor can synthesize +/// a `web_search_call` with status `failed` rather than dropping the call. +/// `None` from [`parse_web_search_tool_call`] means the item is not a +/// web_search call and should pass through untouched. +#[derive(Debug, Clone)] +enum WebSearchCallDetection { + Valid(WebSearchToolCall), + Invalid { id: String, error: String }, +} + // ───────────────────────────────────────────────────────────────────────────── // Detection // ───────────────────────────────────────────────────────────────────────────── /// Parse a web_search tool call from a JSON value. -fn parse_web_search_tool_call(value: &Value) -> Option { +fn parse_web_search_tool_call(value: &Value) -> Option { let obj = value.as_object()?; let type_val = obj.get("type")?.as_str()?; @@ -171,16 +183,20 @@ fn parse_web_search_tool_call(value: &Value) -> Option { .to_string(); let arguments_str = obj.get("arguments")?.as_str()?; - let args = WebSearchToolArguments::parse(arguments_str)?; - - Some(WebSearchToolCall { - id, - query: args.query, - }) + match WebSearchToolArguments::parse(arguments_str) { + Ok(args) => Some(WebSearchCallDetection::Valid(WebSearchToolCall { + id, + query: args.query, + })), + Err(e) => Some(WebSearchCallDetection::Invalid { + id, + error: format!("could not parse `arguments` (expected {{\"query\": \"...\"}}): {e}"), + }), + } } /// Detect web_search tool calls in an SSE chunk. -fn detect_web_search_in_chunk(chunk: &[u8]) -> Vec { +fn detect_web_search_in_chunk(chunk: &[u8]) -> Vec { let Some(chunk_str) = std::str::from_utf8(chunk).ok() else { return Vec::new(); }; @@ -315,6 +331,55 @@ fn format_web_search_call_output_event(item_id: &str) -> Option { Some(Bytes::from(format!("data: {}\n\n", json_str))) } +/// Build a self-contained handle for a `web_search` call whose arguments +/// couldn't be parsed. Emits a `web_search_call` item with status `failed` +/// (the spec's failure status) and feeds the error back as a +/// `function_call_output` so the loop continues and the model can retry. +#[cfg(feature = "server")] +fn synthesize_web_search_invalid_handle( + call_id: &str, + error: &str, +) -> crate::services::server_tools::ToolExecutionHandle { + let id = call_id.to_string(); + let failed_item = WebSearchCallOutput { + type_: WebSearchCallOutputType::WebSearchCall, + id: id.clone(), + status: WebSearchStatus::Failed, + }; + let done_event = serde_json::json!({ + "type": "response.output_item.done", + "output_index": 0, + "item": failed_item, + }); + let events = vec![ + format_web_search_in_progress_event(&id, 0), + Bytes::from(format!( + "data: {}\n\n", + serde_json::to_string(&done_event).unwrap_or_default() + )), + ]; + + let continuation_item = ResponsesInputItem::FunctionCallOutput(FunctionCallOutput { + type_: FunctionCallOutputType::FunctionCallOutput, + id: Some(id.clone()), + call_id: id.clone(), + output: crate::services::server_tools::invalid_arguments_text( + WebSearchToolArguments::FUNCTION_NAME, + error, + ), + status: None, + }); + let result = crate::services::server_tools::ToolCallResult { + call_id: id, + continuation_items: vec![continuation_item], + }; + + crate::services::server_tools::ToolExecutionHandle { + events: Box::pin(futures_util::stream::iter(events)), + result: Box::pin(async move { Ok(result) }), + } +} + // ───────────────────────────────────────────────────────────────────────────── // Streaming wrapper // ───────────────────────────────────────────────────────────────────────────── @@ -382,13 +447,24 @@ impl crate::services::server_tools::ServerExecutedTool for WebSearchExecutor { ) -> Vec { detect_web_search_in_chunk(event) .into_iter() - .map(|tc| crate::services::server_tools::DetectedToolCall { - tool_name: WebSearchToolArguments::FUNCTION_NAME, - call_id: tc.id.clone(), - arguments: serde_json::json!({ - "id": tc.id, - "query": tc.query, - }), + .map(|detection| match detection { + WebSearchCallDetection::Valid(tc) => { + crate::services::server_tools::DetectedToolCall::new( + WebSearchToolArguments::FUNCTION_NAME, + tc.id.clone(), + serde_json::json!({ + "id": tc.id, + "query": tc.query, + }), + ) + } + WebSearchCallDetection::Invalid { id, error } => { + crate::services::server_tools::DetectedToolCall::invalid( + WebSearchToolArguments::FUNCTION_NAME, + id, + error, + ) + } }) .collect() } @@ -401,6 +477,12 @@ impl crate::services::server_tools::ServerExecutedTool for WebSearchExecutor { crate::services::server_tools::ToolExecutionHandle, crate::services::server_tools::ToolError, > { + // The model emitted a `web_search` call we recognized but couldn't + // parse. Surface a `web_search_call` with status `failed` and feed + // the error back so the loop continues — never drop it. + if let Some(error) = &call.invalid { + return Ok(synthesize_web_search_invalid_handle(&call.call_id, error)); + } let query = call .arguments .get("query") @@ -545,11 +627,31 @@ mod tests { "call_id": "call_123", "arguments": "{\"query\": \"rust async programming\"}" }); - let tc = parse_web_search_tool_call(&value).unwrap(); + let WebSearchCallDetection::Valid(tc) = parse_web_search_tool_call(&value).unwrap() else { + panic!("expected a valid web_search call"); + }; assert_eq!(tc.id, "call_123"); assert_eq!(tc.query, "rust async programming"); } + #[test] + fn test_parse_web_search_tool_call_invalid_arguments() { + // `query` is a number, not a string → fails to deserialize. + let value = serde_json::json!({ + "type": "function_call", + "name": "web_search", + "call_id": "call_bad", + "arguments": "{\"query\": 123}" + }); + let WebSearchCallDetection::Invalid { id, error } = + parse_web_search_tool_call(&value).unwrap() + else { + panic!("expected an invalid web_search call"); + }; + assert_eq!(id, "call_bad"); + assert!(!error.is_empty()); + } + #[test] fn test_parse_web_search_tool_call_not_web_search() { let value = serde_json::json!({ @@ -577,7 +679,10 @@ mod tests { "#; let calls = detect_web_search_in_chunk(chunk); assert_eq!(calls.len(), 1); - assert_eq!(calls[0].query, "latest news"); + let WebSearchCallDetection::Valid(tc) = &calls[0] else { + panic!("expected a valid web_search call"); + }; + assert_eq!(tc.query, "latest news"); } #[test] @@ -598,8 +703,11 @@ mod tests { let chunk = b"data: {\"type\": \"response.function_call_arguments.done\", \"name\": \"web_search\", \"item_id\": \"item_789\", \"arguments\": \"{\\\"query\\\": \\\"weather today\\\"}\"}\n\ndata: {\"type\": \"response.output_item.done\", \"item\": {\"type\": \"function_call\", \"name\": \"web_search\", \"call_id\": \"call_789\", \"arguments\": \"{\\\"query\\\": \\\"weather today\\\"}\"}}\n\n"; let calls = detect_web_search_in_chunk(chunk); assert_eq!(calls.len(), 1); - assert_eq!(calls[0].id, "call_789"); - assert_eq!(calls[0].query, "weather today"); + let WebSearchCallDetection::Valid(tc) = &calls[0] else { + panic!("expected a valid web_search call"); + }; + assert_eq!(tc.id, "call_789"); + assert_eq!(tc.query, "weather today"); } #[test]