From b59284e7bd06401c7c28615d086605a6f1f807e4 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Mon, 23 Sep 2024 04:04:03 -0700 Subject: [PATCH] Add emergency checkpoint logging of arrays for debugging. PiperOrigin-RevId: 677727333 --- .../_src/serialization/serialization.py | 21 ++++++++++++++++++- .../emergency/checkpoint_manager.py | 1 + 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py index 805706e3..03f98d20 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py @@ -58,14 +58,25 @@ async def create_async_array_from_callback( inp_sharding: jax.sharding.Sharding, data_callback: Callable[[Index, jax.Device], Awaitable[jax.Array]], ) -> jax.Array: + """Docstring.""" device_to_index_map = inp_sharding.devices_indices_map(global_shape) addressable_da = inp_sharding._addressable_device_assignment # pylint: disable=protected-access future_arrays = [data_callback(device_to_index_map[d], d) for d in addressable_da] dbs = await asyncio.gather(*future_arrays) - return jax.make_array_from_single_device_arrays( + result = jax.make_array_from_single_device_arrays( global_shape, inp_sharding, dbs ) + logging.info( + '[process=%d] create_async_array_from_callback', multihost.process_index() + ) + logging.info('global_shape: %s', global_shape) + logging.info('inp_sharding: %s', inp_sharding) + logging.info('addressable_da: %s', addressable_da) + logging.info( + 'loading indices: %s', [device_to_index_map[d] for d in addressable_da] + ) + return result def _get_metadata(arr): @@ -524,6 +535,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore, async def _read_and_device_put_shard( + index: Index, device: jax.Device, t: ts.TensorStore, new_shard_shape: Sequence[int], @@ -549,6 +561,12 @@ async def _read_and_device_put_shard( # make this work. if out.dtype == jnp.int4: out = jnp.asarray(out) # type: ignore + logging.info( + '_read_and_device_put_shard: (index, shard_shape, arr): %s, %s, %s', + index, + out.shape, + out, + ) return jax.device_put(out, jax.sharding.SingleDeviceSharding(device)) @@ -577,6 +595,7 @@ async def _read_array_index_callback( # Limit the bytes read for every shard. async with reserved_bytes(byte_limiter, requested_bytes): result = await _read_and_device_put_shard( + index, device, t, new_shard_shape, diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py index 2f5e16bd..26c18d0b 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py @@ -1069,6 +1069,7 @@ def _restore_from_local( step, restoring_slice_id, ) + logging.info('jax.local_devices: %s', jax.local_devices()) step_stats = step_statistics.EmergencyRestoreStepStatistics() step_stats.checkpoint_manager_start_time = time.time() step_stats.step = step