diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py
index 3869244b5..dcedec48b 100644
--- a/streaming/base/dataset.py
+++ b/streaming/base/dataset.py
@@ -5,6 +5,7 @@
 
 import json
 import logging
+import mmap
 import os
 import sys
 import warnings
@@ -804,12 +805,33 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
         Args:
             obj (Dict[str, Any]): The state.
         """
+        # The shared memory block is always one memory page (`mmap.PAGESIZE` bytes) long. This
+        # enables calling `load_state_dict` multiple times without needing to resize the shared
+        # memory block. Resizing the shared memory block is not possible, and closing the shared
+        # memory block and replacing it with a new one is causing great difficulties.
+
         name = _get_path(self._shm_prefix_int, RESUME)
-        data = json.dumps(obj, sort_keys=True).encode('utf-8')
+        text = json.dumps(obj, sort_keys=True)
+        # Append the null byte that lets `_resume` find the end of the state dict, and encode
+        # before measuring, so that the size check below counts exactly the bytes written.
+        data = (text + '\0').encode('utf-8')
+
+        len_needed = len(data)
+        # Note: mmap.PAGESIZE has a minimum size of 4096 bytes across systems. For reference,
+        # see the link below:
+        # https://en.wikipedia.org/wiki/Page_(computer_memory)#Multiple_page_sizes
+        if len_needed > mmap.PAGESIZE:
+            raise ValueError(
+                'The StreamingDataset state dict for resumption is currently '
+                f'allocated {mmap.PAGESIZE} bytes, insufficient to store the '
+                f'state dict that was attempted to load in, which uses {len_needed} '
+                'bytes. Please increase the bytes allocated to the state dict by '
+                'changing the SharedMemory size parameter, set in this function. '
+                f'The state dict may also be corrupted. The state dict is: {text}.')
         # Some platforms choose to allocate chunks of memory based upon that platform's memory page
         # size, hence the exact size of the shared memory block that was returned may be larger
         # than what was requested.
-        self._resume_shm = SharedMemory(name=name, size=len(data))
+        self._resume_shm = SharedMemory(name=name, size=mmap.PAGESIZE)
         self._resume_shm.buf[:len(data)] = data
 
     def resample_streams(
diff --git a/tests/test_shared.py b/tests/test_shared.py
index 44d376d93..c3536a083 100644
--- a/tests/test_shared.py
+++ b/tests/test_shared.py
@@ -5,8 +5,10 @@
 
 import pytest
 
+from streaming.base import StreamingDataset
 from streaming.base.shared import get_shm_prefix
 from streaming.base.world import World
+from tests.common.utils import convert_to_mds
 
 
 @pytest.mark.usefixtures('local_remote_dir')
@@ -41,3 +43,107 @@ def test_same_local_remote_none(local_remote_dir: Tuple[str, str]):
     local, _ = local_remote_dir
     _, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World())
     _, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World())
+
+
+@pytest.mark.parametrize('from_beginning', [True, False])
+@pytest.mark.usefixtures('local_remote_dir')
+def test_load_get_state_dict_once(local_remote_dir: Tuple[str, str], from_beginning: bool):
+    local, remote = local_remote_dir
+    convert_to_mds(out_root=remote,
+                   dataset_name='sequencedataset',
+                   num_samples=117,
+                   size_limit=1 << 8)
+    dataset = StreamingDataset(local=local, remote=remote)
+
+    # Get the current dataset state dict
+    old_state_dict = dataset.state_dict(0, from_beginning)
+    assert old_state_dict is not None
+
+    state_keys = list(old_state_dict.keys())
+
+    # Change the state dict and load it back to the dataset.
+    new_state_dict = old_state_dict.copy()
+    for key in state_keys:
+        new_state_dict[key] += 1
+    dataset.load_state_dict(new_state_dict)
+
+    new_loaded_state_dict = dataset.state_dict(0, from_beginning)
+    assert new_loaded_state_dict is not None
+    if from_beginning:
+        for key in state_keys:
+            if key == 'sample_in_epoch':
+                # If `from_beginning` is True, we expect sample_in_epoch to be 0.
+                assert new_loaded_state_dict[key] == 0
+            else:
+                # All other fields in retrieved and loaded state dicts should match.
+                assert new_loaded_state_dict[key] == new_state_dict[key]
+    else:
+        # If `from_beginning` is False, retrieved and loaded state dicts should match completely.
+        assert new_loaded_state_dict == new_state_dict
+
+    for key in state_keys:
+        if key == 'sample_in_epoch' and from_beginning:
+            # If `from_beginning` is True, we expect sample_in_epoch to be the same, 0.
+            assert new_loaded_state_dict[key] == old_state_dict[key]
+        else:
+            assert new_loaded_state_dict[key] == old_state_dict[key] + 1
+
+
+@pytest.mark.parametrize('iterations', [10])
+@pytest.mark.usefixtures('local_remote_dir')
+def test_load_get_state_dict_multiple(local_remote_dir: Tuple[str, str], iterations: int):
+    local, remote = local_remote_dir
+    convert_to_mds(out_root=remote,
+                   dataset_name='sequencedataset',
+                   num_samples=117,
+                   size_limit=1 << 8)
+    dataset = StreamingDataset(local=local, remote=remote)
+
+    # Get the current dataset state dict
+    old_state_dict = dataset.state_dict(0, False)
+    assert old_state_dict is not None
+
+    state_keys = list(old_state_dict.keys())
+
+    for _ in range(iterations):
+        # Change the state dict and load it back to the dataset.
+        new_state_dict = old_state_dict.copy()
+        for key in state_keys:
+            # If the epoch from the loaded state dict is -1, make sure that the new epoch
+            # is greater than -1. Otherwise, we will assume a stale resumption state, ignoring it.
+            if key == 'epoch' and new_state_dict[key] < 0:
+                new_state_dict[key] *= -5
+            else:
+                new_state_dict[key] *= 5
+
+        dataset.load_state_dict(new_state_dict)
+        new_loaded_state_dict = dataset.state_dict(0, False)
+
+        assert new_loaded_state_dict is not None
+        assert new_loaded_state_dict == new_state_dict
+        for key in state_keys:
+            # Ensure we check that epoch has been correctly updated, in case it was negative.
+            if key == 'epoch' and old_state_dict[key] < 0:
+                assert new_loaded_state_dict[key] == old_state_dict[key] * -5
+            else:
+                assert new_loaded_state_dict[key] == old_state_dict[key] * 5
+
+        old_state_dict = new_loaded_state_dict
+
+
+@pytest.mark.usefixtures('local_remote_dir')
+def test_state_dict_too_large(local_remote_dir: Tuple[str, str]):
+    local, remote = local_remote_dir
+    convert_to_mds(out_root=remote,
+                   dataset_name='sequencedataset',
+                   num_samples=117,
+                   size_limit=1 << 8)
+    dataset = StreamingDataset(local=local, remote=remote)
+
+    # Make a state dict that is too large to fit in the allocated shared memory.
+    import mmap
+    key = 'a' * mmap.PAGESIZE
+    big_state_dict = {key: 1}
+
+    with pytest.raises(ValueError, match=r'^The StreamingDataset state dict for resumption'):
+        dataset.load_state_dict(big_state_dict)