6
6
7
7
8
8
class TestResizeConverter (DispatchTestCase ):
9
+
10
+ def compare_resized_tensors (self , tensor1 , tensor2 , input_shape , target_shape ):
11
+ # Check if the sizes match
12
+ if tensor1 .size () != tensor2 .size ():
13
+ return False
14
+
15
+ # Flatten the tensors to ensure we are comparing the valid elements
16
+ flat_tensor1 = tensor1 .flatten ()
17
+ flat_tensor2 = tensor2 .flatten ()
18
+
19
+ # Calculate the number of valid elements to compare
20
+ input_numel = torch .Size (input_shape ).numel ()
21
+ target_numel = torch .Size (target_shape ).numel ()
22
+ min_size = min (input_numel , target_numel )
23
+
24
+ # Compare only the valid elements
25
+ return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
26
+
9
27
@parameterized .expand (
10
28
[
11
29
((3 ,),),
@@ -28,24 +46,7 @@ def forward(self, x):
28
46
input_shape = (5 ,)
29
47
inputs = [torch .randn (input_shape )]
30
48
31
- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
32
- # Check if the sizes match
33
- if tensor1 .size () != tensor2 .size ():
34
- return False
35
-
36
- # Flatten the tensors to ensure we are comparing the valid elements
37
- flat_tensor1 = tensor1 .flatten ()
38
- flat_tensor2 = tensor2 .flatten ()
39
-
40
- # Calculate the number of valid elements to compare
41
- input_numel = torch .Size (input_shape ).numel ()
42
- target_numel = torch .Size (target_shape ).numel ()
43
- min_size = min (input_numel , target_numel )
44
-
45
- # Compare only the valid elements
46
- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
47
-
48
- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
49
+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
49
50
50
51
self .run_test_compare_tensor_attributes_only (
51
52
Resize (),
@@ -76,24 +77,7 @@ def forward(self, x):
76
77
input_shape = (5 ,)
77
78
inputs = [torch .randint (1 , 5 , input_shape )]
78
79
79
- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
80
- # Check if the sizes match
81
- if tensor1 .size () != tensor2 .size ():
82
- return False
83
-
84
- # Flatten the tensors to ensure we are comparing the valid elements
85
- flat_tensor1 = tensor1 .flatten ()
86
- flat_tensor2 = tensor2 .flatten ()
87
-
88
- # Calculate the number of valid elements to compare
89
- input_numel = torch .Size (input_shape ).numel ()
90
- target_numel = torch .Size (target_shape ).numel ()
91
- min_size = min (input_numel , target_numel )
92
-
93
- # Compare only the valid elements
94
- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
95
-
96
- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
80
+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
97
81
98
82
self .run_test_compare_tensor_attributes_only (
99
83
Resize (),
@@ -124,24 +108,7 @@ def forward(self, x):
124
108
input_shape = (4 , 4 )
125
109
inputs = [torch .randint (1 , 10 , input_shape )]
126
110
127
- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
128
- # Check if the sizes match
129
- if tensor1 .size () != tensor2 .size ():
130
- return False
131
-
132
- # Flatten the tensors to ensure we are comparing the valid elements
133
- flat_tensor1 = tensor1 .flatten ()
134
- flat_tensor2 = tensor2 .flatten ()
135
-
136
- # Calculate the number of valid elements to compare
137
- input_numel = torch .Size (input_shape ).numel ()
138
- target_numel = torch .Size (target_shape ).numel ()
139
- min_size = min (input_numel , target_numel )
140
-
141
- # Compare only the valid elements
142
- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
143
-
144
- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
111
+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
145
112
146
113
self .run_test_compare_tensor_attributes_only (
147
114
Resize (),
@@ -171,24 +138,7 @@ def forward(self, x):
171
138
input_shape = (4 , 4 )
172
139
inputs = [torch .randint (1 , 10 , input_shape )]
173
140
174
- def compare_resized_tensors (tensor1 , tensor2 , input_shape , target_shape ):
175
- # Check if the sizes match
176
- if tensor1 .size () != tensor2 .size ():
177
- return False
178
-
179
- # Flatten the tensors to ensure we are comparing the valid elements
180
- flat_tensor1 = tensor1 .flatten ()
181
- flat_tensor2 = tensor2 .flatten ()
182
-
183
- # Calculate the number of valid elements to compare
184
- input_numel = torch .Size (input_shape ).numel ()
185
- target_numel = torch .Size (target_shape ).numel ()
186
- min_size = min (input_numel , target_numel )
187
-
188
- # Compare only the valid elements
189
- return torch .equal (flat_tensor1 [:min_size ], flat_tensor2 [:min_size ])
190
-
191
- comparators = [(compare_resized_tensors , [input_shape , target_shape ])]
141
+ comparators = [(self .compare_resized_tensors , [input_shape , target_shape ])]
192
142
193
143
self .run_test_compare_tensor_attributes_only (
194
144
Resize (),
0 commit comments