Add initial IR for alloc enqueue
jhalakpatel committed Oct 9, 2024
1 parent fe1ed9b commit 0a8dc8e
Showing 20 changed files with 620 additions and 73 deletions.
2 changes: 2 additions & 0 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
@@ -487,6 +487,8 @@ StableHloToExecutableTask::compileStableHLOToExecutable(
runner = pm.get();
}

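// Dump the pipeline that is about to run, for debugging.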
runner->printAsTextualPipeline(llvm::dbgs());

// Setup pass manager
if (failed(runner->run(module)))
return getInternalErrorStatus(
30 changes: 24 additions & 6 deletions mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
@@ -198,6 +198,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) {
return !client.ptr;
}

/// Returns null client.
static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() {
return MTRT_RuntimeClient{nullptr};
}

/// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the
/// program execution.
/// The `stream` passed to the client is used by all underlying CUDA methods
@@ -291,6 +296,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
return !value.ptr;
}

/// Returns whether the RuntimeValue is a MemRef.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);

/// Returns whether the RuntimeValue is a Scalar.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);

/// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
MLIR_CAPI_EXPORTED MTRT_RuntimeValue
mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
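For illustration, a minimal sketch of dispatching on a value's kind before down-casting. It only uses the predicates above plus `mtrtRuntimeValueDynCastToScalar` from this header; the helper name `describeRuntimeValue` is an assumption, not part of the API.

#include "mlir-executor-c/Runtime/Runtime.h"

#include <cstdio>

// Hypothetical helper showing how the new kind predicates pair with the
// existing dyn-cast entry points.
static void describeRuntimeValue(MTRT_RuntimeValue value) {
  if (mtrtRuntimeValueIsMemRef(value)) {
    std::printf("value is a MemRef\n");
  } else if (mtrtRuntimeValueIsScalar(value)) {
    MTRT_ScalarValue scalar = mtrtRuntimeValueDynCastToScalar(value);
    (void)scalar; // e.g. read it back with the scalar accessors
    std::printf("value is a Scalar\n");
  }
}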
@@ -375,15 +386,22 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
}

/// Using `session`, execute the public function with the specified name.
/// The `inArgs` and `outArgs` are arrays for input arguments and destination
/// arguments, respectively. Input arguments may be MemRefs or scalars, but
/// destination arguments must be MemRefs.
/// A stream may optionally be specified, otherwise pass the result of
/// `mtrtStreamGetNull()`.
/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments and
/// optional destination arguments and results, respectively. Input arguments
/// may be MemRefs or scalars, but destination arguments and results must be
/// MemRefs. If `outArgs` are present, it is expected that `results` are empty
/// and vice-versa. A stream may optionally be specified, otherwise pass the
/// result of `mtrtStreamGetNull()`.
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, size_t numResults, MTRT_Stream stream,
MTRT_RuntimeClient client);

MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNbResults(
    MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults);
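A usage sketch of the extended execute API: the wrapper name `runWithResults` and the `mtrtStatusIsOk` error check are assumptions, while the entry points, `mtrtStreamGetNull()`, and the results-array contract come from the declarations above.

#include "mlir-executor-c/Runtime/Runtime.h"

#include <vector>

// Sketch: query the number of results, then execute with no destination
// args so the session returns freshly allocated results.
static MTRT_Status runWithResults(MTRT_RuntimeSession session,
                                  MTRT_RuntimeClient client,
                                  MTRT_StringView name,
                                  const MTRT_RuntimeValue *inArgs,
                                  size_t numInArgs,
                                  std::vector<MTRT_RuntimeValue> &results) {
  int64_t numResults = 0;
  MTRT_Status status =
      mtrtRuntimeSessionGetNbResults(session, name, &numResults);
  if (!mtrtStatusIsOk(status)) // assumed status predicate
    return status;

  results.resize(static_cast<size_t>(numResults));
  return mtrtRuntimeSessionExecuteFunction(
      session, name, inArgs, numInArgs,
      /*outArgs=*/nullptr, /*numOutArgs=*/0, results.data(), results.size(),
      mtrtStreamGetNull(), client);
}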

//===----------------------------------------------------------------------===//
// DLPack
11 changes: 7 additions & 4 deletions mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h
@@ -120,7 +120,7 @@ std::string_view stringifyPointerType(PointerType ptrType);
//===----------------------------------------------------------------------===//
// TypeView
// This section includes classes that form the TypeUnion:
// MemrefTypeView, ScalarTypeView, ExternalOpaqueTypeView
// MemRefTypeView, ScalarTypeView, ExternalOpaqueTypeView
//===----------------------------------------------------------------------===//

/// Base class for all the below classes that provide flatbuffer-view wrappers
@@ -155,10 +155,10 @@ class ExternalOpaqueTypeView
/// A wrapper around `impl::MemRefTypeT` to provide additional convenience
/// utilities. It does not own any memory; it only
// provides a read-only view into the buffer.
class MemrefTypeView : public FlatbufferTypeObjectView<impl::MemRefType,
class MemRefTypeView : public FlatbufferTypeObjectView<impl::MemRefType,
impl::Type::MemRefType> {
public:
MemrefTypeView(const impl::MemRefType *view)
MemRefTypeView(const impl::MemRefType *view)
: FlatbufferTypeObjectView(view) {}

int64_t getRank() const { return view->shape()->size(); }
@@ -877,6 +877,8 @@ class RuntimeSession {

ResourceTracker &getResourceTracker() { return *resourceTracker; }

OutputAllocatorTracker &getOutputAllocatorTracker() { return *outputAllocatorTracker; }

/// Returns the options used to construct the session.
const RuntimeSessionOptions &getOptions() { return options; }

@@ -888,6 +890,7 @@ class RuntimeSession {
std::unique_ptr<PinnedMemoryAllocator> pinnedMemoryAllocator;
std::unique_ptr<AllocTracker> allocTracker;
std::unique_ptr<ResourceTracker> resourceTracker;
std::unique_ptr<OutputAllocatorTracker> outputAllocatorTracker;
};

//===----------------------------------------------------------------------===//
@@ -988,7 +991,7 @@ llvm::raw_ostream &print(llvm::raw_ostream &os, const Executable &exe);
/// Print a text summary of the constant to the stream.
llvm::raw_ostream &print(llvm::raw_ostream &os, const ConstantView &constant);
/// Print a text summary of the type to the stream.
llvm::raw_ostream &print(llvm::raw_ostream &os, const MemrefTypeView &type);
llvm::raw_ostream &print(llvm::raw_ostream &os, const MemRefTypeView &type);
/// Print a text summary of the type to the stream.
llvm::raw_ostream &print(llvm::raw_ostream &os, const ScalarTypeView &type);
/// Print a text summary of the type to the stream.
@@ -37,6 +37,6 @@ void registerLuaRuntimeMethods(lua_State *state,
const RuntimeSessionOptions &options,
PinnedMemoryAllocator *pinnedMemoryAllocator,
AllocTracker *allocTracker,
ResourceTracker *resourceTracker);
ResourceTracker *resourceTracker,
OutputAllocatorTracker *outputAllocatorTracker);

} // namespace mlirtrt::runtime
@@ -93,12 +93,20 @@ StatusOr<int64_t> runExecutorExecutable(
LuaRuntimeSession::LuaModuleRegistrationFunc registerExtraLuaFuncs = {});

/// Execute a named function in the session with the specified input args and
/// output (destination args). Returns any results.
/// output (destination args). Returns optional results.
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
llvm::ArrayRef<RuntimeValue *> inputArgs,
llvm::ArrayRef<RuntimeValue *> outputArgs,
std::optional<CudaStream> stream = {});
std::optional<CudaStream> stream = {},
std::optional<RuntimeClient *> client = {});

/// Execute a named function in the session with the specified input args and
/// return results.
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
executeFunctionWithResultWithLuaBackend(
LuaRuntimeSession &session, RuntimeClient &client, std::string_view name,
llvm::ArrayRef<RuntimeValue *> inputArgs,
std::optional<CudaStream> stream = {});

} // namespace mlirtrt::runtime
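A brief sketch contrasting the new results-returning entry point with the existing destination-argument form. Only the two backend functions are from this commit; the wrapper name and the function name "main" are illustrative, and the header's include path is not visible in this diff, so it is omitted.

// Sketch: `session`, `client`, and `inputs` are assumed to exist already.
using namespace mlirtrt;
using namespace mlirtrt::runtime;

StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
runAndReturnResults(LuaRuntimeSession &session, RuntimeClient &client,
                    llvm::ArrayRef<RuntimeValue *> inputs) {
  // Unlike executeFunctionWithLuaBackend, no destination arguments are
  // passed; the backend allocates the results and returns owning pointers.
  return executeFunctionWithResultWithLuaBackend(session, client, "main",
                                                 inputs,
                                                 /*stream=*/std::nullopt);
}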

@@ -37,7 +37,7 @@ class ResourceTracker;
/// Lua state.
void registerExecutorTensorRTModuleLuaRuntimeMethods(
lua_State *luaState, PinnedMemoryAllocator *pinnedMemoryAllocator,
AllocTracker *allocTracker, ResourceTracker *resourceTracker);
AllocTracker *allocTracker, ResourceTracker *resourceTracker,
OutputAllocatorTracker *outputAllocatorTracker);

} // namespace mlirtrt::runtime

88 changes: 88 additions & 0 deletions mlir-tensorrt/executor/include/mlir-executor/Support/Allocators.h
@@ -128,6 +128,94 @@ class PinnedMemoryAllocator {
std::unique_ptr<BlockEventQueue> pendingBlockEvents;
};

//===----------------------------------------------------------------------===//
// OutputAllocator and CustomTensorRTOuputAllocator
//===----------------------------------------------------------------------===//

//!
//! Class to allocate memory for outputs with data-dependent shapes. Their
//! sizes are not known ahead of time, so pre-allocation is not possible.
//!
class OutputAllocator {
public:
virtual ~OutputAllocator() = default;
virtual void setTensorName(const char *tensorName) = 0;
virtual void setCurrentMemory(void *currentMemory) = 0;
virtual void setOutputSize(const int64_t outputSize) = 0;
virtual void *reallocateOutputAsync(char const *tensorName,
void *currentMemory, uint64_t size,
uint64_t alignment,
CudaStream /*stream*/) = 0;
virtual void notifyShape(char const *tensorName, const int64_t *dims,
int64_t nbDims) = 0;
};
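To make the intended protocol concrete, a hedged sketch follows: the helper function, tensor name, sizes, and alignment are illustrative; only the OutputAllocator interface itself comes from this header.

#include "mlir-executor/Support/Allocators.h"

namespace mlirtrt {

// Sketch of the call sequence a caller might drive for one output tensor
// whose shape is only known during execution.
inline void exampleDataDependentOutput(OutputAllocator &alloc,
                                       CudaStream stream) {
  // Identify the output; no storage is committed yet.
  alloc.setTensorName("output0");
  alloc.setCurrentMemory(nullptr);
  alloc.setOutputSize(0);

  // Once the real size becomes known during execution, the engine asks the
  // allocator to (re)allocate, then reports the concrete shape.
  void *buffer = alloc.reallocateOutputAsync(
      "output0", /*currentMemory=*/nullptr, /*size=*/1024, /*alignment=*/256,
      stream);
  const int64_t dims[] = {256};
  alloc.notifyShape("output0", dims, /*nbDims=*/1);
  (void)buffer;
}

} // namespace mlirtrt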


class CustomTensorRTOuputAllocator : public OutputAllocator {
public:
CustomTensorRTOuputAllocator() = default;
~CustomTensorRTOuputAllocator();

//! Methods are called just after construction. TODO: can they be called
//! during construction?
void setTensorName(const char *tensorName) override {
mTensorName = tensorName;
}

void setCurrentMemory(void *currentMemory) override {
mCurrentMemory = currentMemory;
}

void setOutputSize(int64_t outputSize) override { mOutputSize = outputSize; }

void *reallocateOutputAsync(char const *tensorName, void *currentMemory,
uint64_t size, uint64_t alignment,
CudaStream /*stream*/) override;

void notifyShape(char const *tensorName, const int64_t *dims,
int64_t nbDims) override;

//! nullptr if memory could not be allocated
void *mOutputPtr{nullptr};

//! Size of allocation pointed to by output.
uint64_t mOutputSize{0};

bool mReallocateOutputCalled{false};

bool mNotifyShapeCalled{false};

//! Dimensions of tensor.
std::vector<int64_t> mOutputDims;

private:
const char *mTensorName;
void *mCurrentMemory;
};

class OutputAllocatorTracker {
public:
OutputAllocatorTracker() = default;
~OutputAllocatorTracker() = default;

OutputAllocatorTracker(const OutputAllocatorTracker &) = delete;
OutputAllocatorTracker &operator=(const OutputAllocatorTracker &) = delete;
OutputAllocatorTracker(OutputAllocatorTracker &&) = default;
OutputAllocatorTracker &operator=(OutputAllocatorTracker &&) = default;

// Track a new OutputAllocator
void track(std::unique_ptr<OutputAllocator> allocator) {
mOutputAllocatorRegistry.emplace_back(std::move(allocator));
}

OutputAllocator *get(int64_t index) {
return mOutputAllocatorRegistry[index].get();
}

private:
std::vector<std::unique_ptr<OutputAllocator>> mOutputAllocatorRegistry;
};

} // namespace mlirtrt

#endif // MLIR_TENSORRT_SUPPORT_ALLOCATORS_H
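And a sketch of how the new tracker might be wired to a session: the accessors come from this commit's API.h and Allocators.h changes, while the helper name, the tensor name "result0", and the index convention used with get() are assumptions.

#include "mlir-executor/Runtime/API/API.h"
#include "mlir-executor/Support/Allocators.h"

#include <memory>

// Sketch: register a per-output allocator with the tracker owned by a
// RuntimeSession, then look it back up by registration index.
inline void
registerOutputAllocator(mlirtrt::runtime::RuntimeSession &session) {
  auto allocator = std::make_unique<mlirtrt::CustomTensorRTOuputAllocator>();
  allocator->setTensorName("result0"); // hypothetical tensor name
  session.getOutputAllocatorTracker().track(std::move(allocator));

  mlirtrt::OutputAllocator *alloc = session.getOutputAllocatorTracker().get(0);
  (void)alloc;
}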
4 changes: 2 additions & 2 deletions mlir-tensorrt/executor/lib/CAPI/Common/Common.cpp
@@ -395,8 +395,8 @@ MTRT_Status getTypeHelper(TypeUnionView typeUnionView, MTRT_Type *type) {
// concrete object, and release it to be owned by the CAPI object.
auto typeUnion = std::make_unique<impl::TypeUnion>();
// Extract the correct type.
if (typeUnionView.isa<MemrefTypeView>()) {
auto memrefView = typeUnionView.get<MemrefTypeView>();
if (typeUnionView.isa<MemRefTypeView>()) {
auto memrefView = typeUnionView.get<MemRefTypeView>();
impl::MemRefTypeT memref;
memref.shape = memrefView.getShape();
memref.strides = memrefView.getStrides();
41 changes: 35 additions & 6 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
@@ -645,6 +645,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
return wrap(static_cast<ScalarValue *>(x));
}

bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::MemRef;
}

bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::Scalar;
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeSessionOptions
//===----------------------------------------------------------------------===//
@@ -691,7 +701,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) {
MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) {
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, size_t numResults, MTRT_Stream stream,
MTRT_RuntimeClient client) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));

Expand All @@ -701,19 +712,37 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
llvm::SmallVector<RuntimeValue *> outArgValues =
llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
[](MTRT_RuntimeValue arg) { return unwrap(arg); });

StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> resultValues =
executeFunctionWithLuaBackend(
*cppSession, std::string_view(name.data, name.length), inArgValues,
outArgValues,
!mtrtStreamIsNull(stream)
? std::optional(unwrap(stream)->getRawStream())
: std::nullopt);
if (!result.isOk())
return wrap(result.getStatus());
: std::nullopt,
!mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client))
: std::nullopt);
if (!resultValues.isOk())
return wrap(resultValues.getStatus());

assert(resultValues->size() == numResults);
for (size_t i = 0; i < numResults; ++i)
results[i] = wrap((*resultValues)[i].release());

return mtrtStatusGetOk();
}

MTRT_Status mtrtRuntimeSessionGetNbResults(MTRT_RuntimeSession session,
MTRT_StringView name,
int64_t *numResults) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));
*numResults = cppSession->getExecutable()
.getFunction(std::string_view(name.data, name.length))
.getSignature()
.getNumResults();
return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeClient
//===----------------------------------------------------------------------===//
11 changes: 6 additions & 5 deletions mlir-tensorrt/executor/lib/Runtime/API/API.cpp
@@ -360,7 +360,8 @@ RuntimeSession::RuntimeSession(RuntimeSessionOptions options,
: options(std::move(options)), executable(exe),
pinnedMemoryAllocator(std::make_unique<PinnedMemoryAllocator>()),
allocTracker(std::make_unique<AllocTracker>()),
resourceTracker(std::make_unique<ResourceTracker>()) {}
resourceTracker(std::make_unique<ResourceTracker>()),
outputAllocatorTracker(std::make_unique<OutputAllocatorTracker>()) {}

//===----------------------------------------------------------------------===//
// AllocTracker
Expand Down Expand Up @@ -750,7 +751,7 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
MemRefValue::MemRefValue(RuntimeClient *client,
mlirtrt::runtime::PointerType addressSpace,
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device,
std::optional<ScalarType> scalarType)
Expand Down Expand Up @@ -1123,8 +1124,8 @@ static llvm::raw_ostream &squareBraces(llvm::raw_ostream &os, Callable c) {
}

llvm::raw_ostream &rt::print(llvm::raw_ostream &os, const TypeUnionView &arg) {
if (arg.isa<MemrefTypeView>())
return print(os, arg.get<MemrefTypeView>());
if (arg.isa<MemRefTypeView>())
return print(os, arg.get<MemRefTypeView>());
if (arg.isa<ScalarTypeView>())
return print(os, arg.get<ScalarTypeView>());
if (arg.isa<ExternalOpaqueTypeView>())
Expand Down Expand Up @@ -1180,7 +1181,7 @@ llvm::raw_ostream &rt::print(llvm::raw_ostream &os,
<< ">";
return os;
}
llvm::raw_ostream &rt::print(llvm::raw_ostream &os, const MemrefTypeView &exe) {
llvm::raw_ostream &rt::print(llvm::raw_ostream &os, const MemRefTypeView &exe) {

auto handleDimOrStride = [](llvm::raw_ostream &os, int64_t x) {
if (x != std::numeric_limits<int64_t>::min())