diff --git a/tripy/tests/backend/api/test_executable.py b/tripy/tests/backend/api/test_executable.py index 3b588b46..40ac8d01 100644 --- a/tripy/tests/backend/api/test_executable.py +++ b/tripy/tests/backend/api/test_executable.py @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable): assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD assert param.annotation == tp.Tensor - assert signature.return_annotation == tp.Tensor + assert signature.return_annotation == Sequence[tp.Tensor] def test_signature_multiple_return_values(self, multiple_return_executable): signature = inspect.signature(multiple_return_executable) diff --git a/tripy/tests/frontend/test_tensor.py b/tripy/tests/frontend/test_tensor.py index f0a21ded..837ec8af 100644 --- a/tripy/tests/frontend/test_tensor.py +++ b/tripy/tests/frontend/test_tensor.py @@ -226,8 +226,7 @@ def test_no_explicit_cast(self): "devices", [ ("cpu", "gpu"), - # TODO(#155) - # ("gpu", "cpu"), + ("gpu", "cpu"), ], ) def test_explicit_copy(self, devices): diff --git a/tripy/tests/integration/test_iota.py b/tripy/tests/integration/test_iota.py index 39df3578..48094c88 100644 --- a/tripy/tests/integration/test_iota.py +++ b/tripy/tests/integration/test_iota.py @@ -82,16 +82,17 @@ def test_iota_like(self, dtype, shape, dim): @pytest.mark.parametrize("dtype", DATA_TYPES.values()) def test_negative_no_casting(self, dtype): - from tripy.frontend.trace.ops.iota import Iota + with tp.logger.use_verbosity("ir"): + from tripy.frontend.trace.ops.iota import Iota - if dtype in [tp.float32, tp.int32, tp.int64]: - pytest.skip("tp.iota() supports float32, int32, and int64 without cast") + if dtype in [tp.float32, tp.int32, tp.int64]: + pytest.skip("tp.iota() supports float32, int32, and int64 without cast") - # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint - a = tp.ones((2, 2)) - out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) + # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint + a = tp.ones((2, 2)) + out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) - exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values" + exception_str = "InternalError: failed to run compilation on module with symbol name." if dtype == tp.bool: exception_str = "InternalError: failed to run compilation" with helper.raises( diff --git a/tripy/tests/integration/test_quantize.py b/tripy/tests/integration/test_quantize.py index b5029386..9bf96be9 100644 --- a/tripy/tests/integration/test_quantize.py +++ b/tripy/tests/integration/test_quantize.py @@ -117,5 +117,6 @@ def test_non_constant_scale(self): input = tp.ones((4, 4)) scale = tp.ones((4,)) quantized = tp.quantize(input, scale, tp.int8, dim=0) + quantized_int32 = tp.cast(quantized, tp.int32) - assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8))) + assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32))) diff --git a/tripy/tripy/backend/api/compile.py b/tripy/tripy/backend/api/compile.py index 1f631580..0491f861 100644 --- a/tripy/tripy/backend/api/compile.py +++ b/tripy/tripy/backend/api/compile.py @@ -196,5 +196,4 @@ def process_arg(name, arg): return Executable( executable, compiled_arg_names, - output_devices=[out.device for out in trace.outputs], ) diff --git a/tripy/tripy/backend/api/executable.py b/tripy/tripy/backend/api/executable.py index 33347b31..8c977153 100644 --- a/tripy/tripy/backend/api/executable.py +++ b/tripy/tripy/backend/api/executable.py @@ -14,7 +14,7 @@ # limitations under the License. import base64 import inspect -from typing import Sequence, Union +from typing import Sequence, Union, Tuple, Callable import mlir_tensorrt.runtime.api as runtime @@ -37,13 +37,11 @@ class Executable: """ # The constructor is intentionally undocumented because it is not meant to be called by users. - # TODO(#155): output_devices is not needed after they can be queried from executable - def __init__(self, executable, arg_names, output_devices): + def __init__(self, executable, arg_names): self._executable = executable self._executor = Executor(self._executable) self._arg_names = arg_names self._num_expected_args = len(arg_names) - self._output_devices = output_devices self._executable_signature = self._executable.get_signature("main") # Build a signature so the executable works with `inspect.signature` @@ -128,7 +126,7 @@ def add(a, b): tensor.eval() try: - executor_outputs = self._executor.execute(self._output_devices, input_tensors) + executor_outputs = self._executor.execute(input_tensors) except runtime.MTRTException as err: # TODO: Evaluate whether this should be moved into the executor if "function expects a memref type with element type" in str(err): @@ -170,15 +168,22 @@ def add(a, b): output_tensors = output_tensors[0] return output_tensors - def _get_arg_info(self, idx): - arg = self._executable_signature.get_arg(idx) - arg = runtime.MemRefType(arg) - arg_bound = self._executable_signature.get_arg_bound(idx) - shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max())) - if len(shape_bounds) == 0: - # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape - shape_bounds = tuple((x, x) for x in arg.shape) - return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype)) + def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo: + item = runtime.MemRefType(get_item(idx)) + bound = get_bound(idx) + shape_bounds = tuple(zip(bound.min(), bound.max())) + + if not shape_bounds: + # For static shape, fallback to item.shape + shape_bounds = tuple((x, x) for x in item.shape) + + return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype)) + + def _get_arg_info(self, idx: int) -> ArgInfo: + return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound) + + def _get_result_info(self, idx: int) -> ArgInfo: + return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound) def get_input_info(self) -> Sequence[ArgInfo]: """ @@ -221,11 +226,16 @@ def add(a, b): compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)]) print(compiled_add.get_output_info()) """ - output_info = [] - offset = self._executable_signature.get_num_input_args() - for idx in range(self._executable_signature.get_num_output_args()): - output_info.append(self._get_arg_info(idx + offset)) - return output_info + num_input_args = self._executable_signature.get_num_input_args() + num_output_args = self._executable_signature.get_num_output_args() + num_results = self._executable_signature.get_num_results() + + assert not (num_output_args and num_results), "Cannot have both output arguments and results" + + if num_output_args: + return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)] + else: + return [self._get_result_info(idx) for idx in range(num_results)] def save(self, path: str) -> None: """ @@ -289,7 +299,6 @@ def add(a, b): def encode_executable(executable): return { "arg_names": executable._arg_names, - "output_devices": executable._output_devices, "executable": base64.b64encode(executable._executable.serialize()).decode(), } @@ -300,5 +309,4 @@ def decode_executable(executable_dict): return Executable( runtime.Executable(executable_bytes), executable_dict["arg_names"], - executable_dict["output_devices"], ) diff --git a/tripy/tripy/backend/mlir/compiler.py b/tripy/tripy/backend/mlir/compiler.py index 1874e893..517b978d 100644 --- a/tripy/tripy/backend/mlir/compiler.py +++ b/tripy/tripy/backend/mlir/compiler.py @@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level): f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}", f"--tensorrt-builder-opt-level={trt_builder_opt_level}", "--tensorrt-strongly-typed=True", + "--enable-non-dps-returns", ] if config.enable_mlir_debug or config.enable_tensorrt_debug: opts.append("--debug=true") diff --git a/tripy/tripy/backend/mlir/executor.py b/tripy/tripy/backend/mlir/executor.py index b03c507f..447ad2cf 100644 --- a/tripy/tripy/backend/mlir/executor.py +++ b/tripy/tripy/backend/mlir/executor.py @@ -31,89 +31,17 @@ class Executor: def __init__(self, executable: runtime.Executable) -> None: - + runtime.GlobalDebug.flag = True + debug_types = ["allocator", "runtime"] + runtime.GlobalDebug.set_types(debug_types) self.runtime_client = MLIRRuntimeClient() session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) self.session = runtime.RuntimeSession(session_options, executable) self.device = self.runtime_client.get_devices()[0] # Assume a single device is available. self.signature = executable.get_signature("main") self.stream = default_stream() - self.num_input_args = self.signature.get_num_input_args() - self.num_output_args = self.signature.get_num_output_args() - self.output_args = [ - self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args) - ] - self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args] - - def _create_shape_memref(self, shape): - shape = make_tuple(shape) - if len(shape) == 0: - return create_memref( - shape=(0,), - dtype=datatype.int64, - device=device("cpu"), - ) - return create_memref( - array=convert_list_to_array(shape, datatype.int64), - shape=(len(shape),), - dtype=datatype.int64, - device=device("cpu"), - ) - - def _get_outputs_shape(self): - outputs_shape = [] - all_outputs_known = True - for memref in self.output_memrefs: - outputs_shape.append(memref.shape) - all_outputs_known &= all(dim >= 0 for dim in memref.shape) - return outputs_shape, all_outputs_known - - def _get_inputs_runtime_shape(self, inputs): - inputs_shape = [] - for input in inputs: - inputs_shape.append(input.trace_tensor.producer.data.shape) - return inputs_shape - - def _execute_shape_inference(self, inputs_shape, outputs_shape): - inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape] - outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape] - self.session.execute_function( - name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref - ) - - outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref] - return outputs_runtime_shape - - def _get_output_tensor_info(self, outputs_runtime_shape, output_devices): - outputs_tensor_info = [] - for index in range(self.num_output_args): - memref = self.output_memrefs[index] - dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype) - - output_device = output_devices[index] - if not output_device: - output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0)) - - runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])] - outputs_tensor_info.append( - TensorInfo( - len(runtime_shape), - tuple(runtime_shape), - dtype, - output_device, - ) - ) - return outputs_tensor_info - - def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]): - outputs_shape, all_outputs_known = self._get_outputs_shape() - if not all_outputs_known: - inputs_shape = self._get_inputs_runtime_shape(inputs) - outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape) - output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices) - return output_tensor_info - def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: + def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: in_args = [] for inp in inputs: memref = inp.trace_tensor.producer.data @@ -131,45 +59,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> ) in_args.append(memref) - # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR. - out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices) - - # Allocate output memory and store buffer pointers. - outputs = [ - create_memref( - shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream - ) - for info in out_tensor_info - ] - - out_args = [] - for out in outputs: - memref = out - # HACK (#155): MLIR-TensorRT requires inputs to be on device. - # Remove explicit copy to device once #155 is addressed. - if memref.address_space != runtime.PointerType.device: - memref = self.runtime_client.copy_to_device( - host_memref=memref, - device=self.runtime_client.get_devices()[0], - stream=self.stream._active_cuda_stream, - ) - if not memref: - raise_error("Could not allocate output memref", details=memref.error_details) - out_args.append(memref) - # Execute and populate device pointers. - self.session.execute_function( - "main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream + outputs = self.session.execute_function( + "main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client ) - # For outputs that were on the host, do the copy back - # TODO(#155): MLIR-TensorRT should allow output tensor placements on host. - for idx, out_info in enumerate(out_tensor_info): - if out_info.device.kind != "gpu": - self.runtime_client.copy_to_host( - device_memref=out_args[idx], - existing_host_memref=outputs[idx], - stream=self.stream._active_cuda_stream, - ) - + # For now return results on GPU. return outputs diff --git a/tripy/tripy/flat_ir/ops/copy.py b/tripy/tripy/flat_ir/ops/copy.py index 48598b2c..092ff937 100644 --- a/tripy/tripy/flat_ir/ops/copy.py +++ b/tripy/tripy/flat_ir/ops/copy.py @@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp): target: tripy.common.device + def set_memory_space_attr(self, tensor, mem_space_attr): + current_type = tensor.type + # Set the encoding attribute on the operation's result + new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr) + tensor.set_type(new_type) + def to_mlir(self, operands): from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith @@ -46,7 +52,10 @@ def to_mlir(self, operands): sliced_dims.append(dim) alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr) + self.set_memory_space_attr(alloc_tensor, mem_space_attr) result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor) + self.set_memory_space_attr(result_tensor, mem_space_attr) cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor) + self.set_memory_space_attr(cast_tensor, mem_space_attr) return [cast_tensor] diff --git a/tripy/tripy/frontend/tensor.py b/tripy/tripy/frontend/tensor.py index 344567ba..5c5e2983 100644 --- a/tripy/tripy/frontend/tensor.py +++ b/tripy/tripy/frontend/tensor.py @@ -185,11 +185,11 @@ def eval(self) -> runtime.MemRefValue: compiler = Compiler(trt_builder_opt_level=0) executable = compiler.compile(mlir, flat_ir=flat_ir) - executor = Executor(executable) + self.executor = Executor(executable) # Upon computing the value of this tensor, we switch it to have a `Storage` # parameter so that it does not need to be computed again. - data = executor.execute([out.device for out in flat_ir.outputs]) - executor.stream.synchronize() + data = self.executor.execute() + self.executor.stream.synchronize() assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor" data = data[0]