
Unsupported Scalar Type 5? -- Portable/optimized ops don't consistently support half/bfloat16 #7748

Open
bluejack opened this issue Jan 17, 2025 · 3 comments
@bluejack

🐛 Describe the bug

After exporting a model to pte form and running it through executor_runner, I get:

E 00:00:02.220756 executorch:inputs_portable.cpp:45] Unsupported scalar type 5

I believe this is the "Half" type, i.e. float16.

Does that simply mean executor_runner does not support float16? Or does the whole framework not support float16?
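For what it's worth, ExecuTorch reuses PyTorch's c10 `ScalarType` numbering, so the "5" in the error can be decoded with a small lookup table. The snippet below is an illustrative sketch (the table mirrors a subset of the c10 enum; `decode_scalar_type` is just a helper name, not an ExecuTorch API):

```python
# Subset of the c10::ScalarType enum that PyTorch/ExecuTorch share.
SCALAR_TYPES = {
    0: "Byte (uint8)",
    1: "Char (int8)",
    2: "Short (int16)",
    3: "Int (int32)",
    4: "Long (int64)",
    5: "Half (float16)",
    6: "Float (float32)",
    7: "Double (float64)",
    11: "Bool",
    15: "BFloat16",
}

def decode_scalar_type(code: int) -> str:
    """Return a human-readable name for a ScalarType code."""
    return SCALAR_TYPES.get(code, f"unknown ({code})")

print(decode_scalar_type(5))  # -> Half (float16)
```

So "Unsupported scalar type 5" is indeed complaining about float16 inputs.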

Note that when I investigate the file using a Python script, I get as far as sending it my float16 tensors, but it still fails to execute with a similar error:

[op_native_layer_norm.cpp:169] In function operator()(), assert failed (false): Unhandled dtype Half for native_layer_norm.out

I'm including the versions below, but note that this is executorch built from head rather than the last release. Should I expect the framework to support float16, and look to my own code for the error?

Versions

PyTorch version: 2.6.0.dev20250104
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: version 3.31.4
Libc version: N/A

Python version: 3.12.7 | packaged by Anaconda, Inc. | (main, Oct 4 2024, 08:22:19) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] executorch==0.6.0a0+cd0e584
[pip3] flake8==7.0.0
[pip3] mypy==1.11.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.0.0
[pip3] numpydoc==1.7.0
[pip3] torch==2.6.0.dev20250104
[pip3] torchao==0.8.0+git2e032c6b
[pip3] torchaudio==2.6.0.dev20250104
[pip3] torchsr==1.0.4
[pip3] torchvision==0.22.0.dev20250104
[conda] executorch 0.6.0a0+cd0e584 pypi_0 pypi
[conda] numpy 2.0.0 pypi_0 pypi
[conda] numpydoc 1.7.0 py312hca03da5_0
[conda] torch 2.6.0.dev20250104 pypi_0 pypi
[conda] torchao 0.8.0+git2e032c6b pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250104 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250104 pypi_0 pypi

swolchok added a commit that referenced this issue Jan 17, 2025
Partial fix for #7748.

ghstack-source-id: 0c7e0a5712cba6829fdf5461ea50a8cc4afd39f0
ghstack-comment-id: 2599375147
Pull Request resolved: #7750
@swolchok (Contributor)

> Does that simply mean executor_runner does not support float16?

It looks like this particular function does not support float16. I've just sent #7750 to fix it.

> Or does the whole framework not support float16?

We are capable of supporting it, but it looks like portable ops coverage is spotty. I'll send a fix for native_layer_norm and as many other places as I can find.

swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: 9f183dddcd87edb2493af0f97d7ad4e40d9be434
ghstack-comment-id: 2599398274
Pull Request resolved: #7758
@swolchok swolchok self-assigned this Jan 18, 2025
@swolchok swolchok changed the title Unsupported Scalar Type 5? Unsupported Scalar Type 5? -- Portable/optimized ops don't consistently support half/bfloat16 Jan 18, 2025
swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: a72e5e33f005abc47cc1143f7b282f8050374955
ghstack-comment-id: 2599413770
Pull Request resolved: #7760
@swolchok (Contributor) commented Jan 18, 2025

By the way, if you're running on your Mac, you might want to enable the XNNPACK delegate when exporting; there's a good chance you will get both better performance and a workaround for the remaining instance of this issue I haven't got PRs out for yet (though I don't know whether XNNPACK has layer norm off the top of my head).
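For readers following along, an export flow with the XNNPACK delegate enabled might look roughly like the untested sketch below. `MyModel` and the input shape are placeholders, and the calls shown assume the `to_edge_transform_and_lower` API available in ExecuTorch around the 0.5/0.6 releases; check the current export docs before relying on it:

```python
# Untested sketch: export a model to .pte with the XNNPACK delegate.
# Requires the executorch package; MyModel is a placeholder nn.Module.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

model = MyModel().eval()
example_inputs = (torch.randn(1, 16, dtype=torch.float16),)

# Capture the graph, then lower XNNPACK-supported subgraphs to the delegate.
exported = torch.export.export(model, example_inputs)
et_program = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

with open("model.pte", "wb") as f:
    f.write(et_program.buffer)
```

Ops the partitioner claims run through XNNPACK's own kernels, which may sidestep gaps in the portable-op dtype coverage; anything left unpartitioned still falls back to the portable ops discussed above.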

swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: 02bfc58615997b27f0ecb99f8efcf7fce0694b8c
ghstack-comment-id: 2599413770
Pull Request resolved: #7760
swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: b7b33809ec99537c0f44c7abb5880c6502d30698
ghstack-comment-id: 2599481711
Pull Request resolved: #7767
swolchok added a commit that referenced this issue Jan 18, 2025
Partial fix for #7748.

ghstack-source-id: 02a1dc797b933f836efe17aa659b6a0c27ecc460
ghstack-comment-id: 2599483099
Pull Request resolved: #7769
@bluejack (Author)

> By the way, if you're running on your Mac, you might want to enable the XNNPACK delegate when exporting; there's a good chance you will get both better performance and a workaround for the remaining instance of this issue I haven't got PRs out for yet (though I don't know whether XNNPACK has layer norm off the top of my head).

OK, I will look into this option. Thanks for the tip.
