Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,39 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio

def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
self.context_stack.append(node.name)
i = len(node.body) - 1
nbody = node.body
i = len(nbody) - 1
test_qualified_name = ".".join(self.context_stack)
key = test_qualified_name + "#" + str(self.abs_path)
original_runtimes = self.original_runtimes
optimized_runtimes = self.optimized_runtimes

while i >= 0:
line_node = node.body[i]
line_node = nbody[i]
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
j = len(line_node.body) - 1
line_body = line_node.body
j = len(line_body) - 1
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
nodes_to_check = [compound_line_node]
nodes_to_check.extend(getattr(compound_line_node, "body", []))
for internal_node in nodes_to_check:
if isinstance(internal_node, (ast.stmt, ast.Assign)):
inv_id = str(i) + "_" + str(j)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[internal_node.lineno] = self.get_comment(match_key)
compound_line_node: ast.stmt = line_body[j]
# Fast-path: most ast.stmt don't have a body
if hasattr(compound_line_node, "body"):
for internal_node in compound_line_node.body:
if isinstance(internal_node, (ast.stmt, ast.Assign)):
inv_id = f"{i}_{j}"
match_key = f"{key}#{inv_id}"
if match_key in original_runtimes and match_key in optimized_runtimes:
self.results[internal_node.lineno] = self.get_comment(match_key)
# Always check the compound_line_node itself
if isinstance(compound_line_node, (ast.stmt, ast.Assign)):
inv_id = f"{i}_{j}"
match_key = f"{key}#{inv_id}"
if match_key in original_runtimes and match_key in optimized_runtimes:
self.results[compound_line_node.lineno] = self.get_comment(match_key)
j -= 1
else:
inv_id = str(i)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
match_key = f"{key}#{inv_id}"
if match_key in original_runtimes and match_key in optimized_runtimes:
self.results[line_node.lineno] = self.get_comment(match_key)
i -= 1
self.context_stack.pop()
Expand Down
Loading