Skip to content

Commit dec4b48

Browse files
racinmatgemini-code-assist[bot]holtskinnerkthota-g
authored
fix: openapi working in sub-app (#324)
The `A2AFastAPIApplication` is very useful, but currently can not be used as a [Sub Application](https://fastapi.tiangolo.com/advanced/sub-applications/) without breaking the openapi specification. [Lifespan does not work for Sub Applications](https://fastapi.tiangolo.com/advanced/events/#sub-applications), which makes this otherwise useful integration work in existing larger FastAPI apps. This fix gets rid of the lifespan and instead enriches the openapi on the first call of the [openapi method](https://github.com/fastapi/fastapi/blob/0.116.1/fastapi/applications.py#L966). I tested it locally and everything works with sub-application. The following works after the fix with openapi components populated: ```python from fastapi import FastAPI from a2a.server.apps import A2AFastAPIApplication app = FastAPI() agent_app = A2AFastAPIApplication(...).build() app.mount("/a2a", agent_app) ``` I am not adding `@override` although this overrides the `openapi`, because the pyright fails on it, see [the failed lint run](https://github.com/a2aproject/a2a-python/actions/runs/16370448989/job/46257297467?pr=324). --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: kthota-g <[email protected]>
1 parent 3032aa6 commit dec4b48

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import logging
22

3-
from collections.abc import AsyncIterator
4-
from contextlib import asynccontextmanager
53
from typing import Any
64

75
from fastapi import FastAPI
@@ -21,6 +19,28 @@
2119
logger = logging.getLogger(__name__)
2220

2321

22+
class A2AFastAPI(FastAPI):
23+
"""A FastAPI application that adds A2A-specific OpenAPI components."""
24+
25+
_a2a_components_added: bool = False
26+
27+
def openapi(self) -> dict[str, Any]:
28+
"""Generates the OpenAPI schema for the application."""
29+
openapi_schema = super().openapi()
30+
if not self._a2a_components_added:
31+
a2a_request_schema = A2ARequest.model_json_schema(
32+
ref_template='#/components/schemas/{model}'
33+
)
34+
defs = a2a_request_schema.pop('$defs', {})
35+
component_schemas = openapi_schema.setdefault(
36+
'components', {}
37+
).setdefault('schemas', {})
38+
component_schemas.update(defs)
39+
component_schemas['A2ARequest'] = a2a_request_schema
40+
self._a2a_components_added = True
41+
return openapi_schema
42+
43+
2444
class A2AFastAPIApplication(JSONRPCApplication):
2545
"""A FastAPI application implementing the A2A protocol server endpoints.
2646
@@ -92,23 +112,7 @@ def build(
92112
Returns:
93113
A configured FastAPI application instance.
94114
"""
95-
96-
@asynccontextmanager
97-
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
98-
a2a_request_schema = A2ARequest.model_json_schema(
99-
ref_template='#/components/schemas/{model}'
100-
)
101-
defs = a2a_request_schema.pop('$defs', {})
102-
openapi_schema = app.openapi()
103-
component_schemas = openapi_schema.setdefault(
104-
'components', {}
105-
).setdefault('schemas', {})
106-
component_schemas.update(defs)
107-
component_schemas['A2ARequest'] = a2a_request_schema
108-
109-
yield
110-
111-
app = FastAPI(lifespan=lifespan, **kwargs)
115+
app = A2AFastAPI(**kwargs)
112116

113117
self.add_routes_to_app(
114118
app, agent_card_url, rpc_url, extended_agent_card_url

tests/server/apps/jsonrpc/test_serialization.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest import mock
22

33
import pytest
4+
from fastapi import FastAPI
45

56
from pydantic import ValidationError
67
from starlette.testclient import TestClient
@@ -183,3 +184,21 @@ def test_handle_unicode_characters(agent_card_with_api_key: AgentCard):
183184
data = response.json()
184185
assert 'error' not in data or data['error'] is None
185186
assert data['result']['parts'][0]['text'] == f'Received: {unicode_text}'
187+
188+
189+
def test_fastapi_sub_application(agent_card_with_api_key: AgentCard):
190+
"""
191+
Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application.
192+
"""
193+
handler = mock.AsyncMock()
194+
sub_app_instance = A2AFastAPIApplication(agent_card_with_api_key, handler)
195+
app_instance = FastAPI()
196+
app_instance.mount('/a2a', sub_app_instance.build())
197+
client = TestClient(app_instance)
198+
199+
response = client.get('/a2a/openapi.json')
200+
assert response.status_code == 200
201+
response_data = response.json()
202+
203+
assert 'servers' in response_data
204+
assert response_data['servers'] == [{'url': '/a2a'}]

0 commit comments

Comments
 (0)