Skip to content

Commit 5891dfa

Browse files
committed
linting fixes
1 parent 1a5f103 commit 5891dfa

File tree

2 files changed

+44
-41
lines changed

2 files changed

+44
-41
lines changed

codeflash/code_utils/deduplicate_code.py

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
import ast
22
import hashlib
3-
from typing import Dict, Set
43

54

65
class VariableNormalizer(ast.NodeTransformer):
76
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
7+
88
Preserves function names, class names, parameters, built-ins, and imported names.
99
"""
1010

11-
def __init__(self):
11+
def __init__(self) -> None:
1212
self.var_counter = 0
13-
self.var_mapping: Dict[str, str] = {}
13+
self.var_mapping: dict[str, str] = {}
1414
self.scope_stack = []
1515
self.builtins = set(dir(__builtins__))
16-
self.imports: Set[str] = set()
17-
self.global_vars: Set[str] = set()
18-
self.nonlocal_vars: Set[str] = set()
19-
self.parameters: Set[str] = set() # Track function parameters
16+
self.imports: set[str] = set()
17+
self.global_vars: set[str] = set()
18+
self.nonlocal_vars: set[str] = set()
19+
self.parameters: set[str] = set() # Track function parameters
2020

21-
def enter_scope(self):
22-
"""Enter a new scope (function/class)"""
21+
def enter_scope(self): # noqa : ANN201
22+
"""Enter a new scope (function/class)."""
2323
self.scope_stack.append(
2424
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
2525
)
2626

27-
def exit_scope(self):
28-
"""Exit current scope and restore parent scope"""
27+
def exit_scope(self): # noqa : ANN201
28+
"""Exit current scope and restore parent scope."""
2929
if self.scope_stack:
3030
scope = self.scope_stack.pop()
3131
self.var_mapping = scope["var_mapping"]
3232
self.var_counter = scope["var_counter"]
3333
self.parameters = scope["parameters"]
3434

3535
def get_normalized_name(self, name: str) -> str:
36-
"""Get or create normalized name for a variable"""
36+
"""Get or create normalized name for a variable."""
3737
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
3838
if (
3939
name in self.builtins
@@ -50,34 +50,34 @@ def get_normalized_name(self, name: str) -> str:
5050
self.var_counter += 1
5151
return self.var_mapping[name]
5252

53-
def visit_Import(self, node):
54-
"""Track imported names"""
53+
def visit_Import(self, node): # noqa : ANN001, ANN201
54+
"""Track imported names."""
5555
for alias in node.names:
5656
name = alias.asname if alias.asname else alias.name
5757
self.imports.add(name.split(".")[0])
5858
return node
5959

60-
def visit_ImportFrom(self, node):
61-
"""Track imported names from modules"""
60+
def visit_ImportFrom(self, node): # noqa : ANN001, ANN201
61+
"""Track imported names from modules."""
6262
for alias in node.names:
6363
name = alias.asname if alias.asname else alias.name
6464
self.imports.add(name)
6565
return node
6666

67-
def visit_Global(self, node):
68-
"""Track global variable declarations"""
67+
def visit_Global(self, node): # noqa : ANN001, ANN201
68+
"""Track global variable declarations."""
6969
# Avoid repeated .add calls by using set.update with list
7070
self.global_vars.update(node.names)
7171
return node
7272

73-
def visit_Nonlocal(self, node):
74-
"""Track nonlocal variable declarations"""
73+
def visit_Nonlocal(self, node): # noqa : ANN001, ANN201
74+
"""Track nonlocal variable declarations."""
7575
for name in node.names:
7676
self.nonlocal_vars.add(name)
7777
return node
7878

79-
def visit_FunctionDef(self, node):
80-
"""Process function but keep function name and parameters unchanged"""
79+
def visit_FunctionDef(self, node): # noqa : ANN001, ANN201
80+
"""Process function but keep function name and parameters unchanged."""
8181
self.enter_scope()
8282

8383
# Track all parameters (don't modify them)
@@ -95,19 +95,19 @@ def visit_FunctionDef(self, node):
9595
self.exit_scope()
9696
return node
9797

98-
def visit_AsyncFunctionDef(self, node):
99-
"""Handle async functions same as regular functions"""
98+
def visit_AsyncFunctionDef(self, node): # noqa : ANN001, ANN201
99+
"""Handle async functions same as regular functions."""
100100
return self.visit_FunctionDef(node)
101101

102-
def visit_ClassDef(self, node):
103-
"""Process class but keep class name unchanged"""
102+
def visit_ClassDef(self, node): # noqa : ANN001, ANN201
103+
"""Process class but keep class name unchanged."""
104104
self.enter_scope()
105105
node = self.generic_visit(node)
106106
self.exit_scope()
107107
return node
108108

109-
def visit_Name(self, node):
110-
"""Normalize variable names in Name nodes"""
109+
def visit_Name(self, node): # noqa : ANN001, ANN201
110+
"""Normalize variable names in Name nodes."""
111111
if isinstance(node.ctx, (ast.Store, ast.Del)):
112112
# For assignments and deletions, check if we should normalize
113113
if (
@@ -118,20 +118,20 @@ def visit_Name(self, node):
118118
and node.id not in self.nonlocal_vars
119119
):
120120
node.id = self.get_normalized_name(node.id)
121-
elif isinstance(node.ctx, ast.Load):
121+
elif isinstance(node.ctx, ast.Load): # noqa : SIM102
122122
# For loading, use existing mapping if available
123123
if node.id in self.var_mapping:
124124
node.id = self.var_mapping[node.id]
125125
return node
126126

127-
def visit_ExceptHandler(self, node):
128-
"""Normalize exception variable names"""
127+
def visit_ExceptHandler(self, node): # noqa : ANN001, ANN201
128+
"""Normalize exception variable names."""
129129
if node.name:
130130
node.name = self.get_normalized_name(node.name)
131131
return self.generic_visit(node)
132132

133-
def visit_comprehension(self, node):
134-
"""Normalize comprehension target variables"""
133+
def visit_comprehension(self, node): # noqa : ANN001, ANN201
134+
"""Normalize comprehension target variables."""
135135
# Create new scope for comprehension
136136
old_mapping = dict(self.var_mapping)
137137
old_counter = self.var_counter
@@ -144,23 +144,25 @@ def visit_comprehension(self, node):
144144
self.var_counter = old_counter
145145
return node
146146

147-
def visit_For(self, node):
148-
"""Handle for loop target variables"""
147+
def visit_For(self, node): # noqa : ANN001, ANN201
148+
"""Handle for loop target variables."""
149149
# The target in a for loop is a local variable that should be normalized
150150
return self.generic_visit(node)
151151

152-
def visit_With(self, node):
153-
"""Handle with statement as variables"""
152+
def visit_With(self, node): # noqa : ANN001, ANN201
153+
"""Handle with statement as variables."""
154154
return self.generic_visit(node)
155155

156156

157-
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str:
157+
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: # noqa : FBT002, FBT001
158158
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
159+
159160
Function names, class names, and parameters are preserved.
160161
161162
Args:
162163
code: Python source code as string
163164
remove_docstrings: Whether to remove docstrings
165+
return_ast_dump: return_ast_dump
164166
165167
Returns:
166168
Normalized code as string
@@ -191,7 +193,7 @@ def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: b
191193
raise ValueError(msg) from e
192194

193195

194-
def remove_docstrings_from_ast(node):
196+
def remove_docstrings_from_ast(node): # noqa : ANN001, ANN201
195197
"""Remove docstrings from AST nodes."""
196198
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
197199
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
@@ -242,6 +244,7 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
242244
try:
243245
normalized1 = normalize_code(code1, return_ast_dump=True)
244246
normalized2 = normalize_code(code2, return_ast_dump=True)
245-
return normalized1 == normalized2
246247
except Exception:
247248
return False
249+
else:
250+
return normalized1 == normalized2

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def unique_invocation_loop_id(self) -> str:
558558
return f"{self.loop_index}:{self.id.id()}"
559559

560560

561-
class TestResults(BaseModel):
561+
class TestResults(BaseModel): # noqa: PLW1641
562562
# don't modify these directly, use the add method
563563
# also we don't support deletion of test results elements - caution is advised
564564
test_results: list[FunctionTestInvocation] = []

0 commit comments

Comments
 (0)