Skip to content

Commit 4c80772

Browse files
committed
chore: Moved comparators function to a class method to avoid redundancy
1 parent 06b304d commit 4c80772

File tree

1 file changed

+22
-72
lines changed

1 file changed

+22
-72
lines changed

tests/py/dynamo/conversion/test_resize_aten.py

Lines changed: 22 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@
66

77

88
class TestResizeConverter(DispatchTestCase):
9+
10+
def compare_resized_tensors(self, tensor1, tensor2, input_shape, target_shape):
11+
# Check if the sizes match
12+
if tensor1.size() != tensor2.size():
13+
return False
14+
15+
# Flatten the tensors to ensure we are comparing the valid elements
16+
flat_tensor1 = tensor1.flatten()
17+
flat_tensor2 = tensor2.flatten()
18+
19+
# Calculate the number of valid elements to compare
20+
input_numel = torch.Size(input_shape).numel()
21+
target_numel = torch.Size(target_shape).numel()
22+
min_size = min(input_numel, target_numel)
23+
24+
# Compare only the valid elements
25+
return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size])
26+
927
@parameterized.expand(
1028
[
1129
((3,),),
@@ -28,24 +46,7 @@ def forward(self, x):
2846
input_shape = (5,)
2947
inputs = [torch.randn(input_shape)]
3048

31-
def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape):
32-
# Check if the sizes match
33-
if tensor1.size() != tensor2.size():
34-
return False
35-
36-
# Flatten the tensors to ensure we are comparing the valid elements
37-
flat_tensor1 = tensor1.flatten()
38-
flat_tensor2 = tensor2.flatten()
39-
40-
# Calculate the number of valid elements to compare
41-
input_numel = torch.Size(input_shape).numel()
42-
target_numel = torch.Size(target_shape).numel()
43-
min_size = min(input_numel, target_numel)
44-
45-
# Compare only the valid elements
46-
return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size])
47-
48-
comparators = [(compare_resized_tensors, [input_shape, target_shape])]
49+
comparators = [(self.compare_resized_tensors, [input_shape, target_shape])]
4950

5051
self.run_test_compare_tensor_attributes_only(
5152
Resize(),
@@ -76,24 +77,7 @@ def forward(self, x):
7677
input_shape = (5,)
7778
inputs = [torch.randint(1, 5, input_shape)]
7879

79-
def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape):
80-
# Check if the sizes match
81-
if tensor1.size() != tensor2.size():
82-
return False
83-
84-
# Flatten the tensors to ensure we are comparing the valid elements
85-
flat_tensor1 = tensor1.flatten()
86-
flat_tensor2 = tensor2.flatten()
87-
88-
# Calculate the number of valid elements to compare
89-
input_numel = torch.Size(input_shape).numel()
90-
target_numel = torch.Size(target_shape).numel()
91-
min_size = min(input_numel, target_numel)
92-
93-
# Compare only the valid elements
94-
return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size])
95-
96-
comparators = [(compare_resized_tensors, [input_shape, target_shape])]
80+
comparators = [(self.compare_resized_tensors, [input_shape, target_shape])]
9781

9882
self.run_test_compare_tensor_attributes_only(
9983
Resize(),
@@ -124,24 +108,7 @@ def forward(self, x):
124108
input_shape = (4, 4)
125109
inputs = [torch.randint(1, 10, input_shape)]
126110

127-
def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape):
128-
# Check if the sizes match
129-
if tensor1.size() != tensor2.size():
130-
return False
131-
132-
# Flatten the tensors to ensure we are comparing the valid elements
133-
flat_tensor1 = tensor1.flatten()
134-
flat_tensor2 = tensor2.flatten()
135-
136-
# Calculate the number of valid elements to compare
137-
input_numel = torch.Size(input_shape).numel()
138-
target_numel = torch.Size(target_shape).numel()
139-
min_size = min(input_numel, target_numel)
140-
141-
# Compare only the valid elements
142-
return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size])
143-
144-
comparators = [(compare_resized_tensors, [input_shape, target_shape])]
111+
comparators = [(self.compare_resized_tensors, [input_shape, target_shape])]
145112

146113
self.run_test_compare_tensor_attributes_only(
147114
Resize(),
@@ -171,24 +138,7 @@ def forward(self, x):
171138
input_shape = (4, 4)
172139
inputs = [torch.randint(1, 10, input_shape)]
173140

174-
def compare_resized_tensors(tensor1, tensor2, input_shape, target_shape):
175-
# Check if the sizes match
176-
if tensor1.size() != tensor2.size():
177-
return False
178-
179-
# Flatten the tensors to ensure we are comparing the valid elements
180-
flat_tensor1 = tensor1.flatten()
181-
flat_tensor2 = tensor2.flatten()
182-
183-
# Calculate the number of valid elements to compare
184-
input_numel = torch.Size(input_shape).numel()
185-
target_numel = torch.Size(target_shape).numel()
186-
min_size = min(input_numel, target_numel)
187-
188-
# Compare only the valid elements
189-
return torch.equal(flat_tensor1[:min_size], flat_tensor2[:min_size])
190-
191-
comparators = [(compare_resized_tensors, [input_shape, target_shape])]
141+
comparators = [(self.compare_resized_tensors, [input_shape, target_shape])]
192142

193143
self.run_test_compare_tensor_attributes_only(
194144
Resize(),

0 commit comments

Comments
 (0)