Skip to content

Commit f4c9c18

Browse files
authored
fix: handle concurrent task completion during cancellation (#449)
### Description We [check](https://github.com/a2aproject/a2a-python/blob/d2e869f/src/a2a/server/request_handlers/default_request_handler.py#L149) that a Task is in a cancellable state before calling `agent_executor.cancel`. This doesn't guarantee there's no task completion event in the queue which will be applied before our task cancellation request gets handled. This PR adds an extra check to ensure that we don't return a Task in a non-cancelled state as a successful cancellation call response. Instead we raise `TaskNotCancelableError`.
1 parent 9da9ecc commit f4c9c18

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,21 @@ async def on_cancel_task(
180180

181181
consumer = EventConsumer(queue)
182182
result = await result_aggregator.consume_all(consumer)
183-
if isinstance(result, Task):
184-
return result
183+
if not isinstance(result, Task):
184+
raise ServerError(
185+
error=InternalError(
186+
message='Agent did not return valid response for cancel'
187+
)
188+
)
185189

186-
raise ServerError(
187-
error=InternalError(
188-
message='Agent did not return valid response for cancel'
190+
if result.status.state != TaskState.canceled:
191+
raise ServerError(
192+
error=TaskNotCancelableError(
193+
message=f'Task cannot be canceled - current state: {result.status.state}'
194+
)
189195
)
190-
)
196+
197+
return result
191198

192199
async def _run_event_stream(
193200
self, request: RequestContext, queue: EventQueue

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,56 @@ async def test_on_cancel_task_cancels_running_agent():
263263
mock_agent_executor.cancel.assert_awaited_once()
264264

265265

266+
@pytest.mark.asyncio
267+
async def test_on_cancel_task_completes_during_cancellation():
268+
"""Test on_cancel_task fails to cancel a task due to concurrent task completion."""
269+
task_id = 'running_agent_task_to_cancel'
270+
sample_task = create_sample_task(task_id=task_id)
271+
mock_task_store = AsyncMock(spec=TaskStore)
272+
mock_task_store.get.return_value = sample_task
273+
274+
mock_queue_manager = AsyncMock(spec=QueueManager)
275+
mock_event_queue = AsyncMock(spec=EventQueue)
276+
mock_queue_manager.tap.return_value = mock_event_queue
277+
278+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
279+
280+
# Mock ResultAggregator
281+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
282+
mock_result_aggregator_instance.consume_all.return_value = (
283+
create_sample_task(task_id=task_id, status_state=TaskState.completed)
284+
)
285+
286+
request_handler = DefaultRequestHandler(
287+
agent_executor=mock_agent_executor,
288+
task_store=mock_task_store,
289+
queue_manager=mock_queue_manager,
290+
)
291+
292+
# Simulate a running agent task
293+
mock_producer_task = AsyncMock(spec=asyncio.Task)
294+
request_handler._running_agents[task_id] = mock_producer_task
295+
296+
from a2a.utils.errors import (
297+
ServerError, # Local import
298+
TaskNotCancelableError, # Local import
299+
)
300+
301+
with patch(
302+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
303+
return_value=mock_result_aggregator_instance,
304+
):
305+
params = TaskIdParams(id=task_id)
306+
with pytest.raises(ServerError) as exc_info:
307+
await request_handler.on_cancel_task(
308+
params, create_server_call_context()
309+
)
310+
311+
mock_producer_task.cancel.assert_called_once()
312+
mock_agent_executor.cancel.assert_awaited_once()
313+
assert isinstance(exc_info.value.error, TaskNotCancelableError)
314+
315+
266316
@pytest.mark.asyncio
267317
async def test_on_cancel_task_invalid_result_type():
268318
"""Test on_cancel_task when result_aggregator returns a Message instead of a Task."""

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ async def test_on_cancel_task_success(self) -> None:
149149
call_context = ServerCallContext(state={'foo': 'bar'})
150150

151151
async def streaming_coro():
152+
mock_task.status.state = TaskState.canceled
152153
yield mock_task
153154

154155
with patch(
@@ -160,6 +161,7 @@ async def streaming_coro():
160161
assert mock_agent_executor.cancel.call_count == 1
161162
self.assertIsInstance(response.root, CancelTaskSuccessResponse)
162163
assert response.root.result == mock_task # type: ignore
164+
assert response.root.result.status.state == TaskState.canceled
163165
mock_agent_executor.cancel.assert_called_once()
164166

165167
async def test_on_cancel_task_not_supported(self) -> None:

0 commit comments

Comments
 (0)