Skip to content

Commit ba93053

Browse files
rajeshvelichetigemini-code-assist[bot]holtskinner
authored
feat: Add agent card as a route in rest adapter (#386)
# Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [ ] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: Holt Skinner <[email protected]>
1 parent 2db7f81 commit ba93053

File tree

3 files changed

+102
-24
lines changed

3 files changed

+102
-24
lines changed

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import logging
22

3+
from collections.abc import Callable
34
from typing import TYPE_CHECKING, Any
45

56

67
if TYPE_CHECKING:
78
from fastapi import APIRouter, FastAPI, Request, Response
9+
from fastapi.responses import JSONResponse
810

911
_package_fastapi_installed = True
1012
else:
1113
try:
1214
from fastapi import APIRouter, FastAPI, Request, Response
15+
from fastapi.responses import JSONResponse
1316

1417
_package_fastapi_installed = True
1518
except ImportError:
@@ -23,6 +26,7 @@
2326

2427
from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder
2528
from a2a.server.apps.rest.rest_adapter import RESTAdapter
29+
from a2a.server.context import ServerCallContext
2630
from a2a.server.request_handlers.request_handler import RequestHandler
2731
from a2a.types import AgentCard
2832
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
@@ -39,11 +43,17 @@ class A2ARESTFastAPIApplication:
3943
(SSE).
4044
"""
4145

42-
def __init__(
46+
def __init__( # noqa: PLR0913
4347
self,
4448
agent_card: AgentCard,
4549
http_handler: RequestHandler,
50+
extended_agent_card: AgentCard | None = None,
4651
context_builder: CallContextBuilder | None = None,
52+
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
53+
extended_card_modifier: Callable[
54+
[AgentCard, ServerCallContext], AgentCard
55+
]
56+
| None = None,
4757
):
4858
"""Initializes the A2ARESTFastAPIApplication.
4959
@@ -56,6 +66,11 @@ def __init__(
5666
context_builder: The CallContextBuilder used to construct the
5767
ServerCallContext passed to the http_handler. If None, no
5868
ServerCallContext is passed.
69+
card_modifier: An optional callback to dynamically modify the public
70+
agent card before it is served.
71+
extended_card_modifier: An optional callback to dynamically modify
72+
the extended agent card before it is served. It receives the
73+
call context.
5974
"""
6075
if not _package_fastapi_installed:
6176
raise ImportError(
@@ -66,7 +81,10 @@ def __init__(
6681
self._adapter = RESTAdapter(
6782
agent_card=agent_card,
6883
http_handler=http_handler,
84+
extended_agent_card=extended_agent_card,
6985
context_builder=context_builder,
86+
card_modifier=card_modifier,
87+
extended_card_modifier=extended_card_modifier,
7088
)
7189

7290
def build(
@@ -95,7 +113,8 @@ def build(
95113

96114
@router.get(f'{rpc_url}{agent_card_url}')
97115
async def get_agent_card(request: Request) -> Response:
98-
return await self._adapter.handle_get_agent_card(request)
116+
card = await self._adapter.handle_get_agent_card(request)
117+
return JSONResponse(card)
99118

100119
app.include_router(router)
101120
return app

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,34 @@ class RESTAdapter:
5252
manages response generation including Server-Sent Events (SSE).
5353
"""
5454

55-
def __init__(
55+
def __init__( # noqa: PLR0913
5656
self,
5757
agent_card: AgentCard,
5858
http_handler: RequestHandler,
59+
extended_agent_card: AgentCard | None = None,
5960
context_builder: CallContextBuilder | None = None,
61+
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
62+
extended_card_modifier: Callable[
63+
[AgentCard, ServerCallContext], AgentCard
64+
]
65+
| None = None,
6066
):
6167
"""Initializes the RESTApplication.
6268
6369
Args:
6470
agent_card: The AgentCard describing the agent's capabilities.
6571
http_handler: The handler instance responsible for processing A2A
6672
requests via http.
73+
extended_agent_card: An optional, distinct AgentCard to be served
74+
at the authenticated extended card endpoint.
6775
context_builder: The CallContextBuilder used to construct the
6876
ServerCallContext passed to the http_handler. If None, no
6977
ServerCallContext is passed.
78+
card_modifier: An optional callback to dynamically modify the public
79+
agent card before it is served.
80+
extended_card_modifier: An optional callback to dynamically modify
81+
the extended agent card before it is served. It receives the
82+
call context.
7083
"""
7184
if not _package_starlette_installed:
7285
raise ImportError(
@@ -75,9 +88,20 @@ def __init__(
7588
' optional dependencies, `a2a-sdk[http-server]`.'
7689
)
7790
self.agent_card = agent_card
91+
self.extended_agent_card = extended_agent_card
92+
self.card_modifier = card_modifier
93+
self.extended_card_modifier = extended_card_modifier
7894
self.handler = RESTHandler(
7995
agent_card=agent_card, request_handler=http_handler
8096
)
97+
if (
98+
self.agent_card.supports_authenticated_extended_card
99+
and self.extended_agent_card is None
100+
and self.extended_card_modifier is None
101+
):
102+
logger.error(
103+
'AgentCard.supports_authenticated_extended_card is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
104+
)
81105
self._context_builder = context_builder or DefaultCallContextBuilder()
82106

83107
@rest_error_handler
@@ -108,33 +132,35 @@ async def event_generator(
108132
event_generator(method(request, call_context))
109133
)
110134

111-
@rest_error_handler
112-
async def handle_get_agent_card(self, request: Request) -> JSONResponse:
135+
async def handle_get_agent_card(
136+
self, request: Request, call_context: ServerCallContext | None = None
137+
) -> dict[str, Any]:
113138
"""Handles GET requests for the agent card endpoint.
114139
115140
Args:
116141
request: The incoming Starlette Request object.
142+
call_context: ServerCallContext
117143
118144
Returns:
119145
A JSONResponse containing the agent card data.
120146
"""
121-
# The public agent card is a direct serialization of the agent_card
122-
# provided at initialization.
123-
return JSONResponse(
124-
self.agent_card.model_dump(mode='json', exclude_none=True)
125-
)
147+
card_to_serve = self.agent_card
148+
if self.card_modifier:
149+
card_to_serve = self.card_modifier(card_to_serve)
150+
151+
return card_to_serve.model_dump(mode='json', exclude_none=True)
126152

127-
@rest_error_handler
128153
async def handle_authenticated_agent_card(
129-
self, request: Request
130-
) -> JSONResponse:
154+
self, request: Request, call_context: ServerCallContext | None = None
155+
) -> dict[str, Any]:
131156
"""Hook for per credential agent card response.
132157
133158
If a dynamic card is needed based on the credentials provided in the request
134159
override this method and return the customized content.
135160
136161
Args:
137162
request: The incoming Starlette Request object.
163+
call_context: ServerCallContext
138164
139165
Returns:
140166
A JSONResponse containing the authenticated card.
@@ -145,9 +171,18 @@ async def handle_authenticated_agent_card(
145171
message='Authenticated card not supported'
146172
)
147173
)
148-
return JSONResponse(
149-
self.agent_card.model_dump(mode='json', exclude_none=True)
150-
)
174+
card_to_serve = self.extended_agent_card
175+
176+
if not card_to_serve:
177+
card_to_serve = self.agent_card
178+
179+
if self.extended_card_modifier:
180+
context = self._context_builder.build(request)
181+
# If no base extended card is provided, pass the public card to the modifier
182+
base_card = card_to_serve if card_to_serve else self.agent_card
183+
card_to_serve = self.extended_card_modifier(base_card, context)
184+
185+
return card_to_serve.model_dump(mode='json', exclude_none=True)
151186

152187
def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
153188
"""Constructs a dictionary of API routes and their corresponding handlers.
@@ -201,6 +236,8 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
201236
),
202237
}
203238
if self.agent_card.supports_authenticated_extended_card:
204-
routes[('/v1/card', 'GET')] = self.handle_authenticated_agent_card
239+
routes[('/v1/card', 'GET')] = functools.partial(
240+
self._handle_request, self.handle_authenticated_agent_card
241+
)
205242

206243
return routes

tests/integration/test_client_server_integration.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
32
from collections.abc import AsyncGenerator
43
from typing import NamedTuple
54
from unittest.mock import ANY, AsyncMock
@@ -8,7 +7,6 @@
87
import httpx
98
import pytest
109
import pytest_asyncio
11-
1210
from grpc.aio import Channel
1311

1412
from a2a.client.transports import JsonRpcTransport, RestTransport
@@ -38,7 +36,6 @@
3836
TransportProtocol,
3937
)
4038

