Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 43 additions & 40 deletions codeflash/code_utils/deduplicate_code.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
import ast
import hashlib
from typing import Dict, Set


class VariableNormalizer(ast.NodeTransformer):
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.

Preserves function names, class names, parameters, built-ins, and imported names.
"""

def __init__(self):
def __init__(self) -> None:
self.var_counter = 0
self.var_mapping: Dict[str, str] = {}
self.var_mapping: dict[str, str] = {}
self.scope_stack = []
self.builtins = set(dir(__builtins__))
self.imports: Set[str] = set()
self.global_vars: Set[str] = set()
self.nonlocal_vars: Set[str] = set()
self.parameters: Set[str] = set() # Track function parameters
self.imports: set[str] = set()
self.global_vars: set[str] = set()
self.nonlocal_vars: set[str] = set()
self.parameters: set[str] = set() # Track function parameters

def enter_scope(self):
"""Enter a new scope (function/class)"""
def enter_scope(self): # noqa : ANN201
"""Enter a new scope (function/class)."""
self.scope_stack.append(
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
)

def exit_scope(self):
"""Exit current scope and restore parent scope"""
def exit_scope(self): # noqa : ANN201
"""Exit current scope and restore parent scope."""
if self.scope_stack:
scope = self.scope_stack.pop()
self.var_mapping = scope["var_mapping"]
self.var_counter = scope["var_counter"]
self.parameters = scope["parameters"]

def get_normalized_name(self, name: str) -> str:
"""Get or create normalized name for a variable"""
"""Get or create normalized name for a variable."""
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
if (
name in self.builtins
Expand All @@ -50,34 +50,34 @@ def get_normalized_name(self, name: str) -> str:
self.var_counter += 1
return self.var_mapping[name]

def visit_Import(self, node):
"""Track imported names"""
def visit_Import(self, node): # noqa : ANN001, ANN201
"""Track imported names."""
for alias in node.names:
name = alias.asname if alias.asname else alias.name
self.imports.add(name.split(".")[0])
return node

def visit_ImportFrom(self, node):
"""Track imported names from modules"""
def visit_ImportFrom(self, node): # noqa : ANN001, ANN201
"""Track imported names from modules."""
for alias in node.names:
name = alias.asname if alias.asname else alias.name
self.imports.add(name)
return node

def visit_Global(self, node):
"""Track global variable declarations"""
def visit_Global(self, node): # noqa : ANN001, ANN201
"""Track global variable declarations."""
# Avoid repeated .add calls by using set.update with list
self.global_vars.update(node.names)
return node

def visit_Nonlocal(self, node):
"""Track nonlocal variable declarations"""
def visit_Nonlocal(self, node): # noqa : ANN001, ANN201
"""Track nonlocal variable declarations."""
for name in node.names:
self.nonlocal_vars.add(name)
return node

def visit_FunctionDef(self, node):
"""Process function but keep function name and parameters unchanged"""
def visit_FunctionDef(self, node): # noqa : ANN001, ANN201
"""Process function but keep function name and parameters unchanged."""
self.enter_scope()

# Track all parameters (don't modify them)
Expand All @@ -95,19 +95,19 @@ def visit_FunctionDef(self, node):
self.exit_scope()
return node

def visit_AsyncFunctionDef(self, node):
"""Handle async functions same as regular functions"""
def visit_AsyncFunctionDef(self, node): # noqa : ANN001, ANN201
"""Handle async functions same as regular functions."""
return self.visit_FunctionDef(node)

def visit_ClassDef(self, node):
"""Process class but keep class name unchanged"""
def visit_ClassDef(self, node): # noqa : ANN001, ANN201
"""Process class but keep class name unchanged."""
self.enter_scope()
node = self.generic_visit(node)
self.exit_scope()
return node

def visit_Name(self, node):
"""Normalize variable names in Name nodes"""
def visit_Name(self, node): # noqa : ANN001, ANN201
"""Normalize variable names in Name nodes."""
if isinstance(node.ctx, (ast.Store, ast.Del)):
# For assignments and deletions, check if we should normalize
if (
Expand All @@ -118,20 +118,20 @@ def visit_Name(self, node):
and node.id not in self.nonlocal_vars
):
node.id = self.get_normalized_name(node.id)
elif isinstance(node.ctx, ast.Load):
elif isinstance(node.ctx, ast.Load): # noqa : SIM102
# For loading, use existing mapping if available
if node.id in self.var_mapping:
node.id = self.var_mapping[node.id]
return node

def visit_ExceptHandler(self, node):
"""Normalize exception variable names"""
def visit_ExceptHandler(self, node): # noqa : ANN001, ANN201
"""Normalize exception variable names."""
if node.name:
node.name = self.get_normalized_name(node.name)
return self.generic_visit(node)

def visit_comprehension(self, node):
"""Normalize comprehension target variables"""
def visit_comprehension(self, node): # noqa : ANN001, ANN201
"""Normalize comprehension target variables."""
# Create new scope for comprehension
old_mapping = dict(self.var_mapping)
old_counter = self.var_counter
Expand All @@ -144,23 +144,25 @@ def visit_comprehension(self, node):
self.var_counter = old_counter
return node

def visit_For(self, node):
"""Handle for loop target variables"""
def visit_For(self, node): # noqa : ANN001, ANN201
"""Handle for loop target variables."""
# The target in a for loop is a local variable that should be normalized
return self.generic_visit(node)

def visit_With(self, node):
"""Handle with statement as variables"""
def visit_With(self, node): # noqa : ANN001, ANN201
"""Handle with statement as variables."""
return self.generic_visit(node)


def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str:
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: # noqa : FBT002, FBT001
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.

Function names, class names, and parameters are preserved.

Args:
code: Python source code as string
remove_docstrings: Whether to remove docstrings
return_ast_dump: return_ast_dump

Returns:
Normalized code as string
Expand Down Expand Up @@ -191,7 +193,7 @@ def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: b
raise ValueError(msg) from e


def remove_docstrings_from_ast(node):
def remove_docstrings_from_ast(node): # noqa : ANN001, ANN201
"""Remove docstrings from AST nodes."""
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
Expand Down Expand Up @@ -242,6 +244,7 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
try:
normalized1 = normalize_code(code1, return_ast_dump=True)
normalized2 = normalize_code(code2, return_ast_dump=True)
return normalized1 == normalized2
except Exception:
return False
else:
return normalized1 == normalized2
56 changes: 31 additions & 25 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,15 +365,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
targets=[ast.Name(id="test_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1),
]
),
lineno=lineno + 1,
Expand Down Expand Up @@ -453,7 +453,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value="_"),
ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
]
Expand All @@ -466,25 +466,31 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(
value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.IfExp(
test=ast.Name(id="test_class_name", ctx=ast.Load()),
test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
body=ast.BinOp(
left=ast.Name(id="test_class_name", ctx=ast.Load()),
left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
op=ast.Add(),
right=ast.Constant(value="."),
),
orelse=ast.Constant(value=""),
),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(
value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(
value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
]
Expand Down Expand Up @@ -537,7 +543,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
ast.Assign(
targets=[ast.Name(id="return_value", ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="wrapped", ctx=ast.Load()),
func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
),
Expand Down Expand Up @@ -664,11 +670,11 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
ast.Tuple(
elts=[
ast.Name(id="test_module_name", ctx=ast.Load()),
ast.Name(id="test_class_name", ctx=ast.Load()),
ast.Name(id="test_name", ctx=ast.Load()),
ast.Name(id="function_name", ctx=ast.Load()),
ast.Name(id="loop_index", ctx=ast.Load()),
ast.Name(id="codeflash_test_module_name", ctx=ast.Load()),
ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
ast.Name(id="codeflash_test_name", ctx=ast.Load()),
ast.Name(id="codeflash_function_name", ctx=ast.Load()),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
ast.Name(id="invocation_id", ctx=ast.Load()),
ast.Name(id="codeflash_duration", ctx=ast.Load()),
ast.Name(id="pickled_return_value", ctx=ast.Load()),
Expand Down Expand Up @@ -707,13 +713,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
name="codeflash_wrap",
args=ast.arguments(
args=[
ast.arg(arg="wrapped", annotation=None),
ast.arg(arg="test_module_name", annotation=None),
ast.arg(arg="test_class_name", annotation=None),
ast.arg(arg="test_name", annotation=None),
ast.arg(arg="function_name", annotation=None),
ast.arg(arg="line_id", annotation=None),
ast.arg(arg="loop_index", annotation=None),
ast.arg(arg="codeflash_wrapped", annotation=None),
ast.arg(arg="codeflash_test_module_name", annotation=None),
ast.arg(arg="codeflash_test_class_name", annotation=None),
ast.arg(arg="codeflash_test_name", annotation=None),
ast.arg(arg="codeflash_function_name", annotation=None),
ast.arg(arg="codeflash_line_id", annotation=None),
ast.arg(arg="codeflash_loop_index", annotation=None),
*([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
*([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
],
Expand Down
2 changes: 1 addition & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def unique_invocation_loop_id(self) -> str:
return f"{self.loop_index}:{self.id.id()}"


class TestResults(BaseModel):
class TestResults(BaseModel): # noqa: PLW1641
# don't modify these directly, use the add method
# also we don't support deletion of test results elements - caution is advised
test_results: list[FunctionTestInvocation] = []
Expand Down
12 changes: 6 additions & 6 deletions tests/test_instrument_all_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,31 @@
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture

# Used by cli instrumentation
codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}"
print(f"!$######{{test_stdout_tag}}######$!")
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f"!######{{test_stdout_tag}}######!")
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception
Expand Down
Loading
Loading