from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
+import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt

    ConverterRegistry,
    DynamoConverterImplSignature,
)
-from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op
+from torch_tensorrt.fx.converters.converter_utils import (
+    broadcast,
+    get_axes_for_reduce_op,
+)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

-import tensorrt as trt
-
_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -205,6 +207,72 @@ def broadcastable(
    return True


+def broadcast_to_same_shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: TRTTensor,
+    rhs_val: TRTTensor,
+) -> Tuple[TRTTensor, TRTTensor]:
+    """Broadcast ITensors `lhs_val` and `rhs_val` to the same shape. If the shapes are already the same, return the
+    original tensors. If the shapes are different, broadcast the tensors to the same shape.
+
+    This helper function is different from fx/converter_utils.broadcast.
+    fx/converter_utils.broadcast only broadcasts two ITensors to the same number of dimensions (rank)
+    by prepending 1s, while this function broadcasts two ITensors to the same shape.
+
+    For example, given original ITensors lhs_val.shape: (2, 3) and rhs_val.shape: (2, 2, 1, 3):
+    calling fx/converter_utils.broadcast gives lhs_val.shape: (1, 1, 2, 3) and rhs_val.shape: (2, 2, 1, 3),
+    while calling broadcast_to_same_shape gives lhs_val.shape: (2, 2, 2, 3) and rhs_val.shape: (2, 2, 2, 3).
+
+    Args:
+        lhs_val (TRTTensor): A TensorRT ITensor.
+        rhs_val (TRTTensor): A TensorRT ITensor.
+
+    Returns:
+        Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors broadcast to the same shape.
+
+    """
+    lhs_val, rhs_val = broadcast(
+        ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+    )
+
+    lhs_val_shape = lhs_val.shape
+    rhs_val_shape = rhs_val.shape
+
+    if tuple(lhs_val_shape) != tuple(rhs_val_shape):
+        rank = len(lhs_val_shape)
+        expanded_dims = [-1] * len(lhs_val_shape)
+
+        for dim in range(rank):
+            expanded_dims[dim] = max(lhs_val_shape[dim], rhs_val_shape[dim])
+
+        expanded_shape = tuple(expanded_dims)
+
+        if lhs_val_shape != expanded_shape:
+            lhs_val = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_lhs_val",
+                lhs_val,
+                expanded_shape,
+            )
+
+        if rhs_val_shape != expanded_shape:
+            rhs_val = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_rhs_val",
+                rhs_val,
+                expanded_shape,
+            )
+
+    return lhs_val, rhs_val
+
+
get_axes_for_reduce_op = functools.partial(
    get_axes_for_reduce_op, has_implicit_batch_dimension=False
)
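As a side note (not part of this commit), the shape semantics described in the new docstring can be reproduced with plain PyTorch. The sketch below only illustrates the difference between rank-only broadcasting and full-shape broadcasting; the tensor names are made up for the example:

```python
import torch

lhs = torch.empty(2, 3)        # stands in for lhs_val, shape (2, 3)
rhs = torch.empty(2, 2, 1, 3)  # stands in for rhs_val, shape (2, 2, 1, 3)

# Rank-only broadcast (what fx/converter_utils.broadcast does):
# prepend 1s until both tensors have the same number of dimensions.
lhs_rank_only = lhs.reshape((1,) * (rhs.dim() - lhs.dim()) + tuple(lhs.shape))
print(lhs_rank_only.shape)     # torch.Size([1, 1, 2, 3])

# Full broadcast (the end result of broadcast_to_same_shape): expand every
# dimension to the per-dimension maximum of the two shapes.
common = torch.broadcast_shapes(lhs.shape, rhs.shape)
print(common)                  # torch.Size([2, 2, 2, 3])
lhs_full, rhs_full = lhs.expand(*common), rhs.expand(*common)
print(lhs_full.shape, rhs_full.shape)  # both torch.Size([2, 2, 2, 3])
```

In the converter itself the per-dimension maximum is computed explicitly and then materialized with impl.slice.expand, since TensorRT ITensors are network tensors and have to be expanded through network layers rather than with PyTorch view semantics.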