add save_quantized log model total size (ModelCloud#320)
PZS-ModelCloud authored Aug 2, 2024
1 parent 6b3e6c7 commit b764a24
Showing 2 changed files with 38 additions and 1 deletion.
18 changes: 17 additions & 1 deletion gptqmodel/models/base.py
@@ -31,7 +31,7 @@
convert_gptq_v2_to_v1_format, copy_py_files, find_layers, get_checkpoints, get_device,
get_module_by_name_prefix, get_module_by_name_suffix, get_moe_layer_modules,
gptqmodel_post_init, make_quant, move_to, nested_move_to, pack_model,
simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes, get_model_files_size)
from ..version import __version__
from ._const import CPU, CUDA_0, DEVICE, SUPPORTED_MODELS

@@ -607,6 +607,9 @@ def save_quantized(
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)

pre_quantized_size_mb = get_model_files_size(self.model_name_or_path)
pre_quantized_size_gb = pre_quantized_size_mb / 1024

# write gptqmodel tooling fingerprint to config
self.quantize_config.meta_set_versionable(
key=META_FIELD_QUANTIZER,
@@ -703,6 +706,7 @@ def save_quantized(
logger.warning(
"We highly suggest saving quantized model using safetensors format for security reasons. Please set `use_safetensors=True` whenever possible.")
torch.save(model.state_dict(), join(save_dir, model_save_name))
total_size_mb = os.path.getsize(join(save_dir, model_save_name)) / (1024 * 1024)
else:
# Shard checkpoint
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
@@ -723,6 +727,7 @@
):
os.remove(full_filename)

total_size_mb = 0
# Save the model
for shard_file, shard in shards.items():
if use_safetensors:
@@ -759,6 +764,8 @@ def save_quantized(
safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
else:
torch.save(shard, join(save_dir, shard_file))
shard_size_mb = os.path.getsize(join(save_dir, shard_file)) / (1024 * 1024)
total_size_mb += shard_size_mb

if index is not None:
index_save_name = model_save_name + ".index.json"
@@ -767,6 +774,15 @@
with open(index_save_path, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

total_size_gb = total_size_mb / 1024
size_diff_mb = pre_quantized_size_mb - total_size_mb
size_diff_gb = size_diff_mb / 1024
percent_diff = (size_diff_mb / pre_quantized_size_mb) * 100
logger.info(f"Pre-Quantized model size: {pre_quantized_size_mb:.2f}MB, {pre_quantized_size_gb:.2f}GB")
logger.info(f"Quantized model size: {total_size_mb:.2f}MB, {total_size_gb:.2f}GB")
logger.info(f"Size difference: {size_diff_mb:.2f}MB, {size_diff_gb:.2f}GB - {percent_diff:.2f}%")

config.quantization_config = quantize_config.to_dict()
config.save_pretrained(save_dir)

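The new bookkeeping in save_quantized accumulates total_size_mb for whatever is actually written to disk (a single checkpoint or every shard) and compares it against the pre-quantized size. Below is a minimal, self-contained sketch of that comparison; the sizes are made-up values used only to illustrate the MB/GB/percent arithmetic, not measurements from any real model.

# Standalone sketch of the size comparison added in base.py above.
# The numbers are hypothetical placeholders.
pre_quantized_size_mb = 13500.0  # e.g. a ~13 GB fp16 checkpoint
total_size_mb = 3900.0           # e.g. the quantized output

pre_quantized_size_gb = pre_quantized_size_mb / 1024
total_size_gb = total_size_mb / 1024
size_diff_mb = pre_quantized_size_mb - total_size_mb
size_diff_gb = size_diff_mb / 1024
percent_diff = (size_diff_mb / pre_quantized_size_mb) * 100

print(f"Pre-Quantized model size: {pre_quantized_size_mb:.2f}MB, {pre_quantized_size_gb:.2f}GB")
print(f"Quantized model size: {total_size_mb:.2f}MB, {total_size_gb:.2f}GB")
print(f"Size difference: {size_diff_mb:.2f}MB, {size_diff_gb:.2f}GB - {percent_diff:.2f}%")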
21 changes: 21 additions & 0 deletions gptqmodel/utils/model.py
@@ -662,3 +662,24 @@ def copy_py_files(save_dir, file_extension=".py", model_id_or_path=""):
if file.rfilename.endswith(file_extension):
_ = hf_hub_download(repo_id=model_id_or_path, filename=file.rfilename,
local_dir=save_dir)

def get_model_files_size(pre_quantized_model_path, file_extension=['.bin', '.safetensors', '.pth', '.pt', '.ckpt', '.h5', '.pb', '.onnx']):
if os.path.isdir(pre_quantized_model_path):
pre_quantized_size_bytes = sum(
os.path.getsize(os.path.join(pre_quantized_model_path, f))
for f in os.listdir(pre_quantized_model_path)
if os.path.isfile(os.path.join(pre_quantized_model_path, f)) and os.path.splitext(f)[
1] in file_extension
)
else:
api = HfApi()
files_data = api.list_repo_files(pre_quantized_model_path)
pre_quantized_size_bytes = 0
for file_info in files_data:
if any(file_info.endswith(ext) for ext in file_extension):
file_metadata = api.model_info(pre_quantized_model_path, files_metadata=True)
for file_data in file_metadata.siblings:
if file_data.rfilename == file_info:
pre_quantized_size_bytes += file_data.size
pre_quantized_size_mb = pre_quantized_size_bytes / (1024 * 1024)
return pre_quantized_size_mb
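For reference, a hedged usage sketch of the new helper: the local path and Hub repo id below are illustrative placeholders, and the import path simply reflects where the function is defined in this diff. Given a local directory it sums matching checkpoint files on disk; given a Hub repo id it falls back to HfApi file metadata.

# Hypothetical usage of get_model_files_size(); path and repo id are placeholders.
from gptqmodel.utils.model import get_model_files_size

local_mb = get_model_files_size("/path/to/local/fp16-model")  # sums .bin/.safetensors/... on disk
hub_mb = get_model_files_size("facebook/opt-125m")            # queries HfApi metadata instead
print(f"local: {local_mb:.2f}MB, hub: {hub_mb:.2f}MB")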
