Skip to content

Commit

Permalink
Fixed subset function in SampledObs for multiple local devices.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus committed Dec 11, 2023
1 parent bffcb52 commit e04168d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
3 changes: 2 additions & 1 deletion jVMC/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def subset(self, start=None, end=None, step=None):
newObs._data = _get_subset_helper(self._data, (start, end, step))
newObs._weights = newObs._weights / normalization
newObs._data = newObs._data / jnp.sqrt(normalization)
newObs._mean = mpi.global_sum( _subset_mean_helper(newObs._data, newObs._weights, self._mean)[None,...] )

newObs._mean = mpi.global_sum( _subset_mean_helper(newObs._data, newObs._weights, 0.0)[:,None,...] ) + self._mean
newObs._data = _subset_data_prep(newObs._data, newObs._weights, self._mean, newObs._mean)

return newObs
Expand Down
2 changes: 1 addition & 1 deletion jVMC/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Current jVMC version at head on Github."""
__version__ = "1.2.4"
__version__ = "1.2.5"
18 changes: 18 additions & 0 deletions tests/stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,21 @@ def test_sampled_obs(self):
O = obs2._data.reshape((-1,2))
self.assertTrue(jnp.allclose(obs2.tangent_kernel(), jnp.matmul(O, jnp.conj(jnp.transpose(O)))))

def test_subset_function(self):

N = 10
Obs1 = jnp.reshape(jnp.arange(jax.device_count()*N), (jax.device_count(), N, 1))
p = jax.random.uniform(jax.random.PRNGKey(123), (jax.device_count(),N))
p = p / jnp.sum(p)

obs1 = SampledObs(Obs1, p)
obs2 = obs1.subset(0,N//2)

self.assertTrue( jnp.allclose(obs1.mean(), jnp.sum(jnp.reshape(Obs1, (jax.device_count(), N)) * p)) )

self.assertTrue(obs2.mean() - jnp.sum(jnp.reshape(Obs1, (jax.device_count(), N))[:,0:N//2] * p[:,0:N//2]) / jnp.sum(p[:,0:N//2]))


obs3 = SampledObs(Obs1[:,0:N//2,:], p[:,0:N//2] / jnp.sum(p[:,0:N//2]))

self.assertTrue( jnp.allclose(obs3.covar(), obs2.covar()))

0 comments on commit e04168d

Please sign in to comment.