From 696ea76b66d50c9753da415a9d8dd755c2a7b2bf Mon Sep 17 00:00:00 2001
From: "Yanan Cao (PyTorch)"
Date: Wed, 18 Dec 2024 17:02:09 -0800
Subject: [PATCH] pytorch/ao/torchao/experimental/ops/mps/test

Reviewed By: avikchaudhuri

Differential Revision: D67388057
---
 torchao/experimental/ops/mps/test/test_quantizer.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py
index 72a6b76fa..eb6663c1e 100644
--- a/torchao/experimental/ops/mps/test/test_quantizer.py
+++ b/torchao/experimental/ops/mps/test/test_quantizer.py
@@ -4,18 +4,17 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Optional
 import copy
 import itertools
 import os
 import sys
+import unittest
+from typing import Optional

 import torch
-import unittest
 from parameterized import parameterized

-from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
-from torchao.experimental.quant_api import _quantize
+from torchao.experimental.quant_api import _quantize, UIntxWeightOnlyLinearQuantizer

 libname = "libtorchao_ops_mps_aten.dylib"
 libpath = os.path.abspath(
@@ -80,7 +79,7 @@ def test_export(self, nbit):
         activations = torch.randn(m, k0, dtype=torch.float32, device="mps")

         quantized_model = self._quantize_model(model, torch.float32, nbit, group_size)
-        exported = torch.export.export(quantized_model, (activations,))
+        exported = torch.export.export(quantized_model, (activations,), strict=True)

         for node in exported.graph.nodes:
             if node.op == "call_function":