Skip to content

Commit 5aa9e9c

Browse files
vokrackonicoddemus
andauthored
Add spy_return_iter attribute to spy (#524)
--------- Co-authored-by: Bruno Oliveira <[email protected]>
1 parent dc6df75 commit 5aa9e9c

File tree

4 files changed

+94
-0
lines changed

4 files changed

+94
-0
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Releases
22
========
33

4+
UNRELEASED
5+
----------
6+
7+
* `#524 <https://github.com/pytest-dev/pytest-mock/pull/524>`_: Added ``spy_return_iter`` to ``mocker.spy``, which contains a duplicate of the return value of the spied method if it is an ``Iterator``.
8+
49
3.14.1 (2025-05-26)
510
-------------------
611

docs/usage.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ are available (like ``assert_called_once_with`` or ``call_count`` in the example
8181
In addition, spy objects contain two extra attributes:
8282

8383
* ``spy_return``: contains the last returned value of the spied function.
84+
* ``spy_return_iter``: contains a duplicate of the last returned value of the spied function if the value was an iterator. Uses `tee <https://docs.python.org/3/library/itertools.html#itertools.tee>`__) to duplicate the iterator.
8485
* ``spy_return_list``: contains a list of all returned values of the spied function (new in ``3.13``).
8586
* ``spy_exception``: contain the last exception value raised by the spied function/method when
8687
it was last called, or ``None`` if no exception was raised.

src/pytest_mock/plugin.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import builtins
22
import functools
33
import inspect
4+
import itertools
45
import unittest.mock
56
import warnings
67
from dataclasses import dataclass
@@ -137,6 +138,8 @@ def resetall(
137138
# NOTE: The mock may be a dictionary
138139
if hasattr(mock_item.mock, "spy_return_list"):
139140
mock_item.mock.spy_return_list = []
141+
if hasattr(mock_item.mock, "spy_return_iter"):
142+
mock_item.mock.spy_return_iter = None
140143
if isinstance(mock_item.mock, supports_reset_mock_with_args):
141144
mock_item.mock.reset_mock(
142145
return_value=return_value, side_effect=side_effect
@@ -178,6 +181,12 @@ def wrapper(*args, **kwargs):
178181
spy_obj.spy_exception = e
179182
raise
180183
else:
184+
if isinstance(r, Iterator):
185+
r, duplicated_iterator = itertools.tee(r, 2)
186+
spy_obj.spy_return_iter = duplicated_iterator
187+
else:
188+
spy_obj.spy_return_iter = None
189+
181190
spy_obj.spy_return = r
182191
spy_obj.spy_return_list.append(r)
183192
return r
@@ -204,6 +213,7 @@ async def async_wrapper(*args, **kwargs):
204213

205214
spy_obj = self.patch.object(obj, name, side_effect=wrapped, autospec=autospec)
206215
spy_obj.spy_return = None
216+
spy_obj.spy_return_iter = None
207217
spy_obj.spy_return_list = []
208218
spy_obj.spy_exception = None
209219
return spy_obj

tests/test_pytest_mock.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import Any
88
from typing import Callable
99
from typing import Generator
10+
from typing import Iterable
11+
from typing import Iterator
1012
from typing import Tuple
1113
from typing import Type
1214
from unittest.mock import AsyncMock
@@ -265,12 +267,14 @@ def bar(self, arg):
265267
assert other.bar(arg=10) == 20
266268
foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
267269
assert foo.bar.spy_return == 20 # type:ignore[attr-defined]
270+
assert foo.bar.spy_return_iter is None # type:ignore[attr-defined]
268271
assert foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
269272
spy.assert_called_once_with(arg=10)
270273
assert spy.spy_return == 20
271274
assert foo.bar(arg=11) == 22
272275
assert foo.bar(arg=12) == 24
273276
assert spy.spy_return == 24
277+
assert spy.spy_return_iter is None
274278
assert spy.spy_return_list == [20, 22, 24]
275279

276280

@@ -349,11 +353,13 @@ def bar(self, x):
349353

350354
spy = mocker.spy(Foo, "bar")
351355
assert spy.spy_return is None
356+
assert spy.spy_return_iter is None
352357
assert spy.spy_return_list == []
353358
assert spy.spy_exception is None
354359

355360
Foo().bar(10)
356361
assert spy.spy_return == 30
362+
assert spy.spy_return_iter is None
357363
assert spy.spy_return_list == [30]
358364
assert spy.spy_exception is None
359365

@@ -363,11 +369,13 @@ def bar(self, x):
363369
with pytest.raises(ValueError):
364370
Foo().bar(0)
365371
assert spy.spy_return is None
372+
assert spy.spy_return_iter is None
366373
assert spy.spy_return_list == []
367374
assert str(spy.spy_exception) == "invalid x"
368375

369376
Foo().bar(15)
370377
assert spy.spy_return == 45
378+
assert spy.spy_return_iter is None
371379
assert spy.spy_return_list == [45]
372380
assert spy.spy_exception is None
373381

@@ -404,6 +412,7 @@ class Foo(Base):
404412
calls = [mocker.call(foo, arg=10), mocker.call(other, arg=10)]
405413
assert spy.call_args_list == calls
406414
assert spy.spy_return == 20
415+
assert spy.spy_return_iter is None
407416
assert spy.spy_return_list == [20, 20]
408417

409418

@@ -418,9 +427,11 @@ def bar(cls, arg):
418427
assert Foo.bar(arg=10) == 20
419428
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
420429
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
430+
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
421431
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
422432
spy.assert_called_once_with(arg=10)
423433
assert spy.spy_return == 20
434+
assert spy.spy_return_iter is None
424435
assert spy.spy_return_list == [20]
425436

426437

@@ -438,9 +449,11 @@ class Foo(Base):
438449
assert Foo.bar(arg=10) == 20
439450
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
440451
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
452+
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
441453
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
442454
spy.assert_called_once_with(arg=10)
443455
assert spy.spy_return == 20
456+
assert spy.spy_return_iter is None
444457
assert spy.spy_return_list == [20]
445458

446459

@@ -460,9 +473,11 @@ def bar(cls, arg):
460473
assert Foo.bar(arg=10) == 20
461474
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
462475
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
476+
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
463477
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
464478
spy.assert_called_once_with(arg=10)
465479
assert spy.spy_return == 20
480+
assert spy.spy_return_iter is None
466481
assert spy.spy_return_list == [20]
467482

468483

@@ -477,9 +492,11 @@ def bar(arg):
477492
assert Foo.bar(arg=10) == 20
478493
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
479494
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
495+
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
480496
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
481497
spy.assert_called_once_with(arg=10)
482498
assert spy.spy_return == 20
499+
assert spy.spy_return_iter is None
483500
assert spy.spy_return_list == [20]
484501

485502

@@ -497,9 +514,11 @@ class Foo(Base):
497514
assert Foo.bar(arg=10) == 20
498515
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
499516
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
517+
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
500518
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
501519
spy.assert_called_once_with(arg=10)
502520
assert spy.spy_return == 20
521+
assert spy.spy_return_iter is None
503522
assert spy.spy_return_list == [20]
504523

505524

@@ -521,9 +540,68 @@ def __call__(self, x):
521540
uut.call_like(10)
522541
spy.assert_called_once_with(10)
523542
assert spy.spy_return == 20
543+
assert spy.spy_return_iter is None
524544
assert spy.spy_return_list == [20]
525545

526546

547+
@pytest.mark.parametrize("iterator", [(i for i in range(3)), iter([0, 1, 2])])
548+
def test_spy_return_iter(mocker: MockerFixture, iterator: Iterator[int]) -> None:
549+
class Foo:
550+
def bar(self) -> Iterator[int]:
551+
return iterator
552+
553+
foo = Foo()
554+
spy = mocker.spy(foo, "bar")
555+
result = list(foo.bar())
556+
557+
assert result == [0, 1, 2]
558+
assert spy.spy_return is not None
559+
assert spy.spy_return_iter is not None
560+
assert list(spy.spy_return_iter) == result
561+
562+
[return_value] = spy.spy_return_list
563+
assert isinstance(return_value, Iterator)
564+
565+
566+
@pytest.mark.parametrize("iterable", [(0, 1, 2), [0, 1, 2], range(3)])
567+
def test_spy_return_iter_ignore_plain_iterable(
568+
mocker: MockerFixture, iterable: Iterable[int]
569+
) -> None:
570+
class Foo:
571+
def bar(self) -> Iterable[int]:
572+
return iterable
573+
574+
foo = Foo()
575+
spy = mocker.spy(foo, "bar")
576+
result = foo.bar()
577+
578+
assert result == iterable
579+
assert spy.spy_return == result
580+
assert spy.spy_return_iter is None
581+
assert spy.spy_return_list == [result]
582+
583+
584+
def test_spy_return_iter_resets(mocker: MockerFixture) -> None:
585+
class Foo:
586+
iterables: Any = [
587+
(i for i in range(3)),
588+
99,
589+
]
590+
591+
def bar(self) -> Any:
592+
return self.iterables.pop(0)
593+
594+
foo = Foo()
595+
spy = mocker.spy(foo, "bar")
596+
result_iterator = list(foo.bar())
597+
598+
assert result_iterator == [0, 1, 2]
599+
assert list(spy.spy_return_iter) == result_iterator
600+
601+
assert foo.bar() == 99
602+
assert spy.spy_return_iter is None
603+
604+
527605
@pytest.mark.asyncio
528606
async def test_instance_async_method_spy(mocker: MockerFixture) -> None:
529607
class Foo:

0 commit comments

Comments
 (0)