From ba428271d898a348c91f4180481ab4d74290b49b Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Mon, 23 Sep 2024 04:04:03 -0700 Subject: [PATCH] Add emergency checkpoint logging of arrays for debugging. PiperOrigin-RevId: 677727333 --- .../emergency/checkpoint_manager.py | 15 ++++++++++++++ .../checkpoint/multihost/multislice_utils.py | 20 ++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py index 63fb40e6d..5a6f80c77 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py @@ -71,6 +71,12 @@ _SECONDARY_REPLICA_ID = 1 +def _log_array(keypath, arr): + logging.info('Key: %s', utils.tuple_path_from_keypath(keypath)) + for shard in arr.addressable_shards: + logging.info(shard.data) + + def _write_process_metadata(path: epath.Path, mesh: jax.sharding.Mesh): """Write process metadata to the given path.""" logging.info('Saving process index metadata at %s', path) @@ -1025,6 +1031,10 @@ def _get_single_slice_sharding( time.time() - step_stats.checkpointer_start_time ) in_tree = tuple(jax.tree.flatten(single_slice_pytree)[0]) + + logging.info('Restored arrays') + jax.tree_util.tree_map_with_path(_log_array, in_tree) + else: logging.vlog( 1, @@ -1057,6 +1067,9 @@ def create_zeros(shape_dtype_tup): '/orbax/emergency/checkpoint/read/broadcast_duration_secs', broadcast_elapsed_s, ) + logging.info('Broadcasted arrays') + jax.tree_util.tree_map_with_path(_log_array, shared_states) + step_stats.broadcast_start_time = start_broadcast step_stats.broadcast_duration_secs = broadcast_elapsed_s step_stats.checkpoint_manager_duration_secs = ( @@ -1072,6 +1085,8 @@ def create_zeros(shape_dtype_tup): finalized_shared_states = self._consistent_restore_mesh_to_global_mesh( shared_states ) + logging.info('Finalized arrays') + jax.tree_util.tree_map_with_path(_log_array, finalized_shared_states) return jax.tree.unflatten(tree_defs, finalized_shared_states) diff --git a/checkpoint/orbax/checkpoint/multihost/multislice_utils.py b/checkpoint/orbax/checkpoint/multihost/multislice_utils.py index 1a6ec7938..3910a777a 100644 --- a/checkpoint/orbax/checkpoint/multihost/multislice_utils.py +++ b/checkpoint/orbax/checkpoint/multihost/multislice_utils.py @@ -21,6 +21,7 @@ import jax from jax import numpy as jnp import numpy as np +from orbax.checkpoint import tree as tree_utils from orbax.checkpoint.multihost import utils PyTree = Any @@ -30,6 +31,12 @@ MEMORY_FACTOR = 3 +def _log_array(keypath, arr): + logging.info('Key: %s', utils.tuple_path_from_keypath(keypath)) + for shard in arr.addressable_shards: + logging.info(shard.data) + + def process_slice_id( process_index: int, global_mesh: jax.sharding.Mesh, @@ -217,7 +224,8 @@ def broadcast_one_replica_to_all( logging.info('Using available memory of %d bytes.', memory_limit_bytes) # Set replica_axis to be 0, regardless of its actual value. - def globalize_single_replica_arrays(inp): + def globalize_single_replica_arrays(keypath, inp): + logging.info('Key: %s', tree_utils.tuple_path_from_keypath(keypath)) sharding = inp.sharding if not isinstance(sharding, jax.sharding.NamedSharding): raise ValueError( @@ -233,9 +241,12 @@ def globalize_single_replica_arrays(inp): ) global_shape = (num_replicas,) + inp.shape[1:] global_sharding = jax.sharding.NamedSharding(global_mesh, in_spec) - return jax.make_array_from_single_device_arrays( + result = jax.make_array_from_single_device_arrays( global_shape, global_sharding, [s.data for s in inp.addressable_shards] ) + for shard in result.addressable_shards: + logging.info(shard.data) + return result tree_len = len(in_tree) start = 0 @@ -270,7 +281,9 @@ def globalize_single_replica_arrays(inp): ), subtree, ) - in_tree_sharded = jax.tree.map(globalize_single_replica_arrays, subtree) + in_tree_sharded = jax.tree_util.tree_map_with_path( + globalize_single_replica_arrays, subtree + ) # Delete immediately to conserve memory. jax.tree.map(lambda x: x.delete(), subtree) @@ -278,6 +291,7 @@ def globalize_single_replica_arrays(inp): lambda tree: jax.tree.map(functools.partial(jnp.sum, axis=0), tree), out_shardings=out_sharding, )(in_tree_sharded) + jax.tree_util.tree_map_with_path(_log_array, out_subtree) out_tree.extend(out_subtree) jax.block_until_ready(out_subtree) start = end