41-
4239
# --- Test Constants ---
4340

4441
TASK_FROM_STREAM = Task(
@@ -130,7 +127,7 @@ def agent_card() -> AgentCard:
130127
default_input_modes=['text/plain'],
131128
default_output_modes=['text/plain'],
132129
preferred_transport=TransportProtocol.jsonrpc,
133-
supports_authenticated_extended_card=True,
130+
supports_authenticated_extended_card=False,
134131
additional_interfaces=[
135132
AgentInterface(
136133
transport=TransportProtocol.http_json, url='http://testserver'
@@ -709,9 +706,7 @@ async def test_http_transport_get_card(
709706
transport_setup_fixture
710707
)
711708
transport = transport_setup.transport
712-
713-
# The transport starts with a minimal card, get_card() fetches the full one
714-
transport.agent_card.supports_authenticated_extended_card = True
709+
# Get the base card.
715710
result = await transport.get_card()
716711

717712
assert result.name == agent_card.name
@@ -722,6 +717,33 @@ async def test_http_transport_get_card(
722717
await transport.close()
723718

724719

720+
@pytest.mark.asyncio
721+
async def test_http_transport_get_authenticated_card(
722+
agent_card: AgentCard,
723+
mock_request_handler: AsyncMock,
724+
) -> None:
725+
agent_card.supports_authenticated_extended_card = True
726+
extended_agent_card = agent_card.model_copy(deep=True)
727+
extended_agent_card.name = 'Extended Agent Card'
728+
729+
app_builder = A2ARESTFastAPIApplication(
730+
agent_card,
731+
mock_request_handler,
732+
extended_agent_card=extended_agent_card,
733+
)
734+
app = app_builder.build()
735+
httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app))
736+
737+
transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card)
738+
result = await transport.get_card()
739+
assert result.name == extended_agent_card.name
740+
assert transport.agent_card.name == extended_agent_card.name
741+
assert transport._needs_extended_card is False
742+
743+
if hasattr(transport, 'close'):
744+
await transport.close()
745+
746+
725747
@pytest.mark.asyncio
726748
async def test_grpc_transport_get_card(
727749
grpc_server_and_handler: tuple[str, AsyncMock],

0 commit comments

Comments
 (0)