@@ -405,6 +405,134 @@ async def get_current_result():
405
405
mock_agent_executor .execute .assert_awaited_once ()
406
406
407
407
408
+ @pytest .mark .asyncio
409
+ async def test_on_message_send_with_push_notification_in_non_blocking_request ():
410
+ """Test that push notification callback is called during background event processing for non-blocking requests."""
411
+ mock_task_store = AsyncMock (spec = TaskStore )
412
+ mock_push_notification_store = AsyncMock (spec = PushNotificationConfigStore )
413
+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
414
+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
415
+ mock_push_sender = AsyncMock ()
416
+
417
+ task_id = 'non_blocking_task_1'
418
+ context_id = 'non_blocking_ctx_1'
419
+
420
+ # Create a task that will be returned after the first event
421
+ initial_task = create_sample_task (
422
+ task_id = task_id , context_id = context_id , status_state = TaskState .working
423
+ )
424
+
425
+ # Create a final task that will be available during background processing
426
+ final_task = create_sample_task (
427
+ task_id = task_id , context_id = context_id , status_state = TaskState .completed
428
+ )
429
+
430
+ mock_task_store .get .return_value = None
431
+
432
+ # Mock request context
433
+ mock_request_context = MagicMock (spec = RequestContext )
434
+ mock_request_context .task_id = task_id
435
+ mock_request_context .context_id = context_id
436
+ mock_request_context_builder .build .return_value = mock_request_context
437
+
438
+ request_handler = DefaultRequestHandler (
439
+ agent_executor = mock_agent_executor ,
440
+ task_store = mock_task_store ,
441
+ push_config_store = mock_push_notification_store ,
442
+ request_context_builder = mock_request_context_builder ,
443
+ push_sender = mock_push_sender ,
444
+ )
445
+
446
+ # Configure push notification
447
+ push_config = PushNotificationConfig (url = 'http://callback.com/push' )
448
+ message_config = MessageSendConfiguration (
449
+ push_notification_config = push_config ,
450
+ accepted_output_modes = ['text/plain' ],
451
+ blocking = False , # Non-blocking request
452
+ )
453
+ params = MessageSendParams (
454
+ message = Message (
455
+ role = Role .user ,
456
+ message_id = 'msg_non_blocking' ,
457
+ parts = [],
458
+ task_id = task_id ,
459
+ context_id = context_id ,
460
+ ),
461
+ configuration = message_config ,
462
+ )
463
+
464
+ # Mock ResultAggregator with custom behavior
465
+ mock_result_aggregator_instance = AsyncMock (spec = ResultAggregator )
466
+
467
+ # First call returns the initial task and indicates interruption (non-blocking)
468
+ mock_result_aggregator_instance .consume_and_break_on_interrupt .return_value = (
469
+ initial_task ,
470
+ True , # interrupted = True for non-blocking
471
+ )
472
+
473
+ # Mock the current_result property to return the final task
474
+ async def get_current_result ():
475
+ return final_task
476
+
477
+ type(mock_result_aggregator_instance ).current_result = PropertyMock (
478
+ return_value = get_current_result ()
479
+ )
480
+
481
+ # Track if the event_callback was passed to consume_and_break_on_interrupt
482
+ event_callback_passed = False
483
+ event_callback_received = None
484
+
485
+ async def mock_consume_and_break_on_interrupt (
486
+ consumer , blocking = True , event_callback = None
487
+ ):
488
+ nonlocal event_callback_passed , event_callback_received
489
+ event_callback_passed = event_callback is not None
490
+ event_callback_received = event_callback
491
+ return initial_task , True # interrupted = True for non-blocking
492
+
493
+ mock_result_aggregator_instance .consume_and_break_on_interrupt = (
494
+ mock_consume_and_break_on_interrupt
495
+ )
496
+
497
+ with (
498
+ patch (
499
+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
500
+ return_value = mock_result_aggregator_instance ,
501
+ ),
502
+ patch (
503
+ 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
504
+ return_value = initial_task ,
505
+ ),
506
+ patch (
507
+ 'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message' ,
508
+ return_value = initial_task ,
509
+ ),
510
+ ):
511
+ # Execute the non-blocking request
512
+ result = await request_handler .on_message_send (
513
+ params , create_server_call_context ()
514
+ )
515
+
516
+ # Verify the result is the initial task (non-blocking behavior)
517
+ assert result == initial_task
518
+
519
+ # Verify that the event_callback was passed to consume_and_break_on_interrupt
520
+ assert event_callback_passed , (
521
+ 'event_callback should have been passed to consume_and_break_on_interrupt'
522
+ )
523
+ assert event_callback_received is not None , (
524
+ 'event_callback should not be None'
525
+ )
526
+
527
+ # Verify that the push notification was sent with the final task
528
+ mock_push_sender .send_notification .assert_called_with (final_task )
529
+
530
+ # Verify that the push notification config was stored
531
+ mock_push_notification_store .set_info .assert_awaited_once_with (
532
+ task_id , push_config
533
+ )
534
+
535
+
408
536
@pytest .mark .asyncio
409
537
async def test_on_message_send_with_push_notification_no_existing_Task ():
410
538
"""Test on_message_send for new task sets push notification info if provided."""
0 commit comments