Skip to content

Commit a4319b5

Browse files
committed
feat(run): add lifecycle interrupt + inject, cancel-aware tools, safer tracing
1 parent 50a909a commit a4319b5

File tree

6 files changed

+1194
-73
lines changed

6 files changed

+1194
-73
lines changed

src/agents/_run_impl.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import dataclasses
56
import inspect
67
from collections.abc import Awaitable
@@ -226,6 +227,29 @@ def get_model_tracing_impl(
226227
return ModelTracing.ENABLED_WITHOUT_DATA
227228

228229

230+
# Helpers for cancellable tool execution
231+
232+
233+
async def _await_cancellable(awaitable):
234+
"""Await an awaitable in its own task so CancelledError interrupts promptly."""
235+
task = asyncio.create_task(awaitable)
236+
try:
237+
return await task
238+
except asyncio.CancelledError:
239+
# propagate so run.py can handle terminal cancel
240+
raise
241+
242+
243+
def _maybe_call_cancel_hook(tool_obj) -> None:
244+
"""Best-effort: call a cancel/terminate hook on the tool if present."""
245+
for name in ("cancel", "terminate", "stop"):
246+
cb = getattr(tool_obj, name, None)
247+
if callable(cb):
248+
with contextlib.suppress(Exception):
249+
cb()
250+
break
251+
252+
229253
class RunImpl:
230254
@classmethod
231255
async def execute_tools_and_side_effects(
@@ -572,16 +596,24 @@ async def run_single_tool(
572596
if config.trace_include_sensitive_data:
573597
span_fn.span_data.input = tool_call.arguments
574598
try:
575-
_, _, result = await asyncio.gather(
599+
# run start hooks first (don’t tie them to the cancellable task)
600+
await asyncio.gather(
576601
hooks.on_tool_start(tool_context, agent, func_tool),
577602
(
578603
agent.hooks.on_tool_start(tool_context, agent, func_tool)
579604
if agent.hooks
580605
else _coro.noop_coroutine()
581606
),
582-
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
583607
)
584608

609+
try:
610+
result = await _await_cancellable(
611+
func_tool.on_invoke_tool(tool_context, tool_call.arguments)
612+
)
613+
except asyncio.CancelledError:
614+
_maybe_call_cancel_hook(func_tool)
615+
raise
616+
585617
await asyncio.gather(
586618
hooks.on_tool_end(tool_context, agent, func_tool, result),
587619
(
@@ -590,6 +622,7 @@ async def run_single_tool(
590622
else _coro.noop_coroutine()
591623
),
592624
)
625+
593626
except Exception as e:
594627
_error_tracing.attach_error_to_current_span(
595628
SpanError(
@@ -660,7 +693,6 @@ async def execute_computer_actions(
660693
config: RunConfig,
661694
) -> list[RunItem]:
662695
results: list[RunItem] = []
663-
# Need to run these serially, because each action can affect the computer state
664696
for action in actions:
665697
acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
666698
if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
@@ -677,24 +709,28 @@ async def execute_computer_actions(
677709
if ack:
678710
acknowledged.append(
679711
ComputerCallOutputAcknowledgedSafetyCheck(
680-
id=check.id,
681-
code=check.code,
682-
message=check.message,
712+
id=check.id, code=check.code, message=check.message
683713
)
684714
)
685715
else:
686716
raise UserError("Computer tool safety check was not acknowledged")
687717

688-
results.append(
689-
await ComputerAction.execute(
690-
agent=agent,
691-
action=action,
692-
hooks=hooks,
693-
context_wrapper=context_wrapper,
694-
config=config,
695-
acknowledged_safety_checks=acknowledged,
718+
try:
719+
item = await _await_cancellable(
720+
ComputerAction.execute(
721+
agent=agent,
722+
action=action,
723+
hooks=hooks,
724+
context_wrapper=context_wrapper,
725+
config=config,
726+
acknowledged_safety_checks=acknowledged,
727+
)
696728
)
697-
)
729+
except asyncio.CancelledError:
730+
_maybe_call_cancel_hook(action.computer_tool)
731+
raise
732+
733+
results.append(item)
698734

699735
return results
700736

@@ -1068,16 +1104,23 @@ async def execute(
10681104
else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call)
10691105
)
10701106

1071-
_, _, output = await asyncio.gather(
1107+
# start hooks first
1108+
await asyncio.gather(
10721109
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
10731110
(
10741111
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
10751112
if agent.hooks
10761113
else _coro.noop_coroutine()
10771114
),
1078-
output_func,
10791115
)
1080-
1116+
# run the action (screenshot/etc) in a cancellable task
1117+
try:
1118+
output = await _await_cancellable(output_func)
1119+
except asyncio.CancelledError:
1120+
_maybe_call_cancel_hook(action.computer_tool)
1121+
raise
1122+
1123+
# end hooks
10811124
await asyncio.gather(
10821125
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
10831126
(
@@ -1185,10 +1228,20 @@ async def execute(
11851228
data=call.tool_call,
11861229
)
11871230
output = call.local_shell_tool.executor(request)
1188-
if inspect.isawaitable(output):
1189-
result = await output
1190-
else:
1191-
result = output
1231+
try:
1232+
if inspect.isawaitable(output):
1233+
result = await _await_cancellable(output)
1234+
else:
1235+
# If executor returns a sync result, just use it (can’t cancel mid-call)
1236+
result = output
1237+
except asyncio.CancelledError:
1238+
# Best-effort: if the executor or tool exposes a cancel/terminate, call it
1239+
_maybe_call_cancel_hook(call.local_shell_tool)
1240+
# If your executor returns a proc handle (common pattern), adddress it here if needed:
1241+
# with contextlib.suppress(Exception):
1242+
# proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
1243+
# proc.kill()
1244+
raise
11921245

11931246
await asyncio.gather(
11941247
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
@@ -1201,7 +1254,7 @@ async def execute(
12011254

12021255
return ToolCallOutputItem(
12031256
agent=agent,
1204-
output=output,
1257+
output=result,
12051258
raw_item={
12061259
"type": "local_shell_call_output",
12071260
"id": call.tool_call.call_id,

src/agents/models/openai_responses.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
from collections.abc import AsyncIterator
56
from dataclasses import dataclass
@@ -175,15 +176,30 @@ async def stream_response(
175176

176177
final_response: Response | None = None
177178

178-
async for chunk in stream:
179-
if isinstance(chunk, ResponseCompletedEvent):
180-
final_response = chunk.response
181-
yield chunk
179+
try:
180+
async for chunk in stream: # ensure type checkers relax here
181+
if isinstance(chunk, ResponseCompletedEvent):
182+
final_response = chunk.response
183+
yield chunk
184+
except asyncio.CancelledError:
185+
# Cooperative cancel: ensure the HTTP stream is closed, then propagate
186+
try:
187+
await stream.close()
188+
except Exception:
189+
pass
190+
raise
191+
finally:
192+
# Always close the stream if the async iterator exits (normal or error)
193+
try:
194+
await stream.close()
195+
except Exception:
196+
pass
182197

183198
if final_response and tracing.include_data():
184199
span_response.span_data.response = final_response
185200
span_response.span_data.input = input
186201

202+
187203
except Exception as e:
188204
span_response.set_error(
189205
SpanError(

src/agents/result.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import asyncio
5+
import contextlib
56
from collections.abc import AsyncIterator
67
from dataclasses import dataclass, field
78
from typing import TYPE_CHECKING, Any, cast
@@ -143,6 +144,12 @@ class RunResultStreaming(RunResultBase):
143144
is_complete: bool = False
144145
"""Whether the agent has finished running."""
145146

147+
_emit_status_events: bool = False
148+
"""Whether to emit RunUpdatedStreamEvent status updates.
149+
150+
Defaults to False for backward compatibility.
151+
"""
152+
146153
# Queues that the background run_loop writes to
147154
_event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
148155
default_factory=asyncio.Queue, repr=False
@@ -164,17 +171,45 @@ def last_agent(self) -> Agent[Any]:
164171
"""
165172
return self.current_agent
166173

167-
def cancel(self) -> None:
168-
"""Cancels the streaming run, stopping all background tasks and marking the run as
169-
complete."""
170-
self._cleanup_tasks() # Cancel all running tasks
171-
self.is_complete = True # Mark the run as complete to stop event streaming
174+
def cancel(self, reason: str | None = None) -> None:
175+
# 1) Signal cooperative cancel to the runner
176+
active = getattr(self, "_active_run", None)
177+
if active:
178+
with contextlib.suppress(Exception):
179+
active.cancel(reason)
180+
# 2) Do NOT cancel the background task; let the loop unwind cooperatively
181+
# task = getattr(self, "_run_impl_task", None)
182+
# if task and not task.done():
183+
# with contextlib.suppress(Exception):
184+
# task.cancel()
185+
186+
# 4) Mark complete; flushing only when status events are disabled
187+
self.is_complete = True
188+
if not getattr(self, "_emit_status_events", False):
189+
with contextlib.suppress(Exception):
190+
while not self._event_queue.empty():
191+
self._event_queue.get_nowait()
192+
self._event_queue.task_done()
193+
with contextlib.suppress(Exception):
194+
while not self._input_guardrail_queue.empty():
195+
self._input_guardrail_queue.get_nowait()
196+
self._input_guardrail_queue.task_done()
197+
198+
def inject(self, items: list[TResponseInputItem]) -> None:
199+
"""
200+
Inject new input items mid-run. They will be consumed at the start of the next step.
201+
"""
202+
active = getattr(self, "_active_run", None)
203+
if active is not None:
204+
try:
205+
active.inject(items)
206+
except Exception:
207+
pass
172208

173-
# Optionally, clear the event queue to prevent processing stale events
174-
while not self._event_queue.empty():
175-
self._event_queue.get_nowait()
176-
while not self._input_guardrail_queue.empty():
177-
self._input_guardrail_queue.get_nowait()
209+
@property
210+
def active_run(self):
211+
"""Access the underlying ActiveRun handle (may be None early in startup)."""
212+
return getattr(self, "_active_run", None)
178213

179214
async def stream_events(self) -> AsyncIterator[StreamEvent]:
180215
"""Stream deltas for new items as they are generated. We're using the types from the
@@ -243,21 +278,33 @@ def _check_errors(self):
243278
# Check the tasks for any exceptions
244279
if self._run_impl_task and self._run_impl_task.done():
245280
run_impl_exc = self._run_impl_task.exception()
246-
if run_impl_exc and isinstance(run_impl_exc, Exception):
281+
if (
282+
run_impl_exc
283+
and isinstance(run_impl_exc, Exception)
284+
and not isinstance(run_impl_exc, asyncio.CancelledError)
285+
):
247286
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
248287
run_impl_exc.run_data = self._create_error_details()
249288
self._stored_exception = run_impl_exc
250289

251290
if self._input_guardrails_task and self._input_guardrails_task.done():
252291
in_guard_exc = self._input_guardrails_task.exception()
253-
if in_guard_exc and isinstance(in_guard_exc, Exception):
292+
if (
293+
in_guard_exc
294+
and isinstance(in_guard_exc, Exception)
295+
and not isinstance(in_guard_exc, asyncio.CancelledError)
296+
):
254297
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
255298
in_guard_exc.run_data = self._create_error_details()
256299
self._stored_exception = in_guard_exc
257300

258301
if self._output_guardrails_task and self._output_guardrails_task.done():
259302
out_guard_exc = self._output_guardrails_task.exception()
260-
if out_guard_exc and isinstance(out_guard_exc, Exception):
303+
if (
304+
out_guard_exc
305+
and isinstance(out_guard_exc, Exception)
306+
and not isinstance(out_guard_exc, asyncio.CancelledError)
307+
):
261308
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
262309
out_guard_exc.run_data = self._create_error_details()
263310
self._stored_exception = out_guard_exc

0 commit comments

Comments
 (0)