From c6b86b7560ac7fc6b92a6f540178a64c7ae0ac04 Mon Sep 17 00:00:00 2001
From: Daniel Ng
Date: Fri, 8 Nov 2024 13:26:34 -0800
Subject: [PATCH] Add Layout support

PiperOrigin-RevId: 694619629
---
 checkpoint/CHANGELOG.md                       |  1 +
 .../_src/serialization/serialization.py       | 25 +++++++++++--
 .../_src/serialization/serialization_test.py  | 37 +++++++++++++++++++
 .../_src/serialization/type_handlers.py       | 19 ++++++----
 4 files changed, 70 insertions(+), 12 deletions(-)

diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md
index 834a2837..7a910a6f 100644
--- a/checkpoint/CHANGELOG.md
+++ b/checkpoint/CHANGELOG.md
@@ -15,6 +15,7 @@
   that contain utilities to perform de/serialization for `RootMetadata` and
   `StepMetadata`.
 - `ReplicaSlice`/`ReplicaSlices` construct to facilitate saving replicated
   arrays.
+- Added support for restoring with a custom `jax.experimental.layout.Layout`.
 
 ### Changed
 - Refactor metadata/tree_test.py and move common test types to
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py
index 1beaeeff..8ba26117 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py
@@ -28,6 +28,7 @@
 from absl import logging
 import humanize
 import jax
+from jax.experimental import layout
 import jax.numpy as jnp
 import numpy as np
 from orbax.checkpoint._src.arrays import fragments
@@ -48,8 +49,9 @@
 ]
 
 
-Shape = types.Shape
 Index = types.Index
+Layout = layout.Layout
+Shape = types.Shape
 
 
 async def create_async_array_from_callback(
@@ -426,6 +428,7 @@ async def _read_and_device_put_shard(
     dtype: jnp.dtype,
     requested_domain: ts.IndexDomain,
     restricted_domain: ts.IndexDomain,
+    dll: Optional[layout.DeviceLocalLayout],
 ) -> jax.Array:
   """Reads a single shard from TensorStore and places it on device."""
   # This maybe needed because the shape the array was saved with is smaller
@@ -445,7 +448,9 @@ async def _read_and_device_put_shard(
   # make this work.
   if out.dtype == jnp.int4:
     out = jnp.asarray(out)  # type: ignore
-  return jax.device_put(out, jax.sharding.SingleDeviceSharding(device))
+  return jax.device_put(
+      out, Layout(dll, jax.sharding.SingleDeviceSharding(device))
+  )
 
 
 async def _read_array_index_callback(
@@ -457,6 +462,7 @@ async def _read_array_index_callback(
     dtype: jnp.dtype,
     byte_limiter: ByteLimiter,
     strict: bool,
+    dll: Optional[layout.DeviceLocalLayout],
 ) -> jax.Array:
   """Callback that reads an array index and places on device."""
   if strict and t.shape != shape:
@@ -479,12 +485,13 @@ async def _read_array_index_callback(
         dtype,
         requested_domain,
         restricted_domain,
+        dll,
     )
   return result
 
 
 async def async_deserialize(
-    user_in_sharding: jax.sharding.Sharding,
+    user_in_sharding: jax.sharding.Sharding | Layout,
     tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
     global_shape: Optional[Sequence[int]] = None,
     dtype: Optional[jnp.dtype] = None,
@@ -496,11 +503,20 @@ async def async_deserialize(
   """Reads an array using TensorStore."""
   byte_limiter = byte_limiter or get_byte_limiter()
   context = context or ts_utils.get_ts_context(use_ocdbt=False)
-  in_sharding = user_in_sharding
+  in_sharding = (
+      user_in_sharding.sharding
+      if isinstance(user_in_sharding, Layout)
+      else user_in_sharding
+  )
   if not isinstance(in_sharding, jax.sharding.Sharding):
     raise ValueError(
         'sharding passed to deserialization should be specified, concrete and'
         f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
+  dll = (
+      user_in_sharding.device_local_layout
+      if isinstance(user_in_sharding, Layout)
+      else None
+  )
   t = await ts.open(
       tensorstore_spec,
       open=True,
@@ -520,5 +536,6 @@ async def async_deserialize(
           dtype=dtype,
           byte_limiter=byte_limiter,
           strict=strict,
+          dll=dll,
       ),
   )
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py
index a769549a..db9467a5 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py
@@ -25,6 +25,7 @@
 from absl.testing import parameterized
 import jax
 from jax import dtypes as _dtypes
+from jax.experimental import layout
 import jax.numpy as jnp
 import numpy as np
 from orbax.checkpoint import future
@@ -38,6 +39,8 @@
 GSPMDSharding = jax.sharding.GSPMDSharding
 NamedSharding = jax.sharding.NamedSharding
 P = jax.sharding.PartitionSpec
+DLL = layout.DeviceLocalLayout
+Layout = layout.Layout
 
 jax.config.update('jax_enable_x64', True)
 
@@ -598,6 +601,40 @@ def test_odd_resharding(self):
     for i, shard in enumerate(restored.addressable_shards):
       self.assertArraysEqual(np.asarray(shard.data), np.arange(4) + (i * 4))
 
+  def test_load_with_layout(self):
+    mesh = create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(32).reshape(8, 4)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    out_layout = (
+        jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO))
+        .lower(arr)
+        .compile()
+        .output_layouts
+    )
+    self.assertEqual(
+        arr.layout.device_local_layout.major_to_minor,
+        out_layout.device_local_layout.major_to_minor[::-1],
+    )
+
+    ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
+    ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)
+    tspecs = jax.tree.map(serialization.get_tensorstore_spec, [ckpt_path])
+
+    serialize(
+        [arr],
+        tspecs,
+    )
+
+    (out,) = deserialize([out_layout], tspecs)
+
+    self.assertEqual(out.layout, out_layout)
+    self.assertIsInstance(out, jax.Array)
+    self.assertArraysEqual(out, np_inp)
+    for s in out.addressable_shards:
+      self.assertArraysEqual(s.data, np_inp[s.index])
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
index bdc8a739..09b49372 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
@@ -31,6 +31,7 @@
 from absl import logging
 from etils import epath
 import jax
+from jax.experimental import layout
 import jax.numpy as jnp
 import numpy as np
 from orbax.checkpoint import future
@@ -49,6 +50,7 @@
 import tensorstore as ts
 
 
+Layout = layout.Layout
 Shape = types.Shape
 Scalar = Union[int, float, np.number]
 Metadata = value_metadata.Metadata
@@ -1041,13 +1043,12 @@ class ArrayRestoreArgs(RestoreArgs):
   mesh_axes:
     The mesh_axes that the array should be restored as. Cannot be None.
   sharding:
-    `jax.sharding.Sharding` object which takes precedence over mesh and
-    mesh_axes if provided. Otherwise, mesh and mesh_axes will be used to
-    construct a NamedSharding object OR `ShardingMetadata` which is an orbax
-    representation of `jax.sharding.Sharding` that stores the same properties
-    but does not require accessing real devices.
-  global_shape:
-    The global shape that the array should be restored into. If not
+    `jax.sharding.Sharding`, `ShardingMetadata`, or `Layout` object, which
+    takes precedence over mesh and mesh_axes if provided; otherwise, mesh and
+    mesh_axes are used to construct a NamedSharding. `ShardingMetadata` is an
+    Orbax representation of `jax.sharding.Sharding` that stores the same
+    properties but does not require accessing real devices.
+  global_shape: The global shape that the array should be restored into. If not
     provided, the shape will be restored as written. Presently, arbitrary shape
     transformations are not supported (for example, reshaping to different
     dimensions). Padding and truncating are supported. When the global_shape is
@@ -1064,7 +1065,9 @@ class ArrayRestoreArgs(RestoreArgs):
   restore_type: Optional[Any] = jax.Array
   mesh: Optional[jax.sharding.Mesh] = None
   mesh_axes: Optional[jax.sharding.PartitionSpec] = None
-  sharding: Optional[Union[jax.sharding.Sharding, ShardingMetadata]] = None
+  sharding: Optional[Union[jax.sharding.Sharding, ShardingMetadata, Layout]] = (
+      None
+  )
   global_shape: Optional[Tuple[int, ...]] = None
   strict: bool = True
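
Usage sketch (illustrative, not part of the diff above): the snippet below shows one way to drive the new Layout-aware restore path, mirroring test_load_with_layout and the updated async_deserialize signature. The checkpoint path is a placeholder, the (4, 2) mesh assumes 8 local devices, and an (8, 4) array is assumed to have been serialized to that path beforehand.

import asyncio

import jax
from jax.experimental import layout
import numpy as np
from orbax.checkpoint._src.serialization import serialization

DLL = layout.DeviceLocalLayout
Layout = layout.Layout

# Hypothetical checkpoint location; assumes an (8, 4) array was already
# serialized here with a compatible TensorStore spec.
ckpt_path = '/tmp/ckpt/first'

# Mirrors the (4, 2) mesh used by test_load_with_layout; assumes 8 devices.
mesh = jax.sharding.Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
arr = jax.device_put(np.arange(32).reshape(8, 4), sharding)

# Ask the compiler for a concrete device-local layout, as the new test does.
out_layout = (
    jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO))
    .lower(arr)
    .compile()
    .output_layouts
)

# Passing a Layout (rather than a bare Sharding) routes its device_local_layout
# down to the per-shard jax.device_put call added in this patch.
tspec = serialization.get_tensorstore_spec(ckpt_path)
restored = asyncio.run(serialization.async_deserialize(out_layout, tspec))
assert restored.layout.device_local_layout == out_layout.device_local_layout

Because _read_and_device_put_shard now wraps its target in Layout(dll, SingleDeviceSharding(device)), each restored shard is placed with the requested device-local layout rather than the default one.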