
Issue with converting VLM model "InternVL2_5-1B" #1073

Open · endomorphosis opened this issue Dec 13, 2024 · 10 comments · May be fixed by #1105

endomorphosis commented Dec 13, 2024
devel@workstation:/tmp$ optimum-cli export openvino --model OpenGVLab/InternVL2_5-1B InternVL2_5-1B --trust-remote-code >> /tmp/openvino.txt

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
The class `optimum.bettertransformers.transformation.BetterTransformer` is deprecated and will be removed in a future release.
WARNING:root:Cannot apply model.to_bettertransformer because of the exception:
The model type qwen2 is not yet supported to be used with BetterTransformer. Feel free to open an issue at https://github.com/huggingface/optimum/issues if you would like this model type to be supported. Currently supported models are: dict_keys(['albert', 'bark', 'bart', 'bert', 'bert-generation', 'blenderbot', 'bloom', 'camembert', 'blip-2', 'clip', 'codegen', 'data2vec-text', 'deit', 'distilbert', 'electra', 'ernie', 'fsmt', 'gpt2', 'gptj', 'gpt_neo', 'gpt_neox', 'hubert', 'layoutlm', 'm2m_100', 'marian', 'markuplm', 'mbart', 'opt', 'pegasus', 'rembert', 'prophetnet', 'roberta', 'roc_bert', 'roformer', 'splinter', 'tapas', 't5', 'vilt', 'vit', 'vit_mae', 'vit_msn', 'wav2vec2', 'xlm-roberta', 'yolos']).. Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
/home/devel/.local/lib/python3.12/site-packages/transformers/cache_utils.py:458: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
/home/devel/.local/lib/python3.12/site-packages/transformers/cache_utils.py:443: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  elif len(self.key_cache[layer_idx]) == 0:  # fills previously skipped layers; checking for tensor causes errors
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float32.
/home/devel/.local/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py:51: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_seqlen_in_batch = seqlens_in_batch.max().item()
/home/devel/.local/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py:106: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if query_length == kv_seq_len:
/home/devel/.local/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py:111: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  elif query_length == 1:
/home/devel/.local/lib/python3.12/site-packages/flash_attn/bert_padding.py:115: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_seqlen_in_batch = seqlens_in_batch.max().item()
Export model to OpenVINO directly failed with: 
Couldn't get TorchScript module by tracing.
Please check correctness of provided 'example_input'. Sometimes models can be converted in scripted mode, please try running conversion without 'example_input'.
 You can also provide TorchScript module that you obtained yourself, please refer to PyTorch documentation: https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html..
Model will be exported to ONNX
[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. A stateless model will be exported instead. It may result in sub-optimal inference performance.Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path.
Traceback (most recent call last):
  File "/home/devel/.local/lib/python3.12/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 57, in __init__
    pt_module = self._get_scripted_model(
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 156, in _get_scripted_model
    scripted = torch.jit.trace(
               ^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/jit/_trace.py", line 820, in trace
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/jit/_trace.py", line 1088, in trace_module
    module._c._create_method_from_trace(
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 11, in forward
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/convert.py", line 419, in ts_patched_forward
    outputs = patched_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/onnx/model_patcher.py", line 151, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 895, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 443, in forward
    attn_output = _flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 246, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
                                                                                   ^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 121, in _upad_input
    query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 4)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/convert.py", line 431, in export_pytorch
    ov_model = convert_model(
               ^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/openvino/tools/ovc/convert.py", line 101, in convert_model
    ov_model, _ = _convert(cli_parser, params, True)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/openvino/tools/ovc/convert_impl.py", line 563, in _convert
    raise e
  File "/home/devel/.local/lib/python3.12/site-packages/openvino/tools/ovc/convert_impl.py", line 458, in _convert
    get_pytorch_decoder(args['input_model'], example_inputs, args)
  File "/home/devel/.local/lib/python3.12/site-packages/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py", line 70, in get_pytorch_decoder
    decoder = TorchScriptPythonDecoder(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/openvino/frontend/pytorch/ts_decoder.py", line 69, in __init__
    raise RuntimeError(
RuntimeError: Couldn't get TorchScript module by tracing.
Please check correctness of provided 'example_input'. Sometimes models can be converted in scripted mode, please try running conversion without 'example_input'.
 You can also provide TorchScript module that you obtained yourself, please refer to PyTorch documentation: https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/commands/export/openvino.py", line 394, in run
    main_export(
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/__main__.py", line 420, in main_export
    submodel_paths = export_from_model(
                     ^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/convert.py", line 771, in export_from_model
    export_models(
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/convert.py", line 553, in export_models
    export(
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/convert.py", line 179, in export
    return export_pytorch(
           ^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/convert.py", line 455, in export_pytorch
    return export_pytorch_via_onnx(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/openvino/convert.py", line 298, in export_pytorch_via_onnx
    input_names, output_names = export_pytorch_to_onnx(
                                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/onnx/convert.py", line 584, in export_pytorch
    onnx_export(
  File "/home/devel/.local/lib/python3.12/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/devel/.local/lib/python3.12/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/jit/_trace.py", line 1310, in _get_trace_graph
    outs = ONNXTracedModule(
           ^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/optimum/exporters/onnx/model_patcher.py", line 151, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 895, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 145, in wrapped
    return module_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 443, in forward
    attn_output = _flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 246, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
                                                                                   ^^^^^^^^^^^^
  File "/home/devel/.local/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 121, in _upad_input
    query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 4)
eaidova (Collaborator) commented Dec 25, 2024

@endomorphosis please try to uninstall the flash_attention_2 package

endomorphosis (Author):

> @endomorphosis please try to uninstall the flash_attention_2 package

I would, but that is unlikely to work within the design parameters of my project, e.g. a peer-to-peer model server and endpoint aggregator that is platform-agnostic and model-agnostic.

endomorphosis (Author):

[image attachment]

Also, merry Christmas.

endomorphosis (Author):

If this is indeed meant to be treated like an executable, why not bake all of its dependencies in, so that this sort of package conflict won't happen?

Florianoli commented Jan 5, 2025

Hi @endomorphosis, is there a solution to this yet? I encountered the same problem while converting the jina-embeddings-v3 model.

endomorphosis (Author):

I did not find a way to fix this CLI command. There is, however, another means of converting the model, whereby OpenVINO traces the TorchScript code as it is evaluated and then converts it to OpenVINO IR; see e.g. https://github.com/endomorphosis/ipfs_accelerate_py/blob/212c5ad39db2f8d60c3e0230f0025e25c72cf6c2/ipfs_accelerate_py/worker/openvino_utils.py#L197
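
For reference, here is a minimal sketch of that trace-and-convert flow. It is not the exact code from the linked openvino_utils.py; the model id, the text-only example input, and the use of `AutoModel` are illustrative assumptions, and a VLM like InternVL will generally also need image inputs in its example:

```python
import torch
import openvino as ov
from transformers import AutoModel, AutoTokenizer

model_id = "OpenGVLab/InternVL2_5-1B"  # illustrative; any HF model id works
model = AutoModel.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float32
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
example_input = dict(tokenizer("a short prompt", return_tensors="pt"))

# ov.convert_model traces the PyTorch module with the given example input and
# returns an openvino.Model that can be saved as IR (.xml/.bin) or compiled.
ov_model = ov.convert_model(model, example_input=example_input)
ov.save_model(ov_model, "model.xml")
```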

eaidova self-assigned this Jan 10, 2025
eaidova (Collaborator) commented Jan 10, 2025

> If this is indeed meant to be treated like an executable, why not bake all of its dependencies in, so that this sort of package conflict won't happen?

Unfortunately, this is impossible because optimum is a flexible, configurable tool that follows common Hugging Face design practice: lazy initialization and on-demand installation of requirements, including support for remote code. We cannot predict which additional packages each model will require, so we cannot install them all up front; installing every known model dependency would also hurt the user experience (if you only need to run BERT, for example, there is no reason to install the dependencies for Stable Diffusion).

The only thing I can recommend is to change the attention implementation (the InternVL code always forces the flash_attn implementation if that package is available in the environment) or to ask the model authors to fix the model.

From my side, I can only try to patch the model inside the tool so that the attention implementation is changed automatically.
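
A minimal sketch of what that workaround could look like on the user side, assuming the remote code honors the standard `attn_implementation` argument (the InternVL code may still force flash_attn while the package is importable, in which case uninstalling flash-attn remains the simpler route); the model id and dtype are illustrative:

```python
import torch
from transformers import AutoModel

# Load the model with a non-flash attention implementation before export.
# "eager" (or "sdpa") avoids the flash_attn unpad path that broke tracing above.
model = AutoModel.from_pretrained(
    "OpenGVLab/InternVL2_5-1B",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    attn_implementation="eager",
)
model.eval()

# An exporter-side patch could apply the same idea automatically, for example
# by overriding model.config._attn_implementation before tracing; see the pull
# request linked to this issue for the approach taken in optimum-intel.
```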

eaidova (Collaborator) commented Jan 10, 2025

@Florianoli could you please provide the command that you used to export the model with optimum-intel?

eaidova linked a pull request on Jan 10, 2025 that will close this issue
endomorphosis (Author):

> If this is indeed meant to be treated like an executable, why not bake all of its dependencies in, so that this sort of package conflict won't happen?
>
> Unfortunately, this is impossible because optimum is a flexible, configurable tool that follows common Hugging Face design practice: lazy initialization and on-demand installation of requirements, including support for remote code. We cannot predict which additional packages each model will require, so we cannot install them all up front; installing every known model dependency would also hurt the user experience (if you only need to run BERT, for example, there is no reason to install the dependencies for Stable Diffusion).
>
> The only thing I can recommend is to change the attention implementation (the InternVL code always forces the flash_attn implementation if that package is available in the environment) or to ask the model authors to fix the model.
>
> From my side, I can only try to patch the model inside the tool so that the attention implementation is changed automatically.

I can see why someone would trade off a lot of bloat to save time, but this seems like a lot of bloat. I understand that not every model architecture is the same, but llama.cpp doesn't need to dynamically import dependencies in order to quantize models, and if I remember correctly it explicitly lists model architecture data in its conversion tool instead of relying on Hugging Face libraries and tracing TorchScript.

Florianoli:

> @Florianoli could you please provide the command that you used to export the model with optimum-intel?

@eaidova
Hi! I used this command:
optimum-cli export openvino --model jinaai/jina-embeddings-v3 --task feature-extraction --weight-format fp16 ov_model/
Thanks!
