Skip to content

Commit

Permalink
Avoid using deprecated jax.numpy.ndarray.tile() method
Browse files Browse the repository at this point in the history
This has been deprecated since JAX version 0.3.7, and will be removed in jax-ml/jax#11944. The jax.numpy.tile() function is a direct replacement, but in some places arrays can be constructed more efficiently using `lax.broadcasted_iota`.

PiperOrigin-RevId: 468227966
  • Loading branch information
Jake VanderPlas authored and The jax3d Authors committed Aug 17, 2022
1 parent bad1d9a commit 06a9583
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax3d/math/volume_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def sample_1d(
# Initialize t with leading bin edges
t = jnp.linspace(0.0, 1.0, sample_count, endpoint=False, dtype=dtype)
t = t.reshape(*([1] * len(batch_shape)), sample_count)
t = t.tile(tuple(batch_shape) + (1,))
t = jnp.tile(t, tuple(batch_shape) + (1,))

if strategy == SamplingStrategy.STRATIFIED:
# Randomly perturb points within depth bins
Expand Down Expand Up @@ -106,7 +106,7 @@ def sample_1d_grid(
t = jnp.linspace(0.0, 1.0, sample_count, endpoint=False, dtype=dtype)
t += 0.5 / sample_count
t = t.reshape(*([1] * len(batch_shape)), sample_count)
t = t.tile(tuple(batch_shape) + (1,))
t = jnp.tile(t, tuple(batch_shape) + (1,))
return t


Expand Down

0 comments on commit 06a9583

Please sign in to comment.