Skip to content

Commit 70b4999

Browse files
authored
feat: support non-blocking sendMessage (#349)
1 parent d9e463c commit 70b4999

File tree

3 files changed

+71
-8
lines changed

3 files changed

+71
-8
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,18 @@ async def on_message_send(
280280
consumer = EventConsumer(queue)
281281
producer_task.add_done_callback(consumer.agent_task_callback)
282282

283-
interrupted = False
283+
blocking = True # Default to blocking behavior
284+
if params.configuration and params.configuration.blocking is False:
285+
blocking = False
286+
287+
interrupted_or_non_blocking = False
284288
try:
285289
(
286290
result,
287-
interrupted,
288-
) = await result_aggregator.consume_and_break_on_interrupt(consumer)
291+
interrupted_or_non_blocking,
292+
) = await result_aggregator.consume_and_break_on_interrupt(
293+
consumer, blocking=blocking
294+
)
289295
if not result:
290296
raise ServerError(error=InternalError())
291297

@@ -300,7 +306,7 @@ async def on_message_send(
300306
logger.error(f'Agent execution failed. Error: {e}')
301307
raise
302308
finally:
303-
if interrupted:
309+
if interrupted_or_non_blocking:
304310
# TODO: Track this disconnected cleanup task.
305311
asyncio.create_task( # noqa: RUF006
306312
self._cleanup_producer(producer_task, task_id)

src/a2a/server/tasks/result_aggregator.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,19 @@ async def consume_all(
9292
return await self.task_manager.get_task()
9393

9494
async def consume_and_break_on_interrupt(
95-
self, consumer: EventConsumer
95+
self, consumer: EventConsumer, blocking: bool = True
9696
) -> tuple[Task | Message | None, bool]:
9797
"""Processes the event stream until completion or an interruptable state is encountered.
9898
99-
Interruptable states currently include `TaskState.auth_required`.
99+
If `blocking` is False, it returns after the first event that creates a Task or Message.
100+
If `blocking` is True, it waits for completion unless an `auth_required`
101+
state is encountered, which is always an interruption.
100102
If interrupted, consumption continues in a background task.
101103
102104
Args:
103105
consumer: The `EventConsumer` to read events from.
106+
blocking: If `False`, the method returns as soon as a task/message
107+
is available. If `True`, it waits for a terminal state.
104108
105109
Returns:
106110
A tuple containing:
@@ -117,10 +121,15 @@ async def consume_and_break_on_interrupt(
117121
self._message = event
118122
return event, False
119123
await self.task_manager.process(event)
120-
if (
124+
125+
should_interrupt = False
126+
is_auth_required = (
121127
isinstance(event, Task | TaskStatusUpdateEvent)
122128
and event.status.state == TaskState.auth_required
123-
):
129+
)
130+
131+
# Always interrupt on auth_required, as it needs external action.
132+
if is_auth_required:
124133
# auth-required is a special state: the message should be
125134
# escalated back to the caller, but the agent is expected to
126135
# continue producing events once the authorization is received
@@ -130,6 +139,16 @@ async def consume_and_break_on_interrupt(
130139
logger.debug(
131140
'Encountered an auth-required task: breaking synchronous message/send flow.'
132141
)
142+
should_interrupt = True
143+
# For non-blocking calls, interrupt as soon as a task is available.
144+
elif not blocking:
145+
logger.debug(
146+
'Non-blocking call: returning task after first event.'
147+
)
148+
should_interrupt = True
149+
150+
if should_interrupt:
151+
# Continue consuming the rest of the events in the background.
133152
# TODO: We should track all outstanding tasks to ensure they eventually complete.
134153
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
135154
interrupted = True

tests/server/tasks/test_result_aggregator.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,44 @@ async def raiser_gen_interrupt():
384384
)
385385
self.mock_task_manager.get_task.assert_not_called()
386386

387+
@patch('asyncio.create_task')
388+
async def test_consume_and_break_non_blocking(
389+
self, mock_create_task: MagicMock
390+
):
391+
"""Test that with blocking=False, the method returns after the first event."""
392+
first_event = create_sample_task('non_blocking_task')
393+
event_after = create_sample_message('should be consumed later')
394+
395+
async def mock_consume_generator():
396+
yield first_event
397+
yield event_after
398+
399+
self.mock_event_consumer.consume_all.return_value = (
400+
mock_consume_generator()
401+
)
402+
# After processing `first_event`, the current result will be that task.
403+
self.aggregator.task_manager.get_task.return_value = first_event
404+
405+
self.aggregator._continue_consuming = AsyncMock()
406+
mock_create_task.side_effect = lambda coro: asyncio.ensure_future(coro)
407+
408+
(
409+
result,
410+
interrupted,
411+
) = await self.aggregator.consume_and_break_on_interrupt(
412+
self.mock_event_consumer, blocking=False
413+
)
414+
415+
self.assertEqual(result, first_event)
416+
self.assertTrue(interrupted)
417+
self.mock_task_manager.process.assert_called_once_with(first_event)
418+
mock_create_task.assert_called_once()
419+
# The background task should be created with the remaining stream
420+
self.aggregator._continue_consuming.assert_called_once()
421+
self.assertIsInstance(
422+
self.aggregator._continue_consuming.call_args[0][0], AsyncIterator
423+
)
424+
387425
@patch('asyncio.create_task') # To verify _continue_consuming is called
388426
async def test_continue_consuming_processes_remaining_events(
389427
self, mock_create_task: MagicMock

0 commit comments

Comments
 (0)