176 changes: 176 additions & 0 deletions benchmarks/quantization/eval_accuracy_for_readme.py
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
PerRow,
)


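# Map a recipe name, as passed on the command line (and used in
# eval_accuracy_for_readme.sh), to the corresponding torchao quantization config.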
def string_to_config(s):
if s is None:
return None
elif s == "float8_rowwise":
return Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
elif s == "int4_groupwise_weight_float8_rowwise_activation":
return Float8DynamicActivationInt4WeightConfig()
elif s == "int4_groupwise_hqq_weight_only":
return Int4WeightOnlyConfig(
group_size=32,
int4_packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
)
elif s == "int8_rowwise_weight_only":
return Int8WeightOnlyConfig()
elif s == "int8_rowwise":
return Int8DynamicActivationInt8WeightConfig()
else:
raise AssertionError(f"unsupported {s}")


def quantize_model_and_save(model_id, quant_config, output_dir="results"):
"""Quantize the model and save it to the output directory."""
print("Quantizing model with config: ", quant_config)
if quant_config is None:
quantization_config = None
else:
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
dtype=torch.bfloat16,
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantized_model.save_pretrained(output_dir, safe_serialization=False)
tokenizer.save_pretrained(output_dir, safe_serialization=False)
return quantized_model, tokenizer


def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8):
"""Run the lm_eval command using subprocess."""
tasks_str = ",".join(tasks_list)
command = [
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={model_dir}",
"--tasks",
f"{tasks_str}",
"--device",
f"{device}",
"--batch_size",
f"{batch_size}",
"--output_path",
f"{model_dir}/lm_eval_outputs/",
]
subprocess.run(command, check=True)


def get_size_of_dir(model_output_dir):
# get dir size from shell, to skip complexity of dealing with tensor
# subclasses
result = subprocess.run(
["du", "-sb", model_output_dir], capture_output=True, text=True
)
size = int(result.stdout.split()[0])
return size


def run(
model_id: str,
quant_recipe_name: str | None,
tasks,
device,
batch_size,
model_output_dir,
):
print(f"\nRunning {model_id=} with {quant_recipe_name=}\n")
model_name = model_id.split("/")[-1]
    model_output_dir = f"{model_output_dir}/{model_name}-{quant_recipe_name}"
quant_config = string_to_config(quant_recipe_name)
quantized_model, tokenizer = quantize_model_and_save(
model_id, quant_config=quant_config, output_dir=model_output_dir
)
print(quantized_model)

model_size = get_size_of_dir(model_output_dir) / 1e9
print(f"checkpoint size: {model_size} GB")

run_lm_eval(
model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size
)
print("done\n")


if __name__ == "__main__":
try:
import lm_eval # noqa: F401
    except ImportError:
print(
"lm_eval is required to run this script. Please install it using pip install lm-eval."
)
exit(0)

# Set up argument parser
parser = argparse.ArgumentParser(
description="Quantize a model and evaluate its throughput."
)
parser.add_argument(
"--model_id",
type=str,
default="meta-llama/Llama-3.1-8B",
help="The model ID to use.",
)
parser.add_argument(
"--quant_recipe_name",
type=str,
default=None,
help="The quantization recipe to use.",
)
parser.add_argument(
"--tasks",
nargs="+",
type=str,
default=["wikitext"],
help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2",
)
parser.add_argument(
"--device", type=str, default="cuda:0", help="Device to run the model on."
)
parser.add_argument(
"--batch_size", type=str, default="auto", help="Batch size for lm_eval."
)
parser.add_argument(
"--output_dir",
type=str,
default="quantized_models",
help="Output directory for quantized model.",
)
args = parser.parse_args()

# Use parsed arguments
run(
model_id=args.model_id,
quant_recipe_name=args.quant_recipe_name,
tasks=args.tasks,
device=args.device,
batch_size=args.batch_size,
model_output_dir=args.output_dir,
)
30 changes: 30 additions & 0 deletions benchmarks/quantization/eval_accuracy_for_readme.sh
@@ -0,0 +1,30 @@
#!/bin/bash

set -e

# Get model_id as positional argument (optional)
MODEL_ID="${1:-meta-llama/Llama-3.1-8B}"

# Get log file as second positional argument (optional)
LOG_FILE="${2:-benchmarks/data/eval_accuracy_for_readme_log.txt}"

# Build the base command arguments
BASE_ARGS="--tasks wikitext winogrande"
if [[ -n "$MODEL_ID" ]]; then
BASE_ARGS="--model_id $MODEL_ID $BASE_ARGS"
fi

# baseline
# note: the -u flag is to prevent python from buffering stdout and stderr
# and make the output log file be in chronological order
time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS 2>&1 | tee "$LOG_FILE"

# quantized recipes
# note:
# * `int4_groupwise_hqq_weight_float8_rowwise_activation` doesn't work with dtype_map auto: https://gist.github.com/vkuzo/6b128681b628744d445c553cdeac8a85
# * `int4_groupwise_hqq_weight_only` only works on A100
for quant_recipe in float8_rowwise int4_groupwise_weight_float8_rowwise_activation int4_groupwise_hqq_weight_only int8_rowwise_weight_only int8_rowwise; do
time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS --quant_recipe_name $quant_recipe 2>&1 | tee -a "$LOG_FILE"
done

# TODO(future PR): script to parse the log file instead of manual copy-paste
66 changes: 42 additions & 24 deletions torchao/quantization/README.md
@@ -1,36 +1,54 @@
# Quantization
Quantization algorithms typically use different schemes for quantizing the activations and the weights; A16W8, for instance, means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a one-line change. Note: exact APIs are not stable, and we may change them in the future.
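As a minimal sketch of that one-line change (the toy model below is a stand-in for a real checkpoint, and `Int8WeightOnlyConfig` is just one of the available configs):

```python
import torch

from torchao.quantization import Int8WeightOnlyConfig, quantize_

# Stand-in model; in practice this would be e.g. a Hugging Face transformer.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16)

# The one-line change: pass a different config to try a different scheme,
# e.g. Int8DynamicActivationInt8WeightConfig or Float8DynamicActivationFloat8WeightConfig.
quantize_(model, Int8WeightOnlyConfig())
```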

## Accuracy benchmarks

All the following benchmarks are for `meta-llama/Llama-3.1-8B` using `lm-eval`, measured on an H100 GPU.

| weight | activation | wikitext-perplexity | winogrande | checkpoint size (GB) |
| --------- | ---------- | ------------------- | ---------- | -------------------- |
| bfloat16 | bfloat16 | 7.3315 | 0.7380 | 16.1 |
| float8_rowwise | float8_rowwise | 7.4197 | 0.7388 | 9.1 |
| int8_rowwise | bfloat16 | 7.3451 | 0.7340 | 9.1 |
| int8_rowwise | int8_rowwise | 7.4535 | 0.7285 | 9.1 |

To reproduce, run the following command:

```bash
./benchmarks/quantization/eval_accuracy_for_readme.sh
```
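A single recipe can also be evaluated directly via the underlying Python script; for example (arguments as defined in `benchmarks/quantization/eval_accuracy_for_readme.py`):

```bash
python benchmarks/quantization/eval_accuracy_for_readme.py \
    --model_id meta-llama/Llama-3.1-8B \
    --quant_recipe_name float8_rowwise \
    --tasks wikitext winogrande
```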

## Performance benchmarks

Benchmarks are gathered using the [generation](../_models/llama/generate.py) script.

### CUDA backend | NVIDIA-A100-80GB GPU
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) |
| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- |
| Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 |
| | int8dq | 8.61 | 64.75 | 9.24 |
| | int8wo | 153.03 | 1150.80 | 10.42 |
| | fp6 | 161.58 | 910.02 | 7.72 |
| | int4wo-64 | 180.80 | 763.33 | 6.88 |
| | int4wo-64-GPTQ | 180.80 | 763.33 | 6.88 |
| | autoquant-int4hqq | 188.41 | 800.58 | 7.14 |

### CUDA backend | NVIDIA-H100 GPU
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) |
| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- |
| Llama-3.1-8B | Base (bfloat16) | 126.90 | 1904.75 | 16.75 |
| | int8wo | 198.85 | 1495.41 | 11.05 |
| | int4wo-64 | 241.39 | 1019.14 | 7.08 |
| | float8wo | 178.46 | 1339.93 | 12.09 |
| | float8dq (PerTensor) | 116.40 | 873.58 | 11.14 |
| | float8dq (Per Row) | 154.63 | 1161.47 | 11.14 |

### XPU backend | Intel-Max1100
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) |
| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- |
| Llama-3-8.1B | Base (bfloat16) | 40.36 | 605.77 | 16.35 |
| | int8dq | 13.60 | 102.28 | 18.69 |
| | int8wo | 59.49 | 447.27 | 18.60 |

