nca_jax

An example of NCA (Neighbourhood Components Analysis) implemented in JAX. This is a port of https://github.com/kevinzakka/torchnca; see that repository for the original PyTorch implementation.
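NCA learns a linear map A. Each point i picks a neighbour j with probability p_ij proportional to exp(-||A x_i - A x_j||^2), with p_ii = 0. Training then maximizes the probability that the chosen neighbour shares i's label. The block below is a minimal, hypothetical sketch of that objective and a plain gradient-descent loop in JAX. It is not this repository's code. The names `nca_loss` and `grad_fn`, the toy data, the learning rate, and averaging over points (rather than summing) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def nca_loss(A, X, y):
    """Negative mean NCA objective (hypothetical helper for this sketch).

    A: (d_out, d_in) linear map, X: (n, d_in) inputs, y: (n,) integer labels.
    """
    Z = X @ A.T  # project inputs: (n, d_out)
    sq = jnp.sum(Z ** 2, axis=1)
    # Pairwise squared Euclidean distances ||z_i - z_j||^2, shape (n, n).
    d2 = sq[:, None] + sq[None, :] - 2.0 * (Z @ Z.T)
    # A point may not pick itself as a neighbour (p_ii = 0): mask the diagonal.
    self_mask = jnp.eye(X.shape[0], dtype=bool)
    logits = jnp.where(self_mask, -jnp.inf, -d2)
    p = jax.nn.softmax(logits, axis=1)  # p[i, j] = p_ij; each row sums to 1
    same_class = y[:, None] == y[None, :]
    p_correct = jnp.sum(jnp.where(same_class, p, 0.0), axis=1)  # p_i
    return -jnp.mean(p_correct)  # minimize the negative objective


# jit-compiled loss and gradient with respect to A (the first argument).
grad_fn = jax.jit(jax.value_and_grad(nca_loss))

# Toy data (assumption): 3 classes of 20 points in 4-D, shifted along feature 0.
k_x, k_a = jax.random.split(jax.random.PRNGKey(0))
y = jnp.repeat(jnp.arange(3), 20)
X = jax.random.normal(k_x, (60, 4))
X = X.at[:, 0].add(2.0 * y)

A0 = 0.1 * jax.random.normal(k_a, (2, 4))  # learn a 2-D embedding
A = A0
lr = 0.01  # assumed step size
for _ in range(200):
    loss, g = grad_fn(A, X, y)
    A = A - lr * g  # plain gradient descent step

initial_loss, _ = grad_fn(A0, X, y)
final_loss, _ = grad_fn(A, X, y)
print(f"loss: {float(initial_loss):.3f} -> {float(final_loss):.3f}")
```

Masking the diagonal with -inf before `jax.nn.softmax` gives p_ii = 0 exactly. The softmax also normalizes each row stably, so there is no need to exponentiate distances by hand. The loss lies in [-1, 0]. A value of -1 means every point's stochastic neighbour is always from its own class.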