Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions codeflash/code_utils/deduplicate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def visit_For(self, node):

def visit_With(self, node):
"""Handle with statement as variables"""
return self.generic_visit(node)
# micro-optimization: directly call NodeTransformer's generic_visit (fewer indirections than type-based lookup)
return ast.NodeTransformer.generic_visit(self, node)


def normalize_code(code: str, remove_docstrings: bool = True) -> str:
Expand All @@ -172,7 +173,7 @@ def normalize_code(code: str, remove_docstrings: bool = True) -> str:

# Remove docstrings if requested
if remove_docstrings:
remove_docstrings_from_ast(tree)
fast_remove_docstrings_from_ast(tree)

# Normalize variable names
normalizer = VariableNormalizer()
Expand Down Expand Up @@ -233,3 +234,26 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
return normalized1 == normalized2
except Exception:
return False


def fast_remove_docstrings_from_ast(node):
"""Efficiently remove docstrings from AST nodes without walking the entire tree."""
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
# Use our own stack-based DFS instead of ast.walk for efficiency
stack = [node]
while stack:
current_node = stack.pop()
if isinstance(current_node, node_types):
# Remove docstring if it's the first stmt in body
body = current_node.body
if (
body
and isinstance(body[0], ast.Expr)
and isinstance(body[0].value, ast.Constant)
and isinstance(body[0].value.value, str)
):
current_node.body = body[1:]
# Only these nodes can nest more docstring-containing nodes
# Add their body elements to stack, avoiding unnecessary traversal
stack.extend([child for child in body if isinstance(child, node_types)])
Loading