Skip to content

Commit

Permalink
Add emergency checkpoint logging of arrays for debugging.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677727333
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Sep 25, 2024
1 parent fb7272d commit ac90ce1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@
_SECONDARY_REPLICA_ID = 1


def _log_array(keypath, arr):
  """Logs a pytree leaf's key path, then the data of each addressable shard.

  Intended as the callback for `jax.tree_util.tree_map_with_path`, which is
  how this module invokes it for debugging output.

  Args:
    keypath: Key path of the leaf within the pytree being traversed.
    arr: Leaf value exposing `addressable_shards`; each shard's `data` is
      logged at INFO level.
  """
  tuple_path = utils.tuple_path_from_keypath(keypath)
  logging.info('Key: %s', tuple_path)
  for addressable_shard in arr.addressable_shards:
    shard_data = addressable_shard.data
    logging.info(shard_data)


def _write_process_metadata(path: epath.Path, mesh: jax.sharding.Mesh):
"""Write process metadata to the given path."""
logging.info('Saving process index metadata at %s', path)
Expand Down Expand Up @@ -1025,6 +1031,10 @@ def _get_single_slice_sharding(
time.time() - step_stats.checkpointer_start_time
)
in_tree = tuple(jax.tree.flatten(single_slice_pytree)[0])

logging.info('Restored arrays')
jax.tree_util.tree_map_with_path(_log_array, in_tree)

else:
logging.vlog(
1,
Expand Down Expand Up @@ -1057,6 +1067,9 @@ def create_zeros(shape_dtype_tup):
'/orbax/emergency/checkpoint/read/broadcast_duration_secs',
broadcast_elapsed_s,
)
logging.info('Broadcasted arrays')
jax.tree_util.tree_map_with_path(_log_array, shared_states)

step_stats.broadcast_start_time = start_broadcast
step_stats.broadcast_duration_secs = broadcast_elapsed_s
step_stats.checkpoint_manager_duration_secs = (
Expand All @@ -1072,6 +1085,8 @@ def create_zeros(shape_dtype_tup):
finalized_shared_states = self._consistent_restore_mesh_to_global_mesh(
shared_states
)
logging.info('Finalized arrays')
jax.tree_util.tree_map_with_path(_log_array, finalized_shared_states)

return jax.tree.unflatten(tree_defs, finalized_shared_states)

Expand Down
20 changes: 17 additions & 3 deletions checkpoint/orbax/checkpoint/multihost/multislice_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint.multihost import utils

PyTree = Any
Expand All @@ -30,6 +31,12 @@
MEMORY_FACTOR = 3


def _log_array(keypath, arr):
  """Logs a pytree leaf's key path, then the data of each addressable shard.

  Intended as the callback for `jax.tree_util.tree_map_with_path`, which is
  how this module invokes it for debugging output.

  Args:
    keypath: Key path of the leaf within the pytree being traversed.
    arr: Leaf value exposing `addressable_shards`; each shard's `data` is
      logged at INFO level.
  """
  tuple_path = tree_utils.tuple_path_from_keypath(keypath)
  logging.info('Key: %s', tuple_path)
  for addressable_shard in arr.addressable_shards:
    shard_data = addressable_shard.data
    logging.info(shard_data)


def process_slice_id(
process_index: int,
global_mesh: jax.sharding.Mesh,
Expand Down Expand Up @@ -217,7 +224,8 @@ def broadcast_one_replica_to_all(
logging.info('Using available memory of %d bytes.', memory_limit_bytes)

# Set replica_axis to be 0, regardless of its actual value.
def globalize_single_replica_arrays(inp):
def globalize_single_replica_arrays(keypath, inp):
logging.info('Key: %s', tree_utils.tuple_path_from_keypath(keypath))
sharding = inp.sharding
if not isinstance(sharding, jax.sharding.NamedSharding):
raise ValueError(
Expand All @@ -233,9 +241,12 @@ def globalize_single_replica_arrays(inp):
)
global_shape = (num_replicas,) + inp.shape[1:]
global_sharding = jax.sharding.NamedSharding(global_mesh, in_spec)
return jax.make_array_from_single_device_arrays(
result = jax.make_array_from_single_device_arrays(
global_shape, global_sharding, [s.data for s in inp.addressable_shards]
)
for shard in result.addressable_shards:
logging.info(shard.data)
return result

tree_len = len(in_tree)
start = 0
Expand Down Expand Up @@ -270,14 +281,17 @@ def globalize_single_replica_arrays(inp):
),
subtree,
)
in_tree_sharded = jax.tree.map(globalize_single_replica_arrays, subtree)
in_tree_sharded = jax.tree_util.tree_map_with_path(
globalize_single_replica_arrays, subtree
)
# Delete immediately to conserve memory.
jax.tree.map(lambda x: x.delete(), subtree)

out_subtree = jax.jit(
lambda tree: jax.tree.map(functools.partial(jnp.sum, axis=0), tree),
out_shardings=out_sharding,
)(in_tree_sharded)
jax.tree_util.tree_map_with_path(_log_array, out_subtree)
out_tree.extend(out_subtree)
jax.block_until_ready(out_subtree)
start = end
Expand Down

0 comments on commit ac90ce1

Please sign in to comment.