14
14
from types import TracebackType
15
15
from typing import (
16
16
Any ,
17
- AsyncGenerator ,
17
+ AsyncIterator ,
18
18
Awaitable ,
19
19
Callable ,
20
20
Coroutine ,
@@ -125,7 +125,7 @@ class Will:
125
125
126
126
127
127
class Client :
128
- """The async context manager that manages the connection to the broker.
128
+ """Asynchronous context manager for the connection to the MQTT broker.
129
129
130
130
Args:
131
131
hostname: The hostname or IP address of the remote broker.
@@ -320,10 +320,6 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
320
320
timeout = 10
321
321
self .timeout = timeout
322
322
323
- @property
324
- def messages (self ) -> AsyncGenerator [Message , None ]:
325
- return self ._messages ()
326
-
327
323
@property
328
324
def identifier (self ) -> str :
329
325
"""Return the client identifier.
@@ -333,6 +329,42 @@ def identifier(self) -> str:
333
329
"""
334
330
return self ._client ._client_id .decode () # noqa: SLF001
335
331
332
+ class MessagesIterator :
333
+ """Dynamic view of the message queue."""
334
+
335
+ def __init__ (self , client : Client ) -> None :
336
+ self ._client = client
337
+
338
+ def __aiter__ (self ) -> AsyncIterator [Message ]:
339
+ return self
340
+
341
+ async def __anext__ (self ) -> Message :
342
+ # Wait until we either (1) receive a message or (2) disconnect
343
+ task = self ._client ._loop .create_task (self ._client ._queue .get ()) # noqa: SLF001
344
+ try :
345
+ done , _ = await asyncio .wait (
346
+ (task , self ._client ._disconnected ), # noqa: SLF001
347
+ return_when = asyncio .FIRST_COMPLETED ,
348
+ )
349
+ # If the asyncio.wait is cancelled, we must also cancel the queue task
350
+ except asyncio .CancelledError :
351
+ task .cancel ()
352
+ raise
353
+ # When we receive a message, return it
354
+ if task in done :
355
+ return task .result ()
356
+ # If we disconnect from the broker, stop the generator with an exception
357
+ task .cancel ()
358
+ msg = "Disconnected during message iteration"
359
+ raise MqttError (msg )
360
+
361
+ def __len__ (self ) -> int :
362
+ return self ._client ._queue .qsize () # noqa: SLF001
363
+
364
+ @property
365
+ def messages (self ) -> MessagesIterator :
366
+ return self .MessagesIterator (self )
367
+
336
368
@property
337
369
def _pending_calls (self ) -> Generator [int , None , None ]:
338
370
"""Yield all message IDs with pending calls."""
@@ -456,32 +488,6 @@ async def publish( # noqa: PLR0913
456
488
# Wait for confirmation
457
489
await self ._wait_for (confirmation .wait (), timeout = timeout )
458
490
459
- async def _messages (self ) -> AsyncGenerator [Message , None ]:
460
- """Async generator that yields messages from the underlying message queue."""
461
- while True :
462
- # Wait until we either:
463
- # 1. Receive a message
464
- # 2. Disconnect from the broker
465
- task = self ._loop .create_task (self ._queue .get ())
466
- try :
467
- done , _ = await asyncio .wait (
468
- (task , self ._disconnected ), return_when = asyncio .FIRST_COMPLETED
469
- )
470
- except asyncio .CancelledError :
471
- # If the asyncio.wait is cancelled, we must make sure
472
- # to also cancel the underlying tasks.
473
- task .cancel ()
474
- raise
475
- if task in done :
476
- # We received a message. Return the result.
477
- yield task .result ()
478
- else :
479
- # We were disconnected from the broker
480
- task .cancel ()
481
- # Stop the generator with an exception
482
- msg = "Disconnected during message iteration"
483
- raise MqttError (msg )
484
-
485
491
async def _wait_for (
486
492
self , fut : Awaitable [T ], timeout : float | None , ** kwargs : Any
487
493
) -> T :
0 commit comments