From 1a5f1034eb816ee1be8ba4d7c27b14e5f9b885c2 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 16 Sep 2025 15:29:08 -0700 Subject: [PATCH 1/4] fix overlappings args in codeflash wrap --- .../code_utils/instrument_existing_tests.py | 56 ++++++++++--------- tests/test_instrument_all_and_run.py | 12 ++-- tests/test_instrument_tests.py | 46 +++++++-------- 3 files changed, 60 insertions(+), 54 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 6eac52809..94e732eb3 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -365,15 +365,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun targets=[ast.Name(id="test_id", ctx=ast.Store())], value=ast.JoinedStr( values=[ - ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1), + ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1), ast.Constant(value=":"), - ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1), + ast.FormattedValue(value=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), conversion=-1), ast.Constant(value=":"), - ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1), + ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1), ast.Constant(value=":"), - ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1), + ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1), ast.Constant(value=":"), - ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1), + ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1), ] ), lineno=lineno + 1, @@ -453,7 +453,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun targets=[ast.Name(id="invocation_id", ctx=ast.Store())], value=ast.JoinedStr( values=[ - ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1), + ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1), ast.Constant(value="_"), ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1), ] @@ -466,13 +466,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())], value=ast.JoinedStr( values=[ - ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1), + ast.FormattedValue( + value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1 + ), ast.Constant(value=":"), ast.FormattedValue( value=ast.IfExp( - test=ast.Name(id="test_class_name", ctx=ast.Load()), + test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), body=ast.BinOp( - left=ast.Name(id="test_class_name", ctx=ast.Load()), + left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), op=ast.Add(), right=ast.Constant(value="."), ), @@ -480,11 +482,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ), conversion=-1, ), - ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1), + ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1), ast.Constant(value=":"), - ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1), + ast.FormattedValue( + value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1 + ), ast.Constant(value=":"), - ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1), + ast.FormattedValue( + value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1 + ), ast.Constant(value=":"), ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1), ] @@ -537,7 +543,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ast.Assign( targets=[ast.Name(id="return_value", ctx=ast.Store())], value=ast.Call( - func=ast.Name(id="wrapped", ctx=ast.Load()), + func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], ), @@ -664,11 +670,11 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"), ast.Tuple( elts=[ - ast.Name(id="test_module_name", ctx=ast.Load()), - ast.Name(id="test_class_name", ctx=ast.Load()), - ast.Name(id="test_name", ctx=ast.Load()), - ast.Name(id="function_name", ctx=ast.Load()), - ast.Name(id="loop_index", ctx=ast.Load()), + ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), + ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), + ast.Name(id="codeflash_test_name", ctx=ast.Load()), + ast.Name(id="codeflash_function_name", ctx=ast.Load()), + ast.Name(id="codeflash_loop_index", ctx=ast.Load()), ast.Name(id="invocation_id", ctx=ast.Load()), ast.Name(id="codeflash_duration", ctx=ast.Load()), ast.Name(id="pickled_return_value", ctx=ast.Load()), @@ -707,13 +713,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun name="codeflash_wrap", args=ast.arguments( args=[ - ast.arg(arg="wrapped", annotation=None), - ast.arg(arg="test_module_name", annotation=None), - ast.arg(arg="test_class_name", annotation=None), - ast.arg(arg="test_name", annotation=None), - ast.arg(arg="function_name", annotation=None), - ast.arg(arg="line_id", annotation=None), - ast.arg(arg="loop_index", annotation=None), + ast.arg(arg="codeflash_wrapped", annotation=None), + ast.arg(arg="codeflash_test_module_name", annotation=None), + ast.arg(arg="codeflash_test_class_name", annotation=None), + ast.arg(arg="codeflash_test_name", annotation=None), + ast.arg(arg="codeflash_function_name", annotation=None), + ast.arg(arg="codeflash_line_id", annotation=None), + ast.arg(arg="codeflash_loop_index", annotation=None), *([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []), *([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []), ], diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index 7e1a20f49..cb34727a0 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -15,8 +15,8 @@ from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture # Used by cli instrumentation -codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): - test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' +codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} if test_id in codeflash_wrap.index: @@ -24,14 +24,14 @@ else: codeflash_wrap.index[test_id] = 0 codeflash_test_index = codeflash_wrap.index[test_id] - invocation_id = f'{{line_id}}_{{codeflash_test_index}}' - test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}" + invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}' + test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}" print(f"!$######{{test_stdout_tag}}######$!") exception = None gc.disable() try: counter = time.perf_counter_ns() - return_value = wrapped(*args, **kwargs) + return_value = codeflash_wrapped(*args, **kwargs) codeflash_duration = time.perf_counter_ns() - counter except Exception as e: codeflash_duration = time.perf_counter_ns() - counter @@ -39,7 +39,7 @@ gc.enable() print(f"!######{{test_stdout_tag}}######!") pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) - codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: raise exception diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index ccec5ffe3..8b73329a2 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -27,8 +27,8 @@ from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig -codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): - test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' +codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} if test_id in codeflash_wrap.index: @@ -36,14 +36,14 @@ else: codeflash_wrap.index[test_id] = 0 codeflash_test_index = codeflash_wrap.index[test_id] - invocation_id = f'{{line_id}}_{{codeflash_test_index}}' - test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}" + invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}' + test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}" print(f"!$######{{test_stdout_tag}}######$!") exception = None gc.disable() try: counter = time.perf_counter_ns() - return_value = wrapped(*args, **kwargs) + return_value = codeflash_wrapped(*args, **kwargs) codeflash_duration = time.perf_counter_ns() - counter except Exception as e: codeflash_duration = time.perf_counter_ns() - counter @@ -51,15 +51,15 @@ gc.enable() print(f"!######{{test_stdout_tag}}######!") pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) - codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: raise exception return return_value """ -codeflash_wrap_perfonly_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, *args, **kwargs): - test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' +codeflash_wrap_perfonly_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs): + test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} if test_id in codeflash_wrap.index: @@ -67,14 +67,14 @@ else: codeflash_wrap.index[test_id] = 0 codeflash_test_index = codeflash_wrap.index[test_id] - invocation_id = f'{{line_id}}_{{codeflash_test_index}}' - test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}" + invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}' + test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}" print(f"!$######{{test_stdout_tag}}######$!") exception = None gc.disable() try: counter = time.perf_counter_ns() - return_value = wrapped(*args, **kwargs) + return_value = codeflash_wrapped(*args, **kwargs) codeflash_duration = time.perf_counter_ns() - counter except Exception as e: codeflash_duration = time.perf_counter_ns() - counter @@ -118,8 +118,8 @@ def test_sort(self): from code_to_optimize.bubble_sort import sorter -def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): - test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} if test_id in codeflash_wrap.index: @@ -127,16 +127,16 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi else: codeflash_wrap.index[test_id] = 0 codeflash_test_index = codeflash_wrap.index[test_id] - invocation_id = f'{{line_id}}_{{codeflash_test_index}}' + invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}' """ - expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}' + expected += """test_stdout_tag = f'{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}' """ expected += """print(f'!$######{{test_stdout_tag}}######$!') exception = None gc.disable() try: counter = time.perf_counter_ns() - return_value = wrapped(*args, **kwargs) + return_value = codeflash_wrapped(*args, **kwargs) codeflash_duration = time.perf_counter_ns() - counter except Exception as e: codeflash_duration = time.perf_counter_ns() - counter @@ -144,7 +144,7 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi gc.enable() print(f'!######{{test_stdout_tag}}######!') pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) - codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: raise exception @@ -218,8 +218,8 @@ def test_prepare_image_for_yolo(): from codeflash.validation.equivalence import compare_results -def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): - test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} if test_id in codeflash_wrap.index: @@ -227,16 +227,16 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi else: codeflash_wrap.index[test_id] = 0 codeflash_test_index = codeflash_wrap.index[test_id] - invocation_id = f'{{line_id}}_{{codeflash_test_index}}' + invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}' """ - expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}' + expected += """test_stdout_tag = f'{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}' """ expected += """print(f'!$######{{test_stdout_tag}}######$!') exception = None gc.disable() try: counter = time.perf_counter_ns() - return_value = wrapped(*args, **kwargs) + return_value = codeflash_wrapped(*args, **kwargs) codeflash_duration = time.perf_counter_ns() - counter except Exception as e: codeflash_duration = time.perf_counter_ns() - counter @@ -244,7 +244,7 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi gc.enable() print(f'!######{{test_stdout_tag}}######!') pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) - codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: raise exception From 5891dfaba21a1bf49ed3c8f79c457aa9518340b0 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 16 Sep 2025 15:53:30 -0700 Subject: [PATCH 2/4] linting fixes --- codeflash/code_utils/deduplicate_code.py | 83 ++++++++++++------------ codeflash/models/models.py | 2 +- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/codeflash/code_utils/deduplicate_code.py b/codeflash/code_utils/deduplicate_code.py index d0f9f3271..9a13458ab 100644 --- a/codeflash/code_utils/deduplicate_code.py +++ b/codeflash/code_utils/deduplicate_code.py @@ -1,31 +1,31 @@ import ast import hashlib -from typing import Dict, Set class VariableNormalizer(ast.NodeTransformer): """Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc. + Preserves function names, class names, parameters, built-ins, and imported names. """ - def __init__(self): + def __init__(self) -> None: self.var_counter = 0 - self.var_mapping: Dict[str, str] = {} + self.var_mapping: dict[str, str] = {} self.scope_stack = [] self.builtins = set(dir(__builtins__)) - self.imports: Set[str] = set() - self.global_vars: Set[str] = set() - self.nonlocal_vars: Set[str] = set() - self.parameters: Set[str] = set() # Track function parameters + self.imports: set[str] = set() + self.global_vars: set[str] = set() + self.nonlocal_vars: set[str] = set() + self.parameters: set[str] = set() # Track function parameters - def enter_scope(self): - """Enter a new scope (function/class)""" + def enter_scope(self): # noqa : ANN201 + """Enter a new scope (function/class).""" self.scope_stack.append( {"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)} ) - def exit_scope(self): - """Exit current scope and restore parent scope""" + def exit_scope(self): # noqa : ANN201 + """Exit current scope and restore parent scope.""" if self.scope_stack: scope = self.scope_stack.pop() self.var_mapping = scope["var_mapping"] @@ -33,7 +33,7 @@ def exit_scope(self): self.parameters = scope["parameters"] def get_normalized_name(self, name: str) -> str: - """Get or create normalized name for a variable""" + """Get or create normalized name for a variable.""" # Don't normalize if it's a builtin, import, global, nonlocal, or parameter if ( name in self.builtins @@ -50,34 +50,34 @@ def get_normalized_name(self, name: str) -> str: self.var_counter += 1 return self.var_mapping[name] - def visit_Import(self, node): - """Track imported names""" + def visit_Import(self, node): # noqa : ANN001, ANN201 + """Track imported names.""" for alias in node.names: name = alias.asname if alias.asname else alias.name self.imports.add(name.split(".")[0]) return node - def visit_ImportFrom(self, node): - """Track imported names from modules""" + def visit_ImportFrom(self, node): # noqa : ANN001, ANN201 + """Track imported names from modules.""" for alias in node.names: name = alias.asname if alias.asname else alias.name self.imports.add(name) return node - def visit_Global(self, node): - """Track global variable declarations""" + def visit_Global(self, node): # noqa : ANN001, ANN201 + """Track global variable declarations.""" # Avoid repeated .add calls by using set.update with list self.global_vars.update(node.names) return node - def visit_Nonlocal(self, node): - """Track nonlocal variable declarations""" + def visit_Nonlocal(self, node): # noqa : ANN001, ANN201 + """Track nonlocal variable declarations.""" for name in node.names: self.nonlocal_vars.add(name) return node - def visit_FunctionDef(self, node): - """Process function but keep function name and parameters unchanged""" + def visit_FunctionDef(self, node): # noqa : ANN001, ANN201 + """Process function but keep function name and parameters unchanged.""" self.enter_scope() # Track all parameters (don't modify them) @@ -95,19 +95,19 @@ def visit_FunctionDef(self, node): self.exit_scope() return node - def visit_AsyncFunctionDef(self, node): - """Handle async functions same as regular functions""" + def visit_AsyncFunctionDef(self, node): # noqa : ANN001, ANN201 + """Handle async functions same as regular functions.""" return self.visit_FunctionDef(node) - def visit_ClassDef(self, node): - """Process class but keep class name unchanged""" + def visit_ClassDef(self, node): # noqa : ANN001, ANN201 + """Process class but keep class name unchanged.""" self.enter_scope() node = self.generic_visit(node) self.exit_scope() return node - def visit_Name(self, node): - """Normalize variable names in Name nodes""" + def visit_Name(self, node): # noqa : ANN001, ANN201 + """Normalize variable names in Name nodes.""" if isinstance(node.ctx, (ast.Store, ast.Del)): # For assignments and deletions, check if we should normalize if ( @@ -118,20 +118,20 @@ def visit_Name(self, node): and node.id not in self.nonlocal_vars ): node.id = self.get_normalized_name(node.id) - elif isinstance(node.ctx, ast.Load): + elif isinstance(node.ctx, ast.Load): # noqa : SIM102 # For loading, use existing mapping if available if node.id in self.var_mapping: node.id = self.var_mapping[node.id] return node - def visit_ExceptHandler(self, node): - """Normalize exception variable names""" + def visit_ExceptHandler(self, node): # noqa : ANN001, ANN201 + """Normalize exception variable names.""" if node.name: node.name = self.get_normalized_name(node.name) return self.generic_visit(node) - def visit_comprehension(self, node): - """Normalize comprehension target variables""" + def visit_comprehension(self, node): # noqa : ANN001, ANN201 + """Normalize comprehension target variables.""" # Create new scope for comprehension old_mapping = dict(self.var_mapping) old_counter = self.var_counter @@ -144,23 +144,25 @@ def visit_comprehension(self, node): self.var_counter = old_counter return node - def visit_For(self, node): - """Handle for loop target variables""" + def visit_For(self, node): # noqa : ANN001, ANN201 + """Handle for loop target variables.""" # The target in a for loop is a local variable that should be normalized return self.generic_visit(node) - def visit_With(self, node): - """Handle with statement as variables""" + def visit_With(self, node): # noqa : ANN001, ANN201 + """Handle with statement as variables.""" return self.generic_visit(node) -def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: +def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: # noqa : FBT002, FBT001 """Normalize Python code by parsing, cleaning, and normalizing only variable names. + Function names, class names, and parameters are preserved. Args: code: Python source code as string remove_docstrings: Whether to remove docstrings + return_ast_dump: return_ast_dump Returns: Normalized code as string @@ -191,7 +193,7 @@ def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: b raise ValueError(msg) from e -def remove_docstrings_from_ast(node): +def remove_docstrings_from_ast(node): # noqa : ANN001, ANN201 """Remove docstrings from AST nodes.""" # Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0] node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module) @@ -242,6 +244,7 @@ def are_codes_duplicate(code1: str, code2: str) -> bool: try: normalized1 = normalize_code(code1, return_ast_dump=True) normalized2 = normalize_code(code2, return_ast_dump=True) - return normalized1 == normalized2 except Exception: return False + else: + return normalized1 == normalized2 diff --git a/codeflash/models/models.py b/codeflash/models/models.py index e91bba3c6..8417148ef 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -558,7 +558,7 @@ def unique_invocation_loop_id(self) -> str: return f"{self.loop_index}:{self.id.id()}" -class TestResults(BaseModel): +class TestResults(BaseModel): # noqa: PLW1641 # don't modify these directly, use the add method # also we don't support deletion of test results elements - caution is advised test_results: list[FunctionTestInvocation] = [] From 978924dd6fde9820458b0e39eb2dd7200f088e0f Mon Sep 17 00:00:00 2001 From: Kevin Turcios <106575910+KRRT7@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:31:10 -0700 Subject: [PATCH 3/4] Apply suggestion from @codeflash-ai[bot] Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> --- codeflash/code_utils/deduplicate_code.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/deduplicate_code.py b/codeflash/code_utils/deduplicate_code.py index 9a13458ab..6ad4f1f8f 100644 --- a/codeflash/code_utils/deduplicate_code.py +++ b/codeflash/code_utils/deduplicate_code.py @@ -70,10 +70,10 @@ def visit_Global(self, node): # noqa : ANN001, ANN201 self.global_vars.update(node.names) return node - def visit_Nonlocal(self, node): # noqa : ANN001, ANN201 + def visit_Nonlocal(self, node): """Track nonlocal variable declarations.""" - for name in node.names: - self.nonlocal_vars.add(name) + # Using set.update for batch insertion (faster than add-in-loop) + self.nonlocal_vars.update(node.names) return node def visit_FunctionDef(self, node): # noqa : ANN001, ANN201 From 161f8cf6f5768d256edc94cce0b642bb508f3d7a Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 16 Sep 2025 16:37:02 -0700 Subject: [PATCH 4/4] lint fix --- codeflash/code_utils/deduplicate_code.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/code_utils/deduplicate_code.py b/codeflash/code_utils/deduplicate_code.py index 6ad4f1f8f..35a4a29ff 100644 --- a/codeflash/code_utils/deduplicate_code.py +++ b/codeflash/code_utils/deduplicate_code.py @@ -70,7 +70,7 @@ def visit_Global(self, node): # noqa : ANN001, ANN201 self.global_vars.update(node.names) return node - def visit_Nonlocal(self, node): + def visit_Nonlocal(self, node): # noqa : ANN001, ANN201 """Track nonlocal variable declarations.""" # Using set.update for batch insertion (faster than add-in-loop) self.nonlocal_vars.update(node.names)