|
1 |
| -from unittest.mock import AsyncMock |
| 1 | +from unittest.mock import AsyncMock, MagicMock |
2 | 2 |
|
| 3 | +import grpc |
3 | 4 | import pytest
|
4 | 5 |
|
5 | 6 | from a2a.client.transports.grpc import GrpcTransport
|
6 | 7 | from a2a.grpc import a2a_pb2, a2a_pb2_grpc
|
7 | 8 | from a2a.types import (
|
8 | 9 | AgentCapabilities,
|
9 | 10 | AgentCard,
|
| 11 | + Artifact, |
| 12 | + GetTaskPushNotificationConfigParams, |
10 | 13 | Message,
|
11 | 14 | MessageSendParams,
|
12 | 15 | Part,
|
| 16 | + PushNotificationAuthenticationInfo, |
| 17 | + PushNotificationConfig, |
13 | 18 | Role,
|
14 | 19 | Task,
|
| 20 | + TaskArtifactUpdateEvent, |
15 | 21 | TaskIdParams,
|
| 22 | + TaskPushNotificationConfig, |
16 | 23 | TaskQueryParams,
|
17 | 24 | TaskState,
|
18 | 25 | TaskStatus,
|
| 26 | + TaskStatusUpdateEvent, |
19 | 27 | TextPart,
|
20 | 28 | )
|
21 | 29 | from a2a.utils import get_text_parts, proto_utils
|
| 30 | +from a2a.utils.errors import ServerError |
22 | 31 |
|
23 | 32 |
|
24 |
| -# Fixtures |
25 | 33 | @pytest.fixture
|
26 | 34 | def mock_grpc_stub() -> AsyncMock:
|
27 | 35 | """Provides a mock gRPC stub with methods mocked."""
|
28 | 36 | stub = AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
|
29 | 37 | stub.SendMessage = AsyncMock()
|
30 |
| - stub.SendStreamingMessage = AsyncMock() |
| 38 | + stub.SendStreamingMessage = MagicMock() |
31 | 39 | stub.GetTask = AsyncMock()
|
32 | 40 | stub.CancelTask = AsyncMock()
|
33 |
| - stub.CreateTaskPushNotification = AsyncMock() |
34 |
| - stub.GetTaskPushNotification = AsyncMock() |
| 41 | + stub.CreateTaskPushNotificationConfig = AsyncMock() |
| 42 | + stub.GetTaskPushNotificationConfig = AsyncMock() |
35 | 43 | return stub
|
36 | 44 |
|
37 | 45 |
|
@@ -93,6 +101,78 @@ def sample_message() -> Message:
|
93 | 101 | )
|
94 | 102 |
|
95 | 103 |
|
| 104 | +@pytest.fixture |
| 105 | +def sample_artifact() -> Artifact: |
| 106 | + """Provides a sample Artifact object.""" |
| 107 | + return Artifact( |
| 108 | + artifact_id='artifact-1', |
| 109 | + name='example.txt', |
| 110 | + description='An example artifact', |
| 111 | + parts=[Part(root=TextPart(text='Hi there'))], |
| 112 | + metadata={}, |
| 113 | + extensions=[], |
| 114 | + ) |
| 115 | + |
| 116 | + |
| 117 | +@pytest.fixture |
| 118 | +def sample_task_status_update_event() -> TaskStatusUpdateEvent: |
| 119 | + """Provides a sample TaskStatusUpdateEvent.""" |
| 120 | + return TaskStatusUpdateEvent( |
| 121 | + task_id='task-1', |
| 122 | + context_id='ctx-1', |
| 123 | + status=TaskStatus(state=TaskState.working), |
| 124 | + final=False, |
| 125 | + metadata={}, |
| 126 | + ) |
| 127 | + |
| 128 | + |
| 129 | +@pytest.fixture |
| 130 | +def sample_task_artifact_update_event( |
| 131 | + sample_artifact, |
| 132 | +) -> TaskArtifactUpdateEvent: |
| 133 | + """Provides a sample TaskArtifactUpdateEvent.""" |
| 134 | + return TaskArtifactUpdateEvent( |
| 135 | + task_id='task-1', |
| 136 | + context_id='ctx-1', |
| 137 | + artifact=sample_artifact, |
| 138 | + append=True, |
| 139 | + last_chunk=True, |
| 140 | + metadata={}, |
| 141 | + ) |
| 142 | + |
| 143 | + |
| 144 | +@pytest.fixture |
| 145 | +def sample_authentication_info() -> PushNotificationAuthenticationInfo: |
| 146 | + """Provides a sample AuthenticationInfo object.""" |
| 147 | + return PushNotificationAuthenticationInfo( |
| 148 | + schemes=['apikey', 'oauth2'], credentials='secret-token' |
| 149 | + ) |
| 150 | + |
| 151 | + |
| 152 | +@pytest.fixture |
| 153 | +def sample_push_notification_config( |
| 154 | + sample_authentication_info: PushNotificationAuthenticationInfo, |
| 155 | +) -> PushNotificationConfig: |
| 156 | + """Provides a sample PushNotificationConfig object.""" |
| 157 | + return PushNotificationConfig( |
| 158 | + id='config-1', |
| 159 | + url='https://example.com/notify', |
| 160 | + token='example-token', |
| 161 | + authentication=sample_authentication_info, |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +@pytest.fixture |
| 166 | +def sample_task_push_notification_config( |
| 167 | + sample_push_notification_config: PushNotificationConfig, |
| 168 | +) -> TaskPushNotificationConfig: |
| 169 | + """Provides a sample TaskPushNotificationConfig object.""" |
| 170 | + return TaskPushNotificationConfig( |
| 171 | + task_id='task-1', |
| 172 | + push_notification_config=sample_push_notification_config, |
| 173 | + ) |
| 174 | + |
| 175 | + |
96 | 176 | @pytest.mark.asyncio
|
97 | 177 | async def test_send_message_task_response(
|
98 | 178 | grpc_transport: GrpcTransport,
|
@@ -134,6 +214,57 @@ async def test_send_message_message_response(
|
134 | 214 | )
|
135 | 215 |
|
136 | 216 |
|
| 217 | +@pytest.mark.asyncio |
| 218 | +async def test_send_message_streaming( # noqa: PLR0913 |
| 219 | + grpc_transport: GrpcTransport, |
| 220 | + mock_grpc_stub: AsyncMock, |
| 221 | + sample_message_send_params: MessageSendParams, |
| 222 | + sample_message: Message, |
| 223 | + sample_task: Task, |
| 224 | + sample_task_status_update_event: TaskStatusUpdateEvent, |
| 225 | + sample_task_artifact_update_event: TaskArtifactUpdateEvent, |
| 226 | +): |
| 227 | + """Test send_message_streaming that yields responses.""" |
| 228 | + stream = MagicMock() |
| 229 | + stream.read = AsyncMock( |
| 230 | + side_effect=[ |
| 231 | + a2a_pb2.StreamResponse( |
| 232 | + msg=proto_utils.ToProto.message(sample_message) |
| 233 | + ), |
| 234 | + a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(sample_task)), |
| 235 | + a2a_pb2.StreamResponse( |
| 236 | + status_update=proto_utils.ToProto.task_status_update_event( |
| 237 | + sample_task_status_update_event |
| 238 | + ) |
| 239 | + ), |
| 240 | + a2a_pb2.StreamResponse( |
| 241 | + artifact_update=proto_utils.ToProto.task_artifact_update_event( |
| 242 | + sample_task_artifact_update_event |
| 243 | + ) |
| 244 | + ), |
| 245 | + grpc.aio.EOF, |
| 246 | + ] |
| 247 | + ) |
| 248 | + mock_grpc_stub.SendStreamingMessage.return_value = stream |
| 249 | + |
| 250 | + responses = [ |
| 251 | + response |
| 252 | + async for response in grpc_transport.send_message_streaming( |
| 253 | + sample_message_send_params |
| 254 | + ) |
| 255 | + ] |
| 256 | + |
| 257 | + mock_grpc_stub.SendStreamingMessage.assert_called_once() |
| 258 | + assert isinstance(responses[0], Message) |
| 259 | + assert responses[0].message_id == sample_message.message_id |
| 260 | + assert isinstance(responses[1], Task) |
| 261 | + assert responses[1].id == sample_task.id |
| 262 | + assert isinstance(responses[2], TaskStatusUpdateEvent) |
| 263 | + assert responses[2].task_id == sample_task_status_update_event.task_id |
| 264 | + assert isinstance(responses[3], TaskArtifactUpdateEvent) |
| 265 | + assert responses[3].task_id == sample_task_artifact_update_event.task_id |
| 266 | + |
| 267 | + |
137 | 268 | @pytest.mark.asyncio
|
138 | 269 | async def test_get_task(
|
139 | 270 | grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task
|
@@ -188,3 +319,118 @@ async def test_cancel_task(
|
188 | 319 | a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
|
189 | 320 | )
|
190 | 321 | assert response.status.state == TaskState.canceled
|
| 322 | + |
| 323 | + |
| 324 | +@pytest.mark.asyncio |
| 325 | +async def test_set_task_callback_with_valid_task( |
| 326 | + grpc_transport: GrpcTransport, |
| 327 | + mock_grpc_stub: AsyncMock, |
| 328 | + sample_task_push_notification_config: TaskPushNotificationConfig, |
| 329 | +): |
| 330 | + """Test setting a task push notification config with a valid task id.""" |
| 331 | + mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = ( |
| 332 | + proto_utils.ToProto.task_push_notification_config( |
| 333 | + sample_task_push_notification_config |
| 334 | + ) |
| 335 | + ) |
| 336 | + |
| 337 | + response = await grpc_transport.set_task_callback( |
| 338 | + sample_task_push_notification_config |
| 339 | + ) |
| 340 | + |
| 341 | + mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with( |
| 342 | + a2a_pb2.CreateTaskPushNotificationConfigRequest( |
| 343 | + parent=f'tasks/{sample_task_push_notification_config.task_id}', |
| 344 | + config_id=sample_task_push_notification_config.push_notification_config.id, |
| 345 | + config=proto_utils.ToProto.task_push_notification_config( |
| 346 | + sample_task_push_notification_config |
| 347 | + ), |
| 348 | + ) |
| 349 | + ) |
| 350 | + assert response.task_id == sample_task_push_notification_config.task_id |
| 351 | + |
| 352 | + |
| 353 | +@pytest.mark.asyncio |
| 354 | +async def test_set_task_callback_with_invalid_task( |
| 355 | + grpc_transport: GrpcTransport, |
| 356 | + mock_grpc_stub: AsyncMock, |
| 357 | + sample_task_push_notification_config: TaskPushNotificationConfig, |
| 358 | +): |
| 359 | + """Test setting a task push notification config with an invalid task id.""" |
| 360 | + mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( |
| 361 | + name=( |
| 362 | + f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/' |
| 363 | + f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}' |
| 364 | + ), |
| 365 | + push_notification_config=proto_utils.ToProto.push_notification_config( |
| 366 | + sample_task_push_notification_config.push_notification_config |
| 367 | + ), |
| 368 | + ) |
| 369 | + |
| 370 | + with pytest.raises(ServerError) as exc_info: |
| 371 | + await grpc_transport.set_task_callback( |
| 372 | + sample_task_push_notification_config |
| 373 | + ) |
| 374 | + assert ( |
| 375 | + 'Bad TaskPushNotificationConfig resource name' |
| 376 | + in exc_info.value.error.message |
| 377 | + ) |
| 378 | + |
| 379 | + |
| 380 | +@pytest.mark.asyncio |
| 381 | +async def test_get_task_callback_with_valid_task( |
| 382 | + grpc_transport: GrpcTransport, |
| 383 | + mock_grpc_stub: AsyncMock, |
| 384 | + sample_task_push_notification_config: TaskPushNotificationConfig, |
| 385 | +): |
| 386 | + """Test retrieving a task push notification config with a valid task id.""" |
| 387 | + mock_grpc_stub.GetTaskPushNotificationConfig.return_value = ( |
| 388 | + proto_utils.ToProto.task_push_notification_config( |
| 389 | + sample_task_push_notification_config |
| 390 | + ) |
| 391 | + ) |
| 392 | + params = GetTaskPushNotificationConfigParams( |
| 393 | + id=sample_task_push_notification_config.task_id, |
| 394 | + push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, |
| 395 | + ) |
| 396 | + |
| 397 | + response = await grpc_transport.get_task_callback(params) |
| 398 | + |
| 399 | + mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with( |
| 400 | + a2a_pb2.GetTaskPushNotificationConfigRequest( |
| 401 | + name=( |
| 402 | + f'tasks/{params.id}/' |
| 403 | + f'pushNotificationConfigs/{params.push_notification_config_id}' |
| 404 | + ), |
| 405 | + ) |
| 406 | + ) |
| 407 | + assert response.task_id == sample_task_push_notification_config.task_id |
| 408 | + |
| 409 | + |
| 410 | +@pytest.mark.asyncio |
| 411 | +async def test_get_task_callback_with_invalid_task( |
| 412 | + grpc_transport: GrpcTransport, |
| 413 | + mock_grpc_stub: AsyncMock, |
| 414 | + sample_task_push_notification_config: TaskPushNotificationConfig, |
| 415 | +): |
| 416 | + """Test retrieving a task push notification config with an invalid task id.""" |
| 417 | + mock_grpc_stub.GetTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( |
| 418 | + name=( |
| 419 | + f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/' |
| 420 | + f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}' |
| 421 | + ), |
| 422 | + push_notification_config=proto_utils.ToProto.push_notification_config( |
| 423 | + sample_task_push_notification_config.push_notification_config |
| 424 | + ), |
| 425 | + ) |
| 426 | + params = GetTaskPushNotificationConfigParams( |
| 427 | + id=sample_task_push_notification_config.task_id, |
| 428 | + push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, |
| 429 | + ) |
| 430 | + |
| 431 | + with pytest.raises(ServerError) as exc_info: |
| 432 | + await grpc_transport.get_task_callback(params) |
| 433 | + assert ( |
| 434 | + 'Bad TaskPushNotificationConfig resource name' |
| 435 | + in exc_info.value.error.message |
| 436 | + ) |
0 commit comments