diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 8e50b1d7..3f028d1b 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -61,28 +61,39 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: self.context_stack.append(node.name) - i = len(node.body) - 1 + nbody = node.body + i = len(nbody) - 1 test_qualified_name = ".".join(self.context_stack) key = test_qualified_name + "#" + str(self.abs_path) + original_runtimes = self.original_runtimes + optimized_runtimes = self.optimized_runtimes + while i >= 0: - line_node = node.body[i] + line_node = nbody[i] if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)): - j = len(line_node.body) - 1 + line_body = line_node.body + j = len(line_body) - 1 while j >= 0: - compound_line_node: ast.stmt = line_node.body[j] - nodes_to_check = [compound_line_node] - nodes_to_check.extend(getattr(compound_line_node, "body", [])) - for internal_node in nodes_to_check: - if isinstance(internal_node, (ast.stmt, ast.Assign)): - inv_id = str(i) + "_" + str(j) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: - self.results[internal_node.lineno] = self.get_comment(match_key) + compound_line_node: ast.stmt = line_body[j] + # Fast-path: most ast.stmt don't have a body + if hasattr(compound_line_node, "body"): + for internal_node in compound_line_node.body: + if isinstance(internal_node, (ast.stmt, ast.Assign)): + inv_id = f"{i}_{j}" + match_key = f"{key}#{inv_id}" + if match_key in original_runtimes and match_key in optimized_runtimes: + self.results[internal_node.lineno] = self.get_comment(match_key) + # Always check the compound_line_node itself + if isinstance(compound_line_node, (ast.stmt, ast.Assign)): + inv_id = f"{i}_{j}" + match_key = f"{key}#{inv_id}" + if match_key in original_runtimes and match_key in optimized_runtimes: + self.results[compound_line_node.lineno] = self.get_comment(match_key) j -= 1 else: inv_id = str(i) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: + match_key = f"{key}#{inv_id}" + if match_key in original_runtimes and match_key in optimized_runtimes: self.results[line_node.lineno] = self.get_comment(match_key) i -= 1 self.context_stack.pop()