12
12
from numbers import Complex
13
13
import pprint
14
14
import re
15
+ import sys
15
16
from types import TracebackType
16
17
from typing import Any
17
18
from typing import cast
18
19
from typing import final
20
+ from typing import get_args
21
+ from typing import get_origin
19
22
from typing import overload
20
23
from typing import TYPE_CHECKING
21
24
from typing import TypeVar
24
27
from _pytest .outcomes import fail
25
28
26
29
30
+ if sys .version_info [:2 ] < (3 , 11 ):
31
+ from exceptiongroup import BaseExceptionGroup
32
+ from exceptiongroup import ExceptionGroup
33
+
27
34
if TYPE_CHECKING :
28
35
from numpy import ndarray
29
36
@@ -786,6 +793,10 @@ def _as_numpy_array(obj: object) -> ndarray | None:
786
793
787
794
E = TypeVar ("E" , bound = BaseException )
788
795
796
+ # This will be `typing_GenericAlias` in the backport as opposed to
797
+ # `types.GenericAlias` for native ExceptionGroup. The cast is to not confuse mypy
798
+ _generic_exc_group_type = cast (type , type (ExceptionGroup [Exception ]))
799
+
789
800
790
801
@overload
791
802
def raises (
@@ -954,15 +965,43 @@ def raises(
954
965
f"Raising exceptions is already understood as failing the test, so you don't need "
955
966
f"any special code to say 'this should never raise an exception'."
956
967
)
968
+
969
+ expected_exceptions : tuple [type [E ], ...]
970
+ origin_exc : type [E ] | None = get_origin (expected_exception )
957
971
if isinstance (expected_exception , type ):
958
- expected_exceptions : tuple [type [E ], ...] = (expected_exception ,)
972
+ expected_exceptions = (expected_exception ,)
973
+ elif origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
974
+ expected_exceptions = (cast (type [E ], expected_exception ),)
959
975
else :
960
976
expected_exceptions = expected_exception
961
- for exc in expected_exceptions :
962
- if not isinstance (exc , type ) or not issubclass (exc , BaseException ):
977
+
978
+ def validate_exc (exc : type [E ]) -> type [E ]:
979
+ origin_exc : type [E ] | None = get_origin (exc )
980
+ if origin_exc and issubclass (origin_exc , BaseExceptionGroup ):
981
+ exc_type = get_args (exc )[0 ]
982
+ if issubclass (origin_exc , ExceptionGroup ) and exc_type is Exception :
983
+ return cast (type [E ], origin_exc )
984
+ elif (
985
+ issubclass (origin_exc , BaseExceptionGroup ) and exc_type is BaseException
986
+ ):
987
+ return cast (type [E ], origin_exc )
988
+ else :
989
+ raise ValueError (
990
+ f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
991
+ f"are accepted as generic types but got `{ exc } `. "
992
+ f"`raises` will catch all instances of the base-type regardless so the "
993
+ f"returned type will be wider regardless and has to be checked "
994
+ f"with `ExceptionInfo.group_contains()`"
995
+ )
996
+
997
+ elif not isinstance (exc , type ) or not issubclass (exc , BaseException ):
963
998
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
964
999
not_a = exc .__name__ if isinstance (exc , type ) else type (exc ).__name__
965
1000
raise TypeError (msg .format (not_a ))
1001
+ else :
1002
+ return exc
1003
+
1004
+ expected_exceptions = tuple (validate_exc (exc ) for exc in expected_exceptions )
966
1005
967
1006
message = f"DID NOT RAISE { expected_exception } "
968
1007
@@ -973,14 +1012,14 @@ def raises(
973
1012
msg += ", " .join (sorted (kwargs ))
974
1013
msg += "\n Use context-manager form instead?"
975
1014
raise TypeError (msg )
976
- return RaisesContext (expected_exception , message , match )
1015
+ return RaisesContext (expected_exceptions , message , match )
977
1016
else :
978
1017
func = args [0 ]
979
1018
if not callable (func ):
980
1019
raise TypeError (f"{ func !r} object (type: { type (func )} ) must be callable" )
981
1020
try :
982
1021
func (* args [1 :], ** kwargs )
983
- except expected_exception as e :
1022
+ except expected_exceptions as e :
984
1023
return _pytest ._code .ExceptionInfo .from_exception (e )
985
1024
fail (message )
986
1025
0 commit comments