From deee85cc8740241c26ef7bfa138c8a4b2acd24e0 Mon Sep 17 00:00:00 2001 From: Harsha HS Date: Fri, 5 Jul 2024 02:59:34 -0700 Subject: [PATCH] Enable triton sparse gemm only for CUDA --- xla/service/BUILD | 1 + xla/service/elemental_ir_emitter.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/xla/service/BUILD b/xla/service/BUILD index 766bd8712382c..1c42226ea20c4 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -5387,6 +5387,7 @@ cc_library( name = "elemental_ir_emitter", srcs = ["elemental_ir_emitter.cc"], hdrs = ["elemental_ir_emitter.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":algorithm_util", ":float8_fnuz_ir_emitter", diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index ed4c16fe37cbf..06e18ed09a12d 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -2921,10 +2921,12 @@ absl::StatusOr ElementalIrEmitter::EmitElementalDot( "Algorithm not supported by the ElementalIrEmitter: %s", PrecisionConfig::Algorithm_Name(hlo->precision_config().algorithm()))); } +#ifdef GOOGLE_CUDA const HloDotInstruction* dot = Cast(hlo); if (dot->sparse_operands()) { return Unimplemented("Sparse dot is supported by Triton emitter only."); } +#endif auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1));