Skip to content

Commit db82a65

Browse files
weimchholtskinner
andauthored
fix: non-blocking send_message server handler not invoke push notification (#394)
- Problem: When client use `send_message` with `MessageSendConfiguration.blocking=False`, the `result_aggregator` will enter the logic of `_continue_consuming`. But it's not push notification to client. The client can't get notification for long-running task(non-blocking invoke) at this situation. - Solution: Simply add push notification logic to result_aggregator is okay. Fixes #239 🦕 Co-authored-by: Holt Skinner <[email protected]>
1 parent a371461 commit db82a65

File tree

3 files changed

+158
-6
lines changed

3 files changed

+158
-6
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,19 @@ async def on_message_send(
286286

287287
interrupted_or_non_blocking = False
288288
try:
289+
# Create async callback for push notifications
290+
async def push_notification_callback() -> None:
291+
await self._send_push_notification_if_needed(
292+
task_id, result_aggregator
293+
)
294+
289295
(
290296
result,
291297
interrupted_or_non_blocking,
292298
) = await result_aggregator.consume_and_break_on_interrupt(
293-
consumer, blocking=blocking
299+
consumer,
300+
blocking=blocking,
301+
event_callback=push_notification_callback,
294302
)
295303
if not result:
296304
raise ServerError(error=InternalError()) # noqa: TRY301

src/a2a/server/tasks/result_aggregator.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33

4-
from collections.abc import AsyncGenerator, AsyncIterator
4+
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
55

66
from a2a.server.events import Event, EventConsumer
77
from a2a.server.tasks.task_manager import TaskManager
@@ -24,7 +24,10 @@ class ResultAggregator:
2424
Task object and emit that Task object.
2525
"""
2626

27-
def __init__(self, task_manager: TaskManager):
27+
def __init__(
28+
self,
29+
task_manager: TaskManager,
30+
) -> None:
2831
"""Initializes the ResultAggregator.
2932
3033
Args:
@@ -92,7 +95,10 @@ async def consume_all(
9295
return await self.task_manager.get_task()
9396

9497
async def consume_and_break_on_interrupt(
95-
self, consumer: EventConsumer, blocking: bool = True
98+
self,
99+
consumer: EventConsumer,
100+
blocking: bool = True,
101+
event_callback: Callable[[], Awaitable[None]] | None = None,
96102
) -> tuple[Task | Message | None, bool]:
97103
"""Processes the event stream until completion or an interruptable state is encountered.
98104
@@ -105,6 +111,9 @@ async def consume_and_break_on_interrupt(
105111
consumer: The `EventConsumer` to read events from.
106112
blocking: If `False`, the method returns as soon as a task/message
107113
is available. If `True`, it waits for a terminal state.
114+
event_callback: Optional async callback function to be called after each event
115+
is processed in the background continuation.
116+
Mainly used for push notifications currently.
108117
109118
Returns:
110119
A tuple containing:
@@ -150,13 +159,17 @@ async def consume_and_break_on_interrupt(
150159
if should_interrupt:
151160
# Continue consuming the rest of the events in the background.
152161
# TODO: We should track all outstanding tasks to ensure they eventually complete.
153-
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
162+
asyncio.create_task( # noqa: RUF006
163+
self._continue_consuming(event_stream, event_callback)
164+
)
154165
interrupted = True
155166
break
156167
return await self.task_manager.get_task(), interrupted
157168

158169
async def _continue_consuming(
159-
self, event_stream: AsyncIterator[Event]
170+
self,
171+
event_stream: AsyncIterator[Event],
172+
event_callback: Callable[[], Awaitable[None]] | None = None,
160173
) -> None:
161174
"""Continues processing an event stream in a background task.
162175
@@ -165,6 +178,9 @@ async def _continue_consuming(
165178
166179
Args:
167180
event_stream: The remaining `AsyncIterator` of events from the consumer.
181+
event_callback: Optional async callback function to be called after each event is processed.
168182
"""
169183
async for event in event_stream:
170184
await self.task_manager.process(event)
185+
if event_callback:
186+
await event_callback()

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,134 @@ async def get_current_result():
405405
mock_agent_executor.execute.assert_awaited_once()
406406

407407

408+
@pytest.mark.asyncio
409+
async def test_on_message_send_with_push_notification_in_non_blocking_request():
410+
"""Test that push notification callback is called during background event processing for non-blocking requests."""
411+
mock_task_store = AsyncMock(spec=TaskStore)
412+
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
413+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
414+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
415+
mock_push_sender = AsyncMock()
416+
417+
task_id = 'non_blocking_task_1'
418+
context_id = 'non_blocking_ctx_1'
419+
420+
# Create a task that will be returned after the first event
421+
initial_task = create_sample_task(
422+
task_id=task_id, context_id=context_id, status_state=TaskState.working
423+
)
424+
425+
# Create a final task that will be available during background processing
426+
final_task = create_sample_task(
427+
task_id=task_id, context_id=context_id, status_state=TaskState.completed
428+
)
429+
430+
mock_task_store.get.return_value = None
431+
432+
# Mock request context
433+
mock_request_context = MagicMock(spec=RequestContext)
434+
mock_request_context.task_id = task_id
435+
mock_request_context.context_id = context_id
436+
mock_request_context_builder.build.return_value = mock_request_context
437+
438+
request_handler = DefaultRequestHandler(
439+
agent_executor=mock_agent_executor,
440+
task_store=mock_task_store,
441+
push_config_store=mock_push_notification_store,
442+
request_context_builder=mock_request_context_builder,
443+
push_sender=mock_push_sender,
444+
)
445+
446+
# Configure push notification
447+
push_config = PushNotificationConfig(url='http://callback.com/push')
448+
message_config = MessageSendConfiguration(
449+
push_notification_config=push_config,
450+
accepted_output_modes=['text/plain'],
451+
blocking=False, # Non-blocking request
452+
)
453+
params = MessageSendParams(
454+
message=Message(
455+
role=Role.user,
456+
message_id='msg_non_blocking',
457+
parts=[],
458+
task_id=task_id,
459+
context_id=context_id,
460+
),
461+
configuration=message_config,
462+
)
463+
464+
# Mock ResultAggregator with custom behavior
465+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
466+
467+
# First call returns the initial task and indicates interruption (non-blocking)
468+
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
469+
initial_task,
470+
True, # interrupted = True for non-blocking
471+
)
472+
473+
# Mock the current_result property to return the final task
474+
async def get_current_result():
475+
return final_task
476+
477+
type(mock_result_aggregator_instance).current_result = PropertyMock(
478+
return_value=get_current_result()
479+
)
480+
481+
# Track if the event_callback was passed to consume_and_break_on_interrupt
482+
event_callback_passed = False
483+
event_callback_received = None
484+
485+
async def mock_consume_and_break_on_interrupt(
486+
consumer, blocking=True, event_callback=None
487+
):
488+
nonlocal event_callback_passed, event_callback_received
489+
event_callback_passed = event_callback is not None
490+
event_callback_received = event_callback
491+
return initial_task, True # interrupted = True for non-blocking
492+
493+
mock_result_aggregator_instance.consume_and_break_on_interrupt = (
494+
mock_consume_and_break_on_interrupt
495+
)
496+
497+
with (
498+
patch(
499+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
500+
return_value=mock_result_aggregator_instance,
501+
),
502+
patch(
503+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
504+
return_value=initial_task,
505+
),
506+
patch(
507+
'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message',
508+
return_value=initial_task,
509+
),
510+
):
511+
# Execute the non-blocking request
512+
result = await request_handler.on_message_send(
513+
params, create_server_call_context()
514+
)
515+
516+
# Verify the result is the initial task (non-blocking behavior)
517+
assert result == initial_task
518+
519+
# Verify that the event_callback was passed to consume_and_break_on_interrupt
520+
assert event_callback_passed, (
521+
'event_callback should have been passed to consume_and_break_on_interrupt'
522+
)
523+
assert event_callback_received is not None, (
524+
'event_callback should not be None'
525+
)
526+
527+
# Verify that the push notification was sent with the final task
528+
mock_push_sender.send_notification.assert_called_with(final_task)
529+
530+
# Verify that the push notification config was stored
531+
mock_push_notification_store.set_info.assert_awaited_once_with(
532+
task_id, push_config
533+
)
534+
535+
408536
@pytest.mark.asyncio
409537
async def test_on_message_send_with_push_notification_no_existing_Task():
410538
"""Test on_message_send for new task sets push notification info if provided."""

0 commit comments

Comments
 (0)