@@ -154,7 +154,7 @@ def visit_With(self, node):
154
154
return self .generic_visit (node )
155
155
156
156
157
- def normalize_code (code : str , remove_docstrings : bool = True ) -> str :
157
+ def normalize_code (code : str , remove_docstrings : bool = True , return_ast_dump : bool = False ) -> str :
158
158
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
159
159
Function names, class names, and parameters are preserved.
160
160
@@ -177,6 +177,9 @@ def normalize_code(code: str, remove_docstrings: bool = True) -> str:
177
177
# Normalize variable names
178
178
normalizer = VariableNormalizer ()
179
179
normalized_tree = normalizer .visit (tree )
180
+ if return_ast_dump :
181
+ # This is faster than unparsing etc
182
+ return ast .dump (normalized_tree , annotate_fields = False , include_attributes = False )
180
183
181
184
# Fix missing locations in the AST
182
185
ast .fix_missing_locations (normalized_tree )
@@ -190,16 +193,25 @@ def normalize_code(code: str, remove_docstrings: bool = True) -> str:
190
193
191
194
def remove_docstrings_from_ast (node ):
192
195
"""Remove docstrings from AST nodes."""
193
- # Process all nodes in the tree, but avoid recursion
194
- for current_node in ast .walk (node ):
195
- if isinstance (current_node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef , ast .Module )):
196
+ # Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
197
+ node_types = (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef , ast .Module )
198
+ # Use our own stack-based DFS instead of ast.walk for efficiency
199
+ stack = [node ]
200
+ while stack :
201
+ current_node = stack .pop ()
202
+ if isinstance (current_node , node_types ):
203
+ # Remove docstring if it's the first stmt in body
204
+ body = current_node .body
196
205
if (
197
- current_node . body
198
- and isinstance (current_node . body [0 ], ast .Expr )
199
- and isinstance (current_node . body [0 ].value , ast .Constant )
200
- and isinstance (current_node . body [0 ].value .value , str )
206
+ body
207
+ and isinstance (body [0 ], ast .Expr )
208
+ and isinstance (body [0 ].value , ast .Constant )
209
+ and isinstance (body [0 ].value .value , str )
201
210
):
202
- current_node .body = current_node .body [1 :]
211
+ current_node .body = body [1 :]
212
+ # Only these nodes can nest more docstring-containing nodes
213
+ # Add their body elements to stack, avoiding unnecessary traversal
214
+ stack .extend ([child for child in body if isinstance (child , node_types )])
203
215
204
216
205
217
def get_code_fingerprint (code : str ) -> str :
@@ -228,8 +240,8 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
228
240
229
241
"""
230
242
try :
231
- normalized1 = normalize_code (code1 )
232
- normalized2 = normalize_code (code2 )
243
+ normalized1 = normalize_code (code1 , return_ast_dump = True )
244
+ normalized2 = normalize_code (code2 , return_ast_dump = True )
233
245
return normalized1 == normalized2
234
246
except Exception :
235
247
return False
0 commit comments