@@ -22,11 +22,140 @@ class TestIndexPutConverter(DispatchTestCase):
22
22
indices_tensor = (torch .tensor ([0 , 3 ], dtype = torch .int32 ),),
23
23
value_tensor = torch .tensor ([1 , 3 ], dtype = torch .int32 ),
24
24
),
25
+ param (
26
+ test_name = "2d_indices_single" ,
27
+ source_tensor = torch .zeros ([5 , 5 ], dtype = torch .int32 ),
28
+ indices_tensor = (
29
+ torch .tensor ([2 ], dtype = torch .int32 ),
30
+ torch .tensor ([0 ], dtype = torch .int32 ),
31
+ ),
32
+ value_tensor = torch .tensor ([3 ], dtype = torch .int32 ),
33
+ ),
34
+ param (
35
+ test_name = "2d_indices_multiple" ,
36
+ source_tensor = torch .zeros ([5 , 5 ], dtype = torch .int32 ),
37
+ indices_tensor = (
38
+ torch .tensor ([0 , 2 , 2 ], dtype = torch .int32 ),
39
+ torch .tensor ([2 , 0 , 2 ], dtype = torch .int32 ),
40
+ ),
41
+ value_tensor = torch .tensor ([1 , 3 , 4 ], dtype = torch .int32 ),
42
+ ),
43
+ param (
44
+ test_name = "3d_indices_single" ,
45
+ source_tensor = torch .zeros ([3 , 3 , 3 ], dtype = torch .int32 ),
46
+ indices_tensor = (
47
+ torch .tensor ([1 ], dtype = torch .int32 ),
48
+ torch .tensor ([2 ], dtype = torch .int32 ),
49
+ torch .tensor ([2 ], dtype = torch .int32 ),
50
+ ),
51
+ value_tensor = torch .tensor ([7 ], dtype = torch .int32 ),
52
+ ),
53
+ param (
54
+ test_name = "3d_indices_multiple" ,
55
+ source_tensor = torch .zeros ([3 , 3 , 3 ], dtype = torch .int32 ),
56
+ indices_tensor = (
57
+ torch .tensor ([0 , 1 , 1 ], dtype = torch .int32 ),
58
+ torch .tensor ([1 , 2 , 1 ], dtype = torch .int32 ),
59
+ torch .tensor ([2 , 0 , 2 ], dtype = torch .int32 ),
60
+ ),
61
+ value_tensor = torch .tensor ([5 , 7 , 2 ], dtype = torch .int32 ),
62
+ ),
63
+ param (
64
+ test_name = "4d_indices_single" ,
65
+ source_tensor = torch .zeros ([2 , 2 , 2 , 2 ], dtype = torch .int32 ),
66
+ indices_tensor = (
67
+ torch .tensor ([1 ], dtype = torch .int32 ),
68
+ torch .tensor ([1 ], dtype = torch .int32 ),
69
+ torch .tensor ([0 ], dtype = torch .int32 ),
70
+ torch .tensor ([1 ], dtype = torch .int32 ),
71
+ ),
72
+ value_tensor = torch .tensor ([5 ], dtype = torch .int32 ),
73
+ ),
74
+ param (
75
+ test_name = "4d_indices_multiple" ,
76
+ source_tensor = torch .zeros ([2 , 2 , 2 , 2 ], dtype = torch .int32 ),
77
+ indices_tensor = (
78
+ torch .tensor ([0 , 1 ], dtype = torch .int32 ),
79
+ torch .tensor ([1 , 1 ], dtype = torch .int32 ),
80
+ torch .tensor ([1 , 0 ], dtype = torch .int32 ),
81
+ torch .tensor ([1 , 0 ], dtype = torch .int32 ),
82
+ ),
83
+ value_tensor = torch .tensor ([5 , 7 ], dtype = torch .int32 ),
84
+ ),
85
+ param (
86
+ test_name = "negative_indices" ,
87
+ source_tensor = torch .zeros ([5 , 5 ], dtype = torch .int32 ),
88
+ indices_tensor = (
89
+ torch .tensor ([- 1 , - 2 ], dtype = torch .int32 ),
90
+ torch .tensor ([2 , 0 ], dtype = torch .int32 ),
91
+ ),
92
+ value_tensor = torch .tensor ([1 , 3 ], dtype = torch .int32 ),
93
+ ),
94
+ param (
95
+ test_name = "mixed_indices" ,
96
+ source_tensor = torch .zeros ([4 , 4 ], dtype = torch .int32 ),
97
+ indices_tensor = (
98
+ torch .tensor ([0 , 1 , - 1 , - 2 ], dtype = torch .int32 ),
99
+ torch .tensor ([0 , - 1 , 2 , 1 ], dtype = torch .int32 ),
100
+ ),
101
+ value_tensor = torch .tensor ([2 , 4 , 6 , 8 ], dtype = torch .int32 ),
102
+ ),
103
+ param (
104
+ test_name = "1d_indices_float" ,
105
+ source_tensor = torch .zeros ([5 ], dtype = torch .float32 ),
106
+ indices_tensor = (torch .tensor ([0 , 3 ], dtype = torch .int32 ),),
107
+ value_tensor = torch .tensor ([1.5 , 3.5 ], dtype = torch .float32 ),
108
+ ),
109
+ param (
110
+ test_name = "2d_indices_float" ,
111
+ source_tensor = torch .zeros ([5 , 5 ], dtype = torch .float32 ),
112
+ indices_tensor = (
113
+ torch .tensor ([0 , 2 ], dtype = torch .int32 ),
114
+ torch .tensor ([2 , 0 ], dtype = torch .int32 ),
115
+ ),
116
+ value_tensor = torch .tensor ([1.5 , 3.5 ], dtype = torch .float32 ),
117
+ ),
118
+ param (
119
+ test_name = "3d_indices_float" ,
120
+ source_tensor = torch .zeros ([3 , 3 , 3 ], dtype = torch .float32 ),
121
+ indices_tensor = (
122
+ torch .tensor ([0 , 1 ], dtype = torch .int32 ),
123
+ torch .tensor ([1 , 2 ], dtype = torch .int32 ),
124
+ torch .tensor ([2 , 0 ], dtype = torch .int32 ),
125
+ ),
126
+ value_tensor = torch .tensor ([5.5 , 7.5 ], dtype = torch .float32 ),
127
+ ),
128
+ param (
129
+ test_name = "4d_indices_float" ,
130
+ source_tensor = torch .zeros ([2 , 2 , 2 , 2 ], dtype = torch .float32 ),
131
+ indices_tensor = (
132
+ torch .tensor ([0 , 1 ], dtype = torch .int32 ),
133
+ torch .tensor ([1 , 0 ], dtype = torch .int32 ),
134
+ torch .tensor ([0 , 1 ], dtype = torch .int32 ),
135
+ torch .tensor ([1 , 0 ], dtype = torch .int32 ),
136
+ ),
137
+ value_tensor = torch .tensor ([5.5 , 7.5 ], dtype = torch .float32 ),
138
+ ),
139
+ # param(
140
+ # test_name="2d_indices_accumulate_True",
141
+ # source_tensor=torch.zeros([5, 5], dtype=torch.int32),
142
+ # indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
143
+ # value_tensor=torch.tensor([1, 2], dtype=torch.int32),
144
+ # accumulate=True,
145
+ # ),
146
+ # param(
147
+ # test_name="3d_indices_accumulate_True",
148
+ # source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
149
+ # indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
150
+ # value_tensor=torch.tensor([1, 2], dtype=torch.int32),
151
+ # accumulate=True,
152
+ # ),
25
153
# param(
26
- # test_name="2d_indices",
27
- # source_tensor=torch.zeros([5,5], dtype=torch.int32),
28
- # indices_tensor=(torch.tensor([0,2], dtype=torch.int32),torch.tensor([2,0], dtype=torch.int32),),
29
- # value_tensor=torch.tensor([1,3], dtype=torch.int32),
154
+ # test_name="4d_indices_accumulate_True",
155
+ # source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
156
+ # indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
157
+ # value_tensor=torch.tensor([1, 2], dtype=torch.int32),
158
+ # accumulate=True,
30
159
# ),
31
160
]
32
161
)
@@ -36,7 +165,7 @@ def test_index_put(
36
165
class TestIndexPut (torch .nn .Module ):
37
166
def forward (self , source_tensor , value_tensor ):
38
167
return torch .ops .aten .index_put_ .default (
39
- source_tensor , indices_tensor , value_tensor
168
+ source_tensor , indices_tensor , value_tensor , accumulate
40
169
)
41
170
42
171
self .run_test (
0 commit comments