diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp index d80da87a94b12d..85e7db450aa428 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp @@ -1660,7 +1660,6 @@ void CGOpenMPRuntimeGPU::emitReduction( return; bool ParallelReduction = isOpenMPParallelDirective(Options.ReductionKind); - bool DistributeReduction = isOpenMPDistributeDirective(Options.ReductionKind); bool TeamsReduction = isOpenMPTeamsDirective(Options.ReductionKind); ASTContext &C = CGM.getContext(); @@ -1756,7 +1755,7 @@ void CGOpenMPRuntimeGPU::emitReduction( CGF.Builder.restoreIP(OMPBuilder.createReductionsGPU( OmpLoc, AllocaIP, CodeGenIP, ReductionInfos, false, TeamsReduction, - DistributeReduction, llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang, + llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang, CGF.getTarget().getGridValue(), C.getLangOpts().OpenMPCUDAReductionBufNum, RTLoc)); return; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 9c3535ef8a3b17..d1a397ca66c6de 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -796,7 +796,6 @@ genReductionVars(mlir::Operation *op, lower::AbstractConverter &converter, mlir::Block *entryBlock = firOpBuilder.createBlock( &op->getRegion(0), {}, reductionTypes, blockArgLocs); - // Bind the reduction arguments to their block arguments. for (auto [arg, prv] : llvm::zip_equal(reductionArgs, entryBlock->getArguments())) { @@ -1659,14 +1658,15 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter, loc, llvm::omp::Directive::OMPD_taskwait); } -static void -genTeamsClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, const List &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::TeamsOperands &clauseOps, - mlir::omp::NumTeamsClauseOps &numTeamsClauseOps, - mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps) { +static void genTeamsClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool evalOutsideTarget, + mlir::omp::TeamsOperands &clauseOps, + mlir::omp::NumTeamsClauseOps &numTeamsClauseOps, + mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps, + llvm::SmallVectorImpl &reductionTypes, + llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); @@ -1684,8 +1684,7 @@ genTeamsClauses(lower::AbstractConverter &converter, cp.processNumTeams(stmtCtx, numTeamsClauseOps); cp.processThreadLimit(stmtCtx, threadLimitClauseOps); } - - // cp.processTODO(loc, llvm::omp::Directive::OMPD_teams); + cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); } static void genWsloopClauses( @@ -1874,7 +1873,6 @@ static mlir::omp::ParallelOp genParallelOp( llvm::ArrayRef reductionTypes, DataSharingProcessor *dsp, bool isComposite = false, mlir::omp::TargetOp parentTarget = nullptr) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto reductionCallback = [&](mlir::Operation *op) { genReductionVars(op, converter, loc, reductionSyms, reductionTypes); return llvm::SmallVector(reductionSyms); @@ -2360,14 +2358,22 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::TeamsOperands clauseOps; mlir::omp::NumTeamsClauseOps numTeamsClauseOps; 
   mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> reductionSyms;
+  llvm::SmallVector<mlir::Type> reductionTypes;
   genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
                   evalOutsideTarget, clauseOps, numTeamsClauseOps,
-                  threadLimitClauseOps);
+                  threadLimitClauseOps, reductionTypes, reductionSyms);
+
+  auto reductionCallback = [&](mlir::Operation *op) {
+    genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
+    return llvm::SmallVector<const semantics::Symbol *>(reductionSyms);
+  };
 
   auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_teams)
