Skip to content

Commit

Permalink
python3Packages.jax-cuda12-plugin: patch like jax-cuda12-pjrt
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-huan committed Jan 20, 2025
1 parent 659babe commit 9bd08f0
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions pkgs/development/python-modules/jax-cuda12-plugin/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
jax-cuda12-pjrt,
}:
let
inherit (cudaPackages) cudaVersion;
inherit (jaxlib) version;
inherit (cudaPackages) cudaVersion;
inherit (jax-cuda12-pjrt) cudaLibPath;

getSrcFromPypi =
{
Expand Down Expand Up @@ -94,12 +95,34 @@ buildPythonPackage {
wheelUnpackHook
];

# jax-cuda12-plugin looks for ptxas at runtime, e.g. with a xla custom call.
# Linking into $out is the least bad solution. See
# * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
# * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
# * https://github.com/NixOS/nixpkgs/pull/375186
# for more info.
postInstall = ''
mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
'';

# jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen
# and these implicit dependencies are not recognized by ldd or
# autoPatchelfHook. That means we need to sneak them into rpath. This step
# must be done after autoPatchelfHook and the automatic stripping of
# artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
# patchPhase.
preInstallCheck = ''
patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
'';

dependencies = [ jax-cuda12-pjrt ];

pythonImportsCheck = [ "jax_cuda12_plugin" ];

# no tests
doCheck = false;
# FIXME: there are no tests, but we need to run preInstallCheck above
doCheck = true;

meta = {
description = "JAX Plugin for CUDA12";
Expand Down

0 comments on commit 9bd08f0

Please sign in to comment.