From 666e5ecb84f514b5ea586b9c21e9ba3bc60a55f0 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 23 May 2024 10:24:39 -0700 Subject: [PATCH] ptq test error correction --- tests/py/ts/ptq/test_ptq_dataloader_calibrator.py | 5 +++-- tests/py/ts/ptq/test_ptq_to_backend.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py index 2fac02f542..275aaadd9b 100644 --- a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py +++ b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py @@ -6,6 +6,7 @@ import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms +import torch_tensorrt.ts.ptq as PTQ from torch.nn import functional as F from torch_tensorrt.ts.logging import * @@ -76,11 +77,11 @@ def test_compile_script(self): self.testing_dataloader = torch.utils.data.DataLoader( self.testing_dataset, batch_size=1, shuffle=False, num_workers=1 ) - self.calibrator = torchtrt.ptq.DataLoaderCalibrator( + self.calibrator = PTQ.DataLoaderCalibrator( self.testing_dataloader, cache_file="./calibration.cache", use_cache=False, - algo_type=torchtrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2, + algo_type=PTQ.CalibrationAlgo.ENTROPY_CALIBRATION_2, device=torch.device("cuda:0"), ) diff --git a/tests/py/ts/ptq/test_ptq_to_backend.py b/tests/py/ts/ptq/test_ptq_to_backend.py index d016dedb15..015ce97126 100644 --- a/tests/py/ts/ptq/test_ptq_to_backend.py +++ b/tests/py/ts/ptq/test_ptq_to_backend.py @@ -6,6 +6,7 @@ import torch_tensorrt as torchtrt import torchvision import torchvision.transforms as transforms +import torch_tensorrt.ts.ptq as PTQ from torch.nn import functional as F from torch_tensorrt.ts.logging import * @@ -76,11 +77,11 @@ def test_compile_script(self): self.testing_dataloader = torch.utils.data.DataLoader( self.testing_dataset, batch_size=1, shuffle=False, num_workers=1 ) - self.calibrator = torchtrt.ptq.DataLoaderCalibrator( + self.calibrator = PTQ.DataLoaderCalibrator( self.testing_dataloader, cache_file="./calibration.cache", use_cache=False, - algo_type=torchtrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2, + algo_type=PTQ.CalibrationAlgo.ENTROPY_CALIBRATION_2, device=torch.device("cuda:0"), )