Skip to content

Commit 679ecc8

Browse files
committed
remove duplication
1 parent a17ad34 commit 679ecc8

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,7 @@ async def _setup_message_execution(
231231
)
232232

233233
queue = await self._queue_manager.create_or_tap(task_id)
234-
result_aggregator = ResultAggregator(
235-
task_manager,
236-
push_sender=self._push_sender,
237-
)
234+
result_aggregator = ResultAggregator(task_manager)
238235
# TODO: to manage the non-blocking flows.
239236
producer_task = asyncio.create_task(
240237
self._run_event_stream(request_context, queue)
@@ -289,11 +286,15 @@ async def on_message_send(
289286

290287
interrupted_or_non_blocking = False
291288
try:
289+
# Create async callback for push notifications
290+
async def push_notification_callback() -> None:
291+
await self._send_push_notification_if_needed(task_id, result_aggregator)
292+
292293
(
293294
result,
294295
interrupted_or_non_blocking,
295296
) = await result_aggregator.consume_and_break_on_interrupt(
296-
consumer, blocking=blocking
297+
consumer, blocking=blocking, event_callback=push_notification_callback
297298
)
298299
if not result:
299300
raise ServerError(error=InternalError())

src/a2a/server/tasks/result_aggregator.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import asyncio
22
import logging
33

4-
from collections.abc import AsyncGenerator, AsyncIterator
4+
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Awaitable
55

66
from a2a.server.events import Event, EventConsumer
7-
from a2a.server.tasks.push_notification_sender import PushNotificationSender
87
from a2a.server.tasks.task_manager import TaskManager
98
from a2a.types import Message, Task, TaskState, TaskStatusUpdateEvent
109

@@ -28,17 +27,14 @@ class ResultAggregator:
2827
def __init__(
2928
self,
3029
task_manager: TaskManager,
31-
push_sender: PushNotificationSender | None = None,
3230
) -> None:
3331
"""Initializes the ResultAggregator.
3432
3533
Args:
3634
task_manager: The `TaskManager` instance to use for processing events
3735
and managing the task state.
38-
push_sender: The `PushNotificationSender` instance to use for sending push notifications.
3936
"""
4037
self.task_manager = task_manager
41-
self.push_sender = push_sender
4238
self._message: Message | None = None
4339

4440
@property
@@ -99,7 +95,10 @@ async def consume_all(
9995
return await self.task_manager.get_task()
10096

10197
async def consume_and_break_on_interrupt(
102-
self, consumer: EventConsumer, blocking: bool = True
98+
self,
99+
consumer: EventConsumer,
100+
blocking: bool = True,
101+
event_callback: Callable[[], Awaitable[None]] | None = None
103102
) -> tuple[Task | Message | None, bool]:
104103
"""Processes the event stream until completion or an interruptable state is encountered.
105104
@@ -112,6 +111,9 @@ async def consume_and_break_on_interrupt(
112111
consumer: The `EventConsumer` to read events from.
113112
blocking: If `False`, the method returns as soon as a task/message
114113
is available. If `True`, it waits for a terminal state.
114+
event_callback: Optional async callback function to be called after each event
115+
is processed in the background continuation.
116+
Mainly used for push notifications currently.
115117
116118
Returns:
117119
A tuple containing:
@@ -157,13 +159,15 @@ async def consume_and_break_on_interrupt(
157159
if should_interrupt:
158160
# Continue consuming the rest of the events in the background.
159161
# TODO: We should track all outstanding tasks to ensure they eventually complete.
160-
asyncio.create_task(self._continue_consuming(event_stream)) # noqa: RUF006
162+
asyncio.create_task(self._continue_consuming(event_stream, event_callback)) # noqa: RUF006
161163
interrupted = True
162164
break
163165
return await self.task_manager.get_task(), interrupted
164166

165167
async def _continue_consuming(
166-
self, event_stream: AsyncIterator[Event]
168+
self,
169+
event_stream: AsyncIterator[Event],
170+
event_callback: Callable[[], Awaitable[None]] | None = None
167171
) -> None:
168172
"""Continues processing an event stream in a background task.
169173
@@ -172,14 +176,9 @@ async def _continue_consuming(
172176
173177
Args:
174178
event_stream: The remaining `AsyncIterator` of events from the consumer.
179+
event_callback: Optional async callback function to be called after each event is processed.
175180
"""
176181
async for event in event_stream:
177182
await self.task_manager.process(event)
178-
await self._send_push_notification_if_needed()
179-
180-
async def _send_push_notification_if_needed(self) -> None:
181-
"""Sends push notification if configured and task is available."""
182-
if self.push_sender:
183-
latest_task = await self.current_result
184-
if isinstance(latest_task, Task):
185-
await self.push_sender.send_notification(latest_task)
183+
if event_callback:
184+
await event_callback()

0 commit comments

Comments
 (0)