-          .setClauses(&item->clauses),
+          .setClauses(&item->clauses)
+          .setGenRegionEntryCb(reductionCallback),
       queue, item, clauseOps);
 
   if (numTeamsClauseOps.numTeamsUpper) {
@@ -2436,7 +2442,6 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
                             const ConstructQueue &queue,
                             ConstructQueue::const_iterator item) {
   lower::StatementContext stmtCtx;
 
-  mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> reductionSyms;
   llvm::SmallVector<mlir::Type> reductionTypes;
diff --git a/flang/test/Lower/OpenMP/reduction-target-spmd.f90 b/flang/test/Lower/OpenMP/reduction-target-spmd.f90
new file mode 100644
index 00000000000000..353c540c3bbf32
--- /dev/null
+++ b/flang/test/Lower/OpenMP/reduction-target-spmd.f90
@@ -0,0 +1,15 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s | FileCheck %s
+! RUN: bbc -emit-fir -fopenmp -o - %s | FileCheck %s
+
+! CHECK: omp.teams
+! CHECK-SAME: reduction(@add_reduction_i32 %{{.*}} -> %{{.*}} : !fir.ref<i32>)
+subroutine myfun()
+  integer :: i, j
+  i = 0
+  j = 0
+  !$omp target teams distribute parallel do reduction(+:i)
+  do j = 1,5
+    i = i + j
+  end do
+  !$omp end target teams distribute parallel do
+end subroutine myfun
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-teams.f90 b/flang/test/Lower/OpenMP/reduction-teams.f90
similarity index 96%
rename from flang/test/Lower/OpenMP/Todo/reduction-teams.f90
rename to flang/test/Lower/OpenMP/reduction-teams.f90
index 5948507869452b..eddbd752f7fa40 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-teams.f90
+++ b/flang/test/Lower/OpenMP/reduction-teams.f90
@@ -1,6 +1,5 @@
 ! RUN: bbc -emit-fir -fopenmp -o - %s | FileCheck %s
 ! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s | FileCheck %s
-! XFAIL: *
 
 ! CHECK: omp.teams
 ! CHECK-SAME: reduction
diff --git a/flang/test/Lower/OpenMP/sections-array-reduction.f90 b/flang/test/Lower/OpenMP/sections-array-reduction.f90
index e5319e8d6bcc79..13855e62c2e6d8 100644
--- a/flang/test/Lower/OpenMP/sections-array-reduction.f90
+++ b/flang/test/Lower/OpenMP/sections-array-reduction.f90
@@ -35,7 +35,7 @@ subroutine sectionsReduction(x)
 ! CHECK: omp.parallel {
 ! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
 ! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
-! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 -> %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
+! CHECK: omp.sections reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
 ! CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
 ! CHECK: omp.section {
 ! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>):
diff --git a/flang/test/Lower/OpenMP/sections-reduction.f90 b/flang/test/Lower/OpenMP/sections-reduction.f90
index 854f9ea22a7ddd..9299dfcd6a7115 100644
--- a/flang/test/Lower/OpenMP/sections-reduction.f90
+++ b/flang/test/Lower/OpenMP/sections-reduction.f90
@@ -40,7 +40,7 @@ subroutine sectionsReduction(x,y)
 !
CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEx"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) ! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFsectionsreductionEy"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) ! CHECK: omp.parallel { -! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref) { +! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[ARG_0:.*]] : !fir.ref, @add_reduction_f32 %[[VAL_4]]#0 -> %[[ARG_1:.*]] : !fir.ref) { ! CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref, %[[VAL_6:.*]]: !fir.ref): ! CHECK: omp.section { ! CHECK: ^bb0(%[[VAL_7:.*]]: !fir.ref, %[[VAL_8:.*]]: !fir.ref): @@ -71,7 +71,7 @@ subroutine sectionsReduction(x,y) ! CHECK: omp.terminator ! CHECK: } ! CHECK: omp.parallel { -! CHECK: omp.sections reduction(@add_reduction_f32 -> %[[VAL_3]]#0 : !fir.ref, @add_reduction_f32 -> %[[VAL_4]]#0 : !fir.ref) { +! CHECK: omp.sections reduction(@add_reduction_f32 %[[VAL_3]]#0 -> %[[ARG_2:.*]] : !fir.ref, @add_reduction_f32 %[[VAL_4]]#0 -> %[[ARG_3:.*]] : !fir.ref) { ! CHECK: ^bb0(%[[VAL_23:.*]]: !fir.ref, %[[VAL_24:.*]]: !fir.ref): ! CHECK: omp.section { ! CHECK: ^bb0(%[[VAL_25:.*]]: !fir.ref, %[[VAL_26:.*]]: !fir.ref): diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index ab02a46f433cdf..bd76605adce1da 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1844,8 +1844,6 @@ class OpenMPIRBuilder { /// nowait. /// \param IsTeamsReduction Optional flag set if it is a teams /// reduction. - /// \param HasDistribute Optional flag set if it is a - /// distribute reduction. /// \param GridValue Optional GPU grid value. /// \param ReductionBufNum Optional OpenMPCUDAReductionBufNumValue to be /// used for teams reduction. 
@@ -1854,7 +1852,6 @@ class OpenMPIRBuilder { const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, ArrayRef ReductionInfos, bool IsNoWait = false, bool IsTeamsReduction = false, - bool HasDistribute = false, ReductionGenCBKind ReductionGenCBKind = ReductionGenCBKind::MLIR, std::optional GridValue = {}, unsigned ReductionBufNum = 1024, Value *SrcLocInfo = nullptr); @@ -1926,8 +1923,7 @@ class OpenMPIRBuilder { InsertPointTy AllocaIP, ArrayRef ReductionInfos, ArrayRef IsByRef, bool IsNoWait = false, - bool IsTeamsReduction = false, - bool HasDistribute = false); + bool IsTeamsReduction = false); ///} diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index c2624ee12f5958..8d629f8f0fa9d8 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -3412,9 +3412,9 @@ checkReductionInfos(ArrayRef ReductionInfos, OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, ArrayRef ReductionInfos, - bool IsNoWait, bool IsTeamsReduction, bool HasDistribute, - ReductionGenCBKind ReductionGenCBKind, std::optional GridValue, - unsigned ReductionBufNum, Value *SrcLocInfo) { + bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind, + std::optional GridValue, unsigned ReductionBufNum, + Value *SrcLocInfo) { if (!updateToLocation(Loc)) return InsertPointTy(); Builder.restoreIP(CodeGenIP); @@ -3590,13 +3590,11 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU( ReductionFunc; }); } else { - if (!HasDistribute || IsTeamsReduction) { - Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs"); - Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs"); - Value *Reduced; - RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced); - Builder.CreateStore(Reduced, LHS, false); - } + Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs"); + Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs"); + Value *Reduced; + RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced); + Builder.CreateStore(Reduced, LHS, false); } } emitBlock(ExitBB, CurFunc); @@ -3685,11 +3683,11 @@ static void populateReductionFunction( OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions( const LocationDescription &Loc, InsertPointTy AllocaIP, ArrayRef ReductionInfos, ArrayRef IsByRef, - bool IsNoWait, bool IsTeamsReduction, bool HasDistribute) { + bool IsNoWait, bool IsTeamsReduction) { assert(ReductionInfos.size() == IsByRef.size()); if (Config.isGPU()) return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos, - IsNoWait, IsTeamsReduction, HasDistribute); + IsNoWait, IsTeamsReduction); checkReductionInfos(ReductionInfos, /*IsGPU*/ false); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 089c1e27147e13..bbef95c6db93b9 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -472,16 +472,20 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op, //===----------------------------------------------------------------------===// static ParseResult parseClauseWithRegionArgs( - OpAsmParser &parser, Region ®ion, + OpAsmParser &parser, SmallVectorImpl &operands, SmallVectorImpl &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols, - SmallVectorImpl ®ionPrivateArgs) { + 
SmallVectorImpl ®ionPrivateArgs, + bool parseParens = true) { SmallVector reductionVec; SmallVector isByRefVec; unsigned regionArgOffset = regionPrivateArgs.size(); + OpAsmParser::Delimiter delimiter = parseParens ? OpAsmParser::Delimiter::Paren + : OpAsmParser::Delimiter::None; + if (failed( - parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + parser.parseCommaSeparatedList(delimiter, [&]() { ParseResult optionalByref = parser.parseOptionalKeyword("byref"); if (parser.parseAttribute(reductionVec.emplace_back()) || parser.parseOperand(operands.emplace_back()) || @@ -536,7 +540,7 @@ static ParseResult parseParallelRegion( llvm::SmallVector regionPrivateArgs; if (succeeded(parser.parseOptionalKeyword("reduction"))) { - if (failed(parseClauseWithRegionArgs(parser, region, reductionVars, + if (failed(parseClauseWithRegionArgs(parser, reductionVars, reductionTypes, reductionByref, reductionSyms, regionPrivateArgs))) return failure(); @@ -544,7 +548,7 @@ static ParseResult parseParallelRegion( if (succeeded(parser.parseOptionalKeyword("private"))) { auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {}); - if (failed(parseClauseWithRegionArgs(parser, region, privateVars, + if (failed(parseClauseWithRegionArgs(parser, privateVars, privateTypes, privateByref, privateSyms, regionPrivateArgs))) return failure(); @@ -597,48 +601,26 @@ static ParseResult parseReductionVarList( SmallVectorImpl &reductionVars, SmallVectorImpl &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms) { - SmallVector reductionVec; - SmallVector isByRefVec; - if (failed(parser.parseCommaSeparatedList([&]() { - ParseResult optionalByref = parser.parseOptionalKeyword("byref"); - if (parser.parseAttribute(reductionVec.emplace_back()) || - parser.parseArrow() || - parser.parseOperand(reductionVars.emplace_back()) || - parser.parseColonType(reductionTypes.emplace_back())) - return failure(); - isByRefVec.push_back(optionalByref.succeeded()); - return success(); - }))) - return failure(); - reductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec); - SmallVector reductions(reductionVec.begin(), reductionVec.end()); - reductionSyms = ArrayAttr::get(parser.getContext(), reductions); - return success(); + llvm::SmallVector regionPrivateArgs; + return parseClauseWithRegionArgs(parser, reductionVars, reductionTypes, + reductionByref, reductionSyms, + regionPrivateArgs, /*parseParens=*/false); } /// Print Reduction clause -static void -printReductionVarList(OpAsmPrinter &p, Operation *op, - OperandRange reductionVars, TypeRange reductionTypes, - std::optional reductionByref, - std::optional reductionSyms) { - auto getByRef = [&](unsigned i) -> const char * { - if (!reductionByref || !*reductionByref) - return ""; - assert(reductionByref->empty() || i < reductionByref->size()); - if (!reductionByref->empty() && (*reductionByref)[i]) - return "byref "; - return ""; - }; - - for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) { - if (i != 0) - p << ", "; - p << getByRef(i) << (*reductionSyms)[i] << " -> " << reductionVars[i] - << " : " << reductionVars[i].getType(); +static void printReductionVarList(OpAsmPrinter &p, Operation *op, + OperandRange reductionVars, + TypeRange reductionTypes, + DenseBoolArrayAttr reductionByref, + ArrayAttr reductionSyms) { + if (reductionSyms) { + auto *argsBegin = op->getRegion(0).front().getArguments().begin(); + MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size()); + printClauseWithRegionArgs(p, op, 
argsSubrange, llvm::StringRef(), + reductionVars, reductionTypes, reductionByref, + reductionSyms); } } - /// Verifies Reduction Clause static LogicalResult verifyReductionVarList(Operation *op, std::optional reductionSyms, @@ -1824,7 +1806,7 @@ parseWsloop(OpAsmParser &parser, Region ®ion, // Parse an optional reduction clause llvm::SmallVector privates; if (succeeded(parser.parseOptionalKeyword("reduction"))) { - if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands, + if (failed(parseClauseWithRegionArgs(parser, reductionOperands, reductionTypes, reductionByRef, reductionSymbols, privates))) return failure(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8a1911bb50ebcd..62900d815ec486 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -925,7 +925,7 @@ static LogicalResult createReductionsAndCleanup( SmallVector &owningReductionGens, SmallVector &owningAtomicReductionGens, SmallVector &reductionInfos, - bool isTeamsReduction = false, bool hasDistribute = false) { + bool isNowait = false, bool isTeamsReduction = false) { // Process the reductions if required. if (op.getNumReductionVars() == 0) return success(); @@ -945,8 +945,7 @@ static LogicalResult createReductionsAndCleanup( builder.SetInsertPoint(tempTerminator); llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint = ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos, - isByRef, op.getNowait(), isTeamsReduction, - hasDistribute); + isByRef, isNowait, isTeamsReduction); if (!contInsertPoint.getBlock()) return op->emitOpError() << "failed to convert reductions"; auto nextInsertionPoint = @@ -1161,7 +1160,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, return createReductionsAndCleanup( sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls, privateReductionVariables, isByRef, owningReductionGens, - owningAtomicReductionGens, reductionInfos); + owningAtomicReductionGens, reductionInfos, sectionsOp.getNowait()); } /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder. 
@@ -1205,10 +1204,36 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
-  if (!op.getAllocatorVars().empty() || op.getReductionSyms() ||
-      !op.getPrivateVars().empty() || op.getPrivateSyms())
+  if (!op.getAllocatorVars().empty() || !op.getPrivateVars().empty() ||
+      op.getPrivateSyms())
     return op.emitError("unhandled clauses for translation to LLVM IR");
+  llvm::ArrayRef<bool> isByRef = getIsByRef(op.getReductionByref());
+  assert(isByRef.size() == op.getNumReductionVars());
+
+  SmallVector<omp::DeclareReductionOp> reductionDecls;
+  collectReductionDecls(op, reductionDecls);
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
+
+  SmallVector<llvm::Value *> privateReductionVariables(
+      op.getNumReductionVars());
+  DenseMap<Value, llvm::Value *> reductionVariableMap;
+
+  MutableArrayRef<BlockArgument> reductionArgs = op.getRegion().getArguments();
+
+  if (failed(allocAndInitializeReductionVars(
+          op, reductionArgs, builder, moduleTranslation, allocaIP,
+          reductionDecls, privateReductionVariables, reductionVariableMap,
+          isByRef)))
+    return failure();
+
+  // Store the mapping between reduction variables and their private copies on
+  // ModuleTranslation stack. It can be then recovered when translating
+  // omp.reduce operations in a separate call.
+  LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
+      moduleTranslation, reductionVariableMap);
+
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
         moduleTranslation, allocaIP);
 
@@ -1238,7 +1263,18 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
   builder.restoreIP(ompBuilder->createTeams(
       ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));
 
-  return bodyGenStatus;
+  if (failed(bodyGenStatus))
+    return bodyGenStatus;
+
+  // Process the reductions if required.
+  SmallVector<OwningReductionGen> owningReductionGens;
+  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+  SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+  return createReductionsAndCleanup(op, builder, moduleTranslation, allocaIP,
+                                    reductionDecls, privateReductionVariables,
+                                    isByRef, owningReductionGens,
+                                    owningAtomicReductionGens, reductionInfos,
+                                    /*isNowait=*/false,
+                                    /*isTeamsReduction=*/true);
 }
 
 static void
@@ -1508,8 +1544,8 @@ static LogicalResult convertOmpWsloop(
   return createReductionsAndCleanup(
       wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
       privateReductionVariables, isByRef, owningReductionGens,
-      owningAtomicReductionGens, reductionInfos, /*isTeamsReduction=*/false,
-      distributeCodeGen);
+      owningAtomicReductionGens, reductionInfos, wsloopOp.getNowait(),
+      /*isTeamsReduction=*/false);
 }
 
 static LogicalResult
@@ -1704,10 +1740,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
       // Generate reductions from info
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);
+
      llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
          ompBuilder->createReductions(builder.saveIP(), allocaIP,
-                                      reductionInfos, isByRef, false, false,
-                                      false);
+                                      reductionInfos, isByRef, false, false);
      if (!contInsertPoint.getBlock()) {
        bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
        return;
@@ -3376,24 +3412,6 @@ static LogicalResult convertOmpDistribute(
      builder.SetInsertPoint(regionBlock->getTerminator());
    }
-
-    // FIXME(JAN): We need to know if we are inside a distribute and
-    // if there is an inner wsloop reduction, in that case we need to
-    // generate the teams reduction bits to combine everything correctly. We
-    // will try to collect the reduction info from the inner wsloop and use
-    // that instead of the reduction clause that could have been on the
-    // omp.parallel
-    auto IP = builder.saveIP();
-    if (ompBuilder->Config.isGPU()) {
-      // TODO: Consider passing the isByref array together with reductionInfos
-      // if it needs to match nested parallel-do or simd.
-      SmallVector<bool> isByref(reductionInfos.size(), true);
-      llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
-          ompBuilder->createReductions(IP, allocaIP, reductionInfos, isByref,
-                                       /*IsNoWait=*/false,
-                                       /*IsTeamsReduction=*/true);
-      builder.restoreIP(contInsertPoint);
-    }
   };
 
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
@@ -3661,6 +3679,29 @@ static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
   return op->getParentOfType<OpTy>();
 }
 
+static uint64_t getTypeByteSize(mlir::Type type, DataLayout dl) {
+  uint64_t sizeInBits = dl.getTypeSizeInBits(type);
+  uint64_t sizeInBytes = sizeInBits / 8;
+  return sizeInBytes;
+}
+
+template <typename OpTy>
+static uint64_t getReductionDataSize(OpTy &op) {
+  if (op.getNumReductionVars() > 0) {
+    assert(op.getNumReductionVars() == 1 &&
+           "Only 1 reduction variable currently supported");
+    mlir::Type reductionVarTy = op.getReductionVars()[0].getType();
+    Operation *opp = op.getOperation();
+    DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
+    return getTypeByteSize(reductionVarTy, dl);
+  }
+  return 0;
+}
+
+static uint64_t getTeamsReductionDataSize(mlir::omp::TeamsOp &teamsOp) {
+  return getReductionDataSize(teamsOp);
+}
+
 /// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
 /// values as stated by the corresponding clauses, if constant.
 ///
@@ -3757,32 +3798,9 @@ static void initTargetDefaultBounds(
   // for now.
int32_t reductionDataSize = 0; if (isGPU && innermostCapturedOmpOp) { - if (auto loopNestOp = - mlir::dyn_cast(innermostCapturedOmpOp)) { - // FIXME: This treats 'DO SIMD' as if it was a 'DO' construct. Reductions - // on other constructs apart from 'DO' aren't considered either. - mlir::omp::WsloopOp wsloopOp = nullptr; - SmallVector wrappers; - loopNestOp.gatherWrappers(wrappers); - for (auto wrapper : wrappers) { - wsloopOp = mlir::dyn_cast(*wrapper); - if (wsloopOp) - break; - } - if (wsloopOp) { - if (wsloopOp.getNumReductionVars() > 0) { - assert(wsloopOp.getNumReductionVars() && - "Only 1 reduction variable currently supported"); - mlir::Value reductionVar = wsloopOp.getReductionVars()[0]; - DataLayout dl = - DataLayout(innermostCapturedOmpOp->getParentOfType()); - - mlir::Type reductionVarTy = reductionVar.getType(); - uint64_t sizeInBits = dl.getTypeSizeInBits(reductionVarTy); - uint64_t sizeInBytes = sizeInBits / 8; - reductionDataSize = sizeInBytes; - } - } + if (auto teamsOp = + castOrGetParentOfType(innermostCapturedOmpOp)) { + reductionDataSize = getTeamsReductionDataSize(teamsOp); } } @@ -4059,45 +4077,6 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, return success(); } -/////////////////////////////////////////////////////////////////////////////// -// CompoundConstructs lowering forward declarations -class OpenMPDialectLLVMIRTranslationInterface; - -using ConvertFunctionTy = std::function( - Operation *, llvm::IRBuilderBase &, LLVM::ModuleTranslation &)>; - -class ConversionDispatchList { -private: - llvm::SmallVector functions; - -public: - std::pair - convertOperation(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - for (auto riter = functions.rbegin(); riter != functions.rend(); ++riter) { - bool match = false; - LogicalResult result = failure(); - std::tie(match, result) = (*riter)(op, builder, moduleTranslation); - if (match) - return {true, result}; - } - return {false, failure()}; - } - - void pushConversionFunction(ConvertFunctionTy function) { - functions.push_back(function); - } - void popConversionFunction() { functions.pop_back(); } -}; - -static LogicalResult convertOmpDistributeParallelWsloop( - omp::ParallelOp parallel, omp::DistributeOp distribute, - omp::WsloopOp wsloop, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - ConversionDispatchList &dispatchList); - -/////////////////////////////////////////////////////////////////////////////// -// Dispatch functions /// Given an OpenMP MLIR operation, create the corresponding LLVM IR /// (including OpenMP runtime calls). @@ -4239,55 +4218,6 @@ static bool isTargetDeviceOp(Operation *op) { return false; } -// Returns true if the given block has a single instruction. -static bool singleInstrBlock(Block &block) { - bool result = (block.getOperations().size() == 2); - if (!result) { - llvm::errs() << "Num ops: " << block.getOperations().size() << "\n"; - } - return result; -} - -// Returns the operation if it only contains one instruction otherwise -// return nullptr. -template -Operation *getContainedInstr(OpType op) { - Region ®ion = op.getRegion(); - if (!region.hasOneBlock()) { - llvm::errs() << "Region has multiple blocks\n"; - return nullptr; - } - Block &block = region.front(); - if (!singleInstrBlock(block)) { - return nullptr; - } - return &(block.getOperations().front()); -} - -// Returns the operation if it only contains one instruction otherwise -// return nullptr. 
-template -Block &getContainedBlock(OpType op) { - Region ®ion = op.getRegion(); - return region.front(); -} - -template -bool matchOpScanNest(Block &block, FirstOpType &firstOp, - RestOpTypes &...restOps) { - for (Operation &op : block) { - if ((firstOp = mlir::dyn_cast(op))) { - if constexpr (sizeof...(RestOpTypes) == 0) { - return true; - } else { - Block &innerBlock = getContainedBlock(firstOp); - return matchOpScanNest(innerBlock, restOps...); - } - } - } - return false; -} - template bool matchOpNest(Operation *op, FirstOpType &firstOp, RestOpTypes &...restOps) { if ((firstOp = mlir::dyn_cast(op))) { @@ -4303,17 +4233,7 @@ bool matchOpNest(Operation *op, FirstOpType &firstOp, RestOpTypes &...restOps) { static LogicalResult convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - ConversionDispatchList &dispatchList) { - omp::ParallelOp parallel; - omp::DistributeOp distribute; - omp::WsloopOp wsloop; - // Match composite constructs - if (matchOpNest(op, parallel, distribute, wsloop)) { - return convertOmpDistributeParallelWsloop( - parallel, distribute, wsloop, builder, moduleTranslation, dispatchList); - } - + LLVM::ModuleTranslation &moduleTranslation) { return convertHostOrTargetOperation(op, builder, moduleTranslation); } @@ -4341,67 +4261,6 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder, return failure(interrupted); } -/////////////////////////////////////////////////////////////////////////////// -// CompoundConstructs lowering implementations - -// Implementation converting a nest of operations in a single function. This -// just overrides the parallel and wsloop dispatches but does the normal -// lowering for now. -static LogicalResult convertOmpDistributeParallelWsloop( - omp::ParallelOp parallel, omp::DistributeOp distribute, - omp::WsloopOp wsloop, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - ConversionDispatchList &dispatchList) { - - // Reduction related data structures - SmallVector owningReductionGens; - SmallVector owningAtomicReductionGens; - SmallVector reductionInfos; - llvm::OpenMPIRBuilder::InsertPointTy redAllocaIP; - - // Convert wsloop alternative implementation - ConvertFunctionTy convertWsloop = - [&redAllocaIP, &owningReductionGens, &owningAtomicReductionGens, - &reductionInfos](Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - if (!isa(op)) { - return std::make_pair(false, failure()); - } - - LogicalResult result = convertOmpWsloop( - *op, builder, moduleTranslation, redAllocaIP, owningReductionGens, - owningAtomicReductionGens, reductionInfos); - return std::make_pair(true, result); - }; - - // Convert distribute alternative implementation - ConvertFunctionTy convertDistribute = - [&redAllocaIP, - &reductionInfos](Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { - if (!isa(op)) { - return std::make_pair(false, failure()); - } - - LogicalResult result = convertOmpDistribute( - *op, builder, moduleTranslation, &redAllocaIP, reductionInfos); - return std::make_pair(true, result); - }; - - // Push the new alternative functions - dispatchList.pushConversionFunction(convertWsloop); - dispatchList.pushConversionFunction(convertDistribute); - - // Lower the current parallel operation - LogicalResult result = - convertOmpParallel(parallel, builder, moduleTranslation); - - // Pop the alternative functions - dispatchList.popConversionFunction(); - 
dispatchList.popConversionFunction(); - - return result; -} /////////////////////////////////////////////////////////////////////////////// // OpenMPDialectLLVMIRTranslationInterface @@ -4410,9 +4269,6 @@ static LogicalResult convertOmpDistributeParallelWsloop( /// to the OpenMP dialect to LLVM IR. class OpenMPDialectLLVMIRTranslationInterface : public LLVMTranslationDialectInterface { -private: - mutable ConversionDispatchList dispatchList; - public: using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; @@ -4541,21 +4397,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const { - // Check to see if there is a lowering that overrides the default lowering - // if not use the default dispatch. - bool match = false; - LogicalResult result = success(); - std::tie(match, result) = - dispatchList.convertOperation(op, builder, moduleTranslation); - if (match) - return result; - llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); if (ompBuilder->Config.isTargetDevice()) { if (isTargetDeviceOp(op)) - return convertTargetDeviceOp(op, builder, moduleTranslation, - dispatchList); - return convertTargetOpsInNest(op, builder, moduleTranslation); + return convertTargetDeviceOp(op, builder, moduleTranslation); + else + return convertTargetOpsInNest(op, builder, moduleTranslation); } return convertHostOrTargetOperation(op, builder, moduleTranslation); } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index c8213f6e32f5b3..ccd8c1e1dbf8f6 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1643,7 +1643,8 @@ func.func @omp_task_depend(%data_var: memref) { func.func @omp_task(%ptr: !llvm.ptr) { // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}} - omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr) { + omp.task in_reduction(@add_f32 %ptr -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -1667,7 +1668,8 @@ combiner { func.func @omp_task(%ptr: !llvm.ptr) { // expected-error @below {{op accumulator variable used more than once}} - omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr, @add_f32 -> %ptr : !llvm.ptr) { + omp.task in_reduction(@add_f32 %ptr -> %arg0 : !llvm.ptr, @add_f32 %ptr -> %arg1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -1697,7 +1699,8 @@ atomic { func.func @omp_task(%mem: memref<1xf32>) { // expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr')}} - omp.task in_reduction(@add_i32 -> %mem : memref<1xf32>) { + omp.task in_reduction(@add_i32 %mem -> %arg0 : memref<1xf32>) { + ^bb0(%arg0: memref<1xf32>): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -1908,7 +1911,8 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { %testf32 = "test.f32"() : () -> (!llvm.ptr) %testf32_2 = "test.f32"() : () -> (!llvm.ptr) // expected-error @below {{if a reduction clause is present on the taskloop directive, the nogroup clause must not be specified}} - omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) nogroup { + omp.taskloop reduction(@add_f32 %testf32 -> %arg0 : !llvm.ptr, @add_f32 
%testf32_2 -> %arg1 : !llvm.ptr) nogroup { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { omp.yield } @@ -1933,7 +1937,8 @@ combiner { func.func @taskloop(%lb: i32, %ub: i32, %step: i32) { %testf32 = "test.f32"() : () -> (!llvm.ptr) // expected-error @below {{the same list item cannot appear in both a reduction and an in_reduction clause}} - omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr) in_reduction(@add_f32 -> %testf32 : !llvm.ptr) { + omp.taskloop reduction(@add_f32 %testf32 -> %arg0 : !llvm.ptr) in_reduction(@add_f32 %testf32 -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { omp.yield } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 8d868f8879bd34..651e4314d348b5 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1094,16 +1094,18 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, // Test reduction. %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.teams reduction(@add_f32 -> %{{.+}} : !llvm.ptr) { - omp.teams reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.teams reduction(@add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) { + omp.teams reduction(@add_f32 %0 -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): %1 = arith.constant 2.0 : f32 // CHECK: omp.terminator omp.terminator } // Test reduction byref - // CHECK: omp.teams reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr) { - omp.teams reduction(byref @add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.teams reduction(byref @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) { + omp.teams reduction(byref @add_f32 %0 -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): %1 = arith.constant 2.0 : f32 // CHECK: omp.terminator omp.terminator @@ -1123,8 +1125,9 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32, func.func @sections_reduction() { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr) - omp.sections reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.sections reduction(@add_f32 {{.+}} -> {{.+}} : !llvm.ptr) + omp.sections reduction(@add_f32 %0 -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.section omp.section { %1 = arith.constant 2.0 : f32 @@ -1144,9 +1147,10 @@ func.func @sections_reduction() { func.func @sections_reduction_byref() { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr - // CHECK: omp.sections reduction(byref @add_f32 -> {{.+}} : !llvm.ptr) - omp.sections reduction(byref @add_f32 -> %0 : !llvm.ptr) { - // CHECK: omp.section + // CHECK: omp.sections reduction(byref @add_f32 {{.+}} -> {{.+}} : !llvm.ptr) + omp.sections reduction(byref @add_f32 %0 -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): + // CHECK: omp.section omp.section { %1 = arith.constant 2.0 : f32 omp.terminator @@ -1243,8 +1247,9 @@ func.func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) { // CHECK-LABEL: func @sections_reduction2 func.func @sections_reduction2() { %0 = memref.alloca() : memref<1xf32> - // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>) - omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) { + // CHECK: omp.sections reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>) + omp.sections reduction(@add2_f32 %0 -> %arg0 : memref<1xf32>) 
{ + ^bb0(%arg0: !llvm.ptr): omp.section { %1 = arith.constant 2.0 : f32 omp.terminator @@ -1899,8 +1904,9 @@ func.func @omp_sectionsop(%data_var1 : memref, %data_var2 : memref, omp.terminator }) {operandSegmentSizes = array} : (memref, memref) -> () - // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) + // CHECK: omp.sections reduction(@add_f32 %{{.*}} -> %{{.*}} : !llvm.ptr) "omp.sections" (%redn_var) ({ + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.terminator omp.terminator }) {operandSegmentSizes = array, reduction_byref = array, reduction_syms=[@add_f32]} : (!llvm.ptr) -> () @@ -1911,8 +1917,9 @@ func.func @omp_sectionsop(%data_var1 : memref, %data_var2 : memref, omp.terminator } - // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) { - omp.sections reduction(@add_f32 -> %redn_var : !llvm.ptr) { + // CHECK: omp.sections reduction(@add_f32 %{{.*}} -> %{{.*}} : !llvm.ptr) { + omp.sections reduction(@add_f32 %redn_var -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.terminator omp.terminator } @@ -2085,8 +2092,9 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr %0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr // CHECK: %[[redn_var2:.*]] = llvm.alloca %{{.*}} x f32 : (i32) -> !llvm.ptr %1 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr - // CHECK: omp.task in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) { - omp.task in_reduction(@add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) { + // CHECK: omp.task in_reduction(@add_f32 %[[redn_var1]] -> %arg4 : !llvm.ptr, @add_f32 %[[redn_var2]] -> %arg5 : !llvm.ptr) { + omp.task in_reduction(@add_f32 %0 -> %arg0 : !llvm.ptr, @add_f32 %1 -> %arg1 : !llvm.ptr) { + ^bb0(%arg4: !llvm.ptr, %arg5: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -2094,8 +2102,9 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr } // Checking `in_reduction` clause (mixed) byref - // CHECK: omp.task in_reduction(byref @add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) { - omp.task in_reduction(byref @add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) { + // CHECK: omp.task in_reduction(byref @add_f32 %[[redn_var1]] -> %arg4 : !llvm.ptr, @add_f32 %[[redn_var2]] -> %arg5 : !llvm.ptr) { + omp.task in_reduction(byref @add_f32 %0 -> %arg0 : !llvm.ptr, @add_f32 %1 -> %arg1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -2125,10 +2134,11 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr omp.task allocate(%data_var : memref -> %data_var : memref) // CHECK-SAME: final(%[[bool_var]]) if(%[[bool_var]]) final(%bool_var) if(%bool_var) - // CHECK-SAME: in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, byref @add_f32 -> %[[redn_var2]] : !llvm.ptr) - in_reduction(@add_f32 -> %0 : !llvm.ptr, byref @add_f32 -> %1 : !llvm.ptr) + // CHECK-SAME: in_reduction(@add_f32 %[[redn_var1]] -> %arg4 : !llvm.ptr, byref @add_f32 %[[redn_var2]] -> %arg5 : !llvm.ptr) + in_reduction(@add_f32 %0 -> %arg0 : !llvm.ptr, byref @add_f32 %1 -> %arg1 : !llvm.ptr) // CHECK-SAME: priority(%[[i32_var]] : i32) untied priority(%i32_var : i32) untied { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): // CHECK: "test.foo"() : () -> () "test.foo"() : () -> () // CHECK: omp.terminator @@ -2304,8 +2314,9 @@ func.func @omp_taskgroup_multiple_tasks() -> () { func.func 
@omp_taskgroup_clauses() -> () { %testmemref = "test.memref"() : () -> (memref) %testf32 = "test.f32"() : () -> (!llvm.ptr) - // CHECK: omp.taskgroup allocate(%{{.+}}: memref -> %{{.+}}: memref) task_reduction(@add_f32 -> %{{.+}}: !llvm.ptr) - omp.taskgroup allocate(%testmemref : memref -> %testmemref : memref) task_reduction(@add_f32 -> %testf32 : !llvm.ptr) { + // CHECK: omp.taskgroup allocate(%{{.+}}: memref -> %{{.+}}: memref) task_reduction(@add_f32 %{{.+}} -> %{{.+}}: !llvm.ptr) + omp.taskgroup allocate(%testmemref : memref -> %testmemref : memref) task_reduction(@add_f32 %testf32 -> %arg0 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr): // CHECK: omp.task omp.task { "test.foo"() : () -> () @@ -2376,8 +2387,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () { %testf32 = "test.f32"() : () -> (!llvm.ptr) %testf32_2 = "test.f32"() : () -> (!llvm.ptr) - // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) { - omp.taskloop in_reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) { + // CHECK: omp.taskloop in_reduction(@add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr, @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) { + omp.taskloop in_reduction(@add_f32 %testf32 -> %arg0 : !llvm.ptr, @add_f32 %testf32_2 -> %arg1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { // CHECK: omp.yield omp.yield @@ -2386,8 +2398,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () { } // Checking byref attribute for in_reduction - // CHECK: omp.taskloop in_reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) { - omp.taskloop in_reduction(byref @add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) { + // CHECK: omp.taskloop in_reduction(byref @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr, @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) { + omp.taskloop in_reduction(byref @add_f32 %testf32 -> %arg0 : !llvm.ptr, @add_f32 %testf32_2 -> %arg1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { // CHECK: omp.yield omp.yield @@ -2395,8 +2408,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () { omp.terminator } - // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) { - omp.taskloop reduction(byref @add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) { + // CHECK: omp.taskloop reduction(byref @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr, @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) { + omp.taskloop reduction(byref @add_f32 %testf32 -> %arg0 : !llvm.ptr, @add_f32 %testf32_2 -> %arg1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { // CHECK: omp.yield omp.yield @@ -2405,8 +2419,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () { } // check byref attrbute for reduction - // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, byref @add_f32 -> %{{.+}} : !llvm.ptr) { - omp.taskloop reduction(byref @add_f32 -> %testf32 : !llvm.ptr, byref @add_f32 -> %testf32_2 : !llvm.ptr) { + // CHECK: omp.taskloop reduction(byref @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr, byref @add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) { + omp.taskloop reduction(byref @add_f32 %testf32 -> %arg0 : !llvm.ptr, byref @add_f32 %testf32_2 -> %arg1 : !llvm.ptr) { + ^bb0(%arg0: !llvm.ptr, 
%arg1: !llvm.ptr): omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { // CHECK: omp.yield omp.yield @@ -2414,8 +2429,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () { omp.terminator } - // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr) { - omp.taskloop in_reduction(@add_f32 -> %testf32 : !llvm.ptr) reduction(@add_f32 -> %testf32_2 : !llvm.ptr) { + // CHECK: omp.taskloop in_reduction(@add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) reduction(@add_f32 %{{.+}} -> %{{.+}} : !llvm.ptr) { + omp.taskloop in_reduction(@add_f32 %testf32 -> %arg3 : !llvm.ptr) reduction(@add_f32 %testf32_2 -> %arg4 : !llvm.ptr) { + ^bb0(%arg3: !llvm.ptr, %arg4: !llvm.ptr): omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) { // CHECK: omp.yield omp.yield diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir index 2d8a13ccd2a1f5..2d696fc2a01694 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir @@ -44,7 +44,7 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute %2 = llvm.mlir.constant(1 : index) : i64 omp.parallel { %3 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr - omp.sections reduction(byref @add_reduction_byref_box_Uxf32 -> %3 : !llvm.ptr) { + omp.sections reduction(byref @add_reduction_byref_box_Uxf32 %3 -> %arg1 : !llvm.ptr) { ^bb0(%arg1: !llvm.ptr): omp.section { ^bb0(%arg2: !llvm.ptr): diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir index 694180a5ced373..f70e8ecb15d2d2 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir @@ -13,7 +13,7 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in %0 = llvm.mlir.constant(2.000000e+00 : f32) : f32 %1 = llvm.mlir.constant(1.000000e+00 : f32) : f32 omp.parallel { - omp.sections reduction(@add_reduction_f32 -> %arg0 : !llvm.ptr) { + omp.sections reduction(@add_reduction_f32 %arg0 -> %arg1 : !llvm.ptr) { ^bb0(%arg1: !llvm.ptr): omp.section { ^bb0(%arg2: !llvm.ptr):