Megatron sharding PaliGemma. Faster bs=1 decode.
Co-authored-by: Andreas Steiner <[email protected]>
Co-authored-by: André Susano Pinto <[email protected]>
3 people committed Jul 12, 2024
1 parent 1475d53 commit bd9c689
Showing 2 changed files with 132 additions and 51 deletions.
37 changes: 37 additions & 0 deletions big_vision/sharding.py
@@ -158,3 +158,40 @@ def _update_spec(cur_spec, mesh, name, x):
      return nn.logical_to_mesh_axes(cur_spec.names)
    return cur_spec
  return _update_spec


@Registry.register("shardings.shard_dim")
def shard_dim(axis, dim, ignore_ndim_error=False):
"""Shards the given dimension along the given axis.
Args:
axis: mesh axis name for sharding.
dim: dimension to shard (can be negative).
ignore_ndim_error: if True, a warning error is logged instead of raising an
exception when the given dimension is not compatible with the number of
dimensions of the array.
Returns:
A function that updates the sharding spec.
"""
def _update_spec(cur_spec, mesh, name, x):
del mesh, x
if np.abs(dim) >= len(cur_spec):
msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}"
if ignore_ndim_error:
logging.warning(msg)
return cur_spec
else:
raise ValueError(msg)
pos_dim = dim
if pos_dim < 0:
pos_dim += len(cur_spec)
if cur_spec[pos_dim] is not None:
raise ValueError(
f"Already sharded: shard_dim({axis}, {dim}):"
f" name={name} cur_spec={cur_spec}"
)
new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1 :]
return new_spec

return _update_spec
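
shard_dim is registered as "shardings.shard_dim", so sharding configs can select it by name to shard a single weight dimension Megatron-style. A minimal sketch of the returned update function's behavior, assuming cur_spec is a per-dimension tuple where None means unsharded (the parameter name here is made up):

update = shard_dim("data", -1)  # Shard the last dimension on mesh axis "data".
spec = update((None, None), None, "llm/kv_einsum/w", None)  # mesh and x are unused.
assert spec == (None, "data")

# An out-of-range dim raises, or with ignore_ndim_error=True just logs a
# warning and leaves the spec unchanged:
update = shard_dim("data", 2, ignore_ndim_error=True)
assert update((None, None), None, "w", None) == (None, None)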
146 changes: 95 additions & 51 deletions big_vision/trainers/proj/paligemma/predict_fns.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prediction functions for clippo/generative.py."""
"""Prediction functions for PaliGemma."""

import collections
import functools

from big_vision.pp import registry
@@ -23,6 +24,9 @@
import jax.numpy as jnp
import numpy as np


P = jax.sharding.PartitionSpec

# pylint: disable=missing-function-docstring


@@ -64,56 +68,82 @@ def _image_avg_repr(train_state, batch, *, model, key="img/pre_logits"):

def _decode_with_logp(
    train_state, batch, *, model, devices, max_decode_len, eos_token,
    best_of_n=1, sampler="greedy"):
    best_of_n=1, sampler="greedy", replicate_out=False, eos_look_behind=0):
  """Sample token continuations to the input sequences."""
  mesh = jax.sharding.Mesh(devices, ("fsdp",))
  replicate_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec())
  data_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec("fsdp"))

  # Wrap `jit` to avoid repeating how the functions should be jitted:
  # - Output sharding defaults to data_sharding.
  # - All keyword arguments are static.
  def jit(fn, *, out_shardings=data_sharding, **jit_kw):
    return lambda *a, **kw: jax.jit(
        fn, **jit_kw, out_shardings=out_shardings, static_argnames=kw.keys()
    )(*a, **kw)
  mesh = jax.sharding.Mesh(devices, ("devices",))
  replicate_sharding = jax.sharding.NamedSharding(mesh, P())
  out_sharding = jax.sharding.NamedSharding(
      mesh, P() if replicate_out else P("devices")
  )

  # Prefill the model cache and generate logits for first token.
  logits, cache = jit(_prefill_cache)(
      train_state["params"], {
  logits, cache = jax.jit(
      _prefill_cache,
      out_shardings=out_sharding,
      static_argnames=("model", "max_decode_len"),
  )(
      train_state["params"],
      {
          "image": batch["image"],
          "text": batch["text"],
          "mask_input": batch["mask_input"],
          "mask_ar": batch["mask_ar"],
      },
      model=model, max_decode_len=max_decode_len)
      model=model,
      max_decode_len=max_decode_len,
  )

  # Mask indicating real examples. False if example is used to pad the batch.
  mask = batch["_mask"]

  # Repeat example in case we are picking the best of n.
  logits, cache, mask = jit(_bon_repeat)((logits, cache, mask), n=best_of_n)
  logits, cache, mask = jax.jit(
      _bon_repeat,
      static_argnames=("n",)
  )((logits, cache, mask), n=best_of_n)

  decode_sample_output = jax.jit(
      _decode_sample_output,
      static_argnames=("max_decode_len", "sampler"),
  )
  decode_early_stop = jax.jit(
      _decode_early_stop,
      out_shardings=replicate_sharding,
      static_argnames=("eos_token",),
  )
  extend_cache = jax.jit(
      _extend_cache,
      donate_argnums=1,
      static_argnames=("model",),
  )

  # Keep sampling tokens from last logits until EOS or max_decode_len.
  state = None
  # Setting `eos_look_behind>0` removes blocking transfer with small batches.
  stops = collections.deque(maxlen=1 + eos_look_behind)
  for idx in range(max_decode_len):
    tokens, state = jit(_decode_sample_output)(
        state, logits,
        max_decode_len=max_decode_len, sampler=sampler)
    tokens, state = decode_sample_output(
        state, logits, max_decode_len=max_decode_len, sampler=sampler
    )

    early_stop = jit(_decode_early_stop, out_shardings=replicate_sharding)(
        state, mask, eos_token=eos_token)
    if jax.device_get(early_stop) or (idx + 1 >= max_decode_len):
    if idx + 1 >= max_decode_len:
      break

    stops.append(decode_early_stop(state, mask, eos_token=eos_token))
    if len(stops) == stops.maxlen and jax.device_get(stops[0]):
      break

    # Compute logits for next token
    logits, cache = jit(_extend_cache, donate_argnums=1)(
        train_state["params"], cache, tokens, model=model)
    logits, cache = extend_cache(
        train_state["params"], cache, tokens, model=model
    )

  # Select the best of n sample for each example.
  _, tokens, logp = jit(_bon_select)(state, n=best_of_n, eos_token=eos_token)
  _, tokens, logp = jax.jit(
      _bon_select,
      out_shardings=out_sharding,
      static_argnames=("n", "eos_token"),
  )(state, n=best_of_n, eos_token=eos_token)

  return tokens, logp
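
The eos_look_behind logic above queues each step's early-stop flag in a deque and fetches only the oldest entry to the host, so the blocking device-to-host transfer overlaps with subsequent decode steps instead of stalling them. An illustrative, self-contained sketch of the pattern, with a dummy function standing in for the real decode step:

import collections

import jax
import jax.numpy as jnp

def fake_step(i):
  # Stand-in for one decode step; returns an on-device "done" flag.
  return jnp.asarray(i >= 3)

look_behind = 1
stops = collections.deque(maxlen=1 + look_behind)
for i in range(10):
  stops.append(fake_step(i))
  # Only the oldest queued flag is copied to the host, so checking step
  # i - 1 does not block the device from working on step i.
  if len(stops) == stops.maxlen and jax.device_get(stops[0]):
    break

# The trade-off: with look_behind=1 the loop runs one extra step past EOS,
# which is cheap compared to a per-step blocking transfer at batch size 1.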

@@ -252,7 +282,8 @@ def _temperature_sampling(t, *, logits, rng):


@registry.Registry.register("paligemma_sampler.nucleus")
def _nucleus_sampling(p: float, *, logits, rng):
def _nucleus_sampling(p: float, t: float = 1.0, *, logits, rng):
  logits = logits / t
  neg_inf = np.array(-1.0e7)  # Effective negative infinity.
  logits_sorted = jnp.sort(logits, axis=-1, descending=True)
  sorted_cum_probs = jnp.cumsum(
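
The hunk above threads a temperature t through nucleus sampling by scaling logits before the top-p cutoff; the rest of the function is elided here. A self-contained sketch of temperature-scaled top-p sampling in the same spirit; the cutoff details follow the common recipe and may differ from the elided lines:

import jax
import jax.numpy as jnp

def nucleus_sample(logits, p=0.9, t=1.0, *, rng):
  logits = logits / t  # Temperature first, as in the change above.
  neg_inf = jnp.asarray(-1.0e7)  # Effective negative infinity.
  logits_sorted = jnp.sort(logits, axis=-1, descending=True)
  cum_probs = jnp.cumsum(jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
  # Index of the first token where cumulative probability reaches p; all
  # sorted logits before it are kept.
  cutoff_idx = jnp.sum(cum_probs < p, axis=-1, keepdims=True)
  cutoff = jnp.take_along_axis(logits_sorted, cutoff_idx, axis=-1)
  return jax.random.categorical(rng, jnp.where(logits < cutoff, neg_inf, logits))

tokens = nucleus_sample(jnp.ones((2, 32)), p=0.9, t=0.7,
                        rng=jax.random.PRNGKey(0))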
@@ -266,52 +297,65 @@ def _nucleus_sampling(p: float, *, logits, rng):

def _beam_decode(train_state, batch, *,
                 model, devices, max_decode_len,
                 eos_token, beam_size):
                 eos_token, beam_size, replicate_out=False):
  """Beam search (greedy/top-k exploration)."""
  mesh = jax.sharding.Mesh(devices, ("fsdp",))
  replicate_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec())
  data_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec("fsdp"))

  # Wrap `jit` to avoid repeating how the functions should be jitted:
  # - Output sharding defaults to data_sharding.
  # - All keyword arguments are static.
  def jit(fn, *, out_shardings=data_sharding, **jit_kw):
    return lambda *a, **kw: jax.jit(
        fn, **jit_kw, out_shardings=out_shardings, static_argnames=kw.keys()
    )(*a, **kw)
  mesh = jax.sharding.Mesh(devices, ("devices",))
  replicate_sharding = jax.sharding.NamedSharding(mesh, P())
  out_sharding = jax.sharding.NamedSharding(
      mesh, P() if replicate_out else P("devices")
  )

  # Prefill the model cache and generate logits for first token.
  logits, cache = jit(_prefill_cache)(
      train_state["params"], {
  logits, cache = jax.jit(
      _prefill_cache,
      out_shardings=out_sharding,
      static_argnames=("model", "max_decode_len"),
  )(
      train_state["params"],
      {
          "image": batch["image"],
          "text": batch["text"],
          "mask_input": batch["mask_input"],
          "mask_ar": batch["mask_ar"],
      },
      model=model, max_decode_len=max_decode_len)
      model=model,
      max_decode_len=max_decode_len,
  )

  # Mask indicating real examples. False if example is used to pad the batch.
  mask = batch["_mask"]

  beam_sample_output = jax.jit(
      _beam_sample_output,
      static_argnames=("max_decode_len", "beam_size", "eos_token"),
  )
  beam_early_stop = jax.jit(
      _beam_early_stop,
      out_shardings=replicate_sharding,
      static_argnames=("eos_token",),
  )
  extend_cache = jax.jit(
      _extend_cache,
      donate_argnums=1,
      static_argnames=("model",),
  )

  # Keep sampling tokens from last logits until EOS or max_decode_len.
  state = None
  for idx in range(max_decode_len):
    tokens, state, cache = jit(_beam_sample_output)(
    tokens, state, cache = beam_sample_output(
        state, logits, cache,
        max_decode_len=max_decode_len, beam_size=beam_size, eos_token=eos_token)

    early_stop = jit(_beam_early_stop, out_shardings=replicate_sharding)(
        state, mask, eos_token=eos_token)
    early_stop = beam_early_stop(state, mask, eos_token=eos_token)
    if jax.device_get(early_stop) or (idx + 1 >= max_decode_len):
      break

    # Compute logits for next token
    logits, cache = jit(_extend_cache, donate_argnums=1)(
    logits, cache = extend_cache(
        train_state["params"], cache, tokens, model=model)

  return jit(_beam_make_output)(state)
  return jax.jit(_beam_make_output, out_shardings=out_sharding)(state)


def _beam_early_stop(state, mask, eos_token):
