diff --git a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
index aee2961f56c5b5..93ff4272769476 100644
--- a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
@@ -172,7 +172,7 @@ template <
     typename GeneralDispatcher>
 static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
   if (iter.dtype() == kHalf) {
-    return OpFunctor{}(iter);
+    return OpFunctor{}(iter);
   } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return OpFunctor{}(iter);
diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h
index cad9762d4469f9..3ae2d5284f9310 100644
--- a/c10/util/Half-inl.h
+++ b/c10/util/Half-inl.h
@@ -105,7 +105,12 @@ inline __device__ Half __ldg(const Half* ptr) {
 /// Arithmetic
 
 inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
+#if (defined(__CUDACC__) || defined(__HIPCC__)) && \
+    defined(PYTORCH_ENABLE_NATIVE_HALF)
+  return __half{a} + __half{b};
+#else
   return static_cast<float>(a) + static_cast<float>(b);
+#endif
 }
 
 inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 504d6ee2243db1..8fe222abc1fa38 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1000,6 +1000,10 @@ if(USE_CUDNN)
   target_include_directories(torch::cudnn INTERFACE ${CUDNN_FRONTEND_INCLUDE_DIR})
 endif()
 
+# Note: This variable affects CUDA as well.
+set(PYTORCH_ENABLE_HALF $ENV{PYTORCH_ENABLE_HALF} CACHE BOOL "Native support for the half data type, read from the environment variable." FORCE)
+
+
 # ---[ HIP
 if(USE_ROCM)
   # This prevents linking in the libtinfo from /opt/conda/lib which conflicts with ROCm libtinfo.
@@ -1042,7 +1046,11 @@ if(USE_ROCM)
   list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
   list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
   list(APPEND HIP_CXX_FLAGS -DUSE_ROCM)
-  list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  if(NOT PYTORCH_ENABLE_HALF)
+    list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  else()
+    add_definitions(-DPYTORCH_ENABLE_NATIVE_HALF)
+  endif()
   list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
   list(APPEND HIP_CXX_FLAGS -DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
   list(APPEND HIP_CXX_FLAGS -Wno-shift-count-negative)
@@ -1369,11 +1377,16 @@ if(NOT INTERN_BUILD_MOBILE)
   message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
   string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1"
-                                 " -D__CUDA_NO_HALF_OPERATORS__"
                                  " -D__CUDA_NO_HALF_CONVERSIONS__"
                                  " -D__CUDA_NO_HALF2_OPERATORS__"
                                  " -D__CUDA_NO_BFLOAT16_CONVERSIONS__")
+  if(NOT PYTORCH_ENABLE_HALF)
+    string(APPEND CMAKE_CUDA_FLAGS " -D__CUDA_NO_HALF_OPERATORS__")
+  else()
+    add_definitions(-DPYTORCH_ENABLE_NATIVE_HALF)
+  endif()
+
   string(APPEND CMAKE_C_FLAGS_RELEASE " -DNDEBUG")
   string(APPEND CMAKE_CXX_FLAGS_RELEASE " -DNDEBUG")
   if(NOT GENERATOR_IS_MULTI_CONFIG)
diff --git a/setup.py b/setup.py
index ea087c356c1520..0ae7e53cbca1a9 100644
--- a/setup.py
+++ b/setup.py
@@ -197,6 +197,10 @@
 #
 # BUILD_PYTHON_ONLY
 #   Builds pytorch as a wheel using libtorch.so from a seperate wheel
+#
+# PYTORCH_ENABLE_HALF
+#   If set to '1', enables native support for FP16 data types in certain functors.
+#   Note: This is currently considered experimental and only affects reductions.
 import os
 import pkgutil
@@ -676,6 +680,11 @@ def run(self):
         else:
             report("-- Not using ITT")
 
+        if cmake_cache_vars["PYTORCH_ENABLE_HALF"]:
+            report("-- Using native FP16 support")
+        else:
+            report("-- Not using native FP16 support")
+
         # Do not use clang to compile extensions if `-fstack-clash-protection` is defined
         # in system CFLAGS
         c_flags = str(os.getenv("CFLAGS", ""))
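
For context (this note and the sketch below are not part of the patch): the #if added to c10/util/Half-inl.h chooses between a single native fp16 add and the existing behavior of promoting both operands to float, adding, and truncating back to Half. The standalone CUDA sketch below illustrates those two paths; the file name, kernel name, and the use of __hadd are illustrative assumptions rather than PyTorch code, and it assumes an sm_53+ GPU so native fp16 arithmetic is available.

// half_add_sketch.cu -- hypothetical demo, not part of the PR.
// Illustrates the two addition strategies the PYTORCH_ENABLE_NATIVE_HALF guard selects between.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void add_kernel(const __half* a, const __half* b,
                           __half* out_native, __half* out_via_float) {
  // Native path (what the guard enables): one fp16 add, no intermediate float.
  *out_native = __hadd(*a, *b);
  // Default path: promote to float, add, then truncate back to half.
  *out_via_float = __float2half(__half2float(*a) + __half2float(*b));
}

int main() {
  __half h_a = __float2half(1.5f), h_b = __float2half(2.25f);
  __half *d_a, *d_b, *d_native, *d_via_float;
  cudaMalloc(&d_a, sizeof(__half));
  cudaMalloc(&d_b, sizeof(__half));
  cudaMalloc(&d_native, sizeof(__half));
  cudaMalloc(&d_via_float, sizeof(__half));
  cudaMemcpy(d_a, &h_a, sizeof(__half), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, &h_b, sizeof(__half), cudaMemcpyHostToDevice);
  add_kernel<<<1, 1>>>(d_a, d_b, d_native, d_via_float);
  __half r_native, r_via_float;
  cudaMemcpy(&r_native, d_native, sizeof(__half), cudaMemcpyDeviceToHost);
  cudaMemcpy(&r_via_float, d_via_float, sizeof(__half), cudaMemcpyDeviceToHost);
  printf("native: %f  via float: %f\n",
         __half2float(r_native), __half2float(r_via_float));
  cudaFree(d_a); cudaFree(d_b); cudaFree(d_native); cudaFree(d_via_float);
  return 0;
}

For a single add the two paths should print the same value; they differ in the number of conversions emitted and, across a chain of operations such as a reduction, in how much precision the intermediates keep, which is presumably why the setup.py comment marks the option as experimental and scopes it to reductions.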