From e04168df9be58a7b33a391ed714d2a606e458ff9 Mon Sep 17 00:00:00 2001 From: Markus Date: Mon, 11 Dec 2023 17:02:51 +0100 Subject: [PATCH] Fixed subset function in SampledObs for multiple local devices. --- jVMC/stats.py | 3 ++- jVMC/version.py | 2 +- tests/stats_test.py | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jVMC/stats.py b/jVMC/stats.py index 422e40d..9db9785 100644 --- a/jVMC/stats.py +++ b/jVMC/stats.py @@ -212,7 +212,8 @@ def subset(self, start=None, end=None, step=None): newObs._data = _get_subset_helper(self._data, (start, end, step)) newObs._weights = newObs._weights / normalization newObs._data = newObs._data / jnp.sqrt(normalization) - newObs._mean = mpi.global_sum( _subset_mean_helper(newObs._data, newObs._weights, self._mean)[None,...] ) + + newObs._mean = mpi.global_sum( _subset_mean_helper(newObs._data, newObs._weights, 0.0)[:,None,...] ) + self._mean newObs._data = _subset_data_prep(newObs._data, newObs._weights, self._mean, newObs._mean) return newObs diff --git a/jVMC/version.py b/jVMC/version.py index f023598..83f68a0 100644 --- a/jVMC/version.py +++ b/jVMC/version.py @@ -1,2 +1,2 @@ """Current jVMC version at head on Github.""" -__version__ = "1.2.4" +__version__ = "1.2.5" diff --git a/tests/stats_test.py b/tests/stats_test.py index d667b2f..b1e056c 100644 --- a/tests/stats_test.py +++ b/tests/stats_test.py @@ -34,3 +34,21 @@ def test_sampled_obs(self): O = obs2._data.reshape((-1,2)) self.assertTrue(jnp.allclose(obs2.tangent_kernel(), jnp.matmul(O, jnp.conj(jnp.transpose(O))))) + def test_subset_function(self): + + N = 10 + Obs1 = jnp.reshape(jnp.arange(jax.device_count()*N), (jax.device_count(), N, 1)) + p = jax.random.uniform(jax.random.PRNGKey(123), (jax.device_count(),N)) + p = p / jnp.sum(p) + + obs1 = SampledObs(Obs1, p) + obs2 = obs1.subset(0,N//2) + + self.assertTrue( jnp.allclose(obs1.mean(), jnp.sum(jnp.reshape(Obs1, (jax.device_count(), N)) * p)) ) + + self.assertTrue(obs2.mean() - jnp.sum(jnp.reshape(Obs1, (jax.device_count(), N))[:,0:N//2] * p[:,0:N//2]) / jnp.sum(p[:,0:N//2])) + + + obs3 = SampledObs(Obs1[:,0:N//2,:], p[:,0:N//2] / jnp.sum(p[:,0:N//2])) + + self.assertTrue( jnp.allclose(obs3.covar(), obs2.covar()))