1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
+ import contextlib
4
5
import dataclasses
5
6
import inspect
6
7
from collections .abc import Awaitable
@@ -226,6 +227,29 @@ def get_model_tracing_impl(
226
227
return ModelTracing .ENABLED_WITHOUT_DATA
227
228
228
229
230
+ # Helpers for cancellable tool execution
231
+
232
+
233
+ async def _await_cancellable (awaitable ):
234
+ """Await an awaitable in its own task so CancelledError interrupts promptly."""
235
+ task = asyncio .create_task (awaitable )
236
+ try :
237
+ return await task
238
+ except asyncio .CancelledError :
239
+ # propagate so run.py can handle terminal cancel
240
+ raise
241
+
242
+
243
+ def _maybe_call_cancel_hook (tool_obj ) -> None :
244
+ """Best-effort: call a cancel/terminate hook on the tool if present."""
245
+ for name in ("cancel" , "terminate" , "stop" ):
246
+ cb = getattr (tool_obj , name , None )
247
+ if callable (cb ):
248
+ with contextlib .suppress (Exception ):
249
+ cb ()
250
+ break
251
+
252
+
229
253
class RunImpl :
230
254
@classmethod
231
255
async def execute_tools_and_side_effects (
@@ -572,16 +596,24 @@ async def run_single_tool(
572
596
if config .trace_include_sensitive_data :
573
597
span_fn .span_data .input = tool_call .arguments
574
598
try :
575
- _ , _ , result = await asyncio .gather (
599
+ # run start hooks first (don’t tie them to the cancellable task)
600
+ await asyncio .gather (
576
601
hooks .on_tool_start (tool_context , agent , func_tool ),
577
602
(
578
603
agent .hooks .on_tool_start (tool_context , agent , func_tool )
579
604
if agent .hooks
580
605
else _coro .noop_coroutine ()
581
606
),
582
- func_tool .on_invoke_tool (tool_context , tool_call .arguments ),
583
607
)
584
608
609
+ try :
610
+ result = await _await_cancellable (
611
+ func_tool .on_invoke_tool (tool_context , tool_call .arguments )
612
+ )
613
+ except asyncio .CancelledError :
614
+ _maybe_call_cancel_hook (func_tool )
615
+ raise
616
+
585
617
await asyncio .gather (
586
618
hooks .on_tool_end (tool_context , agent , func_tool , result ),
587
619
(
@@ -590,6 +622,7 @@ async def run_single_tool(
590
622
else _coro .noop_coroutine ()
591
623
),
592
624
)
625
+
593
626
except Exception as e :
594
627
_error_tracing .attach_error_to_current_span (
595
628
SpanError (
@@ -660,7 +693,6 @@ async def execute_computer_actions(
660
693
config : RunConfig ,
661
694
) -> list [RunItem ]:
662
695
results : list [RunItem ] = []
663
- # Need to run these serially, because each action can affect the computer state
664
696
for action in actions :
665
697
acknowledged : list [ComputerCallOutputAcknowledgedSafetyCheck ] | None = None
666
698
if action .tool_call .pending_safety_checks and action .computer_tool .on_safety_check :
@@ -677,24 +709,28 @@ async def execute_computer_actions(
677
709
if ack :
678
710
acknowledged .append (
679
711
ComputerCallOutputAcknowledgedSafetyCheck (
680
- id = check .id ,
681
- code = check .code ,
682
- message = check .message ,
712
+ id = check .id , code = check .code , message = check .message
683
713
)
684
714
)
685
715
else :
686
716
raise UserError ("Computer tool safety check was not acknowledged" )
687
717
688
- results .append (
689
- await ComputerAction .execute (
690
- agent = agent ,
691
- action = action ,
692
- hooks = hooks ,
693
- context_wrapper = context_wrapper ,
694
- config = config ,
695
- acknowledged_safety_checks = acknowledged ,
718
+ try :
719
+ item = await _await_cancellable (
720
+ ComputerAction .execute (
721
+ agent = agent ,
722
+ action = action ,
723
+ hooks = hooks ,
724
+ context_wrapper = context_wrapper ,
725
+ config = config ,
726
+ acknowledged_safety_checks = acknowledged ,
727
+ )
696
728
)
697
- )
729
+ except asyncio .CancelledError :
730
+ _maybe_call_cancel_hook (action .computer_tool )
731
+ raise
732
+
733
+ results .append (item )
698
734
699
735
return results
700
736
@@ -1068,16 +1104,23 @@ async def execute(
1068
1104
else cls ._get_screenshot_sync (action .computer_tool .computer , action .tool_call )
1069
1105
)
1070
1106
1071
- _ , _ , output = await asyncio .gather (
1107
+ # start hooks first
1108
+ await asyncio .gather (
1072
1109
hooks .on_tool_start (context_wrapper , agent , action .computer_tool ),
1073
1110
(
1074
1111
agent .hooks .on_tool_start (context_wrapper , agent , action .computer_tool )
1075
1112
if agent .hooks
1076
1113
else _coro .noop_coroutine ()
1077
1114
),
1078
- output_func ,
1079
1115
)
1080
-
1116
+ # run the action (screenshot/etc) in a cancellable task
1117
+ try :
1118
+ output = await _await_cancellable (output_func )
1119
+ except asyncio .CancelledError :
1120
+ _maybe_call_cancel_hook (action .computer_tool )
1121
+ raise
1122
+
1123
+ # end hooks
1081
1124
await asyncio .gather (
1082
1125
hooks .on_tool_end (context_wrapper , agent , action .computer_tool , output ),
1083
1126
(
@@ -1185,10 +1228,20 @@ async def execute(
1185
1228
data = call .tool_call ,
1186
1229
)
1187
1230
output = call .local_shell_tool .executor (request )
1188
- if inspect .isawaitable (output ):
1189
- result = await output
1190
- else :
1191
- result = output
1231
+ try :
1232
+ if inspect .isawaitable (output ):
1233
+ result = await _await_cancellable (output )
1234
+ else :
1235
+ # If executor returns a sync result, just use it (can’t cancel mid-call)
1236
+ result = output
1237
+ except asyncio .CancelledError :
1238
+ # Best-effort: if the executor or tool exposes a cancel/terminate, call it
1239
+ _maybe_call_cancel_hook (call .local_shell_tool )
1240
+ # If your executor returns a proc handle (common pattern), adddress it here if needed:
1241
+ # with contextlib.suppress(Exception):
1242
+ # proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
1243
+ # proc.kill()
1244
+ raise
1192
1245
1193
1246
await asyncio .gather (
1194
1247
hooks .on_tool_end (context_wrapper , agent , call .local_shell_tool , result ),
@@ -1201,7 +1254,7 @@ async def execute(
1201
1254
1202
1255
return ToolCallOutputItem (
1203
1256
agent = agent ,
1204
- output = output ,
1257
+ output = result ,
1205
1258
raw_item = {
1206
1259
"type" : "local_shell_call_output" ,
1207
1260
"id" : call .tool_call .call_id ,
0 commit comments