diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index 137a70780..7e037ea92 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1-labs -ARG BASE_IMAGE=nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 +ARG BASE_IMAGE=nvidia/cuda:12.5.0-devel-ubuntu22.04 ARG GIT_USER_NAME="JAX Toolbox" ARG GIT_USER_EMAIL=jax@nvidia.com ARG CLANG_VERSION=17 @@ -146,6 +146,7 @@ RUN install-nsight.sh ############################################################################### ADD install-cudnn.sh /usr/local/bin +RUN install-cudnn.sh ############################################################################### ## Install NCCL diff --git a/.github/container/install-cudnn.sh b/.github/container/install-cudnn.sh index fab18edea..54bf804f4 100755 --- a/.github/container/install-cudnn.sh +++ b/.github/container/install-cudnn.sh @@ -20,8 +20,17 @@ cuda_major_version=$(nvcc --version | sed -n 's/^.*release \([0-9]*\.[0-9]*\).*$ # version of CUDA and cuDNN are compatible. # For example, CUDA 12.3 + cuDNN 8.9.6 (libcudnn8 version: 8.9.6.50-1+cuda12.2) is # considered to be compatible. -libcudnn_version=$(apt-cache show libcudnn${CUDNN_MAJOR_VERSION} | sed -n "s/^Version: \(.*+cuda${cuda_major_version}\.[0-9]*\)$/\1/p" | head -n 1) -libcudnn_dev_version=$(apt-cache show libcudnn${CUDNN_MAJOR_VERSION}-dev | sed -n "s/^Version: \(.*+cuda${cuda_major_version}\.[0-9]\)$/\1/p" | head -n 1) +if [[ ${CUDNN_MAJOR_VERSION} -le 8 ]]; then + libcudnn_name=libcudnn${CUDNN_MAJOR_VERSION} + libcudnn_dev_name=libcudnn${CUDNN_MAJOR_VERSION}-dev + version_pattern="s/^Version: \(.*+cuda${cuda_major_version}\.[0-9]*\)$/\1/p" +elif [[ ${CUDNN_MAJOR_VERSION} -eq 9 ]]; then + libcudnn_name=libcudnn${CUDNN_MAJOR_VERSION}-cuda-${cuda_major_version} + libcudnn_dev_name=libcudnn${CUDNN_MAJOR_VERSION}-dev-cuda-${cuda_major_version} + version_pattern="s/^Version: \(${CUDNN_MAJOR_VERSION}\.[0-9.-]*\)$/\1/p" +fi +libcudnn_version=$(apt-cache show $libcudnn_name | sed -n "$version_pattern" | head -n 1) +libcudnn_dev_version=$(apt-cache show $libcudnn_dev_name | sed -n "$version_pattern" | head -n 1) if [[ -z "${libcudnn_version}" || -z "${libcudnn_dev_version}" ]]; then echo "Could not find compatible cuDNN version for CUDA ${cuda_version}" exit 1 @@ -29,8 +38,8 @@ fi apt-get update apt-get install -y \ - libcudnn${CUDNN_MAJOR_VERSION}=${libcudnn_version} \ - libcudnn${CUDNN_MAJOR_VERSION}-dev=${libcudnn_dev_version} + ${libcudnn_name}=${libcudnn_version} \ + ${libcudnn_dev_name}=${libcudnn_dev_version} apt-get clean rm -rf /var/lib/apt/lists/*