Skip to content

Commit b2e3a29

Browse files
ognis1205holtskinnergemini-code-assist[bot]
authored
test: improve test coverage for grpc_client.py (#306)
# Description Refactored `tests/client/test_grpc_client.py` and added test cases to improve coverage for `grpc_client.py`. - Improved test coverage for `src/a2a/client/grpc_client.py`: 41% → 98% - Fixed incorrect field checks on the response object in `grpc_client::send_message` - Fixed missing `await` call at the correct location in `grpc_client::send_message_streaming` - Defined valid `taskId` format (`[a-zA-Z0-9_.-]+`) Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `nox -s format` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes N/A 🦕 --------- Signed-off-by: Shingo OKAWA <[email protected]> Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Holt Skinner <[email protected]>
1 parent 1cf8185 commit b2e3a29

File tree

2 files changed

+253
-7
lines changed

2 files changed

+253
-7
lines changed

src/a2a/utils/proto_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919

2020
# Regexp patterns for matching
21-
_TASK_NAME_MATCH = re.compile(r'tasks/([\w-]+)')
21+
_TASK_NAME_MATCH = re.compile(r'tasks/([^/]+)')
2222
_TASK_PUSH_CONFIG_NAME_MATCH = re.compile(
23-
r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)'
23+
r'tasks/([^/]+)/pushNotificationConfigs/([^/]+)'
2424
)
2525

2626

tests/client/test_grpc_client.py

Lines changed: 251 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,45 @@
1-
from unittest.mock import AsyncMock
1+
from unittest.mock import AsyncMock, MagicMock
22

3+
import grpc
34
import pytest
45

56
from a2a.client.transports.grpc import GrpcTransport
67
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
78
from a2a.types import (
89
AgentCapabilities,
910
AgentCard,
11+
Artifact,
12+
GetTaskPushNotificationConfigParams,
1013
Message,
1114
MessageSendParams,
1215
Part,
16+
PushNotificationAuthenticationInfo,
17+
PushNotificationConfig,
1318
Role,
1419
Task,
20+
TaskArtifactUpdateEvent,
1521
TaskIdParams,
22+
TaskPushNotificationConfig,
1623
TaskQueryParams,
1724
TaskState,
1825
TaskStatus,
26+
TaskStatusUpdateEvent,
1927
TextPart,
2028
)
2129
from a2a.utils import get_text_parts, proto_utils
30+
from a2a.utils.errors import ServerError
2231

2332

24-
# Fixtures
2533
@pytest.fixture
2634
def mock_grpc_stub() -> AsyncMock:
2735
"""Provides a mock gRPC stub with methods mocked."""
2836
stub = AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
2937
stub.SendMessage = AsyncMock()
30-
stub.SendStreamingMessage = AsyncMock()
38+
stub.SendStreamingMessage = MagicMock()
3139
stub.GetTask = AsyncMock()
3240
stub.CancelTask = AsyncMock()
33-
stub.CreateTaskPushNotification = AsyncMock()
34-
stub.GetTaskPushNotification = AsyncMock()
41+
stub.CreateTaskPushNotificationConfig = AsyncMock()
42+
stub.GetTaskPushNotificationConfig = AsyncMock()
3543
return stub
3644

3745

@@ -93,6 +101,78 @@ def sample_message() -> Message:
93101
)
94102

95103

