Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add emergency checkpoint logging of arrays for debugging. #1198

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion checkpoint/orbax/checkpoint/_src/serialization/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,25 @@ async def create_async_array_from_callback(
inp_sharding: jax.sharding.Sharding,
data_callback: Callable[[Index, jax.Device], Awaitable[jax.Array]],
) -> jax.Array:
"""Docstring."""
device_to_index_map = inp_sharding.devices_indices_map(global_shape)
addressable_da = inp_sharding._addressable_device_assignment # pylint: disable=protected-access
future_arrays = [data_callback(device_to_index_map[d], d)
for d in addressable_da]
dbs = await asyncio.gather(*future_arrays)
return jax.make_array_from_single_device_arrays(
result = jax.make_array_from_single_device_arrays(
global_shape, inp_sharding, dbs
)
logging.info(
'[process=%d] create_async_array_from_callback', multihost.process_index()
)
logging.info('global_shape: %s', global_shape)
logging.info('inp_sharding: %s', inp_sharding)
logging.info('addressable_da: %s', addressable_da)
logging.info(
'loading indices: %s', [device_to_index_map[d] for d in addressable_da]
)
return result


def _get_metadata(arr):
Expand Down Expand Up @@ -524,6 +535,7 @@ def estimate_read_memory_footprint(t: ts.TensorStore,


async def _read_and_device_put_shard(
index: Index,
device: jax.Device,
t: ts.TensorStore,
new_shard_shape: Sequence[int],
Expand All @@ -549,6 +561,12 @@ async def _read_and_device_put_shard(
# make this work.
if out.dtype == jnp.int4:
out = jnp.asarray(out) # type: ignore
logging.info(
'_read_and_device_put_shard: (index, shard_shape, arr): %s, %s, %s',
index,
out.shape,
out,
)
return jax.device_put(out, jax.sharding.SingleDeviceSharding(device))


Expand Down Expand Up @@ -577,6 +595,7 @@ async def _read_array_index_callback(
# Limit the bytes read for every shard.
async with reserved_bytes(byte_limiter, requested_bytes):
result = await _read_and_device_put_shard(
index,
device,
t,
new_shard_shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,7 @@ def _restore_from_local(
step,
restoring_slice_id,
)
logging.info('jax.local_devices: %s', jax.local_devices())
step_stats = step_statistics.EmergencyRestoreStepStatistics()
step_stats.checkpoint_manager_start_time = time.time()
step_stats.step = step
Expand Down
Loading