Skip to content

Commit 3a68d8f

Browse files
authored
fix: handle SSE errors occurred after stream started (#894)
The spec doesn't defined this behavior: a2aproject/A2A#1262, but currently it'd close the connection.
1 parent 01b3b2c commit 3a68d8f

8 files changed

Lines changed: 242 additions & 82 deletions

File tree

src/a2a/client/transports/http_helpers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
1313

1414

15+
def _default_sse_error_handler(sse_data: str) -> NoReturn:
16+
raise A2AClientError(f'SSE stream error event received: {sse_data}')
17+
18+
1519
@contextmanager
1620
def handle_http_exceptions(
1721
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
@@ -71,9 +75,22 @@ async def send_http_stream_request(
7175
url: str,
7276
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
7377
| None = None,
78+
sse_error_handler: Callable[[str], NoReturn] = _default_sse_error_handler,
7479
**kwargs: Any,
7580
) -> AsyncGenerator[str]:
76-
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
81+
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.
82+
83+
Args:
84+
httpx_client: The async HTTP client.
85+
method: The HTTP method (e.g. 'POST', 'GET').
86+
url: The URL to send the request to.
87+
status_error_handler: Handler for HTTP status errors. Should raise an
88+
appropriate domain-specific exception.
89+
sse_error_handler: Handler for SSE error events. Called with the
90+
raw SSE data string when an ``event: error`` SSE event is received.
91+
Should raise an appropriate domain-specific exception.
92+
**kwargs: Additional keyword arguments forwarded to ``aconnect_sse``.
93+
"""
7794
with handle_http_exceptions(status_error_handler):
7895
async with _SSEEventSource(
7996
httpx_client, method, url, **kwargs
@@ -97,6 +114,8 @@ async def send_http_stream_request(
97114
async for sse in event_source.aiter_sse():
98115
if not sse.data:
99116
continue
117+
if sse.event == 'error':
118+
sse_error_handler(sse.data)
100119
yield sse.data
101120

102121

src/a2a/client/transports/jsonrpc.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from collections.abc import AsyncGenerator
4-
from typing import Any
4+
from typing import Any, NoReturn
55
from uuid import uuid4
66

77
import httpx
@@ -350,6 +350,7 @@ async def _send_stream_request(
350350
'POST',
351351
self.url,
352352
None,
353+
self._handle_sse_error,
353354
json=rpc_request_payload,
354355
**http_kwargs,
355356
):
@@ -360,3 +361,10 @@ async def _send_stream_request(
360361
json_rpc_response.result, StreamResponse()
361362
)
362363
yield response
364+
365+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
366+
"""Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error."""
367+
json_rpc_response = JSONRPC20Response.from_json(sse_data)
368+
if json_rpc_response.error:
369+
raise self._create_jsonrpc_error(json_rpc_response.error)
370+
raise A2AClientError(f'SSE stream error: {sse_data}')

src/a2a/client/transports/rest.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,47 @@
4141
logger = logging.getLogger(__name__)
4242

4343

44+
def _parse_rest_error(
45+
error_payload: dict[str, Any],
46+
fallback_message: str,
47+
) -> Exception | None:
48+
"""Parses a REST error payload and returns the appropriate A2AError.
49+
50+
Args:
51+
error_payload: The parsed JSON error payload.
52+
fallback_message: Message to use if the payload has no ``message``.
53+
54+
Returns:
55+
The mapped A2AError if a known reason was found, otherwise ``None``.
56+
"""
57+
error_data = error_payload.get('error', {})
58+
message = error_data.get('message', fallback_message)
59+
details = error_data.get('details', [])
60+
if not isinstance(details, list):
61+
return None
62+
63+
# The `details` array can contain multiple different error objects.
64+
# We extract the first `ErrorInfo` object because it contains the
65+
# specific `reason` code needed to map this back to a Python A2AError.
66+
for d in details:
67+
if (
68+
isinstance(d, dict)
69+
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
70+
):
71+
reason = d.get('reason')
72+
metadata = d.get('metadata') or {}
73+
if isinstance(reason, str):
74+
exception_cls = A2A_REASON_TO_ERROR.get(reason)
75+
if exception_cls:
76+
exc = exception_cls(message)
77+
if metadata:
78+
exc.data = metadata
79+
return exc
80+
break
81+
82+
return None
83+
84+
4485
@trace_class(kind=SpanKind.CLIENT)
4586
class RestTransport(ClientTransport):
4687
"""A REST transport for the A2A client."""
@@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
294335
"""Handles HTTP status errors and raises the appropriate A2AError."""
295336
try:
296337
error_payload = e.response.json()
297-
error_data = error_payload.get('error', {})
298-
299-
message = error_data.get('message', str(e))
300-
details = error_data.get('details', [])
301-
if not isinstance(details, list):
302-
details = []
303-
304-
# The `details` array can contain multiple different error objects.
305-
# We extract the first `ErrorInfo` object because it contains the
306-
# specific `reason` code needed to map this back to a Python A2AError.
307-
error_info = {}
308-
for d in details:
309-
if (
310-
isinstance(d, dict)
311-
and d.get('@type')
312-
== 'type.googleapis.com/google.rpc.ErrorInfo'
313-
):
314-
error_info = d
315-
break
316-
reason = error_info.get('reason')
317-
metadata = error_info.get('metadata') or {}
318-
319-
if isinstance(reason, str):
320-
exception_cls = A2A_REASON_TO_ERROR.get(reason)
321-
if exception_cls:
322-
exc = exception_cls(message)
323-
if metadata:
324-
exc.data = metadata
325-
raise exc from e
338+
mapped = _parse_rest_error(error_payload, str(e))
339+
if mapped:
340+
raise mapped from e
326341
except (json.JSONDecodeError, ValueError):
327342
pass
328343

329-
# Fallback mappings for status codes if 'type' is missing or unknown
330344
status_code = e.response.status_code
331345
if status_code == httpx.codes.NOT_FOUND:
332346
raise MethodNotFoundError(
@@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
335349

336350
raise A2AClientError(f'HTTP Error {status_code}: {e}') from e
337351

352+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
353+
"""Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError."""
354+
error_payload = json.loads(sse_data)
355+
mapped = _parse_rest_error(error_payload, sse_data)
356+
if mapped:
357+
raise mapped
358+
raise A2AClientError(sse_data)
359+
338360
async def _send_stream_request(
339361
self,
340362
method: str,
@@ -352,6 +374,7 @@ async def _send_stream_request(
352374
method,
353375
f'{self.url}{path}',
354376
self._handle_http_error,
377+
self._handle_sse_error,
355378
json=json,
356379
**http_kwargs,
357380
):

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,30 @@ def _create_response(
565565
async def event_generator(
566566
stream: AsyncGenerator[dict[str, Any]],
567567
) -> AsyncGenerator[dict[str, str]]:
568-
async for item in stream:
569-
yield {'data': json.dumps(item)}
568+
try:
569+
async for item in stream:
570+
event: dict[str, str] = {
571+
'data': json.dumps(item),
572+
}
573+
if 'error' in item:
574+
event['event'] = 'error'
575+
yield event
576+
except Exception as e:
577+
logger.exception(
578+
'Unhandled error during JSON-RPC SSE stream'
579+
)
580+
rpc_error: A2AError | JSONRPCError = (
581+
e
582+
if isinstance(e, A2AError | JSONRPCError)
583+
else InternalError(message=str(e))
584+
)
585+
error_response = build_error_response(
586+
context.state.get('request_id'), rpc_error
587+
)
588+
yield {
589+
'event': 'error',
590+
'data': json.dumps(error_response),
591+
}
570592

571593
return EventSourceResponse(
572594
event_generator(handler_result), headers=headers

src/a2a/server/routes/rest_dispatcher.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from a2a.utils import constants, proto_utils
2222
from a2a.utils.error_handlers import (
23+
build_rest_error_payload,
2324
rest_error_handler,
2425
rest_stream_error_handler,
2526
)
@@ -32,20 +33,23 @@
3233

3334

3435
if TYPE_CHECKING:
36+
from sse_starlette.event import ServerSentEvent
3537
from sse_starlette.sse import EventSourceResponse
3638
from starlette.requests import Request
3739
from starlette.responses import JSONResponse, Response
3840

3941
_package_starlette_installed = True
4042
else:
4143
try:
44+
from sse_starlette.event import ServerSentEvent
4245
from sse_starlette.sse import EventSourceResponse
4346
from starlette.requests import Request
4447
from starlette.responses import JSONResponse, Response
4548

4649
_package_starlette_installed = True
4750
except ImportError:
4851
EventSourceResponse = Any
52+
ServerSentEvent = Any
4953
Request = Any
5054
JSONResponse = Any
5155
Response = Any
@@ -135,10 +139,17 @@ async def _handle_streaming(
135139
except StopAsyncIteration:
136140
return EventSourceResponse(iter([]))
137141

138-
async def event_generator() -> AsyncIterator[str]:
139-
yield json.dumps(first_item)
140-
async for item in stream:
141-
yield json.dumps(item)
142+
async def event_generator() -> AsyncIterator[ServerSentEvent]:
143+
yield ServerSentEvent(data=json.dumps(first_item))
144+
try:
145+
async for item in stream:
146+
yield ServerSentEvent(data=json.dumps(item))
147+
except Exception as e:
148+
logger.exception('Error during REST SSE stream')
149+
yield ServerSentEvent(
150+
data=json.dumps(build_rest_error_payload(e)),
151+
event='error',
152+
)
142153

143154
return EventSourceResponse(event_generator())
144155

src/a2a/utils/error_handlers.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,43 @@ def _build_error_payload(
5454
return {'error': payload}
5555

5656

57-
def _create_error_response(error: Exception) -> Response:
58-
"""Helper function to create a JSONResponse for an error."""
57+
def build_rest_error_payload(error: Exception) -> dict[str, Any]:
58+
"""Build a REST error payload dict from an exception.
59+
60+
Returns:
61+
A dict with the error payload in the standard REST error format.
62+
"""
5963
if isinstance(error, A2AError):
6064
mapping = A2A_REST_ERROR_MAPPING.get(
6165
type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR')
6266
)
63-
http_code = mapping.http_code
64-
grpc_status = mapping.grpc_status
65-
reason = mapping.reason
67+
# SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response.
68+
metadata = getattr(error, 'data', None) or {}
69+
return _build_error_payload(
70+
code=mapping.http_code,
71+
status=mapping.grpc_status,
72+
message=getattr(error, 'message', str(error)),
73+
reason=mapping.reason,
74+
metadata=metadata,
75+
)
76+
if isinstance(error, ParseError):
77+
return _build_error_payload(
78+
code=400,
79+
status='INVALID_ARGUMENT',
80+
message=str(error),
81+
reason='INVALID_REQUEST',
82+
metadata={},
83+
)
84+
return _build_error_payload(
85+
code=500,
86+
status='INTERNAL',
87+
message='unknown exception',
88+
)
6689

90+
91+
def _create_error_response(error: Exception) -> Response:
92+
"""Helper function to create a JSONResponse for an error."""
93+
if isinstance(error, A2AError):
6794
log_level = (
6895
logging.ERROR
6996
if isinstance(error, InternalError)
@@ -76,42 +103,17 @@ def _create_error_response(error: Exception) -> Response:
76103
getattr(error, 'message', str(error)),
77104
f', Data={error.data}' if error.data else '',
78105
)
79-
80-
# SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response.
81-
metadata = getattr(error, 'data', None) or {}
82-
83-
return JSONResponse(
84-
content=_build_error_payload(
85-
code=http_code,
86-
status=grpc_status,
87-
message=getattr(error, 'message', str(error)),
88-
reason=reason,
89-
metadata=metadata,
90-
),
91-
status_code=http_code,
92-
media_type='application/json',
93-
)
94-
if isinstance(error, ParseError):
106+
elif isinstance(error, ParseError):
95107
logger.warning('Parse error: %s', str(error))
96-
return JSONResponse(
97-
content=_build_error_payload(
98-
code=400,
99-
status='INVALID_ARGUMENT',
100-
message=str(error),
101-
reason='INVALID_REQUEST',
102-
metadata={},
103-
),
104-
status_code=400,
105-
media_type='application/json',
106-
)
107-
logger.exception('Unknown error occurred')
108+
else:
109+
logger.exception('Unknown error occurred')
110+
111+
payload = build_rest_error_payload(error)
112+
# Extract HTTP status code from the payload
113+
http_code = payload.get('error', {}).get('code', 500)
108114
return JSONResponse(
109-
content=_build_error_payload(
110-
code=500,
111-
status='INTERNAL',
112-
message='unknown exception',
113-
),
114-
status_code=500,
115+
content=payload,
116+
status_code=http_code,
115117
media_type='application/json',
116118
)
117119

0 commit comments

Comments
 (0)