Skip to content

Commit 5f6f2b2

Browse files
committed
chore: minor linting issue
1 parent 663cc02 commit 5f6f2b2

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -785,20 +785,20 @@ def aten_ops_select(
785785

786786

787787
def index_put_validator(node: Node) -> bool:
788-
if args_bounds_check(node.args, 3, False): # Check if accumulate is valid
788+
if args_bounds_check(node.args, 3, False): # Check if accumulate is valid
789789
_LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
790790
accumulate_valid = False
791791
else:
792792
accumulate_valid = True
793-
793+
794794
# Retrieve input tensor's meta information
795795
input_meta = node.args[0].meta.get("tensor_meta")
796796
if not input_meta:
797797
_LOGGER.warning(
798798
"Meta information of input is missing. Unable to validate if broadcasting is needed, falling back to PyTorch operation."
799799
)
800800
return False
801-
801+
802802
input_shape = input_meta.shape
803803
input_num_dims = len(input_shape)
804804

@@ -807,9 +807,11 @@ def index_put_validator(node: Node) -> bool:
807807
if indices_num_dims == input_num_dims:
808808
broadcast_valid = True
809809
else:
810-
_LOGGER.debug("We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions.")
810+
_LOGGER.debug(
811+
"We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions."
812+
)
811813
broadcast_valid = False
812-
814+
813815
# Return validation result
814816
return accumulate_valid and broadcast_valid
815817

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class TestIndexPutConverter(DispatchTestCase):
140140
# test_name="3d_indices_float_broadcase_index",
141141
# source_tensor=torch.zeros([3, 3, 3], dtype = torch.int32),
142142
# indices_tensor=(
143-
# torch.tensor([0,1], dtype=torch.int32),
143+
# torch.tensor([0,1], dtype=torch.int32),
144144
# torch.tensor([0,1], dtype=torch.int32),
145145
# ),
146146
# value_tensor=torch.tensor([10], dtype = torch.int32),

0 commit comments

Comments
 (0)