Skip to content

Commit

Permalink
Avoid using deprecated jax.numpy.ndarray.tile() method
Browse files Browse the repository at this point in the history
This has been deprecated since JAX version 0.3.7, and will be removed in jax-ml/jax#11944. The jax.numpy.tile() function is a direct replacement.

PiperOrigin-RevId: 468227966
  • Loading branch information
Jake VanderPlas authored and The jax3d Authors committed Aug 17, 2022
1 parent bad1d9a commit 52d91c5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax3d/math/volume_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def sample_1d(
# Initialize t with leading bin edges
t = jnp.linspace(0.0, 1.0, sample_count, endpoint=False, dtype=dtype)
t = t.reshape(*([1] * len(batch_shape)), sample_count)
t = t.tile(tuple(batch_shape) + (1,))
t = jnp.tile(t, tuple(batch_shape) + (1,))

if strategy == SamplingStrategy.STRATIFIED:
# Randomly perturb points within depth bins
Expand Down Expand Up @@ -106,7 +106,7 @@ def sample_1d_grid(
t = jnp.linspace(0.0, 1.0, sample_count, endpoint=False, dtype=dtype)
t += 0.5 / sample_count
t = t.reshape(*([1] * len(batch_shape)), sample_count)
t = t.tile(tuple(batch_shape) + (1,))
t = jnp.tile(t, tuple(batch_shape) + (1,))
return t


Expand Down

0 comments on commit 52d91c5

Please sign in to comment.