Skip to content

Commit 6754b1f

Browse files
committed
optimize performance
Signed-off-by: Saurabh Misra <[email protected]>
1 parent a831ee3 commit 6754b1f

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

codeflash/code_utils/deduplicate_code.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def visit_With(self, node):
154154
return self.generic_visit(node)
155155

156156

157-
def normalize_code(code: str, remove_docstrings: bool = True) -> str:
157+
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str:
158158
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
159159
Function names, class names, and parameters are preserved.
160160
@@ -177,6 +177,9 @@ def normalize_code(code: str, remove_docstrings: bool = True) -> str:
177177
# Normalize variable names
178178
normalizer = VariableNormalizer()
179179
normalized_tree = normalizer.visit(tree)
180+
if return_ast_dump:
181+
# This is faster than unparsing etc
182+
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
180183

181184
# Fix missing locations in the AST
182185
ast.fix_missing_locations(normalized_tree)
@@ -190,16 +193,25 @@ def normalize_code(code: str, remove_docstrings: bool = True) -> str:
190193

191194
def remove_docstrings_from_ast(node):
192195
"""Remove docstrings from AST nodes."""
193-
# Process all nodes in the tree, but avoid recursion
194-
for current_node in ast.walk(node):
195-
if isinstance(current_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)):
196+
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
197+
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
198+
# Use our own stack-based DFS instead of ast.walk for efficiency
199+
stack = [node]
200+
while stack:
201+
current_node = stack.pop()
202+
if isinstance(current_node, node_types):
203+
# Remove docstring if it's the first stmt in body
204+
body = current_node.body
196205
if (
197-
current_node.body
198-
and isinstance(current_node.body[0], ast.Expr)
199-
and isinstance(current_node.body[0].value, ast.Constant)
200-
and isinstance(current_node.body[0].value.value, str)
206+
body
207+
and isinstance(body[0], ast.Expr)
208+
and isinstance(body[0].value, ast.Constant)
209+
and isinstance(body[0].value.value, str)
201210
):
202-
current_node.body = current_node.body[1:]
211+
current_node.body = body[1:]
212+
# Only these nodes can nest more docstring-containing nodes
213+
# Add their body elements to stack, avoiding unnecessary traversal
214+
stack.extend([child for child in body if isinstance(child, node_types)])
203215

204216

205217
def get_code_fingerprint(code: str) -> str:
@@ -228,8 +240,8 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
228240
229241
"""
230242
try:
231-
normalized1 = normalize_code(code1)
232-
normalized2 = normalize_code(code2)
243+
normalized1 = normalize_code(code1, return_ast_dump=True)
244+
normalized2 = normalize_code(code2, return_ast_dump=True)
233245
return normalized1 == normalized2
234246
except Exception:
235247
return False

0 commit comments

Comments
 (0)