Commit cfeefc6

Fix dynamic shapes and multiple return values

jhalakpatel committed Oct 27, 2024
1 parent 61a1bd9 commit cfeefc6

Showing 12 changed files with 235 additions and 141 deletions.
@@ -161,49 +161,4 @@ def TensorRTRuntime_EnqueueAllocOp : TensorRTRuntime_Op<"enqueue_alloc", [
  }];
}

- //===----------------------------------------------------------------------===//
- // EnqueueAllocOp
- //===----------------------------------------------------------------------===//
-
- def TensorRTRuntime_EnqueueAllocOp : TensorRTRuntime_Op<"enqueue_alloc", [
-     DeclareOpInterfaceMethods<TensorKindOpInterface>,
-     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-   ]> {
-   let description = [{
-     Asynchronously executes the computation represented by the
-     `execution_context` on the specified CUDA stream. This operation
-     can accept inputs of either tensor or memref types and returns
-     results of either tensor or memref types.
-   }];
-
-   let arguments = (ins
-     TensorRTRuntime_Context:$execution_context,
-     CUDA_Stream:$stream,
-     Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs,
-     OptionalAttr<DenseI64ArrayAttr>:$host_tensor_args
-   );
-
-   let results = (outs Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$results);
-
-   let assemblyFormat = [{
-     $execution_context `stream` `(` $stream `)` ` `
-     (`host_tensor_args` $host_tensor_args^ ` ` )?
-     `(` $inputs `)`
-     attr-dict `:` functional-type($inputs, $results)
-   }];
-
-   let hasVerifier = 1;
-
-   let extraClassDeclaration = [{
-     /// Return true if the operand is a host tensor argument.
-     bool isOperandOnHost(OpOperand *operand) {
-       unsigned operandIdx = operand->getOperandNumber();
-       if (std::optional<ArrayRef<int64_t>> indices = getHostTensorArgs()) {
-         return llvm::is_contained(*indices, operandIdx - 2);
-       }
-       return false;
-     }
-   }];
- }

#endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTRUNTIMEOPS_TD
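The duplicated definition deleted above (a definition of the same op remains earlier in the file, per the hunk context) encodes one subtlety worth spelling out: entries in `host_tensor_args` index into the variadic `inputs`, because operands 0 and 1 are the execution context and the stream, which is why `isOperandOnHost` subtracts 2. A standalone sketch of that mapping (hypothetical helper name, not part of the commit):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include <cstdint>

// Operand 0 is `execution_context` and operand 1 is `stream`, so entry k in
// `host_tensor_args` names input #k, i.e. operand #(k + 2).
static bool isHostInput(llvm::ArrayRef<int64_t> hostTensorArgs,
                        unsigned operandIdx) {
  if (operandIdx < 2) // execution_context, stream
    return false;
  return llvm::is_contained(hostTensorArgs, operandIdx - 2);
}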
mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
12 changes: 6 additions & 6 deletions
@@ -223,7 +223,7 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)"));
addOption(
"use-non-dps-call-conv", useNonDPSCallConv, llvm::cl::init(false),
"enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"));
addOption("executor-index-bitwidth", executorIndexBitwidth,
@@ -307,7 +307,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
  plan::StablehloClusteringPassOptions clusteringOpts{};
  clusteringOpts.disallowHostTensorsInTensorRTClusters =
      opts.disallowHostTensorsInTensorRTClusters;
- clusteringOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
+ clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
  clusteringOpts.entrypoint = opts.entrypoint;
  plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

@@ -342,7 +342,7 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(
  // Perform bufferization.
  pm.addPass(createMemRefCastEliminationPass());
  plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
- allocTensorsOpts.useNonDPSCallConv = opts.useNonDPSCallConv;
+ allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
  pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
  pm.addPass(plan::createPlanBufferizePass());
  pm.addPass(createMemRefCastEliminationPass());
@@ -532,8 +532,8 @@ struct ClusteringPipelineCliOpts
*this, "device-compute-capability",
llvm::cl::desc("target device compute capability (SM version)"),
llvm::cl::init(60)};
Option<bool> useNonDPSCallConv{
*this, "use-non-dps-call-conv",
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"),
llvm::cl::init(false)};
@@ -564,7 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
  opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
  opts.deviceMaxSharedMemoryPerBlockKb =
      cliOpts.deviceMaxSharedMemoryPerBlockKb;
- opts.useNonDPSCallConv = cliOpts.useNonDPSCallConv;
+ opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
  opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
  opts.entrypoint = cliOpts.entrypoint;
  return opts;
@@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter,
  SmallVector<int64_t> hostTensorArgs;
  for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
    const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
+   if (!isa<RankedTensorType>(arg.getType()))
+     continue;
    RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
    // To be conservative, we only do this if type is i32 and num elements
    // <= 8.
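The check that follows this comment is collapsed in the diff; a hypothetical reconstruction of the predicate it describes (name and exact form assumed, not from the commit):

#include "mlir/IR/BuiltinTypes.h"

// Conservative filter: only statically-shaped i32 tensors with at most
// eight elements are treated as host tensor arguments.
static bool isSmallI32HostCandidate(mlir::RankedTensorType rtt) {
  return rtt.getElementType().isInteger(32) && rtt.hasStaticShape() &&
         rtt.getNumElements() <= 8;
}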
@@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
} // namespace

/// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
-/// operations.
-static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
+/// and `tensorrt.call_alloc` operations.
+static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
-  llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
-  op->walk([&](tensorrt::CallOp callOp) {
-    func::FuncOp func = callOp.getFuncCallee(collection);
-    if (map.contains(func)) {
-      map[func].push_back(callOp);
-      return;
-    }
-    map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));
+  llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
+  op->walk([&](Operation *callOp) {
+    if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
+      return;
+
+    func::FuncOp func;
+    if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
+      func = call.getFuncCallee(collection);
+    else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
+      func = callAlloc.getFuncCallee(collection);
+    else
+      return;
+
+    if (map.count(func))
+      map[func].push_back(callOp);
+    else
+      map.insert({func, SmallVector<Operation *>{callOp}});
  });
  return map;
}
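The two `dyn_cast` branches resolve the callee through the same `getFuncCallee` method on either op; an equivalent and arguably more idiomatic formulation of that lookup uses `llvm::TypeSwitch` (a sketch of an alternative, not what the commit does; it would replace the guard plus the if/else chain inside the walk callback):

#include "llvm/ADT/TypeSwitch.h"

// Resolve the callee for either call-op kind in one expression; a null
// result means "not a TensorRT call op".
func::FuncOp func =
    llvm::TypeSwitch<Operation *, func::FuncOp>(callOp)
        .Case<tensorrt::CallOp, tensorrt::CallAllocOp>(
            [&](auto call) { return call.getFuncCallee(collection); })
        .Default([](Operation *) { return func::FuncOp(); });
if (!func)
  return;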
@@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
/// `tensorrt.call` operations.
static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
                                      ModuleOp op, func::FuncOp funcOp,
-                                     ArrayRef<tensorrt::CallOp> callOps) {
+                                     ArrayRef<Operation *> callOps) {
  llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
  for (BlockArgument arg : funcOp.getArguments()) {
    if (arg.use_empty())
@@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
    funcOp.eraseArgument(i);

    // Update the call ops.
-   for (tensorrt::CallOp callOp : callOps)
-     callOp.getInputsMutable().erase(i);
+   for (Operation *callOp : callOps) {
+     if (auto call = dyn_cast<tensorrt::CallOp>(callOp)) {
+       call.getInputsMutable().erase(i);
+     } else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp)) {
+       callAlloc.getInputsMutable().erase(i);
+     } else {
+       llvm::errs() << "Unexpected operation type in callOps\n";
+       callOp->dump();
+     }
+   }
  }

  return success();
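Note that `getTensorRTFunctionCallMap` only ever inserts `tensorrt.call` and `tensorrt.call_alloc` operations, so the final `else` branch above should be unreachable; a stricter variant could assert instead of logging (a design alternative, not the commit's choice):

} else {
  // From llvm/Support/ErrorHandling.h.
  llvm_unreachable("callOps must contain only tensorrt.call or "
                   "tensorrt.call_alloc operations");
}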
@@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,

static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
                                           plan::InlineClosedAllocGroupOp op) {
-  return op.emitError("outlinining inline closed alloc group ops to tensorrt "
-                      "dialect is not yet implemented");
+  tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
+  auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
+  FailureOr<FunctionOpInterface> func = createOutlinedFunc(
+      rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
+      "cluster.tensorrt", TypeRange(op.getInputs()),
+      op.getYield()->getOperandTypes());
+  if (failed(func))
+    return failure();
+  assert(func->getFunctionBody().getBlocks().size() == 1 &&
+         "expected body with one block");
+  func->setPublic();
+
+  rewriter.setInsertionPoint(op);
+
+  auto callOp = rewriter.create<tensorrt::CallAllocOp>(
+      op.getLoc(), op.getResultTypes(), op.getInputs(),
+      SymbolRefAttr::get(trtModuleOp.getNameAttr(),
+                         {FlatSymbolRefAttr::get(*func)}));
+
+  // Populate the function arguments attributes.
+  for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
+    BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
+    // We may have scalar (index|signless int)-typed values since we haven't
+    // eliminated `plan.(with_shape|with_values)` ops yet.
+    if (!op.argHasTensorType(i) || srcAttr.isNone())
+      continue;
+    FailureOr<tensorrt::ShapeProfileAttr> boundAttr =
+        getTensorRTShapeProfile(srcAttr, op.getInputs()[i]);
+    if (failed(boundAttr))
+      return op->emitOpError("failed to create TensorRT shape profile "
+                             "attribute from Plan BoundsAttr for argument #")
+             << i << " (" << srcAttr << ")";
+    if (srcAttr.isShapeBound()) {
+      func->setArgAttr(i,
+                       tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
+                       *boundAttr);
+      continue;
+    }
+    assert(srcAttr.isValueBound() && "expected value bound or shape bound");
+    func->setArgAttr(
+        i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
+        *boundAttr);
+    func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
+                     rewriter.getUnitAttr());
+  }
+
+  // Populate the function entry block.
+  rewriter.eraseBlock(&func->getFunctionBody().front());
+
+  // Move private decomposition funcs associated with all `stablehlo.composite`
+  // ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
+  // has its own symbol table.
+  SymbolTableCollection symbolTable;
+  for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
+    auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
+        symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
+                                   compositeOp.getDecompositionAttr()));
+    if (!decompositionFunc)
+      return emitError(compositeOp.getLoc())
+             << "failed to lookup stablehlo.composite decomposition "
+                "function: "
+             << compositeOp.getDecompositionAttr();
+    rewriter.moveOpAfter(decompositionFunc, func->getOperation());
+  }
+
+  // Move region op operations to the func body.
+  Operation *regionYieldOp = op.getYield();
+  rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
+                              func->getFunctionBody().end());
+  rewriter.setInsertionPoint(regionYieldOp);
+  rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
+                                              regionYieldOp->getOperands());
+
+  // Replace the original region results.
+  rewriter.replaceOp(op, callOp);
+  return success();
}

/// Create outlined functions for each `scf.execute_region` operation within
mlir-tensorrt/executor/lib/Runtime/API/API.cpp
11 changes: 6 additions & 5 deletions

@@ -410,7 +410,7 @@ void AllocTracker::incrementExternalCount(uintptr_t ptr) {
llvm::formatv("Untracked pointer {0}", ptr).str().c_str());
std::unique_ptr<Metadata> const &metadata = map.at(ptr);
int32_t ref = ++metadata->externalReferenceCount;
MTRT_DBG("Incremented external reference for pointer %d to %d", ptr, ref);
MTRT_DBGF("Incremented external reference for pointer 0x%lx to %d", ptr, ref);
}

void AllocTracker::decrementExternalCount(uintptr_t ptr) {
@@ -422,11 +422,12 @@ void AllocTracker::decrementExternalCount(uintptr_t ptr) {
llvm::formatv("External reference count cannot be negative: {0}", ref)
.str()
.c_str());
MTRT_DBG("Decremented external reference for pointer %d to %d", ptr, ref);
MTRT_DBGF("Decremented external reference for pointer 0x%lx to %d", ptr, ref);
if (ref == 0 && metadata->releasedInternally) {
MTRT_DBG("External reference to an internally released pointer %d is 0, "
"try deallocating pointer memory of size %lu",
ptr, ref, metadata->info.size);
MTRT_DBGF(
"External reference to an internally released pointer 0x%lx is 0, "
"try deallocating pointer memory of size %lu",
ptr, metadata->info.size);
Status s = safeDeallocate(*this, metadata->info.ptr);
if (!s.isOk())
MTRT_DBGF("error while deallocating dangling memory: %s",
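The `%d` → `0x%lx` changes above fix a printf-style format/argument mismatch: `ptr` is a `uintptr_t`, which `%d` would truncate to 32 bits on LP64 targets, and the old third call also passed an extra `ref` argument with no matching specifier. A minimal illustration, assuming `MTRT_DBGF` forwards printf-style formats:

#include <cinttypes>
#include <cstdio>

int main() {
  std::uintptr_t ptr = 0xdeadbeef;
  // std::printf("%d\n", ptr);           // mismatch: truncates on LP64 targets
  std::printf("0x%lx\n", ptr);           // OK where uintptr_t == unsigned long
  std::printf("0x%" PRIxPTR "\n", ptr);  // fully portable spelling
  return 0;
}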
mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp
32 changes: 19 additions & 13 deletions

@@ -502,30 +502,34 @@ static Status validateArgsTypesAgainstFuncArgs(const RuntimeValue *runArg,
  return getOkStatus();
}

- static constexpr int MEMREF_FIXED_FIELDS = 3; // allocPtr, alignedPtr, offset
+ [[maybe_unused]] static constexpr int MEMREF_FIXED_FIELDS =
+     3; // allocPtr, alignedPtr, offset

// MemRefTableReader encapsulates the logic for reading MemRef data from a Lua
// table.
class MemRefTableReader {
public:
- MemRefTableReader(const sol::protected_function_result &pfr,
+ MemRefTableReader(const sol::protected_function_result &pfr, int resultIndex,
                    impl::CallingConvention conv)
      : mPfr(pfr), mConv(conv), mIndex(1) {
    // Currently, we only support the unpacked calling convention.
    assert(mConv == CallingConvention::unpacked &&
           "Only unpacked calling convention is supported");
+
+   // Assume the result is always a memref.
+   sol::object obj = mPfr[resultIndex];
+   assert(obj.is<sol::table>() && "Expected a table for MemRefValue");
+   mMemRefTable = obj.as<sol::table>();
  }

  // Retrieves the next value of type T from the MemRef table.
  // This method advances the internal index automatically.
  template <typename T>
  T getNextValue() {
-   sol::object obj = mPfr[0];
-   assert(obj.is<sol::table>() && "Expected a table for MemRefValue");
-   sol::table memRefTable = obj.as<sol::table>();
-   return memRefTable.get<T>(mIndex++);
+   return mMemRefTable.get<T>(mIndex++);
  }

+ // TODO: This may not be required since each pfr index stores a memref.
  // Moves to the next MemRef in the table.
  // This is called after processing all data for the current MemRef.
  void nextMemRef(int offset) {

@@ -537,6 +541,7 @@
private:
  const sol::protected_function_result &mPfr;
  impl::CallingConvention mConv;
+ sol::table mMemRefTable;
  int mIndex;
};
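With the new constructor, each result index in `pfr` holds its own Lua table, which a caller drains field by field in the MEMREF_FIXED_FIELDS layout (allocPtr, alignedPtr, offset, then `rank` sizes and `rank` strides). A usage sketch (variable names assumed; `pfr`, `sig`, `i`, and `rank` come from the surrounding parseResults context):

// Read one memref result's fields out of its Lua table.
MemRefTableReader reader(pfr, /*resultIndex=*/i, sig.getCConv());
auto allocPtr = reader.getNextValue<uintptr_t>();
auto alignedPtr = reader.getNextValue<uintptr_t>();
auto offset = reader.getNextValue<int64_t>();
llvm::SmallVector<int64_t> shape, strides;
for (int64_t d = 0; d < rank; ++d)
  shape.push_back(reader.getNextValue<int64_t>());
for (int64_t d = 0; d < rank; ++d)
  strides.push_back(reader.getNextValue<int64_t>());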

@@ -571,9 +576,10 @@ parseResults(const sol::protected_function_result &pfr,
             const FunctionSignatureView &sig,
             std::optional<RuntimeClient *> client) {
  llvm::SmallVector<std::unique_ptr<RuntimeValue>> results;
- MemRefTableReader reader(pfr, sig.getCConv());
-
  for (unsigned i = 0; i < sig.getNumResults(); ++i) {
+   MemRefTableReader reader(pfr, i, sig.getCConv());
+
    if (sig.getResult(i).isa<ScalarTypeView>()) {
      auto scalar = getScalarValue(pfr, i, sig);
      if (!scalar.isOk())
@@ -607,16 +613,16 @@
      return getInvalidArgStatus("Runtime client cannot be nullptr");

    // Create MemRefValue from extracted data.
-   auto memref = MemRefValue::create(
-       *client, resultView.getAddressSpace(),
-       resultView.getElementType().getBitWidth(), allocPtr, offset, shape,
-       strides, (*client)->getDevices()[0].get(), resultView.getElementType());
-
+   auto memref = (*client)->createExternalMemRef(
+       resultView.getAddressSpace(), resultView.getElementType().getBitWidth(),
+       allocPtr, offset, shape, strides, (*client)->getDevices()[0].get(),
+       resultView.getElementType());
+
    if (!memref.isOk())
      return memref.getStatus();

    results.push_back(std::move(*memref));
    reader.nextMemRef(MEMREF_FIXED_FIELDS + rank * 2);
  }

  return results;