Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check onnx size before loading #1467

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import gc
import multiprocessing as mp
import os
import sys
import traceback
from inspect import signature
from itertools import chain
Expand All @@ -28,7 +29,7 @@
import onnx
from transformers.utils import is_tf_available, is_torch_available

from ...onnx.utils import _get_onnx_external_data_tensors, check_model_uses_external_data
from ...onnx.utils import _get_onnx_external_data_tensors
from ...utils import (
TORCH_MINIMUM_VERSION,
is_diffusers_available,
Expand Down Expand Up @@ -581,15 +582,16 @@ def remap(value):
)

# check if external data was exported
# TODO: this is quite inefficient as we load in memory if models are <2GB without external data
onnx_model = onnx.load(str(output), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)
if (
os.path.getsize(str(output)) + sys.getsizeof(bytes()) > onnx.checker.MAXIMUM_PROTOBUF
or FORCE_ONNX_EXTERNAL_DATA
):
onnx_model = onnx.load(str(output), load_external_data=False)

if model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA:
tensors_paths = _get_onnx_external_data_tensors(onnx_model)
logger.info("Saving external data to one file...")

# try free model memory
# try to free up model memory
del model
del onnx_model
gc.collect()
Expand Down
Loading