9
9
from prisma .types import AgentGraphWhereInput
10
10
from pydantic .fields import computed_field
11
11
12
+ from backend .blocks .agent import AgentExecutorBlock
12
13
from backend .blocks .basic import AgentInputBlock , AgentOutputBlock
13
14
from backend .data .block import BlockInput , BlockType , get_block , get_blocks
14
15
from backend .data .db import BaseDbModel , transaction
@@ -174,24 +175,35 @@ def starting_nodes(self) -> list[Node]:
174
175
if node .id not in outbound_nodes or node .id in input_nodes
175
176
]
176
177
177
- def reassign_ids (self , reassign_graph_id : bool = False ):
178
+ def reassign_ids (self , user_id : str , reassign_graph_id : bool = False ):
178
179
"""
179
180
Reassigns all IDs in the graph to new UUIDs.
180
181
This method can be used before storing a new graph to the database.
181
182
"""
182
- self .validate_graph ()
183
183
184
+ # Reassign Graph ID
184
185
id_map = {node .id : str (uuid .uuid4 ()) for node in self .nodes }
185
186
if reassign_graph_id :
186
187
self .id = str (uuid .uuid4 ())
187
188
189
+ # Reassign Node IDs
188
190
for node in self .nodes :
189
191
node .id = id_map [node .id ]
190
192
193
+ # Reassign Link IDs
191
194
for link in self .links :
192
195
link .source_id = id_map [link .source_id ]
193
196
link .sink_id = id_map [link .sink_id ]
194
197
198
+ # Reassign User IDs for agent blocks
199
+ for node in self .nodes :
200
+ if node .block_id != AgentExecutorBlock ().id :
201
+ continue
202
+ node .input_default ["user_id" ] = user_id
203
+ node .input_default .setdefault ("data" , {})
204
+
205
+ self .validate_graph ()
206
+
195
207
def validate_graph (self , for_run : bool = False ):
196
208
def sanitize (name ):
197
209
return name .split ("_#_" )[0 ].split ("_@_" )[0 ].split ("_$_" )[0 ]
@@ -215,6 +227,7 @@ def sanitize(name):
215
227
for_run # Skip input completion validation, unless when executing.
216
228
or block .block_type == BlockType .INPUT
217
229
or block .block_type == BlockType .OUTPUT
230
+ or block .block_type == BlockType .AGENT
218
231
):
219
232
raise ValueError (
220
233
f"Node { block .name } #{ node .id } required input missing: `{ name } `"
@@ -248,18 +261,26 @@ def is_static_output_block(nid: str) -> bool:
248
261
)
249
262
250
263
sanitized_name = sanitize (name )
264
+ vals = node .input_default
251
265
if i == 0 :
252
- fields = f"Valid output fields: { block .output_schema .get_fields ()} "
266
+ fields = (
267
+ block .output_schema .get_fields ()
268
+ if block .block_type != BlockType .AGENT
269
+ else vals .get ("output_schema" , {}).get ("properties" , {}).keys ()
270
+ )
253
271
else :
254
- fields = f"Valid input fields: { block .input_schema .get_fields ()} "
272
+ fields = (
273
+ block .input_schema .get_fields ()
274
+ if block .block_type != BlockType .AGENT
275
+ else vals .get ("input_schema" , {}).get ("properties" , {}).keys ()
276
+ )
255
277
if sanitized_name not in fields :
256
- raise ValueError (f"{ suffix } , `{ name } ` invalid, { fields } " )
278
+ fields_msg = f"Allowed fields: { fields } "
279
+ raise ValueError (f"{ suffix } , `{ name } ` invalid, { fields_msg } " )
257
280
258
281
if is_static_output_block (link .source_id ):
259
282
link .is_static = True # Each value block output should be static.
260
283
261
- # TODO: Add type compatibility check here.
262
-
263
284
@staticmethod
264
285
def from_db (graph : AgentGraph , hide_credentials : bool = False ):
265
286
executions = [
0 commit comments