Skip to content

Commit 9ac5d34

Browse files
Merge pull request #733 from codeflash-ai/deduplicate-better
deduplicate optimizations better
2 parents 2802ae6 + 2792c67 commit 9ac5d34

File tree

4 files changed

+386
-3
lines changed

4 files changed

+386
-3
lines changed
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import ast
2+
import hashlib
3+
from typing import Dict, Set
4+
5+
6+
class VariableNormalizer(ast.NodeTransformer):
7+
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
8+
Preserves function names, class names, parameters, built-ins, and imported names.
9+
"""
10+
11+
def __init__(self):
12+
self.var_counter = 0
13+
self.var_mapping: Dict[str, str] = {}
14+
self.scope_stack = []
15+
self.builtins = set(dir(__builtins__))
16+
self.imports: Set[str] = set()
17+
self.global_vars: Set[str] = set()
18+
self.nonlocal_vars: Set[str] = set()
19+
self.parameters: Set[str] = set() # Track function parameters
20+
21+
def enter_scope(self):
22+
"""Enter a new scope (function/class)"""
23+
self.scope_stack.append(
24+
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
25+
)
26+
27+
def exit_scope(self):
28+
"""Exit current scope and restore parent scope"""
29+
if self.scope_stack:
30+
scope = self.scope_stack.pop()
31+
self.var_mapping = scope["var_mapping"]
32+
self.var_counter = scope["var_counter"]
33+
self.parameters = scope["parameters"]
34+
35+
def get_normalized_name(self, name: str) -> str:
36+
"""Get or create normalized name for a variable"""
37+
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
38+
if (
39+
name in self.builtins
40+
or name in self.imports
41+
or name in self.global_vars
42+
or name in self.nonlocal_vars
43+
or name in self.parameters
44+
):
45+
return name
46+
47+
# Only normalize local variables
48+
if name not in self.var_mapping:
49+
self.var_mapping[name] = f"var_{self.var_counter}"
50+
self.var_counter += 1
51+
return self.var_mapping[name]
52+
53+
def visit_Import(self, node):
54+
"""Track imported names"""
55+
for alias in node.names:
56+
name = alias.asname if alias.asname else alias.name
57+
self.imports.add(name.split(".")[0])
58+
return node
59+
60+
def visit_ImportFrom(self, node):
61+
"""Track imported names from modules"""
62+
for alias in node.names:
63+
name = alias.asname if alias.asname else alias.name
64+
self.imports.add(name)
65+
return node
66+
67+
def visit_Global(self, node):
68+
"""Track global variable declarations"""
69+
# Avoid repeated .add calls by using set.update with list
70+
self.global_vars.update(node.names)
71+
return node
72+
73+
def visit_Nonlocal(self, node):
74+
"""Track nonlocal variable declarations"""
75+
for name in node.names:
76+
self.nonlocal_vars.add(name)
77+
return node
78+
79+
def visit_FunctionDef(self, node):
80+
"""Process function but keep function name and parameters unchanged"""
81+
self.enter_scope()
82+
83+
# Track all parameters (don't modify them)
84+
for arg in node.args.args:
85+
self.parameters.add(arg.arg)
86+
if node.args.vararg:
87+
self.parameters.add(node.args.vararg.arg)
88+
if node.args.kwarg:
89+
self.parameters.add(node.args.kwarg.arg)
90+
for arg in node.args.kwonlyargs:
91+
self.parameters.add(arg.arg)
92+
93+
# Visit function body
94+
node = self.generic_visit(node)
95+
self.exit_scope()
96+
return node
97+
98+
def visit_AsyncFunctionDef(self, node):
99+
"""Handle async functions same as regular functions"""
100+
return self.visit_FunctionDef(node)
101+
102+
def visit_ClassDef(self, node):
103+
"""Process class but keep class name unchanged"""
104+
self.enter_scope()
105+
node = self.generic_visit(node)
106+
self.exit_scope()
107+
return node
108+
109+
def visit_Name(self, node):
110+
"""Normalize variable names in Name nodes"""
111+
if isinstance(node.ctx, (ast.Store, ast.Del)):
112+
# For assignments and deletions, check if we should normalize
113+
if (
114+
node.id not in self.builtins
115+
and node.id not in self.imports
116+
and node.id not in self.parameters
117+
and node.id not in self.global_vars
118+
and node.id not in self.nonlocal_vars
119+
):
120+
node.id = self.get_normalized_name(node.id)
121+
elif isinstance(node.ctx, ast.Load):
122+
# For loading, use existing mapping if available
123+
if node.id in self.var_mapping:
124+
node.id = self.var_mapping[node.id]
125+
return node
126+
127+
def visit_ExceptHandler(self, node):
128+
"""Normalize exception variable names"""
129+
if node.name:
130+
node.name = self.get_normalized_name(node.name)
131+
return self.generic_visit(node)
132+
133+
def visit_comprehension(self, node):
134+
"""Normalize comprehension target variables"""
135+
# Create new scope for comprehension
136+
old_mapping = dict(self.var_mapping)
137+
old_counter = self.var_counter
138+
139+
# Process the comprehension
140+
node = self.generic_visit(node)
141+
142+
# Restore scope
143+
self.var_mapping = old_mapping
144+
self.var_counter = old_counter
145+
return node
146+
147+
def visit_For(self, node):
148+
"""Handle for loop target variables"""
149+
# The target in a for loop is a local variable that should be normalized
150+
return self.generic_visit(node)
151+
152+
def visit_With(self, node):
153+
"""Handle with statement as variables"""
154+
return self.generic_visit(node)
155+
156+
157+
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str:
    """Parse Python source and return a canonical form with local variables renamed.

    Function names, class names, and parameters are preserved.

    Args:
        code: Python source code as string
        remove_docstrings: Whether to strip docstrings before normalizing
        return_ast_dump: If True, return an ``ast.dump`` of the normalized tree
            instead of unparsed source code

    Returns:
        Normalized code (or the AST dump) as string

    Raises:
        ValueError: If a ``SyntaxError`` is raised while processing ``code``

    """
    try:
        module = ast.parse(code)

        if remove_docstrings:
            remove_docstrings_from_ast(module)

        renamed = VariableNormalizer().visit(module)

        if not return_ast_dump:
            # Fill in missing location attributes, then turn the tree back into source
            ast.fix_missing_locations(renamed)
            return ast.unparse(renamed)

        # Per the original author's note, dumping is faster than unparsing
        return ast.dump(renamed, annotate_fields=False, include_attributes=False)
    except SyntaxError as exc:
        msg = f"Invalid Python syntax: {exc}"
        raise ValueError(msg) from exc
192+
193+
194+
def remove_docstrings_from_ast(node):
    """Remove docstrings from ``node`` and every definition nested inside it, in place.

    A docstring is a string-constant expression that is the first statement in
    the body of a module, class, function or async function. Definitions nested
    inside compound statements (``if``, ``for``, ``while``, ``with``, ``try`` and
    its handlers, ``match`` cases) are reached as well.

    Args:
        node: Root AST node to clean; it is modified in place.

    """
    docstring_owners = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
    # Fields that hold lists of nested statements (or handler/case nodes that do)
    nesting_fields = ("body", "orelse", "finalbody", "handlers", "cases")

    # Explicit stack that only follows statement-bearing fields, so expression
    # subtrees are never traversed.
    stack = [node]
    while stack:
        current_node = stack.pop()
        if isinstance(current_node, docstring_owners):
            body = current_node.body
            if (
                body
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)
            ):
                current_node.body = body[1:]
        for field in nesting_fields:
            children = getattr(current_node, field, None)
            # Some nodes (e.g. IfExp, Lambda) use these names for a single
            # expression rather than a statement list; skip those.
            if isinstance(children, list):
                stack.extend(child for child in children if isinstance(child, ast.AST))
215+
216+
217+
def get_code_fingerprint(code: str) -> str:
    """Compute a fingerprint of ``code`` after normalization.

    Args:
        code: Python source code

    Returns:
        Hex-encoded SHA-256 digest of the normalized code

    """
    canonical_source = normalize_code(code)
    digest = hashlib.sha256()
    digest.update(canonical_source.encode())
    return digest.hexdigest()
229+
230+
231+
def are_codes_duplicate(code1: str, code2: str) -> bool:
    """Tell whether two code segments are identical once normalized.

    Args:
        code1: First code segment
        code2: Second code segment

    Returns:
        True if the normalized AST dumps of both segments are equal;
        False if they differ or if normalizing either segment raises

    """
    try:
        # Compare AST dumps rather than unparsed source
        first_dump, second_dump = (normalize_code(source, return_ast_dump=True) for source in (code1, code2))
    except Exception:
        # Best effort: anything that cannot be normalized is not a duplicate
        return False
    return first_dump == second_dump

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def unique_invocation_loop_id(self) -> str:
558558
return f"{self.loop_index}:{self.id.id()}"
559559

560560

561-
class TestResults(BaseModel): # noqa: PLW1641
561+
class TestResults(BaseModel):
562562
# don't modify these directly, use the add method
563563
# also we don't support deletion of test results elements - caution is advised
564564
test_results: list[FunctionTestInvocation] = []

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
REPEAT_OPTIMIZATION_PROBABILITY,
4949
TOTAL_LOOPING_TIME,
5050
)
51+
from codeflash.code_utils.deduplicate_code import normalize_code
5152
from codeflash.code_utils.edit_generated_tests import (
5253
add_runtime_comments_to_generated_tests,
5354
remove_functions_from_generated_tests,
@@ -519,7 +520,7 @@ def determine_best_candidate(
519520
)
520521
continue
521522
# check if this code has been evaluated before by checking the ast normalized code string
522-
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
523+
normalized_code = normalize_code(candidate.source_code.flat.strip())
523524
if normalized_code in ast_code_to_id:
524525
logger.info(
525526
"Current candidate has been encountered before in testing, Skipping optimization candidate."
@@ -669,7 +670,7 @@ def determine_best_candidate(
669670
diff_strs = []
670671
runtimes_list = []
671672
for valid_opt in valid_optimizations:
672-
valid_opt_normalized_code = ast.unparse(ast.parse(valid_opt.candidate.source_code.flat.strip()))
673+
valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip())
673674
new_candidate_with_shorter_code = OptimizedCandidate(
674675
source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
675676
optimization_id=valid_opt.candidate.optimization_id,

0 commit comments

Comments
 (0)