Commit c977d4c
Submission of #1319. All credit and thanks to https://github.com/gspschmid.
PiperOrigin-RevId: 696250850
1 parent: 58d423f
Showing 5 changed files with 421 additions and 163 deletions.
checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py (239 additions, 0 deletions)
@@ -0,0 +1,239 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handles replica slices of jax.Arrays and host transfers."""

import dataclasses
import math
from typing import Optional, Sequence

from absl import logging
import jax
import numpy as np
from orbax.checkpoint._src.arrays import fragments
from orbax.checkpoint._src.arrays import numpy_utils
from orbax.checkpoint._src.arrays import types
from orbax.checkpoint._src.multihost import multihost


Shape = types.Shape
Index = types.Index


@dataclasses.dataclass(frozen=True)
class ReplicaSlice:
  """ReplicaSlice.

  ReplicaSlice represents the part of a jax.Shard that a replica is uniquely
  responsible for. A replica slice can be either on-device (backed by a slice
  of a single-sharding array) or on-host (backed by a numpy ndarray).

  With single-replica checkpointing the entirety of each jax.Shard is owned
  by exactly one replica. (Currently the only option.)
  """

  replica_id: int
  index: Index
  data: jax.Array | np.ndarray

  @property
  def is_on_host(self):
    return isinstance(self.data, np.ndarray)


@dataclasses.dataclass(frozen=True)
class ReplicaSlices:
  """ReplicaSlices.

  ReplicaSlices groups all the sliced data of one jax.Array that a replica is
  uniquely responsible for. Slices are either all on-device or all on-host.
  """

  global_shape: Shape
  local_shape: Shape
  sharding: jax.sharding.Sharding
  dtype: np.dtype
  is_on_host: bool
  replica_slices: list[ReplicaSlice]

  def __post_init__(self):
    if not all(
        rslice.is_on_host == self.is_on_host for rslice in self.replica_slices
    ):
      raise ValueError(f'Inconsistent is_on_host in {self!r}')

  @property
  def nbytes(self) -> int:
    slice_nbytes = math.prod(self.local_shape) * self.dtype.itemsize
    return slice_nbytes * len(self.replica_slices)

  def to_fragments(self) -> fragments.Fragments:
    """Converts replica slices to fragments."""
    assert self.is_on_host
    result = fragments.Fragments(
        shape=self.global_shape,
        dtype=self.dtype,
        fragments=[
            fragments.Fragment(
                index=numpy_utils.resolve_slice(
                    rslice.index, self.global_shape
                ),
                value=rslice.data,
            )
            for rslice in self.replica_slices
        ],
    )
    if result.fragments:
      fragments.validate_fragments_can_be_stacked(result)
    if not result.is_degenerate():
      assert self.local_shape == result.fragments[0].shape
    return result


def get_replica_slices(
    arr: jax.Array,
    replica_id: Optional[int],
) -> ReplicaSlices:
  """Returns the replica slices a given replica is responsible for.

  Does not allocate or transfer any data.

  Args:
    arr: The jax.Array to get replica slices for.
    replica_id: Configured replica_id.

  Returns:
    ReplicaSlices object.
  """
  Result = tuple[list[ReplicaSlice], Shape]
  shard0 = arr.addressable_shards[0]

  # single-replica: a single replica saves an entire shard.
  def pick_single_replica() -> Result:
    # Omitting the replica id just picks the first addressable shard's
    # replica id so that the process writes each of its addressable shards
    # exactly once. (This is the desired behavior for local checkpointing.)
    target_replica_id = replica_id or shard0.replica_id
    rslices = [
        ReplicaSlice(
            replica_id=shard.replica_id,
            index=shard.index,
            data=shard.data,
        )
        for shard in arr.addressable_shards
        if shard.replica_id == target_replica_id
    ]
    local_shape = shard0.data.shape
    return rslices, local_shape

  shards_info = ', '.join([
      f'Shard(index={shard.index}, replica_id={shard.replica_id})'
      for shard in arr.addressable_shards
  ])
  logging.vlog(
      1,
      '[process=%d] get_replica_slices: replica_id=%d, shards=[%s]',
      multihost.process_index(),
      replica_id,
      shards_info,
  )

  # In order for all processes to agree on the right serialization metadata,
  # we want to compute the correct local shape regardless of whether there
  # are any replica slices to save locally.
  rslices, local_shape = pick_single_replica()
  return ReplicaSlices(
      global_shape=arr.shape,
      local_shape=local_shape,
      sharding=arr.sharding,
      dtype=arr.dtype,
      is_on_host=False,
      replica_slices=rslices,
  )


def transfer_arrays_to_host(
    arrays: Sequence[jax.Array],
    replica_id: Optional[int],
    *,
    enable_pinned_host_transfer: bool = True,
) -> Sequence[ReplicaSlices]:
  """Transfers arrays to host memory.

  Transfers jax.Arrays to host memory and returns all the fragments to be
  serialized by the given replica, along with local shape. Blocks until
  completion.

  Args:
    arrays: The jax.Arrays to transfer.
    replica_id: Configured replica_id.
    enable_pinned_host_transfer: Whether to allow transfer to pinned host
      memory.

  Returns:
    ReplicaSlices objects, in host memory.
  """

  def use_pinned_host_transfer(device: jax.Device):
    has_pinned_host = any(
        m.kind == 'pinned_host' for m in device.addressable_memories()
    )
    return (
        enable_pinned_host_transfer
        and has_pinned_host
        and jax._src.config.enable_memories.value  # pylint: disable=protected-access
    )

  def async_transfer_slice(
      rslice: ReplicaSlice,
  ) -> tuple[ReplicaSlice, jax.Array]:
    assert not rslice.is_on_host
    data = rslice.data
    assert isinstance(data, jax.Array)
    device = data.device
    # Start the asynchronous device-to-host copy.
    if use_pinned_host_transfer(device):
      # If available, transfer to pinned host memory.
      data = jax.device_put(
          data,
          jax.sharding.SingleDeviceSharding(device, memory_kind='pinned_host'),
      )
    else:
      data.copy_to_host_async()
    return rslice, data

  # Gather the replica slices to be saved for each array.
  rslices_per_array = [get_replica_slices(arr, replica_id) for arr in arrays]
  # Kick off transfers for all replica slices to be saved.
  transfers_per_array = [
      [async_transfer_slice(rslice) for rslice in rslices.replica_slices]
      for rslices in rslices_per_array
  ]
  # Wait for all the transferred data to be ready.
  return [
      dataclasses.replace(
          rslices,
          is_on_host=True,
          replica_slices=[
              dataclasses.replace(
                  rslice_on_device,
                  # Conversion to numpy arrays forces block_until_ready.
                  data=np.asarray(data),
              )
              for rslice_on_device, data in transfers
          ],
      )
      for rslices, transfers in zip(rslices_per_array, transfers_per_array)
  ]
checkpoint/orbax/checkpoint/_src/serialization/replica_slices_test.py (117 additions, 0 deletions)
@@ -0,0 +1,117 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as np
from orbax.checkpoint._src.serialization import replica_slices


def is_pow_of_two(n):
  while n > 1:
    n, rem = divmod(n, 2)
    if rem == 1:
      return False
  return True


def make_multi_device_array(shape, partitioned):
  """Creates a partially- or fully-replicated array."""
  num_devices = len(jax.devices())
  assert num_devices >= 4
  assert is_pow_of_two(num_devices)
  mesh = jax.sharding.Mesh(
      np.asarray(jax.devices()).reshape((2, num_devices // 2)),
      ('x', 'y'),
  )
  if partitioned:
    # partially-replicated (partitioned dimension 0 along mesh axis x)
    spec = jax.sharding.PartitionSpec('x')
    num_partitions = 2
    num_replicas = num_devices // 2
  else:
    # fully-replicated
    spec = jax.sharding.PartitionSpec()
    num_partitions = 1
    num_replicas = num_devices
  sharding = jax.sharding.NamedSharding(mesh, spec)

  x = jax.random.normal(jax.random.PRNGKey(0), shape)
  data = jax.device_put(x, sharding)

  return data, num_partitions, num_replicas


@parameterized.product(partitioned=[False, True])
class ReplicaSlicesTest(parameterized.TestCase):

  def test_get_replica_slices_single_replica(self, partitioned):
    arr, num_partitions, num_replicas = make_multi_device_array(
        (64, 64),
        partitioned=partitioned,
    )

    # Using an addressable replica_id yields that replica.
    for replica_id in range(num_replicas):
      rslices = replica_slices.get_replica_slices(
          arr, replica_id=replica_id
      ).replica_slices
      self.assertLen(rslices, num_partitions)
      for rslice in rslices:
        self.assertEqual(rslice.replica_id, replica_id)

    # Omitting replica_id yields _some_ replica.
    rslices = replica_slices.get_replica_slices(
        arr, replica_id=None
    ).replica_slices
    self.assertLen(rslices, num_partitions)
    for rslice in rslices:
      self.assertEqual(rslice.replica_id, rslices[0].replica_id)

    # Using an unaddressable replica_id yields nothing.
    rslices = replica_slices.get_replica_slices(
        arr,
        replica_id=-1,
    ).replica_slices
    self.assertEmpty(rslices)

  def test_transfer(self, partitioned):
    arr, num_partitions, _ = make_multi_device_array(
        (64, 64),
        partitioned=partitioned,
    )
    replica0_shards = [
        shard for shard in arr.addressable_shards if shard.replica_id == 0
    ]

    rslices = replica_slices.transfer_arrays_to_host(
        [arr], replica_id=0
    )[0].replica_slices
    self.assertLen(rslices, num_partitions)
    self.assertEqual(len(rslices), len(replica0_shards))

    index_start = lambda x: x.index[0].start or 0
    rslices = sorted(rslices, key=index_start)
    replica0_shards = sorted(replica0_shards, key=index_start)

    for rslice, replica0_shard in zip(rslices, replica0_shards):
      self.assertTrue(rslice.is_on_host)
      self.assertIsInstance(rslice.data, np.ndarray)
      self.assertEqual(rslice.index, replica0_shard.index)
      np.testing.assert_array_equal(rslice.data, replica0_shard.data)


if __name__ == '__main__':
  absltest.main()
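
Note that make_multi_device_array asserts a power-of-two device count of at least four, so these tests need a multi-device backend. As a sketch of one way to satisfy that locally (an assumption on our part, not something the commit documents), XLA's host platform can simulate extra CPU devices if the flag is set before JAX initializes:

import os
# Hypothetical local setup; must run before jax is imported/initialized.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '')
    + ' --xla_force_host_platform_device_count=8'
)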