Skip to content

Commit

Permalink
Emergency checkpoint: compile broadcast function once at init.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694245124
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Nov 7, 2024
1 parent 4b2f712 commit d8a2a63
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 44 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Emergency checkpoint: use JAX for global_max and combine multiple broadcasts
into one for `saved` bool broadcast. This should alleviate concerns about
broadcasting using the distributed system at large scale.
- Emergency checkpoint: compile broadcast function once at init.

## [0.8.0] - 2024-10-29

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,48 +453,6 @@ def _process_local_to_global(
return per_process_values


def _global_max(values: list[int]) -> list[int]:
  """Returns the elementwise global max of local values across all processes.

  Uses JAX-based broadcasting to ensure greater performance at scale.

  Args:
    values: A list of integers (one local value per entry).

  Returns:
    A list of the same length as `values`, where each entry is the max of
    the corresponding entry across all processes.
  """
  num_hosts = multihost.process_count()
  num_devices_per_host = jax.local_device_count()
  # (host, dev) mesh over all devices, so each host's values occupy one
  # row of the gathered array below.
  slice_mesh = jax.sharding.Mesh(
      np.asarray(jax.devices()).reshape(num_hosts, num_devices_per_host),
      ['host', 'dev'],
  )

  # Replicate the local values onto every local device so the global array
  # can be assembled from single-device shards.
  sdas = []
  for d in jax.local_devices():
    sdas.append(
        jax.device_put(np.asarray(values).reshape((1, 1, len(values))), d)
    )
  sharding = jax.sharding.NamedSharding(slice_mesh, P('host', 'dev'))
  # TODO(cpgaffney): Use jax.make_array_from_process_local_data.
  g_arr = jax.make_array_from_single_device_arrays(
      (num_hosts, num_devices_per_host, len(values)), sharding, sdas
  )
  # Identity jit with fully-replicated out_shardings: replicating the
  # output performs the cross-host gather.
  result_arr = jax.jit(
      lambda x: x,
      out_shardings=jax.sharding.NamedSharding(slice_mesh, P()),
  )(g_arr)
  result_arr = np.asarray(result_arr.addressable_data(0))

  # Checks that for every host, values are equal across local devices.
  # Exact elementwise comparison against the first device's column avoids
  # the float round-off risk of a sum-then-divide check.
  assert (result_arr == result_arr[:, :1, :]).all()
  # Select values from first device and compute max for each value across
  # hosts.
  return list(np.max(result_arr[:, 0, :], axis=0).astype(int))


def _all_devices_excepting_slice(
devices: np.ndarray,
*,
Expand Down Expand Up @@ -830,6 +788,18 @@ def __init__(
# Initialize step cache.
self.all_steps(read=True)

# Compile function for global broadcast.
slice_mesh = jax.sharding.Mesh(
np.asarray(jax.devices()).reshape(
multihost.process_count(), jax.local_device_count()
),
['host', 'dev'],
)
self._global_broadcast_fn = jax.jit(
lambda x: x,
out_shardings=jax.sharding.NamedSharding(slice_mesh, P()),
)

logging.info(
'Created emergency.CheckpointManager with slice_id=%d,'
' process_index=%d, jax.process_index=%d',
Expand Down Expand Up @@ -954,8 +924,43 @@ def reached_preemption(self, step: int) -> bool:
return utils.reached_preemption(step)

def _global_max(self, values: list[int]) -> list[int]:
  """Returns the elementwise global max of local values across all processes.

  Uses JAX-based broadcasting to ensure greater performance at scale.

  Args:
    values: A list of integers (one local value per entry).

  Returns:
    A list of the same length as `values`, where each entry is the max of
    the corresponding entry across all processes.
  """
  num_hosts = multihost.process_count()
  num_devices_per_host = jax.local_device_count()
  # (host, dev) mesh over all devices; each host's values occupy one row
  # of the gathered array below.
  slice_mesh = jax.sharding.Mesh(
      np.asarray(jax.devices()).reshape(num_hosts, num_devices_per_host),
      ['host', 'dev'],
  )

  # Replicate the local values onto every local device so the global array
  # can be assembled from single-device shards.
  sdas = []
  for d in jax.local_devices():
    sdas.append(
        jax.device_put(np.asarray(values).reshape((1, 1, len(values))), d)
    )
  sharding = jax.sharding.NamedSharding(slice_mesh, P('host', 'dev'))
  # TODO(cpgaffney): Use jax.make_array_from_process_local_data.
  g_arr = jax.make_array_from_single_device_arrays(
      (num_hosts, num_devices_per_host, len(values)), sharding, sdas
  )
  # Identity function jit-compiled once at init with fully-replicated
  # out_shardings (see __init__); replicating the output performs the
  # cross-host gather.
  result_arr = self._global_broadcast_fn(g_arr)
  result_arr = np.asarray(result_arr.addressable_data(0))

  # Checks that for every host, values are equal across local devices.
  assert (
      np.sum(result_arr, axis=1) / num_devices_per_host == result_arr[:, 0, :]
  ).all()
  # Select values from first device and compute max for each value across
  # hosts.
  return list(np.max(result_arr[:, 0, :], axis=0).astype(int))

def should_save(self, step: int) -> bool:
"""Returns True if a checkpoint should be saved for the current step.
Expand Down

0 comments on commit d8a2a63

Please sign in to comment.