Tripy changes for non-DPS
jhalakpatel committed Nov 4, 2024
1 parent d2a879c commit 4c7b5b8
Showing 10 changed files with 61 additions and 150 deletions.
2 changes: 1 addition & 1 deletion tripy/tests/backend/api/test_executable.py
@@ -87,7 +87,7 @@ def test_signature(self, single_return_executable):
             assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
             assert param.annotation == tp.Tensor

-        assert signature.return_annotation == tp.Tensor
+        assert signature.return_annotation == Sequence[tp.Tensor]

     def test_signature_multiple_return_values(self, multiple_return_executable):
         signature = inspect.signature(multiple_return_executable)
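(Aside: a quick sketch of what the updated assertion means in practice, reusing the `add` example from the `Executable` docstring later in this commit; the `tp` alias and the input shapes are assumptions borrowed from the surrounding tests.)

    import inspect
    from typing import Sequence

    import tripy as tp

    def add(a, b):
        return a + b

    # Mirrors the docstring example in executable.py below.
    compiled_add = tp.compile(
        add,
        args=[
            tp.InputInfo(([1, 2, 3],), dtype=tp.float32),
            tp.InputInfo(([1, 2, 3],), dtype=tp.float32),
        ],
    )

    # With non-DPS returns, even a single-output executable advertises a
    # sequence return type rather than a single tensor.
    assert inspect.signature(compiled_add).return_annotation == Sequence[tp.Tensor]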
3 changes: 1 addition & 2 deletions tripy/tests/frontend/test_tensor.py
@@ -226,8 +226,7 @@ def test_no_explicit_cast(self):
         "devices",
         [
             ("cpu", "gpu"),
-            # TODO(#155)
-            # ("gpu", "cpu"),
+            ("gpu", "cpu"),
         ],
     )
     def test_explicit_copy(self, devices):
15 changes: 8 additions & 7 deletions tripy/tests/integration/test_iota.py
@@ -82,16 +82,17 @@ def test_iota_like(self, dtype, shape, dim):

     @pytest.mark.parametrize("dtype", DATA_TYPES.values())
     def test_negative_no_casting(self, dtype):
-        from tripy.frontend.trace.ops.iota import Iota
+        with tp.logger.use_verbosity("ir"):
+            from tripy.frontend.trace.ops.iota import Iota

-        if dtype in [tp.float32, tp.int32, tp.int64]:
-            pytest.skip("tp.iota() supports float32, int32, and int64 without cast")
+            if dtype in [tp.float32, tp.int32, tp.int64]:
+                pytest.skip("tp.iota() supports float32, int32, and int64 without cast")

-        # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
-        a = tp.ones((2, 2))
-        out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)
+            # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
+            a = tp.ones((2, 2))
+            out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)

-        exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values"
+            exception_str = "InternalError: failed to run compilation on module with symbol name."
         if dtype == tp.bool:
             exception_str = "InternalError: failed to run compilation"
         with helper.raises(
3 changes: 2 additions & 1 deletion tripy/tests/integration/test_quantize.py
@@ -117,5 +117,6 @@ def test_non_constant_scale(self):
         input = tp.ones((4, 4))
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
+        quantized_int32 = tp.cast(quantized, tp.int32)

-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32)))
1 change: 0 additions & 1 deletion tripy/tripy/backend/api/compile.py
@@ -196,5 +196,4 @@ def process_arg(name, arg):
     return Executable(
         executable,
         compiled_arg_names,
-        output_devices=[out.device for out in trace.outputs],
     )
50 changes: 29 additions & 21 deletions tripy/tripy/backend/api/executable.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import base64
 import inspect
-from typing import Sequence, Union
+from typing import Sequence, Union, Tuple, Callable

 import mlir_tensorrt.runtime.api as runtime
@@ -37,13 +37,11 @@ class Executable:
     """

     # The constructor is intentionally undocumented because it is not meant to be called by users.
-    # TODO(#155): output_devices is not needed after they can be queried from executable
-    def __init__(self, executable, arg_names, output_devices):
+    def __init__(self, executable, arg_names):
         self._executable = executable
         self._executor = Executor(self._executable)
         self._arg_names = arg_names
         self._num_expected_args = len(arg_names)
-        self._output_devices = output_devices
         self._executable_signature = self._executable.get_signature("main")

         # Build a signature so the executable works with `inspect.signature`
@@ -128,7 +126,7 @@ def add(a, b):
                 tensor.eval()

         try:
-            executor_outputs = self._executor.execute(self._output_devices, input_tensors)
+            executor_outputs = self._executor.execute(input_tensors)
         except runtime.MTRTException as err:
             # TODO: Evaluate whether this should be moved into the executor
             if "function expects a memref type with element type" in str(err):
@@ -170,15 +168,22 @@ def add(a, b):
             output_tensors = output_tensors[0]
         return output_tensors

-    def _get_arg_info(self, idx):
-        arg = self._executable_signature.get_arg(idx)
-        arg = runtime.MemRefType(arg)
-        arg_bound = self._executable_signature.get_arg_bound(idx)
-        shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
-        if len(shape_bounds) == 0:
-            # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
-            shape_bounds = tuple((x, x) for x in arg.shape)
-        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
+    def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo:
+        item = runtime.MemRefType(get_item(idx))
+        bound = get_bound(idx)
+        shape_bounds = tuple(zip(bound.min(), bound.max()))
+
+        if not shape_bounds:
+            # For static shape, fallback to item.shape
+            shape_bounds = tuple((x, x) for x in item.shape)
+
+        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype))
+
+    def _get_arg_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound)
+
+    def _get_result_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound)

     def get_input_info(self) -> Sequence[ArgInfo]:
         """
@@ -221,11 +226,16 @@ def add(a, b):
             compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
             print(compiled_add.get_output_info())
         """
-        output_info = []
-        offset = self._executable_signature.get_num_input_args()
-        for idx in range(self._executable_signature.get_num_output_args()):
-            output_info.append(self._get_arg_info(idx + offset))
-        return output_info
+        num_input_args = self._executable_signature.get_num_input_args()
+        num_output_args = self._executable_signature.get_num_output_args()
+        num_results = self._executable_signature.get_num_results()
+
+        assert not (num_output_args and num_results), "Cannot have both output arguments and results"
+
+        if num_output_args:
+            return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)]
+        else:
+            return [self._get_result_info(idx) for idx in range(num_results)]

     def save(self, path: str) -> None:
         """
@@ -289,7 +299,6 @@ def add(a, b):
 def encode_executable(executable):
     return {
         "arg_names": executable._arg_names,
-        "output_devices": executable._output_devices,
         "executable": base64.b64encode(executable._executable.serialize()).decode(),
     }

@@ -300,5 +309,4 @@ def decode_executable(executable_dict):
     return Executable(
         runtime.Executable(executable_bytes),
         executable_dict["arg_names"],
-        executable_dict["output_devices"],
     )
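(Aside: the `get_output_info` change above is the user-visible half of non-DPS returns. A condensed sketch of the two calling conventions it now distinguishes, using only the signature accessors shown in the diff; setup of `executable` is assumed:)

    signature = executable.get_signature("main")  # as in Executable.__init__

    num_output_args = signature.get_num_output_args()  # DPS: outputs are trailing arguments
    num_results = signature.get_num_results()          # non-DPS: outputs are true results

    # An executable is compiled one way or the other, never both.
    assert not (num_output_args and num_results)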
1 change: 1 addition & 0 deletions tripy/tripy/backend/mlir/compiler.py
@@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
             f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
             f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
             "--tensorrt-strongly-typed=True",
+            "--enable-non-dps-returns",
         ]
         if config.enable_mlir_debug or config.enable_tensorrt_debug:
             opts.append("--debug=true")
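(Aside: destination-passing style, DPS, means the caller pre-allocates output buffers and passes them as trailing arguments; `--enable-non-dps-returns` makes the compiled function allocate and return its outputs instead. A plain-Python analogy of the two conventions, not MLIR-TensorRT API:)

    def main_dps(in0, in1, out):
        # Old convention: the caller owns `out` and the callee writes into it.
        out[:] = in0 + in1

    def main_non_dps(in0, in1):
        # New convention: the callee allocates and returns the output.
        return in0 + in1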
121 changes: 7 additions & 114 deletions tripy/tripy/backend/mlir/executor.py
@@ -31,89 +31,17 @@

 class Executor:
     def __init__(self, executable: runtime.Executable) -> None:
-
+        runtime.GlobalDebug.flag = True
+        debug_types = ["allocator", "runtime"]
+        runtime.GlobalDebug.set_types(debug_types)
         self.runtime_client = MLIRRuntimeClient()
         session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
         self.session = runtime.RuntimeSession(session_options, executable)
         self.device = self.runtime_client.get_devices()[0]  # Assume a single device is available.
         self.signature = executable.get_signature("main")
         self.stream = default_stream()
-        self.num_input_args = self.signature.get_num_input_args()
-        self.num_output_args = self.signature.get_num_output_args()
-        self.output_args = [
-            self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
-        ]
-        self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]
-
-    def _create_shape_memref(self, shape):
-        shape = make_tuple(shape)
-        if len(shape) == 0:
-            return create_memref(
-                shape=(0,),
-                dtype=datatype.int64,
-                device=device("cpu"),
-            )
-        return create_memref(
-            array=convert_list_to_array(shape, datatype.int64),
-            shape=(len(shape),),
-            dtype=datatype.int64,
-            device=device("cpu"),
-        )
-
-    def _get_outputs_shape(self):
-        outputs_shape = []
-        all_outputs_known = True
-        for memref in self.output_memrefs:
-            outputs_shape.append(memref.shape)
-            all_outputs_known &= all(dim >= 0 for dim in memref.shape)
-        return outputs_shape, all_outputs_known
-
-    def _get_inputs_runtime_shape(self, inputs):
-        inputs_shape = []
-        for input in inputs:
-            inputs_shape.append(input.trace_tensor.producer.data.shape)
-        return inputs_shape
-
-    def _execute_shape_inference(self, inputs_shape, outputs_shape):
-        inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
-        outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
-        self.session.execute_function(
-            name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref
-        )
-
-        outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref]
-        return outputs_runtime_shape
-
-    def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
-        outputs_tensor_info = []
-        for index in range(self.num_output_args):
-            memref = self.output_memrefs[index]
-            dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)
-
-            output_device = output_devices[index]
-            if not output_device:
-                output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0))
-
-            runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
-            outputs_tensor_info.append(
-                TensorInfo(
-                    len(runtime_shape),
-                    tuple(runtime_shape),
-                    dtype,
-                    output_device,
-                )
-            )
-        return outputs_tensor_info
-
-    def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):
-        outputs_shape, all_outputs_known = self._get_outputs_shape()
-        if not all_outputs_known:
-            inputs_shape = self._get_inputs_runtime_shape(inputs)
-            outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
-        output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices)
-        return output_tensor_info

-    def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
+    def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
         in_args = []
         for inp in inputs:
             memref = inp.trace_tensor.producer.data
@@ -131,45 +59,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
                 )
             in_args.append(memref)

-        # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
-        out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices)
-
-        # Allocate output memory and store buffer pointers.
-        outputs = [
-            create_memref(
-                shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
-            )
-            for info in out_tensor_info
-        ]
-
-        out_args = []
-        for out in outputs:
-            memref = out
-            # HACK (#155): MLIR-TensorRT requires inputs to be on device.
-            # Remove explicit copy to device once #155 is addressed.
-            if memref.address_space != runtime.PointerType.device:
-                memref = self.runtime_client.copy_to_device(
-                    host_memref=memref,
-                    device=self.runtime_client.get_devices()[0],
-                    stream=self.stream._active_cuda_stream,
-                )
-            if not memref:
-                raise_error("Could not allocate output memref", details=memref.error_details)
-            out_args.append(memref)
-
-        # Execute and populate device pointers.
-        self.session.execute_function(
-            "main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream
+        outputs = self.session.execute_function(
+            "main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client
         )

-        # For outputs that were on the host, do the copy back
-        # TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
-        for idx, out_info in enumerate(out_tensor_info):
-            if out_info.device.kind != "gpu":
-                self.runtime_client.copy_to_host(
-                    device_memref=out_args[idx],
-                    existing_host_memref=outputs[idx],
-                    stream=self.stream._active_cuda_stream,
-                )
-
+        # For now return results on GPU.
         return outputs
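(Aside: the net effect of this rewrite is that shape inference, output allocation, and host copy-back all move into the runtime. A minimal sketch of the resulting execution path, using only calls visible in this diff; setup of `executable` and `in_args` is assumed:)

    import mlir_tensorrt.runtime.api as runtime

    client = MLIRRuntimeClient()
    session = runtime.RuntimeSession(
        runtime.RuntimeSessionOptions(num_devices=1, device_id=0), executable
    )
    stream = default_stream()

    # The session allocates output buffers itself and returns them, instead of
    # writing into pre-allocated `out_args`.
    outputs = session.execute_function(
        "main", in_args=in_args, stream=stream._active_cuda_stream, client=client
    )
    # For now, results come back as device (GPU) memrefs.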
9 changes: 9 additions & 0 deletions tripy/tripy/flat_ir/ops/copy.py
@@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp):

     target: tripy.common.device

+    def set_memory_space_attr(self, tensor, mem_space_attr):
+        current_type = tensor.type
+        # Set the encoding attribute on the operation's result
+        new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr)
+        tensor.set_type(new_type)
+
     def to_mlir(self, operands):
         from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith
@@ -46,7 +52,10 @@ def to_mlir(self, operands):
             sliced_dims.append(dim)

         alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr)
+        self.set_memory_space_attr(alloc_tensor, mem_space_attr)
         result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor)
+        self.set_memory_space_attr(result_tensor, mem_space_attr)
         cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor)
+        self.set_memory_space_attr(cast_tensor, mem_space_attr)

         return [cast_tensor]
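(Aside: tagging every intermediate value, the allocation, the materialization, and the cast, with the memory-space encoding keeps the tensor in its intended address space through bufferization. This lowering backs the explicit host/device copies exercised by the ("gpu", "cpu") case re-enabled in test_tensor.py above. A hedged usage sketch; whether the frontend exposes this op as `tp.copy` is an assumption, not shown in this diff:)

    import tripy as tp

    a = tp.ones((2, 2))               # lives on the GPU after evaluation
    b = tp.copy(a, tp.device("cpu"))  # would lower through CopyOp.to_mlir above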
6 changes: 3 additions & 3 deletions tripy/tripy/frontend/tensor.py
@@ -185,11 +185,11 @@ def eval(self) -> runtime.MemRefValue:

         compiler = Compiler(trt_builder_opt_level=0)
         executable = compiler.compile(mlir, flat_ir=flat_ir)
-        executor = Executor(executable)
+        self.executor = Executor(executable)
         # Upon computing the value of this tensor, we switch it to have a `Storage`
         # parameter so that it does not need to be computed again.
-        data = executor.execute([out.device for out in flat_ir.outputs])
-        executor.stream.synchronize()
+        data = self.executor.execute()
+        self.executor.stream.synchronize()
         assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
         data = data[0]
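(Aside: storing the executor on the tensor (`self.executor = ...`) presumably keeps the runtime session, and with it the output buffers the runtime now allocates, alive for the lifetime of the evaluated tensor. End to end, eager evaluation follows the same non-DPS path; a minimal sketch using `tp.ones` as elsewhere in this commit:)

    import tripy as tp

    a = tp.ones((2, 2)) + tp.ones((2, 2))
    # eval() compiles the trace, runs Executor.execute() with no output-device
    # list, and keeps the single returned memref as the tensor's storage.
    memref = a.eval()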
