Skip to content

Commit 7a4a22f

Browse files
authored
fix #1750 better error message when passing AgentHooks to Runner (#1752)
This pull request resolves #1750
1 parent a4c125e commit 7a4a22f

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

src/agents/run.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
ToolCallItemTypes,
5454
TResponseInputItem,
5555
)
56-
from .lifecycle import RunHooks
56+
from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase
5757
from .logger import logger
5858
from .memory import Session, SessionInputCallback
5959
from .model_settings import ModelSettings
@@ -461,13 +461,11 @@ async def run(
461461
) -> RunResult:
462462
context = kwargs.get("context")
463463
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
464-
hooks = kwargs.get("hooks")
464+
hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks")))
465465
run_config = kwargs.get("run_config")
466466
previous_response_id = kwargs.get("previous_response_id")
467467
conversation_id = kwargs.get("conversation_id")
468468
session = kwargs.get("session")
469-
if hooks is None:
470-
hooks = RunHooks[Any]()
471469
if run_config is None:
472470
run_config = RunConfig()
473471

@@ -668,14 +666,12 @@ def run_streamed(
668666
) -> RunResultStreaming:
669667
context = kwargs.get("context")
670668
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
671-
hooks = kwargs.get("hooks")
669+
hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks")))
672670
run_config = kwargs.get("run_config")
673671
previous_response_id = kwargs.get("previous_response_id")
674672
conversation_id = kwargs.get("conversation_id")
675673
session = kwargs.get("session")
676674

677-
if hooks is None:
678-
hooks = RunHooks[Any]()
679675
if run_config is None:
680676
run_config = RunConfig()
681677

@@ -732,6 +728,23 @@ def run_streamed(
732728
)
733729
return streamed_result
734730

731+
@staticmethod
732+
def _validate_run_hooks(
733+
hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None,
734+
) -> RunHooks[Any]:
735+
if hooks is None:
736+
return RunHooks[Any]()
737+
input_hook_type = type(hooks).__name__
738+
if isinstance(hooks, AgentHooksBase):
739+
raise TypeError(
740+
"Run hooks must be instances of RunHooks. "
741+
f"Received agent-scoped hooks ({input_hook_type}). "
742+
"Attach AgentHooks to an Agent via Agent(..., hooks=...)."
743+
)
744+
if not isinstance(hooks, RunHooksBase):
745+
raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.")
746+
return hooks
747+
735748
@classmethod
736749
async def _maybe_filter_model_input(
737750
cls,

tests/test_run_hooks.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from collections import defaultdict
2-
from typing import Any, Optional
2+
from typing import Any, Optional, cast
33

44
import pytest
55

66
from agents.agent import Agent
77
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
8-
from agents.lifecycle import RunHooks
8+
from agents.lifecycle import AgentHooks, RunHooks
99
from agents.models.interface import Model
1010
from agents.run import Runner
1111
from agents.run_context import RunContextWrapper, TContext
@@ -191,6 +191,29 @@ async def boom(*args, **kwargs):
191191
assert hooks.events["on_agent_end"] == 0
192192

193193

194+
class DummyAgentHooks(AgentHooks):
195+
"""Agent-scoped hooks used to verify runtime validation."""
196+
197+
198+
@pytest.mark.asyncio
199+
async def test_runner_run_rejects_agent_hooks():
200+
model = FakeModel()
201+
agent = Agent(name="A", model=model)
202+
hooks = cast(RunHooks, DummyAgentHooks())
203+
204+
with pytest.raises(TypeError, match="Run hooks must be instances of RunHooks"):
205+
await Runner.run(agent, input="hello", hooks=hooks)
206+
207+
208+
def test_runner_run_streamed_rejects_agent_hooks():
209+
model = FakeModel()
210+
agent = Agent(name="A", model=model)
211+
hooks = cast(RunHooks, DummyAgentHooks())
212+
213+
with pytest.raises(TypeError, match="Run hooks must be instances of RunHooks"):
214+
Runner.run_streamed(agent, input="hello", hooks=hooks)
215+
216+
194217
class BoomModel(Model):
195218
async def get_response(self, *a, **k):
196219
raise AssertionError("get_response should not be called in streaming test")

0 commit comments

Comments
 (0)