Reviewer note: this patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed. Indentation and blank context lines
are reconstructed from the hunk line counts; verify them against the tree
before applying. The handle_minimum hunk is amended (docstring added,
property access instead of calls), so the post-image hash on the
iree/turbine/kernel/wave/codegen.py index line is stale.

diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index ffa73618..1c1f0653 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -114,6 +114,10 @@ def maximum(lhs: "Register", rhs: "Register") -> "Register":
     ...
 
 
+def minimum(lhs: "Register", rhs: "Register") -> "Register":
+    ...
+
+
 def broadcast(
     arg: "Register", target_shape: Optional[IndexExpr | int] = None
 ) -> "Register":
@@ -607,6 +611,7 @@ def transform_index(
 @define_py_op(operator.mul)
 @define_py_op(operator.truediv)
 @define_interface_op("maximum")
+@define_interface_op("minimum")
 @dataclass
 class BinaryPyOp(CustomOp, ABC):
     """
diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index d6507a41..ba06f4f3 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -58,6 +58,7 @@
     reciprocal,
     abs,
     maximum,
+    minimum,
     get_custom,
     get_result,
     allocate,
@@ -1076,6 +1077,34 @@ def handle_maximum(lhs: Value, rhs: Value) -> OpResult:
     return result
 
 
+@handle_binary_op(minimum)
+def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
+    """Lower an elementwise `minimum` to the matching arith op.
+
+    The op is chosen from the element type of `lhs`: floats use
+    `arith.minimumf`, signed and signless integers use `arith.minsi`,
+    and unsigned integers use `arith.minui`.
+
+    Raises:
+        ValidationError: if the element type is none of the above.
+    """
+    element_type = get_type_or_element_type(lhs.type)
+    if _is_float_type(element_type):
+        result = arith_d.minimumf(lhs, rhs)
+    # `is_signed`, `is_signless` and `is_unsigned` are read-only properties
+    # on the MLIR Python `IntegerType`, so they must not be called.
+    elif _is_integer_like_type(element_type) and (
+        element_type.is_signed or element_type.is_signless
+    ):
+        result = arith_d.minsi(lhs, rhs)
+    elif _is_integer_like_type(element_type) and element_type.is_unsigned:
+        result = arith_d.minui(lhs, rhs)
+    else:
+        raise ValidationError(
+            f"Found unhandled operand type for minimum: {element_type}"
+        )
+    return result
+
+
 ###############################################################################
 # Unary math Ops
 ###############################################################################
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 8bbf75fc..9eff299a 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -1281,6 +1281,7 @@ def test(
         res = a_reg - b_reg
         res = res * a_reg
         res = res / b_reg
+        res = tkw.minimum(a_reg, b_reg)
         tkw.write(res, a, elements_per_thread=4)
 
     a = torch.randn(16, 16, dtype=torch.float16)
@@ -1291,6 +1292,7 @@ def test(
         # CHECK: %[[SUB:.+]] = arith.subf
         # CHECK: %[[MUL:.+]] = arith.mulf %[[SUB]]
         # CHECK: %[[DIV:.+]] = arith.divf %[[MUL]]
+        # CHECK: %[[MINIMUM:.+]] = arith.minimumf
 
 
 # TODO: Something is broken in codegen and we are getting int in place of fx.Node