Skip to content

Commit 5a636b4

Browse files
committed
refactor: Replace messages generator with iterator class that implements len()
1 parent db8418c commit 5a636b4

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

aiomqtt/client.py

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from types import TracebackType
1515
from typing import (
1616
Any,
17-
AsyncGenerator,
17+
AsyncIterator,
1818
Awaitable,
1919
Callable,
2020
Coroutine,
@@ -125,7 +125,7 @@ class Will:
125125

126126

127127
class Client:
128-
"""The async context manager that manages the connection to the broker.
128+
"""Asynchronous context manager for the connection to the MQTT broker.
129129
130130
Args:
131131
hostname: The hostname or IP address of the remote broker.
@@ -320,10 +320,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
320320
timeout = 10
321321
self.timeout = timeout
322322

323-
@property
324-
def messages(self) -> AsyncGenerator[Message, None]:
325-
return self._messages()
326-
327323
@property
328324
def identifier(self) -> str:
329325
"""Return the client identifier.
@@ -333,6 +329,42 @@ def identifier(self) -> str:
333329
"""
334330
return self._client._client_id.decode() # noqa: SLF001
335331

332+
class MessagesIterator:
333+
"""Dynamic view of the message queue."""
334+
335+
def __init__(self, client: Client) -> None:
336+
self._client = client
337+
338+
def __aiter__(self) -> AsyncIterator[Message]:
339+
return self
340+
341+
async def __anext__(self) -> Message:
342+
# Wait until we either (1) receive a message or (2) disconnect
343+
task = self._client._loop.create_task(self._client._queue.get()) # noqa: SLF001
344+
try:
345+
done, _ = await asyncio.wait(
346+
(task, self._client._disconnected), # noqa: SLF001
347+
return_when=asyncio.FIRST_COMPLETED,
348+
)
349+
# If the asyncio.wait is cancelled, we must also cancel the queue task
350+
except asyncio.CancelledError:
351+
task.cancel()
352+
raise
353+
# When we receive a message, return it
354+
if task in done:
355+
return task.result()
356+
# If we disconnect from the broker, stop the generator with an exception
357+
task.cancel()
358+
msg = "Disconnected during message iteration"
359+
raise MqttError(msg)
360+
361+
def __len__(self) -> int:
362+
return self._client._queue.qsize() # noqa: SLF001
363+
364+
@property
365+
def messages(self) -> MessagesIterator:
366+
return self.MessagesIterator(self)
367+
336368
@property
337369
def _pending_calls(self) -> Generator[int, None, None]:
338370
"""Yield all message IDs with pending calls."""
@@ -456,32 +488,6 @@ async def publish( # noqa: PLR0913
456488
# Wait for confirmation
457489
await self._wait_for(confirmation.wait(), timeout=timeout)
458490

459-
async def _messages(self) -> AsyncGenerator[Message, None]:
460-
"""Async generator that yields messages from the underlying message queue."""
461-
while True:
462-
# Wait until we either:
463-
# 1. Receive a message
464-
# 2. Disconnect from the broker
465-
task = self._loop.create_task(self._queue.get())
466-
try:
467-
done, _ = await asyncio.wait(
468-
(task, self._disconnected), return_when=asyncio.FIRST_COMPLETED
469-
)
470-
except asyncio.CancelledError:
471-
# If the asyncio.wait is cancelled, we must make sure
472-
# to also cancel the underlying tasks.
473-
task.cancel()
474-
raise
475-
if task in done:
476-
# We received a message. Return the result.
477-
yield task.result()
478-
else:
479-
# We were disconnected from the broker
480-
task.cancel()
481-
# Stop the generator with an exception
482-
msg = "Disconnected during message iteration"
483-
raise MqttError(msg)
484-
485491
async def _wait_for(
486492
self, fut: Awaitable[T], timeout: float | None, **kwargs: Any
487493
) -> T:

tests/test_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import logging
45
import pathlib
56
import ssl
67
import sys
8+
from typing import Any
79

810
import anyio
911
import anyio.abc
@@ -413,7 +415,7 @@ async def test_messages_view_is_reusable() -> None:
413415
@pytest.mark.network
414416
async def test_messages_view_multiple_tasks_concurrently() -> None:
415417
"""Test that ``.messages`` can be used concurrently by multiple tasks."""
416-
topic = TOPIC_PREFIX + "test_messages_generator_is_reentrant"
418+
topic = TOPIC_PREFIX + "test_messages_view_multiple_tasks_concurrently"
417419
async with Client(HOSTNAME) as client, anyio.create_task_group() as tg:
418420

419421
async def handle() -> None:

0 commit comments

Comments
 (0)