Skip to content

Commit 0bb5563

Browse files
authored
feat: Convert fields in types.py to use snake_case (#199)
Makes types more pythonic Pydantic aliases can be used for initialization and are used automatically for serialization. Added a workaround to support `camelCase` when setting specific attributes. NOTE: In 0.3.0 and later, only snake_case will be supported for direct attribute access.
1 parent 97f1093 commit 0bb5563

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+785
-673
lines changed

.github/actions/spelling/allow.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ pyversions
6868
respx
6969
resub
7070
RUF
71+
SLF
7172
socio
7273
sse
7374
tagwords

.mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[mypy]
22
exclude = src/a2a/grpc/
33
disable_error_code = import-not-found,annotation-unchecked,import-untyped
4+
plugins = pydantic.mypy
45

56
[mypy-examples.*]
67
follow_imports = skip

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ url = "https://test.pypi.org/simple/"
9999
publish-url = "https://test.pypi.org/legacy/"
100100
explicit = true
101101

102+
[tool.mypy]
103+
plugins = ['pydantic.mypy']
104+
102105
[tool.pyright]
103106
include = ["src"]
104107
exclude = [

scripts/generate_types.sh

100644100755
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ uv run datamodel-codegen \
3333
--class-name A2A \
3434
--use-standard-collections \
3535
--use-subclass-enum \
36-
--base-class a2a._base.A2ABaseModel
36+
--base-class a2a._base.A2ABaseModel \
37+
--field-constraints \
38+
--snake-case-field \
39+
--no-alias
3740

3841
echo "Formatting generated file with ruff..."
3942
uv run ruff format "$GENERATED_FILE"

src/a2a/_base.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,77 @@
1+
from typing import Any, ClassVar
2+
13
from pydantic import BaseModel, ConfigDict
4+
from pydantic.alias_generators import to_camel
5+
6+
7+
def to_camel_custom(snake: str) -> str:
8+
"""Convert a snake_case string to camelCase.
9+
10+
Args:
11+
snake: The string to convert.
12+
13+
Returns:
14+
The converted camelCase string.
15+
"""
16+
# First, remove any trailing underscores. This is common for names that
17+
# conflict with Python keywords, like 'in_' or 'from_'.
18+
if snake.endswith('_'):
19+
snake = snake.rstrip('_')
20+
return to_camel(snake)
221

322

423
class A2ABaseModel(BaseModel):
524
"""Base class for shared behavior across A2A data models.
625
726
Provides a common configuration (e.g., alias-based population) and
827
serves as the foundation for future extensions or shared utilities.
28+
29+
This implementation provides backward compatibility for camelCase aliases
30+
by lazy-loading an alias map upon first use.
931
"""
1032

1133
model_config = ConfigDict(
1234
# SEE: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.populate_by_name
1335
validate_by_name=True,
1436
validate_by_alias=True,
37+
serialize_by_alias=True,
38+
alias_generator=to_camel_custom,
1539
)
40+
41+
# Cache for the alias -> field_name mapping.
42+
# It starts as None and is populated on first access.
43+
_alias_to_field_name_map: ClassVar[dict[str, str] | None] = None
44+
45+
@classmethod
46+
def _get_alias_map(cls) -> dict[str, str]:
47+
"""Lazily builds and returns the alias-to-field-name mapping for the class.
48+
49+
The map is cached on the class object to avoid re-computation.
50+
"""
51+
if cls._alias_to_field_name_map is None:
52+
cls._alias_to_field_name_map = {
53+
field.alias: field_name
54+
for field_name, field in cls.model_fields.items()
55+
if field.alias is not None
56+
}
57+
return cls._alias_to_field_name_map
58+
59+
def __setattr__(self, name: str, value: Any) -> None:
60+
"""Allow setting attributes via their camelCase alias."""
61+
# Get the map and find the corresponding snake_case field name.
62+
field_name = type(self)._get_alias_map().get(name) # noqa: SLF001
63+
# If an alias was used, field_name will be set; otherwise, use the original name.
64+
super().__setattr__(field_name or name, value)
65+
66+
def __getattr__(self, name: str) -> Any:
67+
"""Allow getting attributes via their camelCase alias."""
68+
# Get the map and find the corresponding snake_case field name.
69+
field_name = type(self)._get_alias_map().get(name) # noqa: SLF001
70+
if field_name:
71+
# If an alias was used, retrieve the actual snake_case attribute.
72+
return getattr(self, field_name)
73+
74+
# If it's not a known alias, it's a genuine missing attribute.
75+
raise AttributeError(
76+
f"'{type(self).__name__}' object has no attribute '{name}'"
77+
)

src/a2a/client/auth/interceptor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def intercept(
3636
if (
3737
agent_card is None
3838
or agent_card.security is None
39-
or agent_card.securitySchemes is None
39+
or agent_card.security_schemes is None
4040
):
4141
return request_payload, http_kwargs
4242

@@ -45,8 +45,8 @@ async def intercept(
4545
credential = await self._credential_service.get_credentials(
4646
scheme_name, context
4747
)
48-
if credential and scheme_name in agent_card.securitySchemes:
49-
scheme_def_union = agent_card.securitySchemes.get(
48+
if credential and scheme_name in agent_card.security_schemes:
49+
scheme_def_union = agent_card.security_schemes.get(
5050
scheme_name
5151
)
5252
if not scheme_def_union:

src/a2a/client/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def create_text_message_object(
1515
content: The text content of the message. Defaults to an empty string.
1616
1717
Returns:
18-
A `Message` object with a new UUID messageId.
18+
A `Message` object with a new UUID message_id.
1919
"""
2020
return Message(
21-
role=role, parts=[Part(TextPart(text=content))], messageId=str(uuid4())
21+
role=role, parts=[Part(TextPart(text=content))], message_id=str(uuid4())
2222
)

src/a2a/server/agent_execution/context.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ def __init__( # noqa: PLR0913
5353
# match the request. Otherwise, create them
5454
if self._params:
5555
if task_id:
56-
self._params.message.taskId = task_id
56+
self._params.message.task_id = task_id
5757
if task and task.id != task_id:
5858
raise ServerError(InvalidParamsError(message='bad task id'))
5959
else:
6060
self._check_or_generate_task_id()
6161
if context_id:
62-
self._params.message.contextId = context_id
63-
if task and task.contextId != context_id:
62+
self._params.message.context_id = context_id
63+
if task and task.context_id != context_id:
6464
raise ServerError(
6565
InvalidParamsError(message='bad context id')
6666
)
@@ -148,17 +148,17 @@ def _check_or_generate_task_id(self) -> None:
148148
if not self._params:
149149
return
150150

151-
if not self._task_id and not self._params.message.taskId:
152-
self._params.message.taskId = str(uuid.uuid4())
153-
if self._params.message.taskId:
154-
self._task_id = self._params.message.taskId
151+
if not self._task_id and not self._params.message.task_id:
152+
self._params.message.task_id = str(uuid.uuid4())
153+
if self._params.message.task_id:
154+
self._task_id = self._params.message.task_id
155155

156156
def _check_or_generate_context_id(self) -> None:
157157
"""Ensures a context ID is present, generating one if necessary."""
158158
if not self._params:
159159
return
160160

161-
if not self._context_id and not self._params.message.contextId:
162-
self._params.message.contextId = str(uuid.uuid4())
163-
if self._params.message.contextId:
164-
self._context_id = self._params.message.contextId
161+
if not self._context_id and not self._params.message.context_id:
162+
self._params.message.context_id = str(uuid.uuid4())
163+
if self._params.message.context_id:
164+
self._context_id = self._params.message.context_id

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(
1818
1919
Args:
2020
should_populate_referred_tasks: If True, the builder will fetch tasks
21-
referenced in `params.message.referenceTaskIds` and populate the
21+
referenced in `params.message.reference_task_ids` and populate the
2222
`related_tasks` field in the RequestContext. Defaults to False.
2323
task_store: The TaskStore instance to use for fetching referred tasks.
2424
Required if `should_populate_referred_tasks` is True.
@@ -38,7 +38,7 @@ async def build(
3838
3939
This method assembles the RequestContext object. If the builder was
4040
initialized with `should_populate_referred_tasks=True`, it fetches all tasks
41-
referenced in `params.message.referenceTaskIds` from the `task_store`.
41+
referenced in `params.message.reference_task_ids` from the `task_store`.
4242
4343
Args:
4444
params: The parameters of the incoming message send request.
@@ -57,12 +57,12 @@ async def build(
5757
self._task_store
5858
and self._should_populate_referred_tasks
5959
and params
60-
and params.message.referenceTaskIds
60+
and params.message.reference_task_ids
6161
):
6262
tasks = await asyncio.gather(
6363
*[
6464
self._task_store.get(task_id)
65-
for task_id in params.message.referenceTaskIds
65+
for task_id in params.message.reference_task_ids
6666
]
6767
)
6868
related_tasks = [x for x in tasks if x is not None]

src/a2a/server/apps/jsonrpc/fastapi_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def add_routes_to_app(
8989
)(self._handle_requests)
9090
app.get(agent_card_url)(self._handle_get_agent_card)
9191

92-
if self.agent_card.supportsAuthenticatedExtendedCard:
92+
if self.agent_card.supports_authenticated_extended_card:
9393
app.get(extended_agent_card_url)(
9494
self._handle_get_authenticated_extended_agent_card
9595
)

0 commit comments

Comments
 (0)