Gemlite fixes #1432

Open · wants to merge 1 commit into base: main
8 changes: 8 additions & 0 deletions test/integration/test_integration.py
@@ -962,6 +962,14 @@ def test_gemlite_layout(self, device, dtype):
test_shape=test_shape,
test_dtype=dtype,
)
# test that shapes not divisible by 128 don't cause errors
self._test_lin_weight_subclass_api_impl(
lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)),
device,
15,
test_shape=[1, 1025, 513],
test_dtype=dtype,
)


@parameterized.expand(COMMON_DEVICE_DTYPE)
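For reference, a minimal standalone sketch of what the added test exercises, assuming gemlite_uintx_weight_only is importable from torchao.quantization and a CUDA device with FP16 weights; the shapes below are illustrative:

import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only  # import path assumed

# Neither in_features (513) nor out_features (1025) is divisible by 128, so with this
# change the layer should be skipped with a warning instead of raising an error.
model = torch.nn.Sequential(torch.nn.Linear(513, 1025, bias=False)).half().cuda()
quantize_(model, gemlite_uintx_weight_only(None, 4, 32))  # group_size=None, 4-bit, 32-bit packing
out = model(torch.randn(1, 513, dtype=torch.float16, device="cuda"))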
85 changes: 60 additions & 25 deletions torchao/_models/llama/generate.py
@@ -3,6 +3,7 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import platform
import sys
@@ -18,11 +19,6 @@
import torchao
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5
from torchao._models.utils import (
get_arch_name,
write_json_result_ossci,
write_json_result_local,
)

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

@@ -41,6 +37,14 @@ def elapsed_time(self, other_event):
return abs(other_event.event_time - self.event_time) * 1000


def get_arch_name() -> str:
Contributor commented: why these changes? is this some rebase issue?
if torch.cuda.is_available():
return torch.cuda.get_device_name()
else:
# This returns x86_64 or arm64 (for aarch64)
return platform.machine()


def device_timer(device):
if "cuda" in device:
return torch.cuda.Event(enable_timing=True)
@@ -61,6 +65,39 @@ def device_sync(device):
print(f"device={device} is not yet suppported")


def write_json_result(output_json_path, headers, row):
"""
Write the result into JSON format, so that it can be uploaded to the benchmark database
to be displayed on OSS dashboard. The JSON format is defined at
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
"""
mapping_headers = {headers[i]: v for i, v in enumerate(row)}
record = {
"benchmark": {
"name": "TorchAO benchmark",
"mode": "inference",
"dtype": mapping_headers["dtype"],
"extra_info": {
"device": mapping_headers["device"],
"arch": mapping_headers["arch"],
},
},
"model": {
"name": mapping_headers["name"],
"type": "model",
"origins": ["pytorch"],
},
"metric": {
"name": mapping_headers["metric"],
"benchmark_values": [mapping_headers["actual"]],
"target_value": mapping_headers["target"],
},
}

with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
print(json.dumps(record), file=f)
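
A minimal usage sketch of write_json_result with made-up values; the headers/row pairing mirrors how main() builds its results, and the output path below is illustrative:

headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
row = ["llama-2-7b", "int4wo-64", "cuda", "NVIDIA A100", "tok/s", 105.3, None]  # illustrative values
write_json_result("benchmark_results.json", headers, row)  # appends one JSON record per call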


default_device = (
"cuda"
if torch.cuda.is_available()
@@ -135,7 +172,7 @@ def decode_n_tokens(
next_token, next_prob = next_token.clone(), next_prob.clone()
input_pos += 1
# in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
new_tokens.append(next_token.clone())
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob)
cur_token = next_token
@@ -279,7 +316,6 @@ def main(
precision=torch.bfloat16,
write_result: Optional[Path] = None,
output_json_path: Optional[Path] = None,
output_json_local: bool = False,
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""

@@ -692,10 +728,20 @@ def ffn_or_attn_only(mod, fqn):
example_input=inputs,
)
if "autoquant-all" == quantization:
all_qtensor_classes = (
torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
+ torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+ torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
)
if torchao.utils.is_sm_89():
# these are fp8-related subclasses; should be renamed
all_qtensor_classes += (
torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
)
model = autoquant(
model,
manual=True,
qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
qtensor_class_list=all_qtensor_classes,
example_input=inputs,
)
else:
@@ -713,10 +759,6 @@

# do autoquantization
model.finalize_autoquant()
elif "codebook" in quantization:
from torchao.prototype.quantization.codebook import codebook_weight_only
model.to(device)
quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))

else:
if not TORCH_VERSION_AT_LEAST_2_5:
@@ -936,14 +978,13 @@ def callback(x):
f.write(result_txt)
f.close()

headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = checkpoint_path.parent.name
arch = get_arch_name()
dtype = quantization or str(precision)
memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
if output_json_path:
headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
name = checkpoint_path.parent.name
arch = get_arch_name()
dtype = quantization or "noquant"
memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
write_json_result = write_json_result_local if output_json_local else write_json_result_ossci
write_json_result(output_json_path, headers, memory_result)
write_json_result(output_json_path, headers, performance_result)

@@ -1045,11 +1086,6 @@ def callback(x):
default=None,
help="Path where to write the json result for dashboard",
)
parser.add_argument(
"--output_json_local",
action="store_true",
help="Whether to output json result for local machine or for CI machine, local option will fill in some dummy fields",
)

args = parser.parse_args()
print(args)
@@ -1077,5 +1113,4 @@ def callback(x):
args.precision,
args.write_result,
args.output_json_path,
args.output_json_local,
)
13 changes: 13 additions & 0 deletions torchao/dtypes/uintx/gemlite_layout.py
@@ -15,6 +15,7 @@
from torchao.dtypes.utils import Layout, is_device
from torchao.quantization.quant_primitives import quantize_affine
from torchao.utils import fill_defaults
import warnings

aten = torch.ops.aten

@@ -76,6 +77,14 @@ def apply_gemlite_quant(
out_features, in_features = weight.shape
group_size = in_features if group_size is None else group_size

if in_features % 128 != 0 and out_features % 128 != 0:
warnings.simplefilter("once", UserWarning)
warnings.warn(
"Gemlite only works for layers with in_features or out_features divisible by 128, "
+ "some layers have been skipped", UserWarning
)
return weight

quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)

layout = GemlitePackedLayout(
@@ -173,6 +182,10 @@ def from_plain(
exhaustive=False,
use_cuda_graph=False,
)
if _layout.group_size is None and _layout.bit_width == 4:
from gemlite.core import GEMLITE_ACC_DTYPE
from gemlite.dtypes import DType
GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32
Collaborator commented: This will only work when all the layers use the same group_size, which is OK for now. The other option would be to use https://github.com/mobiusml/gemlite/blob/master/gemlite/core.py#L87, but for now let's keep it like this.

Contributor Author replied: I tested this manually; it works in all cases, even when there are different group sizes.

out_features, in_features = int_data.shape
input_dtype, output_dtype = DType.FP16, DType.FP16