Skip to content

Commit 4d36df3

Browse files
committed
Accept generic ExceptionGroups for raises
Closes #13115
1 parent 517b006 commit 4d36df3

File tree

4 files changed

+73
-5
lines changed

4 files changed

+73
-5
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ Tim Hoffmann
434434
Tim Strazny
435435
TJ Bruno
436436
Tobias Diez
437+
Tobias Petersen
437438
Tom Dalton
438439
Tom Viner
439440
Tomáš Gavenčiak

changelog/13115.improvement.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on ExcInfo.

src/_pytest/python_api.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
from numbers import Complex
1313
import pprint
1414
import re
15+
import sys
1516
from types import TracebackType
1617
from typing import Any
1718
from typing import cast
1819
from typing import final
20+
from typing import get_args
21+
from typing import get_origin
1922
from typing import overload
2023
from typing import TYPE_CHECKING
2124
from typing import TypeVar
@@ -24,6 +27,10 @@
2427
from _pytest.outcomes import fail
2528

2629

30+
if sys.version_info[:2] < (3, 11):
31+
from exceptiongroup import BaseExceptionGroup
32+
from exceptiongroup import ExceptionGroup
33+
2734
if TYPE_CHECKING:
2835
from numpy import ndarray
2936

@@ -786,6 +793,10 @@ def _as_numpy_array(obj: object) -> ndarray | None:
786793

787794
E = TypeVar("E", bound=BaseException)
788795

796+
# This will be `typing_GenericAlias` in the backport as opposed to
797+
# `types.GenericAlias` for native ExceptionGroup. The cast is to not confuse mypy
798+
_generic_exc_group_type = cast(type, type(ExceptionGroup[Exception]))
799+
789800

790801
@overload
791802
def raises(
@@ -954,15 +965,43 @@ def raises(
954965
f"Raising exceptions is already understood as failing the test, so you don't need "
955966
f"any special code to say 'this should never raise an exception'."
956967
)
968+
969+
expected_exceptions: tuple[type[E], ...]
970+
origin_exc: type[E] | None = get_origin(expected_exception)
957971
if isinstance(expected_exception, type):
958-
expected_exceptions: tuple[type[E], ...] = (expected_exception,)
972+
expected_exceptions = (expected_exception,)
973+
elif origin_exc and issubclass(origin_exc, BaseExceptionGroup):
974+
expected_exceptions = (cast(type[E], expected_exception),)
959975
else:
960976
expected_exceptions = expected_exception
961-
for exc in expected_exceptions:
962-
if not isinstance(exc, type) or not issubclass(exc, BaseException):
977+
978+
def validate_exc(exc: type[E]) -> type[E]:
979+
origin_exc: type[E] | None = get_origin(exc)
980+
if origin_exc and issubclass(origin_exc, BaseExceptionGroup):
981+
exc_type = get_args(exc)[0]
982+
if issubclass(origin_exc, ExceptionGroup) and exc_type is Exception:
983+
return cast(type[E], origin_exc)
984+
elif (
985+
issubclass(origin_exc, BaseExceptionGroup) and exc_type is BaseException
986+
):
987+
return cast(type[E], origin_exc)
988+
else:
989+
raise ValueError(
990+
f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
991+
f"are accepted as generic types but got `{exc}`. "
992+
f"`raises` will catch all instances of the base-type regardless so the "
993+
f"returned type will be wider regardless and has to be checked "
994+
f"with `ExceptionInfo.group_contains()`"
995+
)
996+
997+
elif not isinstance(exc, type) or not issubclass(exc, BaseException):
963998
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
964999
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
9651000
raise TypeError(msg.format(not_a))
1001+
else:
1002+
return exc
1003+
1004+
expected_exceptions = tuple(validate_exc(exc) for exc in expected_exceptions)
9661005

9671006
message = f"DID NOT RAISE {expected_exception}"
9681007

@@ -973,14 +1012,14 @@ def raises(
9731012
msg += ", ".join(sorted(kwargs))
9741013
msg += "\nUse context-manager form instead?"
9751014
raise TypeError(msg)
976-
return RaisesContext(expected_exception, message, match)
1015+
return RaisesContext(expected_exceptions, message, match)
9771016
else:
9781017
func = args[0]
9791018
if not callable(func):
9801019
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
9811020
try:
9821021
func(*args[1:], **kwargs)
983-
except expected_exception as e:
1022+
except expected_exceptions as e:
9841023
return _pytest._code.ExceptionInfo.from_exception(e)
9851024
fail(message)
9861025

testing/code/test_excinfo.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from _pytest._code.code import TracebackStyle
3232

3333
if sys.version_info < (3, 11):
34+
from exceptiongroup import BaseExceptionGroup
3435
from exceptiongroup import ExceptionGroup
3536

3637

@@ -453,6 +454,32 @@ def test_division_zero():
453454
result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match])
454455

455456

457+
def test_raises_accepts_generic_group() -> None:
458+
exc_group = ExceptionGroup("", [RuntimeError()])
459+
with pytest.raises(ExceptionGroup[Exception]) as exc_info:
460+
raise exc_group
461+
assert exc_info.group_contains(RuntimeError)
462+
463+
464+
def test_raises_accepts_generic_base_group() -> None:
465+
exc_group = ExceptionGroup("", [RuntimeError()])
466+
with pytest.raises(BaseExceptionGroup[BaseException]) as exc_info:
467+
raise exc_group
468+
assert exc_info.group_contains(RuntimeError)
469+
470+
471+
def test_raises_rejects_specific_generic_group() -> None:
472+
with pytest.raises(ValueError):
473+
pytest.raises(ExceptionGroup[RuntimeError])
474+
475+
476+
def test_raises_accepts_generic_group_in_tuple() -> None:
477+
exc_group = ExceptionGroup("", [RuntimeError()])
478+
with pytest.raises((ValueError, ExceptionGroup[Exception])) as exc_info:
479+
raise exc_group
480+
assert exc_info.group_contains(RuntimeError)
481+
482+
456483
class TestGroupContains:
457484
def test_contains_exception_type(self) -> None:
458485
exc_group = ExceptionGroup("", [RuntimeError()])

0 commit comments

Comments
 (0)