Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Releases
========

UNRELEASED
----------

* `#524 <https://github.com/pytest-dev/pytest-mock/pull/524>`_: Added ``spy_return_iter`` to ``mocker.spy``, which contains a duplicate of the return value of the spied method if it is an ``Iterator``.

3.14.1 (2025-05-26)
-------------------

Expand Down
1 change: 1 addition & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ are available (like ``assert_called_once_with`` or ``call_count`` in the example
In addition, spy objects contain two extra attributes:

* ``spy_return``: contains the last returned value of the spied function.
* ``spy_return_iter``: contains a duplicate of the last returned value of the spied function if the value was an iterator. Uses `tee <https://docs.python.org/3/library/itertools.html#itertools.tee>`__) to duplicate the iterator.
* ``spy_return_list``: contains a list of all returned values of the spied function (new in ``3.13``).
* ``spy_exception``: contain the last exception value raised by the spied function/method when
it was last called, or ``None`` if no exception was raised.
Expand Down
10 changes: 10 additions & 0 deletions src/pytest_mock/plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import builtins
import functools
import inspect
import itertools
import unittest.mock
import warnings
from dataclasses import dataclass
Expand Down Expand Up @@ -137,6 +138,8 @@ def resetall(
# NOTE: The mock may be a dictionary
if hasattr(mock_item.mock, "spy_return_list"):
mock_item.mock.spy_return_list = []
if hasattr(mock_item.mock, "spy_return_iter"):
mock_item.mock.spy_return_iter = None
if isinstance(mock_item.mock, supports_reset_mock_with_args):
mock_item.mock.reset_mock(
return_value=return_value, side_effect=side_effect
Expand Down Expand Up @@ -178,6 +181,12 @@ def wrapper(*args, **kwargs):
spy_obj.spy_exception = e
raise
else:
if isinstance(r, Iterator):
r, duplicated_iterator = itertools.tee(r, 2)
spy_obj.spy_return_iter = duplicated_iterator
else:
spy_obj.spy_return_iter = None

spy_obj.spy_return = r
spy_obj.spy_return_list.append(r)
return r
Expand All @@ -204,6 +213,7 @@ async def async_wrapper(*args, **kwargs):

spy_obj = self.patch.object(obj, name, side_effect=wrapped, autospec=autospec)
spy_obj.spy_return = None
spy_obj.spy_return_iter = None
spy_obj.spy_return_list = []
spy_obj.spy_exception = None
return spy_obj
Expand Down
78 changes: 78 additions & 0 deletions tests/test_pytest_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterable
from typing import Iterator
from typing import Tuple
from typing import Type
from unittest.mock import AsyncMock
Expand Down Expand Up @@ -265,12 +267,14 @@ def bar(self, arg):
assert other.bar(arg=10) == 20
foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert foo.bar.spy_return_iter is None # type:ignore[attr-defined]
assert foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert foo.bar(arg=11) == 22
assert foo.bar(arg=12) == 24
assert spy.spy_return == 24
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20, 22, 24]


Expand Down Expand Up @@ -349,11 +353,13 @@ def bar(self, x):

spy = mocker.spy(Foo, "bar")
assert spy.spy_return is None
assert spy.spy_return_iter is None
assert spy.spy_return_list == []
assert spy.spy_exception is None

Foo().bar(10)
assert spy.spy_return == 30
assert spy.spy_return_iter is None
assert spy.spy_return_list == [30]
assert spy.spy_exception is None

Expand All @@ -363,11 +369,13 @@ def bar(self, x):
with pytest.raises(ValueError):
Foo().bar(0)
assert spy.spy_return is None
assert spy.spy_return_iter is None
assert spy.spy_return_list == []
assert str(spy.spy_exception) == "invalid x"

Foo().bar(15)
assert spy.spy_return == 45
assert spy.spy_return_iter is None
assert spy.spy_return_list == [45]
assert spy.spy_exception is None

Expand Down Expand Up @@ -404,6 +412,7 @@ class Foo(Base):
calls = [mocker.call(foo, arg=10), mocker.call(other, arg=10)]
assert spy.call_args_list == calls
assert spy.spy_return == 20
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20, 20]


Expand All @@ -418,9 +427,11 @@ def bar(cls, arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20]


Expand All @@ -438,9 +449,11 @@ class Foo(Base):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20]


Expand All @@ -460,9 +473,11 @@ def bar(cls, arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20]


Expand All @@ -477,9 +492,11 @@ def bar(arg):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20]


Expand All @@ -497,9 +514,11 @@ class Foo(Base):
assert Foo.bar(arg=10) == 20
Foo.bar.assert_called_once_with(arg=10) # type:ignore[attr-defined]
assert Foo.bar.spy_return == 20 # type:ignore[attr-defined]
assert Foo.bar.spy_return_iter is None # type:ignore[attr-defined]
assert Foo.bar.spy_return_list == [20] # type:ignore[attr-defined]
spy.assert_called_once_with(arg=10)
assert spy.spy_return == 20
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20]


Expand All @@ -521,9 +540,68 @@ def __call__(self, x):
uut.call_like(10)
spy.assert_called_once_with(10)
assert spy.spy_return == 20
assert spy.spy_return_iter is None
assert spy.spy_return_list == [20]


@pytest.mark.parametrize("iterator", [(i for i in range(3)), iter([0, 1, 2])])
def test_spy_return_iter(mocker: MockerFixture, iterator: Iterator[int]) -> None:
class Foo:
def bar(self) -> Iterator[int]:
return iterator

foo = Foo()
spy = mocker.spy(foo, "bar")
result = list(foo.bar())

assert result == [0, 1, 2]
assert spy.spy_return is not None
assert spy.spy_return_iter is not None
assert list(spy.spy_return_iter) == result

[return_value] = spy.spy_return_list
assert isinstance(return_value, Iterator)


@pytest.mark.parametrize("iterable", [(0, 1, 2), [0, 1, 2], range(3)])
def test_spy_return_iter_ignore_plain_iterable(
mocker: MockerFixture, iterable: Iterable[int]
) -> None:
class Foo:
def bar(self) -> Iterable[int]:
return iterable

foo = Foo()
spy = mocker.spy(foo, "bar")
result = foo.bar()

assert result == iterable
assert spy.spy_return == result
assert spy.spy_return_iter is None
assert spy.spy_return_list == [result]


def test_spy_return_iter_resets(mocker: MockerFixture) -> None:
class Foo:
iterables: Any = [
(i for i in range(3)),
99,
]

def bar(self) -> Any:
return self.iterables.pop(0)

foo = Foo()
spy = mocker.spy(foo, "bar")
result_iterator = list(foo.bar())

assert result_iterator == [0, 1, 2]
assert list(spy.spy_return_iter) == result_iterator

assert foo.bar() == 99
assert spy.spy_return_iter is None


@pytest.mark.asyncio
async def test_instance_async_method_spy(mocker: MockerFixture) -> None:
class Foo:
Expand Down
Loading