Commit: Add initial IR for alloc enqueue
jhalakpatel committed Nov 4, 2024
1 parent 4c7b5b8 commit 5ad00f2
Showing 20 changed files with 477 additions and 144 deletions.
15 changes: 14 additions & 1 deletion mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
@@ -222,6 +222,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)"));
addOption(
"enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"));
addOption("executor-index-bitwidth", executorIndexBitwidth,
llvm::cl::init(64));
addOption("device-compute-capability", deviceComputeCapability,
@@ -306,6 +310,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
plan::StablehloClusteringPassOptions clusteringOpts{};
clusteringOpts.disallowHostTensorsInTensorRTClusters =
opts.disallowHostTensorsInTensorRTClusters;
clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
clusteringOpts.entrypoint = opts.entrypoint;
plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

@@ -339,7 +344,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(

// Perform bufferization.
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(plan::createPlanAllocTensorsPass());
plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(createCanonicalizerPass());
@@ -525,6 +532,11 @@ struct ClusteringPipelineCliOpts
*this, "device-compute-capability",
llvm::cl::desc("target device compute capability (SM version)"),
llvm::cl::init(60)};
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"),
llvm::cl::init(false)};
Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
*this, "device-max-smem-per-block",
llvm::cl::desc("max shared memory per block (in kilobytes)"),
@@ -552,6 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
opts.deviceMaxSharedMemoryPerBlockKb =
cliOpts.deviceMaxSharedMemoryPerBlockKb;
opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
opts.entrypoint = cliOpts.entrypoint;
return opts;
@@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter,
SmallVector<int64_t> hostTensorArgs;
for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
if (!isa<RankedTensorType>(arg.getType()))
continue;
RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
// To be conservative, we only do this if type is i32 and num elements
// <= 8.
@@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
} // namespace

/// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
/// operations.
static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
/// and `tensorrt.call_alloc` operations.
static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
op->walk([&](tensorrt::CallOp callOp) {
func::FuncOp func = callOp.getFuncCallee(collection);
if (map.contains(func)) {
map[func].push_back(callOp);
return;
}
map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));
llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
op->walk([&](Operation *callOp) {
if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
return;

func::FuncOp func;
if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
func = call.getFuncCallee(collection);
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
func = callAlloc.getFuncCallee(collection);
else
return;

if (map.count(func))
map[func].push_back(callOp);
else
map.insert({func, SmallVector<Operation *>{callOp}});
});
return map;
}
@@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
/// `tensorrt.call` operations.
static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
ModuleOp op, func::FuncOp funcOp,
ArrayRef<tensorrt::CallOp> callOps) {
ArrayRef<Operation *> callOps) {
llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
for (BlockArgument arg : funcOp.getArguments()) {
if (arg.use_empty())
@@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
funcOp.eraseArgument(i);

// Update the call ops.
for (tensorrt::CallOp callOp : callOps)
callOp.getInputsMutable().erase(i);
for (Operation *callOp : callOps) {
if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
call.getInputsMutable().erase(i);
} else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp)) {
callAlloc.getInputsMutable().erase(i);
} else {
llvm::errs() << "Unexpected operation type in callOps\n";
callOp->dump();
}
}
}

return success();
@@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,

static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
plan::InlineClosedAllocGroupOp op) {
return op.emitError("outlinining inline closed alloc group ops to tensorrt "
"dialect is not yet implemented");
tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
FailureOr<FunctionOpInterface> func = createOutlinedFunc(
rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
"cluster.tensorrt", TypeRange(op.getInputs()),
op.getYield()->getOperandTypes());
if (failed(func))
return failure();
assert(func->getFunctionBody().getBlocks().size() == 1 &&
"expected body with one block");
func->setPublic();

rewriter.setInsertionPoint(op);

auto callOp = rewriter.create<tensorrt::CallAllocOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
SymbolRefAttr::get(trtModuleOp.getNameAttr(),
{FlatSymbolRefAttr::get(*func)}));

// Populate the function arguments attributes.
for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
// We may have scalar (index|signless int)-typed values since we haven't
// eliminated `plan.(with_shape|with_values)` ops yet.
if (!op.argHasTensorType(i) || srcAttr.isNone())
continue;
FailureOr<tensorrt::ShapeProfileAttr> boundAttr =
getTensorRTShapeProfile(srcAttr, op.getInputs()[i]);
if (failed(boundAttr))
return op->emitOpError("failed to create TensorRT shape profile "
"attribute from Plan BoundsAttr for argument #")
<< i << " (" << srcAttr << ")";
if (srcAttr.isShapeBound()) {
func->setArgAttr(i,
tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
*boundAttr);
continue;
}
assert(srcAttr.isValueBound() && "expected value bound or shape bound");
func->setArgAttr(
i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
*boundAttr);
func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
rewriter.getUnitAttr());
}

// Populate the function entry block.
rewriter.eraseBlock(&func->getFunctionBody().front());

// Move private decomposition funcs associated with all `stablehlo.composite`
// ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
// has its own symbol table.
SymbolTableCollection symbolTable;
for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
compositeOp.getDecompositionAttr()));
if (!decompositionFunc)
return emitError(compositeOp.getLoc())
<< "failed to lookup stablehlo.composite decomposition "
"function: "
<< compositeOp.getDecompositionAttr();
rewriter.moveOpAfter(decompositionFunc, func->getOperation());
}

// Move region op operations to the func body.
Operation *regionYieldOp = op.getYield();
rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
func->getFunctionBody().end());
rewriter.setInsertionPoint(regionYieldOp);
rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
regionYieldOp->getOperands());

// Replace the original region results.
rewriter.replaceOp(op, callOp);
return success();
}

/// Create outlined functions for each `scf.execute_region` operation within
30 changes: 25 additions & 5 deletions mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
@@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) {
return !client.ptr;
}

/// Returns null client.
static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() {
return MTRT_RuntimeClient{nullptr};
}

/// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the
/// program execution.
/// The `stream` passed to the client is used by all underlying CUDA methods
@@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
return !value.ptr;
}

/// Returns whether the RuntimeValue is a MemRef.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);

/// Returns whether the RuntimeValue is a Scalar.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);

/// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
MLIR_CAPI_EXPORTED MTRT_RuntimeValue
mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
@@ -391,16 +402,25 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
return !session.ptr;
}

/// Using `session`, execute the pubic function with the specified name.
/// The `inArgs` and `outArgs` are arrays for input arguments and destination
/// arguments, respectively. Input arguments may be MemRefs or scalars, but
/// destination arguments must be MemRefs.
/// Using `session`, execute the public function with the specified name.
/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments,
/// output arguments, and return values, respectively. Arguments and results
/// can be MemRefs, scalars, or other supported types. Both `outArgs` and
/// `results` can be used simultaneously, allowing for functions that both
/// modify arguments and return values.
/// A stream may optionally be specified, otherwise pass the result of
/// `mtrtStreamGetNull()`.
///
/// The `results` array must point to an array with at least the number of
/// elements returned by mtrtRuntimeSessionGetNumResults for the given function.
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client);

MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults(
MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults);

//===----------------------------------------------------------------------===//
// DLPack
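
Editor's note: the following is a usage sketch of the revised execution C API shown above, not part of this commit. The entrypoint name "main", the pre-built session, client, and input arguments, and the omitted error handling are all assumptions; only declarations visible in this diff are used.

#include <stdlib.h>
#include "mlir-executor-c/Runtime/Runtime.h"

/* Sketch: run a function that returns its outputs (non-DPS style) instead of
   writing into caller-provided destination arguments. */
void runAndCollectResults(MTRT_RuntimeSession session, MTRT_RuntimeClient client,
                          const MTRT_RuntimeValue *inArgs, size_t numInArgs) {
  MTRT_StringView funcName;
  funcName.data = "main"; /* hypothetical entrypoint name */
  funcName.length = 4;    /* member names follow the .data/.length uses in this diff */

  int64_t numResults = 0;
  mtrtRuntimeSessionGetNumResults(session, funcName, &numResults);

  /* Caller allocates space for at least `numResults` values (see doc comment above). */
  MTRT_RuntimeValue *results =
      (MTRT_RuntimeValue *)malloc(numResults * sizeof(MTRT_RuntimeValue));

  mtrtRuntimeSessionExecuteFunction(
      session, funcName, inArgs, numInArgs,
      /*outArgs=*/NULL, /*numOutArgs=*/0,
      results, mtrtStreamGetNull(), client);

  for (int64_t i = 0; i < numResults; ++i) {
    if (mtrtRuntimeValueIsMemRef(results[i])) {
      /* inspect the returned MemRef value ... */
    } else if (mtrtRuntimeValueIsScalar(results[i])) {
      /* e.g. cast with mtrtRuntimeValueDynCastToScalar ... */
    }
  }
  free(results);
}

Passing an empty outArgs list together with a caller-allocated results array appears to be the pattern the new non-DPS (alloc-enqueue) support is aimed at: the runtime allocates the outputs and hands them back as return values.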
@@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
std::optional<CudaStream> stream = {},
std::optional<RuntimeClient *> client = {});

// Parses the results of a function call, handling both scalar and MemRef return
// types
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
parseResults(const sol::protected_function_result &pfr,
const FunctionSignatureView &sig,
std::optional<RuntimeClient *> client);

} // namespace mlirtrt::runtime

#endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H
40 changes: 34 additions & 6 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
@@ -675,6 +675,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
return wrap(static_cast<ScalarValue *>(x));
}

bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::MemRef;
}

bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::Scalar;
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeSessionOptions
//===----------------------------------------------------------------------===//
@@ -721,7 +731,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) {
MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) {
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));

@@ -731,19 +742,36 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
llvm::SmallVector<RuntimeValue *> outArgValues =
llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
[](MTRT_RuntimeValue arg) { return unwrap(arg); });

StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> resultValues =
executeFunctionWithLuaBackend(
*cppSession, std::string_view(name.data, name.length), inArgValues,
outArgValues,
!mtrtStreamIsNull(stream)
? std::optional(unwrap(stream)->getRawStream())
: std::nullopt);
if (!result.isOk())
return wrap(result.getStatus());
: std::nullopt,
!mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client))
: std::nullopt);
if (!resultValues.isOk())
return wrap(resultValues.getStatus());

for (size_t i = 0; i < resultValues->size(); ++i)
results[i] = wrap((*resultValues)[i].release());

return mtrtStatusGetOk();
}

MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session,
MTRT_StringView name,
int64_t *numResults) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));
*numResults = cppSession->getExecutable()
.getFunction(std::string_view(name.data, name.length))
.getSignature()
.getNumResults();
return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeClient
//===----------------------------------------------------------------------===//
20 changes: 20 additions & 0 deletions mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp
@@ -24,6 +24,7 @@
#include "mlir-executor/Conversion/ConvertToExecutorCommon.h"
#include "mlir-executor/Conversion/Passes.h"
#include "mlir-executor/Executor/IR/Executor.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
@@ -548,6 +549,21 @@ void executor::populateMemRefToExecutorPatterns(
}

namespace {

class RemoveNoOpClonePattern : public OpRewritePattern<bufferization::CloneOp> {
public:
using OpRewritePattern<bufferization::CloneOp>::OpRewritePattern;

LogicalResult matchAndRewrite(bufferization::CloneOp op,
PatternRewriter &rewriter) const override {
if (op.getInput().getType() == op.getOutput().getType()) {
rewriter.replaceOp(op, op.getInput());
return success();
}
return failure();
}
};

/// Pass to convert `memref` to `executor` dialect operations.
class ConvertMemRefToExecutorPass
: public mlir::executor::impl::ConvertMemRefToExecutorPassBase<
@@ -579,6 +595,10 @@ class ConvertMemRefToExecutorPass
RewritePatternSet patterns(ctx);
executor::populateMemRefToExecutorPatterns(
patterns, typeConverter, allowUncheckedMemrefCastConversion);

// Remove redundant `bufferization.clone` operations whose source and result types match.
patterns.add<RemoveNoOpClonePattern>(ctx);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
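
Editor's note: a minimal sketch (not part of this commit) of how the no-op-clone folding above could be exercised on its own with MLIR's greedy rewrite driver. It assumes `RemoveNoOpClonePattern` (which lives in an anonymous namespace in the pass file) has been copied or otherwise made visible to the caller.

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Folds `%y = bufferization.clone %x` into a direct use of %x whenever the
// source and result types are identical, mirroring what the pattern does
// inside the MemRef-to-Executor conversion.
static mlir::LogicalResult foldTrivialClones(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  patterns.add<RemoveNoOpClonePattern>(module.getContext());
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}

In the pass itself the pattern simply rides along with the MemRef-to-Executor conversion patterns and is applied through `applyPartialConversion`, as shown in the diff above.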