|
53 | 53 | ToolCallItemTypes,
|
54 | 54 | TResponseInputItem,
|
55 | 55 | )
|
56 |
| -from .lifecycle import RunHooks |
| 56 | +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase |
57 | 57 | from .logger import logger
|
58 | 58 | from .memory import Session, SessionInputCallback
|
59 | 59 | from .model_settings import ModelSettings
|
@@ -461,13 +461,11 @@ async def run(
|
461 | 461 | ) -> RunResult:
|
462 | 462 | context = kwargs.get("context")
|
463 | 463 | max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
464 |
| - hooks = kwargs.get("hooks") |
| 464 | + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) |
465 | 465 | run_config = kwargs.get("run_config")
|
466 | 466 | previous_response_id = kwargs.get("previous_response_id")
|
467 | 467 | conversation_id = kwargs.get("conversation_id")
|
468 | 468 | session = kwargs.get("session")
|
469 |
| - if hooks is None: |
470 |
| - hooks = RunHooks[Any]() |
471 | 469 | if run_config is None:
|
472 | 470 | run_config = RunConfig()
|
473 | 471 |
|
@@ -668,14 +666,12 @@ def run_streamed(
|
668 | 666 | ) -> RunResultStreaming:
|
669 | 667 | context = kwargs.get("context")
|
670 | 668 | max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
671 |
| - hooks = kwargs.get("hooks") |
| 669 | + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) |
672 | 670 | run_config = kwargs.get("run_config")
|
673 | 671 | previous_response_id = kwargs.get("previous_response_id")
|
674 | 672 | conversation_id = kwargs.get("conversation_id")
|
675 | 673 | session = kwargs.get("session")
|
676 | 674 |
|
677 |
| - if hooks is None: |
678 |
| - hooks = RunHooks[Any]() |
679 | 675 | if run_config is None:
|
680 | 676 | run_config = RunConfig()
|
681 | 677 |
|
@@ -732,6 +728,23 @@ def run_streamed(
|
732 | 728 | )
|
733 | 729 | return streamed_result
|
734 | 730 |
|
| 731 | + @staticmethod |
| 732 | + def _validate_run_hooks( |
| 733 | + hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, |
| 734 | + ) -> RunHooks[Any]: |
| 735 | + if hooks is None: |
| 736 | + return RunHooks[Any]() |
| 737 | + input_hook_type = type(hooks).__name__ |
| 738 | + if isinstance(hooks, AgentHooksBase): |
| 739 | + raise TypeError( |
| 740 | + "Run hooks must be instances of RunHooks. " |
| 741 | + f"Received agent-scoped hooks ({input_hook_type}). " |
| 742 | + "Attach AgentHooks to an Agent via Agent(..., hooks=...)." |
| 743 | + ) |
| 744 | + if not isinstance(hooks, RunHooksBase): |
| 745 | + raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") |
| 746 | + return hooks |
| 747 | + |
735 | 748 | @classmethod
|
736 | 749 | async def _maybe_filter_model_input(
|
737 | 750 | cls,
|
|
0 commit comments