diff --git a/CMakeLists.txt b/CMakeLists.txt index 7db2663..bf5c284 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -333,13 +333,6 @@ endif() # TRITON_ENABLE_ONNXRUNTIME_OPENVINO # if(TRITON_ONNXRUNTIME_DOCKER_BUILD) set(_GEN_FLAGS "") - if(${RHEL_BUILD} AND "${TRITON_BUILD_TENSORRT_HOME}" STREQUAL "") - if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64") - set(TRITON_BUILD_TENSORRT_HOME "/usr/local/cuda/targets/sbsa-linux/") - else() - set(TRITON_BUILD_TENSORRT_HOME "/usr/local/cuda/targets/x86_64-linux/") - endif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64") - endif(${RHEL_BUILD} AND "${TRITON_BUILD_TENSORRT_HOME}" STREQUAL "") if(NOT ${TRITON_BUILD_TARGET_PLATFORM} STREQUAL "") set(_GEN_FLAGS ${_GEN_FLAGS} "--target-platform=${TRITON_BUILD_TARGET_PLATFORM}") endif() # TRITON_BUILD_TARGET_PLATFORM diff --git a/tools/gen_ort_dockerfile.py b/tools/gen_ort_dockerfile.py index abd52c4..4b002d9 100755 --- a/tools/gen_ort_dockerfile.py +++ b/tools/gen_ort_dockerfile.py @@ -572,7 +572,13 @@ def preprocess_gpu_flags(): print("error: linux build requires --cudnn-home and --cuda-home") if FLAGS.tensorrt_home is None: - FLAGS.tensorrt_home = "/usr/src/tensorrt" + if target_platform() == "rhel": + if platform.machine() == "aarch64": + FLAGS.tensorrt_home = "/usr/local/cuda/targets/sbsa-linux/" + else: + FLAGS.tensorrt_home = "/usr/local/cuda/targets/x86_64-linux/" + else: + FLAGS.tensorrt_home = "/usr/src/tensorrt" if __name__ == "__main__":