pytorch/ao: commit 3360ab6 (torchao/experimental/tests)
Reviewed By: avikchaudhuri

Differential Revision: D67388084
gmagogsfm authored and facebook-github-bot committed Dec 19, 2024
1 parent 38c79d4 commit 3360ab6
Showing 3 changed files with 26 additions and 16 deletions.
17 changes: 11 additions & 6 deletions torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -7,8 +7,8 @@
 import copy

 import glob
-import subprocess
 import os
+import subprocess

 import sys
 import tempfile
@@ -18,10 +18,11 @@

 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
 from torchao.experimental.quant_api import (
-    IntxWeightEmbeddingQuantizer,
     _IntxWeightQuantizedEmbeddingFallback,
+    IntxWeightEmbeddingQuantizer,
 )
+

 def cmake_build_torchao_ops(temp_build_dir):
     from distutils.sysconfig import get_python_lib

@@ -62,7 +63,9 @@ def test_accuracy(self):
         group_size = 128
         embedding_dim = 4096
         num_embeddings = 131
-        model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)])
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
         indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)

         for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
@@ -88,10 +91,11 @@ def test_export_compile_aoti(self):
         group_size = 128
         embedding_dim = 4096
         num_embeddings = 131
-        model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)])
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(num_embeddings, embedding_dim)]
+        )
         indices = torch.randint(0, num_embeddings, (42,), dtype=torch.int32)
-

         print("Quantizing model")
         quantizer = IntxWeightEmbeddingQuantizer(
             device="cpu",
@@ -102,7 +106,7 @@
         quantized_model = quantizer.quantize(model)

         print("Exporting quantized model")
-        exported = torch.export.export(quantized_model, (indices,))
+        exported = torch.export.export(quantized_model, (indices,), strict=True)

         print("Compiling quantized model")
         quantized_model_compiled = torch.compile(quantized_model)
@@ -121,5 +125,6 @@ def test_export_compile_aoti(self):
             fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu")
             fn(indices)

+
 if __name__ == "__main__":
     unittest.main()
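
The substantive change in this file is the explicit strict=True on torch.export.export; the other hunks are import sorting and line wrapping. Below is a minimal sketch of the export-then-AOT-load round trip this test exercises, assuming a toy module and a torch._export.aot_compile call for the compile phase that the visible hunks elide; neither is part of the commit itself.

import tempfile

import torch


class Toy(torch.nn.Module):
    # Stand-in for the quantized test model (assumption for illustration).
    def forward(self, x):
        return x * 2 + 1


model = Toy()
example_args = (torch.randn(8),)

# strict=True traces with TorchDynamo and errors on graph breaks, rather
# than relying on whatever the prevailing default of torch.export.export is.
exported = torch.export.export(model, example_args, strict=True)
assert torch.allclose(exported.module()(*example_args), model(*example_args))

# AOT-Inductor round trip mirroring the test's torch._export.aot_load call;
# the aot_compile step is an assumption, since the hunk does not show it.
with tempfile.TemporaryDirectory() as tmpdir:
    so_path = torch._export.aot_compile(
        model,
        example_args,
        options={"aot_inductor.output_path": f"{tmpdir}/model.so"},
    )
    fn = torch._export.aot_load(so_path, "cpu")
    assert torch.allclose(fn(*example_args), model(*example_args))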
Second changed file (file header not captured):
@@ -129,7 +129,7 @@ def test_export_compile_aoti(self):
         quantized_model = quantizer.quantize(model)

         print("Exporting quantized model")
-        exported = torch.export.export(quantized_model, (activations,))
+        exported = torch.export.export(quantized_model, (activations,), strict=True)

         print("Compiling quantized model")
         quantized_model_compiled = torch.compile(quantized_model)
Third changed file (file header not captured):
@@ -18,13 +18,14 @@

 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))

-from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
-from torchao.quantization.quant_api import quantize_
-
-from torchao.utils import unwrap_tensor_subclass
 from torchao.experimental.quant_api import (
     _Int8DynActIntxWeightQuantizedLinearFallback,
+    int8_dynamic_activation_intx_weight,
 )
+from torchao.quantization.quant_api import quantize_
+
+from torchao.utils import unwrap_tensor_subclass

+
 def cmake_build_torchao_ops(temp_build_dir):
     from distutils.sysconfig import get_python_lib
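
For context on the imports being regrouped here: in torchao's subclass-based path, quantize_ applies int8_dynamic_activation_intx_weight by replacing Linear weights with quantized tensor subclasses, and unwrap_tensor_subclass flattens those subclasses back into plain tensors so torch.export can trace the model. A rough sketch of that flow follows; the keyword names passed to int8_dynamic_activation_intx_weight are assumed from the test's variables (nbit, group_size, has_weight_zeros), not confirmed from the API.

import torch

from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(4096, 1024, bias=False))
activations = torch.randn(2, 4096)

# Replace each Linear weight with a quantized tensor subclass. The keyword
# names below are an assumption modeled on the test's variables.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        group_size=128,
        nbit=4,
        has_weight_zeros=False,
    ),
)

# Flatten the subclasses into plain tensors so torch.export can trace the
# model; this mirrors the unwrap_tensor_subclass(model) call in the test.
unwrap_tensor_subclass(model)
exported = torch.export.export(model, (activations,), strict=True)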
@@ -98,7 +99,7 @@ def test_accuracy(self):
         result = quantized_model(activations)
         expected_result = quantized_model_reference(activations)

-        #TODO: remove expected_result2 checks when we deprecate non-subclass API
+        # TODO: remove expected_result2 checks when we deprecate non-subclass API
         reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
         reference_impl.quantize_and_pack_weights(
             model[0].weight, nbit, group_size, has_weight_zeros
@@ -115,8 +116,12 @@ def test_accuracy(self):
             self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
             if not torch.allclose(actual_val, expected_val):
                 num_mismatch_at_low_tol += 1

-            self.assertTrue(torch.allclose(expected_val, expected_val2, atol=1e-2, rtol=1e-1))
+            self.assertTrue(
+                torch.allclose(
+                    expected_val, expected_val2, atol=1e-2, rtol=1e-1
+                )
+            )
             if not torch.allclose(expected_val, expected_val2):
                 num_mismatch_at_low_tol2 += 1
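
The assertion re-wrapped above compares the subclass and non-subclass reference results at a loose tolerance. torch.allclose(a, b, atol=..., rtol=...) checks |a - b| <= atol + rtol * |b| elementwise, so atol=1e-2, rtol=1e-1 accepts an absolute error of 0.01 plus 10% of the reference value. A small self-contained check of that arithmetic, with made-up values:

import torch

expected = torch.tensor([1.0, 10.0])
actual = torch.tensor([1.009, 10.9])  # within 1e-2 + 1e-1 * |expected|

# Elementwise test: |actual - expected| <= atol + rtol * |expected|.
assert torch.allclose(actual, expected, atol=1e-2, rtol=1e-1)
# At the default tolerances (rtol=1e-5, atol=1e-8) the same pair fails.
assert not torch.allclose(actual, expected)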

@@ -156,8 +161,8 @@ def test_export_compile_aoti(self):
         unwrap_tensor_subclass(model)

         print("Exporting quantized model")
-        exported = torch.export.export(model, (activations,))
+        exported = torch.export.export(model, (activations,), strict=True)

         print("Compiling quantized model")
         compiled = torch.compile(unwrapped_model)
         with torch.no_grad():
