diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index c15aeb17ad2893d..ff5816058a6b768 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -728,13 +728,12 @@ class OpenMPIRBuilder { LoopBodyGenCallbackTy BodyGenCB, Value *TripCount, const Twine &Name = "loop"); - /// Generator for the control flow structure of an OpenMP canonical loop. + /// Calculate the trip count of a canonical loop. /// - /// Instead of a logical iteration space, this allows specifying user-defined - /// loop counter values using increment, upper- and lower bounds. To - /// disambiguate the terminology when counting downwards, instead of lower - /// bounds we use \p Start for the loop counter value in the first body - /// iteration. + /// This allows specifying user-defined loop counter values using increment, + /// upper- and lower bounds. To disambiguate the terminology when counting + /// downwards, instead of lower bounds we use \p Start for the loop counter + /// value in the first body iteration. /// /// Consider the following limitations: /// @@ -758,7 +757,32 @@ class OpenMPIRBuilder { /// /// for (int i = 0; i < 42; i -= 1u) /// - // + /// \param Loc The insert and source location description. + /// \param Start Value of the loop counter for the first iteration. + /// \param Stop Loop counter values past this will stop the loop. + /// \param Step Loop counter increment after each iteration; negative + /// means counting down. + /// \param IsSigned Whether Start, Stop and Step are signed integers. + /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop + /// counter. + /// \param Name Base name used to derive instruction names. + /// + /// \returns The value holding the calculated trip count. 
+ Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc, + Value *Start, Value *Stop, Value *Step, + bool IsSigned, bool InclusiveStop, + const Twine &Name = "loop"); + + /// Generator for the control flow structure of an OpenMP canonical loop. + /// + /// Instead of a logical iteration space, this allows specifying user-defined + /// loop counter values using increment, upper- and lower bounds. To + /// disambiguate the terminology when counting downwards, instead of lower + /// bounds we use \p Start for the loop counter value in the first body iteration. + /// + /// It calls calculateCanonicalLoopTripCount() for trip count calculations, + /// so limitations of that method apply here as well. + /// /// \param Loc The insert and source location description. /// \param BodyGenCB Callback that will generate the loop body code. /// \param Start Value of the loop counter for the first iterations. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index a1445df13ea785e..7bfd9d78cb7479f 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4032,11 +4032,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc, return CL; } -Expected OpenMPIRBuilder::createCanonicalLoop( - const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, - Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, - InsertPointTy ComputeIP, const Twine &Name) { - +Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount( + const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step, + bool IsSigned, bool InclusiveStop, const Twine &Name) { // Consider the following difficulties (assuming 8-bit signed integers): // * Adding \p Step to the loop counter which passes \p Stop may overflow: // DO I = 1, 100, 50 @@ -4048,9 +4046,7 @@ Expected OpenMPIRBuilder::createCanonicalLoop( assert(IndVarTy == Stop->getType() && "Stop type mismatch"); 
assert(IndVarTy == Step->getType() && "Step type mismatch"); - LocationDescription ComputeLoc = - ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc; - updateToLocation(ComputeLoc); + updateToLocation(Loc); ConstantInt *Zero = ConstantInt::get(IndVarTy, 0); ConstantInt *One = ConstantInt::get(IndVarTy, 1); @@ -4090,8 +4086,20 @@ Expected OpenMPIRBuilder::createCanonicalLoop( Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr); CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo); } - Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping, - "omp_" + Name + ".tripcount"); + + return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping, + "omp_" + Name + ".tripcount"); +} + +Expected OpenMPIRBuilder::createCanonicalLoop( + const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB, + Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop, + InsertPointTy ComputeIP, const Twine &Name) { + LocationDescription ComputeLoc = + ComputeIP.isSet() ? 
LocationDescription(ComputeIP, Loc.DL) : Loc; + + Value *TripCount = calculateCanonicalLoopTripCount( + ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name); auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) { Builder.restoreIP(CodeGenIP); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index de6cc5d9781a677..3516e53cf38e6f8 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -1427,8 +1427,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) { EXPECT_EQ(&Loop->getAfter()->front(), RetInst); } -TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) { - using InsertPointTy = OpenMPIRBuilder::InsertPointTy; +TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) { OpenMPIRBuilder OMPBuilder(*M); OMPBuilder.initialize(); IRBuilder<> Builder(BB); @@ -1444,17 +1443,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) { Value *StartVal = ConstantInt::get(LCTy, Start); Value *StopVal = ConstantInt::get(LCTy, Stop); Value *StepVal = ConstantInt::get(LCTy, Step); - auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) { - return Error::success(); - }; - Expected LoopResult = - OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal, - StepVal, IsSigned, InclusiveStop); - assert(LoopResult && "unexpected error"); - CanonicalLoopInfo *Loop = *LoopResult; - Loop->assertOK(); - Builder.restoreIP(Loop->getAfterIP()); - Value *TripCount = Loop->getTripCount(); + Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount( + Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop); return cast(TripCount)->getValue().getZExtValue(); }; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 74205238fbbe59a..d80f7f49f15b58d 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1772,55 +1772,34 @@ LogicalResult 
TargetOp::verify() { Operation *TargetOp::getInnermostCapturedOmpOp() { Dialect *ompDialect = (*this)->getDialect(); Operation *capturedOp = nullptr; - Region *capturedParentRegion = nullptr; - walk([&](Operation *op) { + // Process in pre-order to check operations from outermost to innermost, + // ensuring we only enter the region of an operation if it meets the criteria + // for being captured. We stop the exploration of nested operations as soon as + // we process a region with no operation to be captured. + walk([&](Operation *op) { if (op == *this) - return; - - // Reset captured op if crossing through an omp.loop_nest, so that the top - // level one will be the one captured. - if (llvm::isa(op)) { - capturedOp = nullptr; - capturedParentRegion = nullptr; - } + return WalkResult::advance(); + // Ignore operations of other dialects or omp operations with no regions, + // because these will only be checked if they are siblings of an omp + // operation that can potentially be captured. bool isOmpDialect = op->getDialect() == ompDialect; bool hasRegions = op->getNumRegions() > 0; - - if (capturedOp) { - bool isImmediateParent = false; - for (Region &region : op->getRegions()) { - if (&region == capturedParentRegion) { - isImmediateParent = true; - capturedParentRegion = op->getParentRegion(); - break; - } - } - - // Make sure the captured op is part of a (possibly multi-level) nest of - // OpenMP-only operations containing no unsupported siblings at any level. - if ((hasRegions && isOmpDialect != isImmediateParent) || - (!isImmediateParent && !siblingAllowedInCapture(op))) { - capturedOp = nullptr; - capturedParentRegion = nullptr; - } - } else { - // The first OpenMP dialect op containing a region found while visiting - // in post-order should be the innermost captured OpenMP operation. - if (isOmpDialect && hasRegions) { - capturedOp = op; - capturedParentRegion = op->getParentRegion(); - - // Don't capture this op if it has a not-allowed sibling. 
- for (Operation &sibling : op->getParentRegion()->getOps()) { - if (&sibling != op && !siblingAllowedInCapture(&sibling)) { - capturedOp = nullptr; - capturedParentRegion = nullptr; - } - } - } - } + if (!isOmpDialect || !hasRegions) + return WalkResult::skip(); + + // Don't capture this op if it has a not-allowed sibling, and stop recursing + // into nested operations. + for (Operation &sibling : op->getParentRegion()->getOps()) + if (&sibling != op && !siblingAllowedInCapture(&sibling)) + return WalkResult::interrupt(); + + // Don't continue capturing nested operations if we reach an omp.loop_nest. + // Otherwise, process the contents of this operation. + capturedOp = op; + return llvm::isa(op) ? WalkResult::interrupt() + : WalkResult::advance(); }); return capturedOp; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index a49887394290328..6ec261c104676ba 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -32,6 +32,7 @@ #include "llvm/IR/ReplaceConstant.h" #include "llvm/Support/FileSystem.h" #include "llvm/TargetParser/Triple.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include @@ -4028,6 +4029,72 @@ static uint64_t getTeamsReductionDataSize(mlir::omp::TeamsOp &teamsOp) { return getReductionDataSize(teamsOp); } +/// Follow uses of `host_eval`-defined block arguments of the given `omp.target` +/// operation and populate output variables with their corresponding host value +/// (i.e. operand evaluated outside of the target region), based on their uses +/// inside of the target region. +/// +/// Loop bounds and steps are only optionally populated, if output vectors are +/// provided. 
+static void +extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, + Value &numTeamsLower, Value &numTeamsUpper, + Value &threadLimit, + llvm::SmallVectorImpl *lowerBounds = nullptr, + llvm::SmallVectorImpl *upperBounds = nullptr, + llvm::SmallVectorImpl *steps = nullptr) { + auto blockArgIface = llvm::cast(*targetOp); + for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(), + blockArgIface.getHostEvalBlockArgs())) { + Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item); + + for (Operation *user : blockArg.getUsers()) { + llvm::TypeSwitch(user) + .Case([&](omp::TeamsOp teamsOp) { + if (teamsOp.getNumTeamsLower() == blockArg) + numTeamsLower = hostEvalVar; + else if (teamsOp.getNumTeamsUpper() == blockArg) + numTeamsUpper = hostEvalVar; + else if (teamsOp.getThreadLimit() == blockArg) + threadLimit = hostEvalVar; + else + llvm_unreachable("unsupported host_eval use"); + }) + .Case([&](omp::ParallelOp parallelOp) { + if (parallelOp.getNumThreads() == blockArg) + numThreads = hostEvalVar; + else + llvm_unreachable("unsupported host_eval use"); + }) + .Case([&](omp::LoopNestOp loopOp) { + auto processBounds = + [&](OperandRange opBounds, + llvm::SmallVectorImpl *outBounds) -> bool { + bool found = false; + for (auto [i, lb] : llvm::enumerate(opBounds)) { + if (lb == blockArg) { + found = true; + if (outBounds) + (*outBounds)[i] = hostEvalVar; + } + } + return found; + }; + bool found = + processBounds(loopOp.getLoopLowerBounds(), lowerBounds); + found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) || + found; + found = processBounds(loopOp.getLoopSteps(), steps) || found; + if (!found) + llvm_unreachable("unsupported host_eval use"); + }) + .Default([](Operation *) { + llvm_unreachable("unsupported host_eval use"); + }); + } + } +} + /// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default /// values as stated by the corresponding clauses, if constant. 
/// @@ -4038,6 +4105,10 @@ static void initTargetDefaultBounds( omp::TargetOp targetOp, llvm::OpenMPIRBuilder::TargetKernelDefaultBounds &bounds, bool isTargetDevice, bool isGPU) { + Value hostNumThreads, hostNumTeamsLower, hostNumTeamsUpper, hostThreadLimit; + extractHostEvalClauses(targetOp, hostNumThreads, hostNumTeamsLower, + hostNumTeamsUpper, hostThreadLimit); + // TODO Handle constant IF clauses Operation *innermostCapturedOmpOp = targetOp.getInnermostCapturedOmpOp(); @@ -4047,8 +4118,8 @@ static void initTargetDefaultBounds( castOrGetParentOfType(innermostCapturedOmpOp)) { // TODO Use teamsOp.getNumTeamsLower() to initialize `minTeamsVal`. For now, // just match clang and set min and max to the same value. - Value numTeamsClause = isTargetDevice ? teamsOp.getNumTeamsUpper() - : targetOp.getNumTeamsUpper(); + Value numTeamsClause = + isTargetDevice ? teamsOp.getNumTeamsUpper() : hostNumTeamsUpper; if (numTeamsClause) { if (auto constOp = dyn_cast_if_present( numTeamsClause.getDefiningOp())) { @@ -4091,8 +4162,8 @@ static void initTargetDefaultBounds( if (auto teamsOp = castOrGetParentOfType(innermostCapturedOmpOp)) { - Value threadLimitClause = isTargetDevice ? teamsOp.getThreadLimit() - : targetOp.getTeamsThreadLimit(); + Value threadLimitClause = + isTargetDevice ? teamsOp.getThreadLimit() : hostThreadLimit; setMaxValueFromClause(threadLimitClause, teamsThreadLimitVal); } @@ -4100,8 +4171,8 @@ static void initTargetDefaultBounds( if (innermostCapturedOmpOp) { if (auto parallelOp = castOrGetParentOfType(innermostCapturedOmpOp)) { - Value numThreadsClause = isTargetDevice ? parallelOp.getNumThreads() - : targetOp.getNumThreads(); + Value numThreadsClause = + isTargetDevice ? 
parallelOp.getNumThreads() : hostNumThreads; setMaxValueFromClause(numThreadsClause, maxThreadsVal); } else if (castOrGetParentOfType(innermostCapturedOmpOp, /*immediateParent=*/true)) { @@ -4147,26 +4218,68 @@ static void initTargetDefaultBounds( /// only provide correct results if it's called after the body of \c targetOp /// has been fully generated. static void initTargetRuntimeBounds( - LLVM::ModuleTranslation &moduleTranslation, omp::TargetOp targetOp, + llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, + omp::TargetOp targetOp, llvm::OpenMPIRBuilder::TargetKernelRuntimeBounds &bounds) { + omp::LoopNestOp loopOp = castOrGetParentOfType( + targetOp.getInnermostCapturedOmpOp()); + unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0; + + Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit; + llvm::SmallVector lowerBounds(numLoops), upperBounds(numLoops), + steps(numLoops); + extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper, + teamsThreadLimit, &lowerBounds, &upperBounds, &steps); + // TODO Handle IF clauses. 
- if (Value numTeamsLower = targetOp.getNumTeamsLower()) + llvm::Value *&llvmTargetThreadLimit = + bounds.TargetThreadLimit.emplace_back(nullptr); + if (Value targetThreadLimit = targetOp.getThreadLimit()) + llvmTargetThreadLimit = moduleTranslation.lookupValue(targetThreadLimit); + + if (numTeamsLower) bounds.MinTeams = moduleTranslation.lookupValue(numTeamsLower); llvm::Value *&llvmMaxTeams = bounds.MaxTeams.emplace_back(nullptr); - if (Value numTeamsUpper = targetOp.getNumTeamsUpper()) + if (numTeamsUpper) llvmMaxTeams = moduleTranslation.lookupValue(numTeamsUpper); llvm::Value *&llvmTeamsThreadLimit = bounds.TeamsThreadLimit.emplace_back(nullptr); - if (Value teamsThreadLimit = targetOp.getTeamsThreadLimit()) + if (teamsThreadLimit) llvmTeamsThreadLimit = moduleTranslation.lookupValue(teamsThreadLimit); - if (Value numThreads = targetOp.getNumThreads()) + if (numThreads) bounds.MaxThreads = moduleTranslation.lookupValue(numThreads); - if (Value tripCount = targetOp.getTripCount()) - bounds.LoopTripCount = moduleTranslation.lookupValue(tripCount); + if (targetOp.isTargetSPMDLoop()) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + bounds.LoopTripCount = nullptr; + + // To calculate the trip count, we multiply together the trip counts of + // every collapsed canonical loop. We don't need to create the loop nests + // here, since we're only interested in the trip count. 
+ for (auto [loopLower, loopUpper, loopStep] : + llvm::zip_equal(lowerBounds, upperBounds, steps)) { + llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower); + llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper); + llvm::Value *step = moduleTranslation.lookupValue(loopStep); + + llvm::OpenMPIRBuilder::LocationDescription loc(builder); + llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount( + loc, lowerBound, upperBound, step, /*IsSigned=*/true, + loopOp.getLoopInclusive()); + + if (!bounds.LoopTripCount) { + bounds.LoopTripCount = tripCount; + continue; + } + + // TODO: Enable UndefinedBehaviorSanitizer to diagnose an overflow here. + bounds.LoopTripCount = builder.CreateMul(bounds.LoopTripCount, tripCount, + {}, /*HasNUW=*/true); + } + } } static LogicalResult @@ -4181,13 +4294,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, bool isGPU = ompBuilder->Config.isGPU(); auto parentFn = opInst.getParentOfType(); + auto blockIface = cast(opInst); auto &targetRegion = targetOp.getRegion(); DataLayout dl = DataLayout(opInst.getParentOfType()); SmallVector mapVars = targetOp.getMapVars(); - ArrayRef mapBlockArgs = - cast(opInst).getMapBlockArgs(); + ArrayRef mapBlockArgs = blockIface.getMapBlockArgs(); llvm::Function *llvmOutlinedFn = nullptr; - llvm::OpenMPIRBuilder::TargetKernelRuntimeBounds runtimeBounds; // TODO: It can also be false if a compile-time constant `false` IF clause is // specified. 
@@ -4229,7 +4341,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, OperandRange privateVars = targetOp.getPrivateVars(); std::optional privateSyms = targetOp.getPrivateSyms(); MutableArrayRef privateBlockArgs = - cast(opInst).getPrivateBlockArgs(); + blockIface.getPrivateBlockArgs(); for (auto [privVar, privatizerNameAttr, privBlockArg] : llvm::zip_equal(privateVars, *privateSyms, privateBlockArgs)) { @@ -4264,9 +4376,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, builder.SetInsertPoint(*exitBlock); - if (!isTargetDevice) - initTargetRuntimeBounds(moduleTranslation, targetOp, runtimeBounds); - return builder.saveIP(); }; @@ -4313,6 +4422,29 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, }; llvm::SmallVector kernelInput; + llvm::OpenMPIRBuilder::TargetKernelDefaultBounds defaultBounds; + initTargetDefaultBounds(targetOp, defaultBounds, isTargetDevice, isGPU); + + // Collect host-evaluated values needed to properly launch the kernel from the + // host. + llvm::OpenMPIRBuilder::TargetKernelRuntimeBounds runtimeBounds; + if (!isTargetDevice) + initTargetRuntimeBounds(builder, moduleTranslation, targetOp, + runtimeBounds); + + // Pass host-evaluated values as parameters to the kernel / host fallback, + // except if they are constants. In any case, map the MLIR block argument to + // the corresponding LLVM values. 
+ SmallVector hostEvalVars = targetOp.getHostEvalVars(); + ArrayRef hostEvalBlockArgs = blockIface.getHostEvalBlockArgs(); + for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) { + llvm::Value *value = moduleTranslation.lookupValue(var); + moduleTranslation.mapValue(arg, value); + + if (!llvm::isa(value)) + kernelInput.push_back(value); + } + for (size_t i = 0; i < mapVars.size(); ++i) { // declare target arguments are not passed to kernels as arguments // TODO: We currently do not handle cases where a member is explicitly @@ -4327,14 +4459,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(), moduleTranslation, dds); - llvm::OpenMPIRBuilder::TargetKernelDefaultBounds defaultBounds; - initTargetDefaultBounds(targetOp, defaultBounds, isTargetDevice, isGPU); - - llvm::Value *&llvmTargetThreadLimit = - runtimeBounds.TargetThreadLimit.emplace_back(nullptr); - if (Value targetThreadLimit = targetOp.getThreadLimit()) - llvmTargetThreadLimit = moduleTranslation.lookupValue(targetThreadLimit); - llvm::Value *ifCond = nullptr; if (Value targetIfCond = targetOp.getIfExpr()) ifCond = moduleTranslation.lookupValue(targetIfCond); diff --git a/mlir/test/Target/LLVMIR/omptarget-host-eval.mlir b/mlir/test/Target/LLVMIR/omptarget-host-eval.mlir new file mode 100644 index 000000000000000..a6494f334747132 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-host-eval.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @omp_target_region_() { + %out_teams = llvm.mlir.constant(1000 : i32) : i32 + %out_threads = llvm.mlir.constant(2000 : i32) : i32 + %out_lb = llvm.mlir.constant(0 : i32) : i32 + %out_ub = llvm.mlir.constant(3000 : i32) : i32 + %out_step = llvm.mlir.constant(1 : i32) : i32 + + omp.target + host_eval(%out_teams -> %teams, 
%out_threads -> %threads, + %out_lb -> %lb, %out_ub -> %ub, %out_step -> %step : + i32, i32, i32, i32, i32) { + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// CHECK-LABEL: define void @omp_target_region_ +// CHECK: %[[ARGS:.*]] = alloca %struct.__tgt_kernel_arguments + +// CHECK: %[[TRIPCOUNT_ADDR:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[ARGS]], i32 0, i32 8 +// CHECK: store i64 3000, ptr %[[TRIPCOUNT_ADDR]] + +// CHECK: %[[TEAMS_ADDR:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[ARGS]], i32 0, i32 10 +// CHECK: store [3 x i32] [i32 1000, i32 0, i32 0], ptr %[[TEAMS_ADDR]] + +// CHECK: %[[THREADS_ADDR:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[ARGS]], i32 0, i32 11 +// CHECK: store [3 x i32] [i32 2000, i32 0, i32 0], ptr %[[THREADS_ADDR]] + +// CHECK: call i32 @__tgt_target_kernel(ptr @{{.*}}, i64 {{.*}}, i32 1000, i32 2000, ptr @{{.*}}, ptr %[[ARGS]])