1
1
import ast
2
2
import hashlib
3
- from typing import Dict , Set
4
3
5
4
6
5
class VariableNormalizer (ast .NodeTransformer ):
7
6
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
7
+
8
8
Preserves function names, class names, parameters, built-ins, and imported names.
9
9
"""
10
10
11
- def __init__ (self ):
11
+ def __init__ (self ) -> None :
12
12
self .var_counter = 0
13
- self .var_mapping : Dict [str , str ] = {}
13
+ self .var_mapping : dict [str , str ] = {}
14
14
self .scope_stack = []
15
15
self .builtins = set (dir (__builtins__ ))
16
- self .imports : Set [str ] = set ()
17
- self .global_vars : Set [str ] = set ()
18
- self .nonlocal_vars : Set [str ] = set ()
19
- self .parameters : Set [str ] = set () # Track function parameters
16
+ self .imports : set [str ] = set ()
17
+ self .global_vars : set [str ] = set ()
18
+ self .nonlocal_vars : set [str ] = set ()
19
+ self .parameters : set [str ] = set () # Track function parameters
20
20
21
- def enter_scope (self ):
22
- """Enter a new scope (function/class)"""
21
+ def enter_scope (self ): # noqa : ANN201
22
+ """Enter a new scope (function/class). """
23
23
self .scope_stack .append (
24
24
{"var_mapping" : dict (self .var_mapping ), "var_counter" : self .var_counter , "parameters" : set (self .parameters )}
25
25
)
26
26
27
- def exit_scope (self ):
28
- """Exit current scope and restore parent scope"""
27
+ def exit_scope (self ): # noqa : ANN201
28
+ """Exit current scope and restore parent scope. """
29
29
if self .scope_stack :
30
30
scope = self .scope_stack .pop ()
31
31
self .var_mapping = scope ["var_mapping" ]
32
32
self .var_counter = scope ["var_counter" ]
33
33
self .parameters = scope ["parameters" ]
34
34
35
35
def get_normalized_name (self , name : str ) -> str :
36
- """Get or create normalized name for a variable"""
36
+ """Get or create normalized name for a variable. """
37
37
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
38
38
if (
39
39
name in self .builtins
@@ -50,34 +50,34 @@ def get_normalized_name(self, name: str) -> str:
50
50
self .var_counter += 1
51
51
return self .var_mapping [name ]
52
52
53
- def visit_Import (self , node ):
54
- """Track imported names"""
53
+ def visit_Import (self , node ): # noqa : ANN001, ANN201
54
+ """Track imported names. """
55
55
for alias in node .names :
56
56
name = alias .asname if alias .asname else alias .name
57
57
self .imports .add (name .split ("." )[0 ])
58
58
return node
59
59
60
- def visit_ImportFrom (self , node ):
61
- """Track imported names from modules"""
60
+ def visit_ImportFrom (self , node ): # noqa : ANN001, ANN201
61
+ """Track imported names from modules. """
62
62
for alias in node .names :
63
63
name = alias .asname if alias .asname else alias .name
64
64
self .imports .add (name )
65
65
return node
66
66
67
- def visit_Global (self , node ):
68
- """Track global variable declarations"""
67
+ def visit_Global (self , node ): # noqa : ANN001, ANN201
68
+ """Track global variable declarations. """
69
69
# Avoid repeated .add calls by using set.update with list
70
70
self .global_vars .update (node .names )
71
71
return node
72
72
73
- def visit_Nonlocal (self , node ):
74
- """Track nonlocal variable declarations"""
73
+ def visit_Nonlocal (self , node ): # noqa : ANN001, ANN201
74
+ """Track nonlocal variable declarations. """
75
75
for name in node .names :
76
76
self .nonlocal_vars .add (name )
77
77
return node
78
78
79
- def visit_FunctionDef (self , node ):
80
- """Process function but keep function name and parameters unchanged"""
79
+ def visit_FunctionDef (self , node ): # noqa : ANN001, ANN201
80
+ """Process function but keep function name and parameters unchanged. """
81
81
self .enter_scope ()
82
82
83
83
# Track all parameters (don't modify them)
@@ -95,19 +95,19 @@ def visit_FunctionDef(self, node):
95
95
self .exit_scope ()
96
96
return node
97
97
98
- def visit_AsyncFunctionDef (self , node ):
99
- """Handle async functions same as regular functions"""
98
+ def visit_AsyncFunctionDef (self , node ): # noqa : ANN001, ANN201
99
+ """Handle async functions same as regular functions. """
100
100
return self .visit_FunctionDef (node )
101
101
102
- def visit_ClassDef (self , node ):
103
- """Process class but keep class name unchanged"""
102
+ def visit_ClassDef (self , node ): # noqa : ANN001, ANN201
103
+ """Process class but keep class name unchanged. """
104
104
self .enter_scope ()
105
105
node = self .generic_visit (node )
106
106
self .exit_scope ()
107
107
return node
108
108
109
- def visit_Name (self , node ):
110
- """Normalize variable names in Name nodes"""
109
+ def visit_Name (self , node ): # noqa : ANN001, ANN201
110
+ """Normalize variable names in Name nodes. """
111
111
if isinstance (node .ctx , (ast .Store , ast .Del )):
112
112
# For assignments and deletions, check if we should normalize
113
113
if (
@@ -118,20 +118,20 @@ def visit_Name(self, node):
118
118
and node .id not in self .nonlocal_vars
119
119
):
120
120
node .id = self .get_normalized_name (node .id )
121
- elif isinstance (node .ctx , ast .Load ):
121
+ elif isinstance (node .ctx , ast .Load ): # noqa : SIM102
122
122
# For loading, use existing mapping if available
123
123
if node .id in self .var_mapping :
124
124
node .id = self .var_mapping [node .id ]
125
125
return node
126
126
127
- def visit_ExceptHandler (self , node ):
128
- """Normalize exception variable names"""
127
+ def visit_ExceptHandler (self , node ): # noqa : ANN001, ANN201
128
+ """Normalize exception variable names. """
129
129
if node .name :
130
130
node .name = self .get_normalized_name (node .name )
131
131
return self .generic_visit (node )
132
132
133
- def visit_comprehension (self , node ):
134
- """Normalize comprehension target variables"""
133
+ def visit_comprehension (self , node ): # noqa : ANN001, ANN201
134
+ """Normalize comprehension target variables. """
135
135
# Create new scope for comprehension
136
136
old_mapping = dict (self .var_mapping )
137
137
old_counter = self .var_counter
@@ -144,23 +144,25 @@ def visit_comprehension(self, node):
144
144
self .var_counter = old_counter
145
145
return node
146
146
147
- def visit_For (self , node ):
148
- """Handle for loop target variables"""
147
+ def visit_For (self , node ): # noqa : ANN001, ANN201
148
+ """Handle for loop target variables. """
149
149
# The target in a for loop is a local variable that should be normalized
150
150
return self .generic_visit (node )
151
151
152
- def visit_With (self , node ):
153
- """Handle with statement as variables"""
152
+ def visit_With (self , node ): # noqa : ANN001, ANN201
153
+ """Handle with statement as variables. """
154
154
return self .generic_visit (node )
155
155
156
156
157
- def normalize_code (code : str , remove_docstrings : bool = True , return_ast_dump : bool = False ) -> str :
157
+ def normalize_code (code : str , remove_docstrings : bool = True , return_ast_dump : bool = False ) -> str : # noqa : FBT002, FBT001
158
158
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
159
+
159
160
Function names, class names, and parameters are preserved.
160
161
161
162
Args:
162
163
code: Python source code as string
163
164
remove_docstrings: Whether to remove docstrings
165
+ return_ast_dump: return_ast_dump
164
166
165
167
Returns:
166
168
Normalized code as string
@@ -191,7 +193,7 @@ def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: b
191
193
raise ValueError (msg ) from e
192
194
193
195
194
- def remove_docstrings_from_ast (node ):
196
+ def remove_docstrings_from_ast (node ): # noqa : ANN001, ANN201
195
197
"""Remove docstrings from AST nodes."""
196
198
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
197
199
node_types = (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef , ast .Module )
@@ -242,6 +244,7 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
242
244
try :
243
245
normalized1 = normalize_code (code1 , return_ast_dump = True )
244
246
normalized2 = normalize_code (code2 , return_ast_dump = True )
245
- return normalized1 == normalized2
246
247
except Exception :
247
248
return False
249
+ else :
250
+ return normalized1 == normalized2
0 commit comments