diff --git a/tests/benchmark/README.md b/tests/benchmark/README.md
index e1fb1f01dd5..b6d6f9cddad 100644
--- a/tests/benchmark/README.md
+++ b/tests/benchmark/README.md
@@ -4,29 +4,27 @@ Please refer to https://medium.com/pytorch/bettertransformer-out-of-the-box-perf
 
 # GPTQ benchmark
 
-The results below are for AutoGPTQ 0.4.2, PyTorch 2.0.1, bitsandbytes 0.41.1, transformers 4.32.
+The results below are for AutoGPTQ 0.5.0, PyTorch 2.0.1, bitsandbytes 0.41.1, transformers 4.35.
 
 ## Generation benchmark results
 
 Run
 
 ```shell
-git clone --branch main https://huggingface.co/TheBloke/Llama-2-13B-chat-GPTQ
-cd Llama-2-13B-chat-GPTQ
-mv gptq_model-4bit-128g.safetensors model.safetensors
-mv quantize_config.json quantization_config.json
-
 # pytorch fp16
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --sweep --num-batches 4 --task text-generation
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model meta-llama/Llama-2-13b-chat-hf --sweep --num-batches 4 --task text-generation --generate
+
+# GPTQ with exllamav2 kernel (int4/fp16)
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --sweep --num-batches 4 --gptq --task text-generation --use-exllama --exllama-version 2 --generate
 
 # GPTQ with exllama kernel (int4/fp16)
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --gptq-model /path/to/Llama-2-13B-chat-GPTQ/ --sweep --num-batches 4 --gptq --task text-generation
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --sweep --num-batches 4 --gptq --task text-generation --use-exllama --generate
 
 # GPTQ without exllama kernel (int4/fp16)
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --gptq-model /path/to/Llama-2-13B-chat-GPTQ/ --sweep --num-batches 4 --gptq --task text-generation --disable-exllama
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --sweep --num-batches 4 --gptq --task text-generation --generate
 
 # using bitsandbytes fp4/fp16 scheme
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --sweep --num-batches 4 --task text-generation --bitsandbytes
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model meta-llama/Llama-2-13b-chat-hf --sweep --num-batches 4 --task text-generation --bitsandbytes --generate
 ```
 
 Here are results obtained on a single NVIDIA A100-SXM4-80GB GPU. We use a prompt length of 512, and generate exactly 512 new tokens. Each generation is repeated for 4 batches, and metrics are averaged over the number of batches and generation length.
@@ -42,6 +40,7 @@ Bitsandbytes uses the fp4 scheme, with the compute in fp16.
 |quantization |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
 |-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
 |None|None |None|None |None |26.0 |36.958 |27.058 |29152.98 |
+| gptq | False | 4 | 128 | exllamav2 | 36.07 | 32.25 | 31.01 | 11313.75 |
 |gptq |False |4 |128 |exllama|36.2 |33.711 |29.663 |10484.34 |
 |gptq |False |4 |128 |autogptq-cuda-old|36.2 |46.44 |21.53 |10344.62 |
 |bitsandbytes|None |None|None |None |37.64 |52.00 |19.23 |11018.36 |
@@ -51,6 +50,7 @@ Bitsandbytes uses the fp4 scheme, with the compute in fp16.
 |quantization |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
 |-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
 |None|None |None|None |None |26.0 |37.35 |53.53 |30831.09 |
+| gptq | False | 4 | 128 | exllamav2 | 36.07 | 35.81 | 55.85 | 12112.42 |
 |gptq |False |4 |128 |exllama|36.2 |37.25 |53.68 |12162.43 |
 |gptq |False |4 |128 |autogptq-cuda-old|36.2 |47.41 |42.18 |12020.34 |
 |bitsandbytes|None |None|None |None |37.64 |74.62 |26.80 |12834.84 |
@@ -60,6 +60,7 @@ Bitsandbytes uses the fp4 scheme, with the compute in fp16.
 |quantization |act_order|bits|group_size|kernel |Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
 |-----|---------|----|----------|-----------------|-------------|----------------------|------------------|----------------|
 |None|None |None|None |None |26.0 |37.89 |105.55 |34187.22 |
+| gptq | False | 4 | 128 | exllamav2 | 36.07 | 36.04 | 110.98 | 16387.19 |
 |gptq |False |4 |128 |exllama |36.2 |54.14 |73.87 |15518.55 |
 |gptq |False |4 |128 |autogptq-cuda-old|36.2 |60.98 |65.59 |15374.67 |
 |bitsandbytes|None |None|None |None |37.64 |80.24 |49.85 |16187.69 |
@@ -69,6 +70,7 @@ Bitsandbytes uses the fp4 scheme, with the compute in fp16.
 |quantization |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
 |-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
 |None|None |None|None |None |26.0 |47.37 |168.86 |40327.62 |
+| gptq | False | 4 | 128 | exllamav2 | 36.07 | 47.31 | 169.11 | 22463.02 |
 |gptq |False |4 |128 |exllama|36.2 |73.57 |108.73 |21864.56 |
 |gptq |False |4 |128 |autogptq-cuda-old|36.2 |104.44 |76.59 |20987.68 |
 |bitsandbytes|None |None|None |None |37.64 |91.29 |87.63 |22894.02 |
@@ -78,6 +80,7 @@ Bitsandbytes uses the fp4 scheme, with the compute in fp16.
 |quantization |act_order|bits|group_size|kernel|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Peak memory (MB)|
 |-----|---------|----|----------|------|-------------|----------------------|------------------|----------------|
 |None|None |None|None |None |26.0 |69.94 |228.76 |53986.51 |
+| gptq | False | 4 | 128 | exllamav2 | 36.07 | 83.09 | 192.55 | 35740.95 |
 |gptq |False |4 |128 |exllama|36.2 |95.41 |167.68 |34777.04 |
 |gptq |False |4 |128 |autogptq-cuda-old|36.2 |192.48 |83.12 |35497.62 |
 |bitsandbytes|None |None|None |None |37.64 |113.98 |140.38 |35532.37 |
@@ -88,16 +91,19 @@ Run
 
 ```shell
 # pytorch fp16
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --sweep --num-batches 10 --task text-generation --prefill
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model meta-llama/Llama-2-13b-chat-hf --sweep --num-batches 10 --task text-generation --prefill --generate
 
-# GPTQ with exllama kernel (int4/fp16)
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --gptq-model ../../../Llama-2-13B-chat-GPTQ/ --sweep --num-batches 10 --gptq --task text-generation --prefill
+# GPTQ with exllamav2 kernel (int4/fp16)
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --sweep --num-batches 10 --gptq --task text-generation --prefill --use-exllama --exllama-version 2 --generate
+
+# GPTQ with exllama kernel (int4/fp16)
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --sweep --num-batches 10 --gptq --task text-generation --prefill --use-exllama --generate
 
 # GPTQ without exllama kernel (int4/fp16)
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --gptq-model ../../../Llama-2-13B-chat-GPTQ/ --sweep --num-batches 10 --gptq --task text-generation --prefill --disable-exllama
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --sweep --num-batches 10 --gptq --task text-generation --prefill --generate
 
 # using bitsandbytes fp4/fp16 scheme
-CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model daryl149/llama-2-13b-chat-hf --sweep --num-batches 10 --task text-generation --prefill --bitsandbytes
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model meta-llama/Llama-2-13b-chat-hf --sweep --num-batches 10 --task text-generation --prefill --bitsandbytes --generate
 ```
 
 The benchmark below is for a prompt length of 512, measuring only the prefill step on a single NVIDIA A100-SXM4-80GB GPU. The forward is repeated 10 times. This benchmark typically corresponds to the forward during training (to the difference that here `generate` is called, which has some overhead).
@@ -107,6 +113,7 @@ The benchmark below is for a prompt length of 512, measuring only the prefill st
 |quantization |act_order|bits|group_size|kernel |prompt_length|new_tokens|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Max memory (MB)|
 |-----|---------|----|----------|-----------------|-------------|----------|-------------|----------------------|------------------|---------------|
 |None|None |None|None |None |512 |1 |27.22 |96.38 |10.38 |27999.54 |
+| gptq | False | 4 | 128 | exllamav2 | 512 | 1 | 6.63 | 116.07 | 8.62 | 10260.35 |
 |gptq |False |4 |128 |exllama |512 |1 |38.35 |112.54 |8.89 |9330.89 |
 |gptq |False |4 |128 |autogptq-cuda-old|512 |1 |43.94 |368.13 |2.72 |9474.19 |
 |bitsandbytes|None|None|None|None|512|1 |37.46|139.17 |7.19 |9952.65 |
@@ -116,6 +123,7 @@ The benchmark below is for a prompt length of 512, measuring only the prefill st
 |quantization |act_order|bits|group_size|kernel |prompt_length|new_tokens|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Max memory (MB)|
 |-----|---------|----|----------|-----------------|-------------|----------|-------------|----------------------|------------------|---------------|
 |None|None |None|None |None |512 |1 |27.22 |169.95 |11.77 |28524.37 |
+| gptq | False | 4 | 128 | exllamav2 | 512 | 1 | 6.63 | 212.07 | 9.43 | 10783.60 |
 |gptq |False |4 |128 |exllama |512 |1 |38.35 |190.44 |10.50 |9855.71 |
 |gptq |False |4 |128 |autogptq-cuda-old|512 |1 |43.94 |443.80 |4.51 |9928.23 |
 |bitsandbytes|None|None|None|None|512|1 |37.46|212.76 |9.40 |10421.89|
@@ -125,6 +133,7 @@ The benchmark below is for a prompt length of 512, measuring only the prefill st
 |quantization |act_order|bits|group_size|kernel |prompt_length|new_tokens|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Max memory (MB)|
 |-----|---------|----|----------|-----------------|-------------|----------|-------------|----------------------|------------------|---------------|
 |None|None |None|None |None |512 |1 |27.22 |305.99 |13.07 |29574.01 |
+| gptq | False | 4 | 128 | exllamav2 | 512 | 1 | 6.63 | 385.58 | 10.37 | 11829.59 |
 |gptq |False |4 |128 |exllama |512 |1 |38.35 |345.54 |11.58 |10905.35 |
 |gptq |False |4 |128 |autogptq-cuda-old|512 |1 |43.94 |597.24 |6.70 |10838.42 |
 |bitsandbytes|None|None|None|None|512|1 |37.46|349.18 |11.46|11440.08|
@@ -134,15 +143,46 @@ The benchmark below is for a prompt length of 512, measuring only the prefill st
 |quantization |act_order|bits|group_size|kernel |prompt_length|new_tokens|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Max memory (MB)|
 |-----|---------|----|----------|-----------------|-------------|----------|-------------|----------------------|------------------|---------------|
 |None|None |None|None |None |512 |1 |27.22 |600.47 |13.32 |31673.30 |
+| gptq | False | 4 | 128 | exllamav2 | 512 | 1 | 6.63 | 753.06 | 10.62 | 13920.50 |
 |gptq |False |4 |128 |exllama |512 |1 |38.35 |659.61 |12.13 |13004.64 |
 |gptq |False |4 |128 |autogptq-cuda-old|512 |1 |43.94 |909.09 |8.80 |12862.18 |
 |bitsandbytes|None|None|None|None|512|1 |37.46|643.42 |12.43|13539.37|
 
 ### Batch size = 16
 
-|quantization |act_order|bits|group_size|kernel |num_batches|batch_size|prompt_length|new_tokens|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Max memory (MB)|
-|-----|---------|----|----------|-----------------|-----------|----------|-------------|----------|-------------|----------------------|------------------|---------------|
-|None|None |None|None |None |10 |16 |512 |1 |27.22 |1209.07 |13.23 |35871.88 |
-|gptq |False |4 |128 |exllama |10 |16 |512 |1 |38.35 |1280.25 |12.50 |17203.22 |
-|gptq |False |4 |128 |autogptq-cuda-old|10 |16 |512 |1 |43.94 |1533.54 |10.43 |17060.76 |
+|quantization |act_order|bits|group_size|kernel |prompt_length|new_tokens|Load time (s)|Per-token latency (ms)|Throughput (tok/s)|Max memory (MB)|
+|-----|---------|----|-----------|----------|-------------|----------|-------------|----------------------|------------------|---------------|
+|None|None |None|None |None |512 |1 |27.22 |1209.07 |13.23 |35871.88 |
+| gptq | False | 4 | 128 | exllamav2 | 512 | 1 | 6.63 | 1467.36 | 10.90 | 18104.44 |
+|gptq |False |4 |128 |exllama |512 |1 |38.35 |1280.25 |12.50 |17203.22 |
+|gptq |False |4 |128 |autogptq-cuda-old |512 |1 |43.94 |1533.54 |10.43 |17060.76 |
 |bitsandbytes|None|None|None|None|512|1 |37.46|1256.88|12.73|17737.95|
+
+## Perplexity benchmark results
+
+Run
+
+```shell
+# pytorch fp16
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model meta-llama/Llama-2-13b-chat-hf --task text-generation --ppl
+
+# GPTQ with exllamav2 kernel (int4/fp16)
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --revision gptq-4bit-128g-actorder_True --gptq --task text-generation --use-exllama --exllama-version 2 --ppl
+
+# GPTQ with exllama kernel (int4/fp16)
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --revision gptq-4bit-128g-actorder_True --gptq --task text-generation --use-exllama --ppl
+
+# GPTQ without exllama kernel (int4/fp16)
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model TheBloke/Llama-2-13B-chat-GPTQ --revision gptq-4bit-128g-actorder_True --gptq --task text-generation --ppl
+
+# using bitsandbytes fp4/fp16 scheme
+CUDA_VISIBLE_DEVICES=0 python benchmark_gptq.py --model meta-llama/Llama-2-13b-chat-hf --task text-generation --bitsandbytes --ppl
+```
+
+| quantization | act_order | bits | group_size | kernel | perplexity |
+|--------------|-----------|------|------------|------------------|------------|
+| None | None | None | None | None | 6.61 |
+| gptq | True | 4 | 128 | exllamav2 | 6.77 |
+| gptq | True | 4 | 128 | exllama | 6.77 |
+| gptq | True | 4 | 128 | autogptq-cuda-old | 6.77 |
+| bitsandbytes | None | 4 | None | None | 6.78 |
\ No newline at end of file
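For reference, the GPTQ rows above are produced by loading the quantized checkpoint through Transformers' `GPTQConfig`, exactly as the script change below does. A minimal sketch of that loading path — the model id and kernel version here are only an example, not a prescription:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# Example checkpoint from the tables above; any GPTQ model on the Hub works the same way.
model_id = "TheBloke/Llama-2-13B-chat-GPTQ"

# use_exllama=True with version 2 selects the exllamav2 kernel; version 1 selects exllama.
# use_exllama=False falls back to the AutoGPTQ CUDA / CUDA-old kernels.
quantization_config = GPTQConfig(bits=4, use_exllama=True, exllama_config={"version": 2})

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",
)
```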
diff --git a/tests/benchmark/benchmark_gptq.py b/tests/benchmark/benchmark_gptq.py
index 06af05056a3..29f986015a4 100644
--- a/tests/benchmark/benchmark_gptq.py
+++ b/tests/benchmark/benchmark_gptq.py
@@ -1,12 +1,11 @@
 import argparse
 import gc
-import json
 import os
 import time
 
 import numpy as np
 import torch
-from accelerate import init_empty_weights
+from auto_gptq.utils import Perplexity
 from memory_tracker import MemoryTracker
 from tqdm import tqdm
 from transformers import (
@@ -16,10 +15,10 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     GenerationConfig,
+    GPTQConfig,
 )
 
 from optimum.exporters import TasksManager
-from optimum.gptq import load_quantized_model
 
 
 def get_parser():
@@ -45,13 +44,7 @@ def get_parser():
     parser.add_argument(
         "--model",
         type=str,
-        help="Model to benchmark (in the non-quantized case), or reference architecture corresponding to the quantized model (GPTQ case)",
-    )
-    parser.add_argument(
-        "--gptq-model",
-        type=str,
-        default=None,
-        help="Path to a local GPTQ model.",
+        help="Model to benchmark",
     )
     parser.add_argument(
         "--prompt-length",
@@ -86,10 +79,32 @@ def get_parser():
         help="Use the parameter ranges for (batch_size, prompt_length, new_tokens) defined in the .py file instead of the CLI ones.",
     )
     parser.add_argument(
-        "--disable-exllama",
+        "--use-exllama",
+        action="store_true",
+        help="Use the Exllama kernel rather than the AutoGPTQ CUDA (act-order case) or CUDA-old (no act-order case) kernels.",
+    )
+    parser.add_argument(
+        "--exllama-version",
+        type=int,
+        default=2,
+        help="Exllama kernel version to use (default: 2; set to 1 for the original exllama kernel).",
+    )
+    parser.add_argument(
+        "--generate",
+        action="store_true",
+        help="Measure generation speed (prompt processing + token generation).",
+    )
+    parser.add_argument(
+        "--ppl",
         action="store_true",
-        help="Disable Exllama kernel, to rather use the AutoGPTQ CUDA (act-order case) or CUDA-old (no act-order case) kernels.",
+        help="Calculate perplexity on the wikitext2 dataset.",
     )
+    parser.add_argument(
+        "--revision",
+        default=None,
+        help="Revision of the model to benchmark.",
+    )
+
     return parser
 
 
@@ -222,7 +237,7 @@ def benchmark_memory(
 
     # I am not sure whether we should substract here `inactive_split_bytes.all.peak` (not sure what it corresponds to, though it can get quite large, in the several GB).
     peak_external_mb = peak_nvml_mb - peak_reserved_torch_mb
 
-    assert peak_external_mb > 0
+    # assert peak_external_mb > 0
 
     # This formula is to confirm. We measure the actual allocated PyTorch memory, plus the additional non-PyTorch memory (as the CUDA context, CUDA extension device memory). We need to substract the PyTorch peak reserved memory since this one appears in the peak nvidia-smi/nvmlDeviceGetMemoryInfo.
@@ -266,7 +281,7 @@ def benchmark_memory(
 
 device = torch.device("cuda:0")
 memory_tracker = MemoryTracker()
 
-tokenizer = AutoTokenizer.from_pretrained(args.model)
+tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision, use_fast=False)
 if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
@@ -288,46 +303,22 @@ def benchmark_memory(
 else:
     is_decoder = False
 
-act_order = None
-bits = None
-group_size = None
-kernel = None
-if args.gptq:
-    if not args.gptq_model:
-        raise ValueError("The argument --gptq-model needs to be provided when benchmarking GPTQ.")
-
-    with open(os.path.join(args.gptq_model, "quantization_config.json"), "r", encoding="utf-8") as f:
-        quantize_config_dict = json.load(f)
-
-    act_order = quantize_config_dict["desc_act"]
-    bits = quantize_config_dict["bits"]
-    group_size = quantize_config_dict["group_size"]
-
-    if not args.disable_exllama:
-        # Exllama kernel can handle both the act-order / no act-order cases.
-        kernel = "exllama"
-    elif act_order:
-        kernel = "autotogptq-cuda"
-    else:
-        kernel = "autogptq-cuda-old"
-
 load_start = time.time_ns()
 if args.gptq:
-    with init_empty_weights():
-        empty_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16)
-    empty_model.tie_weights()
-    model = load_quantized_model(
-        empty_model,
-        save_folder=args.gptq_model,
-        state_dict_name="model.safetensors",
+    quantization_config = GPTQConfig(
+        bits=4, use_exllama=args.use_exllama, exllama_config={"version": args.exllama_version}
+    )
+    model = autoclass.from_pretrained(
+        args.model,
+        revision=args.revision,
+        quantization_config=quantization_config,
+        torch_dtype=torch.float16,
         device_map="auto",
-        disable_exllama=args.disable_exllama,
     )
 elif args.bitsandbytes:
     quantization_config = BitsAndBytesConfig(
         load_in_4bit=True, bnb_4bit_quant_type="fp4", bnb_4bit_compute_dtype=torch.float16
     )
-
     model = autoclass.from_pretrained(
         args.model, quantization_config=quantization_config, device_map="auto", torch_dtype=torch.float16
     )
@@ -337,6 +328,29 @@ def benchmark_memory(
 
 torch.cuda.synchronize()
 load_end = time.time_ns()
 
+act_order = None
+bits = None
+group_size = None
+kernel = None
+
+if args.gptq:
+    quantization_config_dict = model.config.quantization_config.to_dict()
+    act_order = quantization_config_dict["desc_act"]
+    bits = quantization_config_dict["bits"]
+    group_size = quantization_config_dict["group_size"]
+    use_exllama = quantization_config_dict["use_exllama"]
+    exllama_version = quantization_config_dict["exllama_config"]["version"]
+
+    if use_exllama:
+        if exllama_version == 2:
+            kernel = "exllamav2"
+        else:
+            kernel = "exllama"
+    elif act_order:
+        kernel = "autogptq-cuda"
+    else:
+        kernel = "autogptq-cuda-old"
+
 load_time = (load_end - load_start) * 1e-9
 print(f"Model load time: {load_time:.1f} s")
@@ -364,82 +378,100 @@ def benchmark_memory(
     file_name = file_name + "_noquant"
     quantization = None
 
-file_name = file_name + ".csv"
-output_file = open(file_name, "w")
-header = "quantization, act_order, bits, group_size, kernel, num_batches, batch_size, prompt_length, new_tokens, Load time (s), Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
-output_file.write(header)
-
-latencies = {}
-throughputs = {}
-all_max_mem = {}
-print(
-    "WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit."
-)
-
-for batch_size in tqdm(batch_sizes):
-    for prompt_length in tqdm(prompt_lengths):
-        for new_token in tqdm(new_tokens):
-            print(f"---- Running: batch_size={batch_size}, prompt_length={prompt_length}, new_tokens={new_token}")
-
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
-
-            input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
-            masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)
-
-            with torch.no_grad():
-                max_mem = benchmark_memory(
-                    model,
-                    input_ids,
-                    masks,
-                    args.num_batches,
-                    is_decoder,
-                    new_token,
-                    tokenizer.pad_token_id,
-                    memory_tracker=memory_tracker,
-                )
-
-            mean_latency = benchmark_latency(
-                model,
-                input_ids,
-                masks,
-                args.num_batches,
-                is_decoder,
-                new_token,
-                tokenizer.pad_token_id,
-                memory_tracker=memory_tracker,
-            )
-
-            index = (batch_size, prompt_length, new_token)
-
-            per_token_latency = mean_latency / new_token
-            latencies[index] = per_token_latency
-
-            throughput = batch_size / (per_token_latency * 1e-3)
-            throughputs[index] = throughput
-            all_max_mem[index] = max_mem
-
-            print(
-                f"Latency per token: {per_token_latency:.3f} ms, throughput: {throughput:.3f} tok/s, peak mem: {max_mem:.2f} MB"
-            )
-
-            line = "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
-                quantization,
-                act_order,
-                bits,
-                group_size,
-                kernel,
-                args.num_batches,
-                batch_size,
-                prompt_length,
-                new_token,
-                f"{load_time:.2f}",
-                f"{per_token_latency:.2f}",
-                f"{throughput:.2f}",
-                f"{max_mem:.2f}",
-            )
-            print(header)
-            print(line)
-            output_file.write(line)
-
-output_file.close()
+if args.ppl:
+    output_file = open(file_name + "_perplexity.csv", "w")
+    header = "quantization, act_order, bits, group_size, kernel, perplexity\n"
+    output_file.write(header)
+    ppl = Perplexity(model, tokenizer)
+    ppl_value = np.mean(ppl.calculate_perplexity())
+    line = "{},{},{},{},{},{}\n".format(
+        quantization,
+        act_order,
+        bits,
+        group_size,
+        kernel,
+        f"{ppl_value:.2f}",
+    )
+    print(header)
+    print(line)
+    output_file.write(line)
+    output_file.close()
+
+if args.generate:
+    output_file = open(file_name + ".csv", "w")
+    header = "quantization, act_order, bits, group_size, kernel, num_batches, batch_size, prompt_length, new_tokens, Load time (s), Per-token latency (ms), Throughput (tok/s), Max memory (MB)\n"
+    output_file.write(header)
+
+    latencies = {}
+    throughputs = {}
+    all_max_mem = {}
+    print(
+        "WARNING: The reported peak memory is only a rough estimate, and can NOT be precisely relied upon to estimate an OOM limit."
+    )
+
+    for batch_size in tqdm(batch_sizes):
+        for prompt_length in tqdm(prompt_lengths):
+            for new_token in tqdm(new_tokens):
+                print(f"---- Running: batch_size={batch_size}, prompt_length={prompt_length}, new_tokens={new_token}")
+
+                torch.cuda.empty_cache()
+                torch.cuda.reset_peak_memory_stats()
+
+                input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
+                masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)
+
+                with torch.no_grad():
+                    max_mem = benchmark_memory(
+                        model,
+                        input_ids,
+                        masks,
+                        args.num_batches,
+                        is_decoder,
+                        new_token,
+                        tokenizer.pad_token_id,
+                        memory_tracker=memory_tracker,
+                    )
+
+                mean_latency = benchmark_latency(
+                    model,
+                    input_ids,
+                    masks,
+                    args.num_batches,
+                    is_decoder,
+                    new_token,
+                    tokenizer.pad_token_id,
+                    memory_tracker=memory_tracker,
+                )
+
+                index = (batch_size, prompt_length, new_token)
+
+                per_token_latency = mean_latency / new_token
+                latencies[index] = per_token_latency
+
+                throughput = batch_size / (per_token_latency * 1e-3)
+                throughputs[index] = throughput
+                all_max_mem[index] = max_mem
+
+                print(
+                    f"Latency per token: {per_token_latency:.3f} ms, throughput: {throughput:.3f} tok/s, peak mem: {max_mem:.2f} MB"
+                )
+
+                line = "{},{},{},{},{},{},{},{},{},{},{},{},{}\n".format(
+                    quantization,
+                    act_order,
+                    bits,
+                    group_size,
+                    kernel,
+                    args.num_batches,
+                    batch_size,
+                    prompt_length,
+                    new_token,
+                    f"{load_time:.2f}",
+                    f"{per_token_latency:.2f}",
+                    f"{throughput:.2f}",
+                    f"{max_mem:.2f}",
+                )
+                print(header)
+                print(line)
+                output_file.write(line)
+    output_file.close()
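The `--ppl` path above leans on AutoGPTQ's `Perplexity` helper. A standalone sketch of that measurement — the model id is only an example, and the evaluation-set defaults (wikitext2) come from AutoGPTQ rather than from this script:

```python
import numpy as np
import torch
from auto_gptq.utils import Perplexity
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example fp16 baseline; the GPTQ / bitsandbytes variants only change how `model` is loaded.
model_id = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# calculate_perplexity() returns a sequence of per-chunk perplexities;
# the benchmark reports their mean, as in the script above.
ppl = Perplexity(model, tokenizer)
ppl_value = np.mean(ppl.calculate_perplexity())
print(f"perplexity: {ppl_value:.2f}")
```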