[LLVMGPU][ROCM] Disable polynomial approximation and use device libs #19672

Open · wants to merge 2 commits into main
Conversation

kuhar (Member) commented Jan 10, 2025

The device lib implementation is selected by the convertToROCDL pass. This implementation is much more efficient than the polynomial approximation in MLIR.

Issue: #19673
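
For illustration, a rough sketch of the delta for something like math.tanh (hedged; the actual expansion produced by MLIR's math-polynomial-approximation patterns is a long sequence of arith ops, elided here):

// Input op:
%y = math.tanh %x : f32

// With polynomial approximation disabled, convertToROCDL instead emits a
// call into the ROCm device library (sketch):
llvm.func @__ocml_tanh_f32(f32) -> f32
%y2 = llvm.call @__ocml_tanh_f32(%x) : (f32) -> f32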

benvanik (Collaborator) commented

is there an issue tracking improving the mlir one?

kuhar (Member, Author) commented Jan 10, 2025

> is there an issue tracking improving the mlir one?

Not yet, this may be a task for our math PhDs like @bjacob or @zjgarvey.

benvanik (Collaborator) commented

cool - would be good to get on the docket - all other targets use the MLIR one and relying on the device libs isn't great long-term

nice job finding the delta - now we have something to target :)

kuhar (Member, Author) commented Jan 10, 2025

@benvanik I filed it here: #19673

bjacob (Contributor) commented Jan 10, 2025

I think there is a deep truth disguised as an accident here: a generic polynomial approximation won't in general be the most efficient implementation on a given target. Achieving optimal results requires looking at the specifics of each math function and each target. For example, on gfx9, it so happens that 1/x and exp(x) are cheap to evaluate, defeating the basic assumption underpinning polynomial approximation, at least for functions such as tanh(x) which are easy to evaluate by 1/x and exp(x) steps.
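
(Concretely: tanh(x) = 1 - 2 / (exp(2x) + 1), which is one exp, one add, and one reciprocal, so on hardware where exp and 1/x are cheap this beats evaluating a high-degree polynomial.)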

kuhar (Member, Author) commented Jan 10, 2025

OK, the remaining issue is that MathToROCDL doesn't handle fpowi and ipowi:
https://github.com/llvm/llvm-project/blob/3fbc344b49800bb0f70fd5af46c0a47f6d55bbd1/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp#L59-L60

Repro:

func.func @main(%arg0: !torch.vtensor<[2,154,6144],f16>) -> !torch.vtensor<[2,154,6144],f16> {
  %str_846 = torch.constant.str "tanh"
  %395 = torch.aten.gelu %arg0, %str_846 : !torch.vtensor<[2,154,6144],f16>, !torch.str -> !torch.vtensor<[2,154,6144],f16>
  return %395 : !torch.vtensor<[2,154,6144],f16>
}
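
(Context for why this repro hits fpowi, as I understand it: the "tanh" flavor of GELU is 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), and the x^3 term is presumably what lowers to math.fpowi here.)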

kuhar (Member, Author) commented Jan 10, 2025

I can see a few builtins related to pow here: https://github.com/llvm/llvm-project/blob/3fbc344b49800bb0f70fd5af46c0a47f6d55bbd1/clang/lib/Headers/__clang_hip_libdevice_declares.h#L86-L87

__device__ __attribute__((pure)) float __ocml_pow_f32(float, float);
__device__ __attribute__((pure)) float __ocml_pown_f32(float, int);
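
A rough sketch of how the fpowi mapping could look in IR (hypothetical lowering; not what MathToROCDL currently emits):

// math.fpowi %x, %n : f32, i32 could become a direct ocml call:
llvm.func @__ocml_pown_f32(f32, i32) -> f32
%r = llvm.call @__ocml_pown_f32(%x, %n) : (f32, i32) -> f32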

kuhar marked this pull request as draft on January 10, 2025
lialan (Contributor) commented Jan 11, 2025

> I can see a few builtins related to pow here: https://github.com/llvm/llvm-project/blob/3fbc344b49800bb0f70fd5af46c0a47f6d55bbd1/clang/lib/Headers/__clang_hip_libdevice_declares.h#L86-L87
>
> __device__ __attribute__((pure)) float __ocml_pow_f32(float, float);
> __device__ __attribute__((pure)) float __ocml_pown_f32(float, int);

That covers fpowi. I guess for ipowi you would cast the result of fpowi from float to int?

kuhar (Member, Author) commented Jan 11, 2025

We need to make sure we handle the whole input space, including large and negative numbers. We could also expand to muls when assumptions allow. We would have to benchmark and decide.

Also, there's a bunch of packed math functions at the very bottom that MathToROCDL doesn't use but could potentially benefit from, especially with fp16. There should be lots of room for improvement!
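
For example, with a small constant exponent, expanding to muls is just square-and-multiply; a hypothetical sketch for ipowi with exponent 5 (assuming the exponent is statically known):

// x^5 = x * (x^2)^2: three multiplies, no library call.
%x2 = arith.muli %x, %x : i32
%x4 = arith.muli %x2, %x2 : i32
%x5 = arith.muli %x4, %x : i32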

lialan (Contributor) commented Jan 11, 2025

> Or you expand to muls. We would have to benchmark and decide.
>
> Also, there's a bunch of packed math functions at the very bottom that MathToROCDL doesn't use but could potentially benefit from, especially with fp16. There should be lots of room for improvement!

What about just lowering ipowi to LLVM::PowIOp and letting the codegen handle it?

lialan (Contributor) commented Jan 11, 2025

> OK, the remaining issue is that MathToROCDL doesn't handle fpowi and ipowi: https://github.com/llvm/llvm-project/blob/3fbc344b49800bb0f70fd5af46c0a47f6d55bbd1/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp#L59-L60

@krzysz00 Where can I find all the ROCDL functions? For example, the _f16 functions in the linked file are not included in the HIP headers in clang/lib/Headers/__clang_hip_libdevice_declares.h.

Also, is there any context on why fpowi and ipowi are excluded from the conversion? I feel like __ocml_pown_f32/f16/f64 could be directly mapped from fpowi, at least.

kuhar (Member, Author) commented Jan 11, 2025

> Also, is there any context on why fpowi and ipowi are excluded from the conversion?

The implementation uses a templated pattern that asserts that the result and operand types are the same; this is not true for fpowi, so it can't easily be plugged in. ipowi, meanwhile, doesn't have a clear mapping to a library function, AFAICT.
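
For reference, the mismatch is visible in the op's signature:

// Base and result are f32, but the exponent operand is i32, which trips
// the pattern's same-operand-and-result-type assumption.
%r = math.fpowi %x, %n : f32, i32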

lialan (Contributor) commented Jan 12, 2025

> The implementation uses a templated pattern that asserts that the result and operand types are the same; this is not true for fpowi, so it can't easily be plugged in.

There is just this one static assert guarding that assumption; I guess we can relax it and make fpowi work.

This should make fpowi work: llvm/llvm-project#122640

lialan (Contributor) commented Jan 12, 2025

Another thing I found is that there is __2f16 __ocml_pown_2f16(__2f16, __2i16);, so naturally we don't want to fully unroll the vector and use the scalar version for f16 vectors.

This applies to sqrt and trunc as well.
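
A sketch of what the packed lowering could look like (hypothetical; assuming the __2f16/__2i16 builtin types map to 2-element vectors):

// One packed call instead of two scalar __ocml_pown_f16 calls.
llvm.func @__ocml_pown_2f16(vector<2xf16>, vector<2xi16>) -> vector<2xf16>
%r = llvm.call @__ocml_pown_2f16(%v, %n) : (vector<2xf16>, vector<2xi16>) -> vector<2xf16>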

benvanik (Collaborator) commented

Yeah, one of many reasons why libm-like libraries are bad (for us) is that they assume everything is scalar above and below the libm call boundary. Native versions that we can represent in IR as vectors have the most potential. As we see here, a totally untuned/unoptimized vector version can't beat a highly tuned/optimized scalar version, but it's useful to keep in mind that a tuned/optimized vector version always has the potential to beat a scalar one, especially as dispatches scale (you don't want to mix vectorized and scalarized code in the same lowering flow, and the chance of that happening goes up a lot with fusion).

kuhar (Member, Author) commented Jan 12, 2025

IMO we should take it one step at a time. First, let's enable lowering to the remaining device lib calls -- this will unblock this PR and fix known performance issues in IREE on MI300-series cards. Then we can pursue the other prongs concurrently:

  • Improve the MathToROCDL conversion, evaluate the packed math variants with 2xf16, use ValueBounds to select alternative lowering paths for functions like ipowi, etc.
  • Drop polynomial approximation for other targets that provide device libs (CUDA?).

kuhar requested a review from lialan on January 13, 2025
kuhar marked this pull request as ready for review on January 13, 2025
kuhar (Member, Author) commented Jan 13, 2025

I added a local LLVMGPU pass that handles only fpowi and ipowi, to unblock this.

kuhar requested a review from Groverkss on January 13, 2025
lialan (Contributor) commented Jan 13, 2025

I am taking it from here to address the remaining bug in this PR.

lialan (Contributor) commented Jan 13, 2025

The failure is caused by a type mismatch while converting math functions:

        %13318 = "llvm.insertvalue"(%13316, %13317) <{position = array<i64: 0, 3, 0, 0, 3>}> : (!llvm.array<1 x array<4 x array<1 x array<1 x array<4 x vector<1xf16>>>>>>, vector<1xf16>) -> !llvm.array<1 x array<4 x array<1 x array<1 x array<4 x vector<1xf16>>>>>>
        %13319 = "builtin.unrealized_conversion_cast"(%13318) : (!llvm.array<1 x array<4 x array<1 x array<1 x array<4 x vector<1xf16>>>>>>) -> vector<1x4x1x1x4x1xf16>
        %13320 = "math.erf"(%13319) <{fastmath = #arith.fastmath<none>}> : (vector<1x4x1x1x4x1xf16>) -> vector<1x4x1x1x4x1xf16>
        %13321 = "builtin.unrealized_conversion_cast"(%13320) : (vector<1x4x1x1x4x1xf16>) -> !llvm.array<1 x array<4 x array<1 x array<1 x array<4 x vector<1xf16>>>>>>
        %13322 = "llvm.extractvalue"(%13321) <{position = array<i64: 0, 0, 0, 0, 0>}> : (!llvm.array<1 x array<4 x array<1 x array<1 x array<4 x vector<1xf16>>>>>>) -> vector<1xf16>

The nested arrays are flattened into a single vector.

Groverkss (Contributor) commented Jan 13, 2025

This patch also fixes #18570, at least for LLVMGPU.

Groverkss (Contributor) left a comment

LGTM

@@ -0,0 +1,34 @@
// Copyright 2023 The IREE Authors
nit: 2025

lialan (Contributor) commented Jan 14, 2025

fpowi support is merged upstream. But the remaining issue in this change is not entirely about fpowi or ipowi: some math functions are still not handled by the ROCDL conversion.

lialan (Contributor) commented Jan 14, 2025

@kuhar This should get the PR past the compilation issues: patch.patch. I cannot push to your branch; can you update your branch with this patch?

And then there is an iree-run-module crash/failure we will need to fix:
https://github.com/iree-org/iree/actions/runs/12767199313/job/35585668365?pr=19697