104+
@pytest.fixture
105+
def sample_artifact() -> Artifact:
106+
"""Provides a sample Artifact object."""
107+
return Artifact(
108+
artifact_id='artifact-1',
109+
name='example.txt',
110+
description='An example artifact',
111+
parts=[Part(root=TextPart(text='Hi there'))],
112+
metadata={},
113+
extensions=[],
114+
)
115+
116+
117+
@pytest.fixture
118+
def sample_task_status_update_event() -> TaskStatusUpdateEvent:
119+
"""Provides a sample TaskStatusUpdateEvent."""
120+
return TaskStatusUpdateEvent(
121+
task_id='task-1',
122+
context_id='ctx-1',
123+
status=TaskStatus(state=TaskState.working),
124+
final=False,
125+
metadata={},
126+
)
127+
128+
129+
@pytest.fixture
130+
def sample_task_artifact_update_event(
131+
sample_artifact,
132+
) -> TaskArtifactUpdateEvent:
133+
"""Provides a sample TaskArtifactUpdateEvent."""
134+
return TaskArtifactUpdateEvent(
135+
task_id='task-1',
136+
context_id='ctx-1',
137+
artifact=sample_artifact,
138+
append=True,
139+
last_chunk=True,
140+
metadata={},
141+
)
142+
143+
144+
@pytest.fixture
145+
def sample_authentication_info() -> PushNotificationAuthenticationInfo:
146+
"""Provides a sample AuthenticationInfo object."""
147+
return PushNotificationAuthenticationInfo(
148+
schemes=['apikey', 'oauth2'], credentials='secret-token'
149+
)
150+
151+
152+
@pytest.fixture
153+
def sample_push_notification_config(
154+
sample_authentication_info: PushNotificationAuthenticationInfo,
155+
) -> PushNotificationConfig:
156+
"""Provides a sample PushNotificationConfig object."""
157+
return PushNotificationConfig(
158+
id='config-1',
159+
url='https://example.com/notify',
160+
token='example-token',
161+
authentication=sample_authentication_info,
162+
)
163+
164+
165+
@pytest.fixture
166+
def sample_task_push_notification_config(
167+
sample_push_notification_config: PushNotificationConfig,
168+
) -> TaskPushNotificationConfig:
169+
"""Provides a sample TaskPushNotificationConfig object."""
170+
return TaskPushNotificationConfig(
171+
task_id='task-1',
172+
push_notification_config=sample_push_notification_config,
173+
)
174+
175+
96176
@pytest.mark.asyncio
97177
async def test_send_message_task_response(
98178
grpc_transport: GrpcTransport,
@@ -134,6 +214,57 @@ async def test_send_message_message_response(
134214
)
135215

136216

217+
@pytest.mark.asyncio
218+
async def test_send_message_streaming( # noqa: PLR0913
219+
grpc_transport: GrpcTransport,
220+
mock_grpc_stub: AsyncMock,
221+
sample_message_send_params: MessageSendParams,
222+
sample_message: Message,
223+
sample_task: Task,
224+
sample_task_status_update_event: TaskStatusUpdateEvent,
225+
sample_task_artifact_update_event: TaskArtifactUpdateEvent,
226+
):
227+
"""Test send_message_streaming that yields responses."""
228+
stream = MagicMock()
229+
stream.read = AsyncMock(
230+
side_effect=[
231+
a2a_pb2.StreamResponse(
232+
msg=proto_utils.ToProto.message(sample_message)
233+
),
234+
a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(sample_task)),
235+
a2a_pb2.StreamResponse(
236+
status_update=proto_utils.ToProto.task_status_update_event(
237+
sample_task_status_update_event
238+
)
239+
),
240+
a2a_pb2.StreamResponse(
241+
artifact_update=proto_utils.ToProto.task_artifact_update_event(
242+
sample_task_artifact_update_event
243+
)
244+
),
245+
grpc.aio.EOF,
246+
]
247+
)
248+
mock_grpc_stub.SendStreamingMessage.return_value = stream
249+
250+
responses = [
251+
response
252+
async for response in grpc_transport.send_message_streaming(
253+
sample_message_send_params
254+
)
255+
]
256+
257+
mock_grpc_stub.SendStreamingMessage.assert_called_once()
258+
assert isinstance(responses[0], Message)
259+
assert responses[0].message_id == sample_message.message_id
260+
assert isinstance(responses[1], Task)
261+
assert responses[1].id == sample_task.id
262+
assert isinstance(responses[2], TaskStatusUpdateEvent)
263+
assert responses[2].task_id == sample_task_status_update_event.task_id
264+
assert isinstance(responses[3], TaskArtifactUpdateEvent)
265+
assert responses[3].task_id == sample_task_artifact_update_event.task_id
266+
267+
137268
@pytest.mark.asyncio
138269
async def test_get_task(
139270
grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task
@@ -188,3 +319,118 @@ async def test_cancel_task(
188319
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
189320
)
190321
assert response.status.state == TaskState.canceled
322+
323+
324+
@pytest.mark.asyncio
325+
async def test_set_task_callback_with_valid_task(
326+
grpc_transport: GrpcTransport,
327+
mock_grpc_stub: AsyncMock,
328+
sample_task_push_notification_config: TaskPushNotificationConfig,
329+
):
330+
"""Test setting a task push notification config with a valid task id."""
331+
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = (
332+
proto_utils.ToProto.task_push_notification_config(
333+
sample_task_push_notification_config
334+
)
335+
)
336+
337+
response = await grpc_transport.set_task_callback(
338+
sample_task_push_notification_config
339+
)
340+
341+
mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with(
342+
a2a_pb2.CreateTaskPushNotificationConfigRequest(
343+
parent=f'tasks/{sample_task_push_notification_config.task_id}',
344+
config_id=sample_task_push_notification_config.push_notification_config.id,
345+
config=proto_utils.ToProto.task_push_notification_config(
346+
sample_task_push_notification_config
347+
),
348+
)
349+
)
350+
assert response.task_id == sample_task_push_notification_config.task_id
351+
352+
353+
@pytest.mark.asyncio
354+
async def test_set_task_callback_with_invalid_task(
355+
grpc_transport: GrpcTransport,
356+
mock_grpc_stub: AsyncMock,
357+
sample_task_push_notification_config: TaskPushNotificationConfig,
358+
):
359+
"""Test setting a task push notification config with an invalid task id."""
360+
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig(
361+
name=(
362+
f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/'
363+
f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}'
364+
),
365+
push_notification_config=proto_utils.ToProto.push_notification_config(
366+
sample_task_push_notification_config.push_notification_config
367+
),
368+
)
369+
370+
with pytest.raises(ServerError) as exc_info:
371+
await grpc_transport.set_task_callback(
372+
sample_task_push_notification_config
373+
)
374+
assert (
375+
'Bad TaskPushNotificationConfig resource name'
376+
in exc_info.value.error.message
377+
)
378+
379+
380+
@pytest.mark.asyncio
381+
async def test_get_task_callback_with_valid_task(
382+
grpc_transport: GrpcTransport,
383+
mock_grpc_stub: AsyncMock,
384+
sample_task_push_notification_config: TaskPushNotificationConfig,
385+
):
386+
"""Test retrieving a task push notification config with a valid task id."""
387+
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = (
388+
proto_utils.ToProto.task_push_notification_config(
389+
sample_task_push_notification_config
390+
)
391+
)
392+
params = GetTaskPushNotificationConfigParams(
393+
id=sample_task_push_notification_config.task_id,
394+
push_notification_config_id=sample_task_push_notification_config.push_notification_config.id,
395+
)
396+
397+
response = await grpc_transport.get_task_callback(params)
398+
399+
mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with(
400+
a2a_pb2.GetTaskPushNotificationConfigRequest(
401+
name=(
402+
f'tasks/{params.id}/'
403+
f'pushNotificationConfigs/{params.push_notification_config_id}'
404+
),
405+
)
406+
)
407+
assert response.task_id == sample_task_push_notification_config.task_id
408+
409+
410+
@pytest.mark.asyncio
411+
async def test_get_task_callback_with_invalid_task(
412+
grpc_transport: GrpcTransport,
413+
mock_grpc_stub: AsyncMock,
414+
sample_task_push_notification_config: TaskPushNotificationConfig,
415+
):
416+
"""Test retrieving a task push notification config with an invalid task id."""
417+
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig(
418+
name=(
419+
f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/'
420+
f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}'
421+
),
422+
push_notification_config=proto_utils.ToProto.push_notification_config(
423+
sample_task_push_notification_config.push_notification_config
424+
),
425+
)
426+
params = GetTaskPushNotificationConfigParams(
427+
id=sample_task_push_notification_config.task_id,
428+
push_notification_config_id=sample_task_push_notification_config.push_notification_config.id,
429+
)
430+
431+
with pytest.raises(ServerError) as exc_info:
432+
await grpc_transport.get_task_callback(params)
433+
assert (
434+
'Bad TaskPushNotificationConfig resource name'
435+
in exc_info.value.error.message
436+
)

0 commit comments

Comments
 (0)