Skip to content

Commit e3e5c4b

Browse files
feat: Add ServerCallContext into task store operations (#443)
In production systems the persistence of the task should be protected based on the credentials of the creator of the task (e.g. a user id or email). Additionally, applications may have other criteria to use for task persistence (like application name, or region task runs in). --- Providing the `ServerCallContext` into the calls to the `get`, `save` and `delete` interface for the task store allows customization of the persisted task data based on the characteristics needed for a real solution. Agent implementors can construct the appropriate `ServerCallContext` based on the incoming request and use that information at task creation, retrieval and deletion time. Fixes #442 🦕 --------- Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: Holt Skinner <[email protected]>
1 parent 813f5cd commit e3e5c4b

File tree

8 files changed

+108
-68
lines changed

8 files changed

+108
-68
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ async def on_get_task(
109109
context: ServerCallContext | None = None,
110110
) -> Task | None:
111111
"""Default handler for 'tasks/get'."""
112-
task: Task | None = await self.task_store.get(params.id)
112+
task: Task | None = await self.task_store.get(params.id, context)
113113
if not task:
114114
raise ServerError(error=TaskNotFoundError())
115115

@@ -141,7 +141,7 @@ async def on_cancel_task(
141141
142142
Attempts to cancel the task managed by the `AgentExecutor`.
143143
"""
144-
task: Task | None = await self.task_store.get(params.id)
144+
task: Task | None = await self.task_store.get(params.id, context)
145145
if not task:
146146
raise ServerError(error=TaskNotFoundError())
147147

@@ -158,6 +158,7 @@ async def on_cancel_task(
158158
context_id=task.context_id,
159159
task_store=self.task_store,
160160
initial_message=None,
161+
context=context,
161162
)
162163
result_aggregator = ResultAggregator(task_manager)
163164

@@ -224,6 +225,7 @@ async def _setup_message_execution(
224225
context_id=params.message.context_id,
225226
task_store=self.task_store,
226227
initial_message=params.message,
228+
context=context,
227229
)
228230
task: Task | None = await task_manager.get_task()
229231

@@ -424,7 +426,7 @@ async def on_set_task_push_notification_config(
424426
if not self._push_config_store:
425427
raise ServerError(error=UnsupportedOperationError())
426428

427-
task: Task | None = await self.task_store.get(params.task_id)
429+
task: Task | None = await self.task_store.get(params.task_id, context)
428430
if not task:
429431
raise ServerError(error=TaskNotFoundError())
430432

@@ -447,7 +449,7 @@ async def on_get_task_push_notification_config(
447449
if not self._push_config_store:
448450
raise ServerError(error=UnsupportedOperationError())
449451

450-
task: Task | None = await self.task_store.get(params.id)
452+
task: Task | None = await self.task_store.get(params.id, context)
451453
if not task:
452454
raise ServerError(error=TaskNotFoundError())
453455

@@ -476,7 +478,7 @@ async def on_resubscribe_to_task(
476478
Allows a client to re-attach to a running streaming task's event stream.
477479
Requires the task and its queue to still be active.
478480
"""
479-
task: Task | None = await self.task_store.get(params.id)
481+
task: Task | None = await self.task_store.get(params.id, context)
480482
if not task:
481483
raise ServerError(error=TaskNotFoundError())
482484

@@ -492,6 +494,7 @@ async def on_resubscribe_to_task(
492494
context_id=task.context_id,
493495
task_store=self.task_store,
494496
initial_message=None,
497+
context=context,
495498
)
496499

497500
result_aggregator = ResultAggregator(task_manager)
@@ -516,7 +519,7 @@ async def on_list_task_push_notification_config(
516519
if not self._push_config_store:
517520
raise ServerError(error=UnsupportedOperationError())
518521

519-
task: Task | None = await self.task_store.get(params.id)
522+
task: Task | None = await self.task_store.get(params.id, context)
520523
if not task:
521524
raise ServerError(error=TaskNotFoundError())
522525

@@ -543,7 +546,7 @@ async def on_delete_task_push_notification_config(
543546
if not self._push_config_store:
544547
raise ServerError(error=UnsupportedOperationError())
545548

546-
task: Task | None = await self.task_store.get(params.id)
549+
task: Task | None = await self.task_store.get(params.id, context)
547550
if not task:
548551
raise ServerError(error=TaskNotFoundError())
549552

src/a2a/server/tasks/database_task_store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"or 'pip install a2a-sdk[sql]'"
2020
) from e
2121

22+
from a2a.server.context import ServerCallContext
2223
from a2a.server.models import Base, TaskModel, create_task_model
2324
from a2a.server.tasks.task_store import TaskStore
2425
from a2a.types import Task # Task is the Pydantic model
@@ -119,15 +120,19 @@ def _from_orm(self, task_model: TaskModel) -> Task:
119120
# Pydantic's model_validate will parse the nested dicts/lists from JSON
120121
return Task.model_validate(task_data_from_db)
121122

122-
async def save(self, task: Task) -> None:
123+
async def save(
124+
self, task: Task, context: ServerCallContext | None = None
125+
) -> None:
123126
"""Saves or updates a task in the database."""
124127
await self._ensure_initialized()
125128
db_task = self._to_orm(task)
126129
async with self.async_session_maker.begin() as session:
127130
await session.merge(db_task)
128131
logger.debug('Task %s saved/updated successfully.', task.id)
129132

130-
async def get(self, task_id: str) -> Task | None:
133+
async def get(
134+
self, task_id: str, context: ServerCallContext | None = None
135+
) -> Task | None:
131136
"""Retrieves a task from the database by ID."""
132137
await self._ensure_initialized()
133138
async with self.async_session_maker() as session:
@@ -142,7 +147,9 @@ async def get(self, task_id: str) -> Task | None:
142147
logger.debug('Task %s not found in store.', task_id)
143148
return None
144149

145-
async def delete(self, task_id: str) -> None:
150+
async def delete(
151+
self, task_id: str, context: ServerCallContext | None = None
152+
) -> None:
146153
"""Deletes a task from the database by ID."""
147154
await self._ensure_initialized()
148155

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
33

4+
from a2a.server.context import ServerCallContext
45
from a2a.server.tasks.task_store import TaskStore
56
from a2a.types import Task
67

@@ -21,13 +22,17 @@ def __init__(self) -> None:
2122
self.tasks: dict[str, Task] = {}
2223
self.lock = asyncio.Lock()
2324

24-
async def save(self, task: Task) -> None:
25+
async def save(
26+
self, task: Task, context: ServerCallContext | None = None
27+
) -> None:
2528
"""Saves or updates a task in the in-memory store."""
2629
async with self.lock:
2730
self.tasks[task.id] = task
2831
logger.debug('Task %s saved successfully.', task.id)
2932

30-
async def get(self, task_id: str) -> Task | None:
33+
async def get(
34+
self, task_id: str, context: ServerCallContext | None = None
35+
) -> Task | None:
3136
"""Retrieves a task from the in-memory store by ID."""
3237
async with self.lock:
3338
logger.debug('Attempting to get task with id: %s', task_id)
@@ -38,7 +43,9 @@ async def get(self, task_id: str) -> Task | None:
3843
logger.debug('Task %s not found in store.', task_id)
3944
return task
4045

41-
async def delete(self, task_id: str) -> None:
46+
async def delete(
47+
self, task_id: str, context: ServerCallContext | None = None
48+
) -> None:
4249
"""Deletes a task from the in-memory store by ID."""
4350
async with self.lock:
4451
logger.debug('Attempting to delete task with id: %s', task_id)

src/a2a/server/tasks/task_manager.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
from a2a.server.context import ServerCallContext
34
from a2a.server.events.event_queue import Event
45
from a2a.server.tasks.task_store import TaskStore
56
from a2a.types import (
@@ -31,6 +32,7 @@ def __init__(
3132
context_id: str | None,
3233
task_store: TaskStore,
3334
initial_message: Message | None,
35+
context: ServerCallContext | None = None,
3436
):
3537
"""Initializes the TaskManager.
3638
@@ -40,6 +42,7 @@ def __init__(
4042
task_store: The `TaskStore` instance for persistence.
4143
initial_message: The `Message` that initiated the task, if any.
4244
Used when creating a new task object.
45+
context: The `ServerCallContext` that this task is produced under.
4346
"""
4447
if task_id is not None and not (isinstance(task_id, str) and task_id):
4548
raise ValueError('Task ID must be a non-empty string')
@@ -49,6 +52,7 @@ def __init__(
4952
self.task_store = task_store
5053
self._initial_message = initial_message
5154
self._current_task: Task | None = None
55+
self._call_context: ServerCallContext | None = context
5256
logger.debug(
5357
'TaskManager initialized with task_id: %s, context_id: %s',
5458
task_id,
@@ -74,7 +78,9 @@ async def get_task(self) -> Task | None:
7478
logger.debug(
7579
'Attempting to get task from store with id: %s', self.task_id
7680
)
77-
self._current_task = await self.task_store.get(self.task_id)
81+
self._current_task = await self.task_store.get(
82+
self.task_id, self._call_context
83+
)
7884
if self._current_task:
7985
logger.debug('Task %s retrieved successfully.', self.task_id)
8086
else:
@@ -167,7 +173,7 @@ async def ensure_task(
167173
logger.debug(
168174
'Attempting to retrieve existing task with id: %s', self.task_id
169175
)
170-
task = await self.task_store.get(self.task_id)
176+
task = await self.task_store.get(self.task_id, self._call_context)
171177

172178
if not task:
173179
logger.info(
@@ -231,7 +237,7 @@ async def _save_task(self, task: Task) -> None:
231237
task: The `Task` object to save.
232238
"""
233239
logger.debug('Saving task with id: %s', task.id)
234-
await self.task_store.save(task)
240+
await self.task_store.save(task, self._call_context)
235241
self._current_task = task
236242
if not self.task_id:
237243
logger.info('New task created with id: %s', task.id)

src/a2a/server/tasks/task_store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
22

3+
from a2a.server.context import ServerCallContext
34
from a2a.types import Task
45

56

@@ -10,13 +11,19 @@ class TaskStore(ABC):
1011
"""
1112

1213
@abstractmethod
13-
async def save(self, task: Task) -> None:
14+
async def save(
15+
self, task: Task, context: ServerCallContext | None = None
16+
) -> None:
1417
"""Saves or updates a task in the store."""
1518

1619
@abstractmethod
17-
async def get(self, task_id: str) -> Task | None:
20+
async def get(
21+
self, task_id: str, context: ServerCallContext | None = None
22+
) -> Task | None:
1823
"""Retrieves a task from the store by ID."""
1924

2025
@abstractmethod
21-
async def delete(self, task_id: str) -> None:
26+
async def delete(
27+
self, task_id: str, context: ServerCallContext | None = None
28+
) -> None:
2229
"""Deletes a task from the store by ID."""

0 commit comments

Comments
 (0)