diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4342c2be3bd..e0b6a4e3ab4 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -247,7 +247,6 @@
   with two or more terms.
   [(#4642)](https://github.com/PennyLaneAI/pennylane/pull/4642)
 
-
 * `_qfunc_output` has been removed from `QuantumScript`, as it is no longer necessary.
   There is still a `_qfunc_output` property on `QNode` instances.
   [(#4651)](https://github.com/PennyLaneAI/pennylane/pull/4651)
@@ -259,6 +258,9 @@
   of the computed qubit operator, if imaginary components are smaller than a threshold.
   [(#4639)](https://github.com/PennyLaneAI/pennylane/pull/4639)
 
+* Improved performance of `qml.data.load()` when partially loading a dataset
+  [(#4674)](https://github.com/PennyLaneAI/pennylane/pull/4674)
+
 * Plots generated with the `pennylane.drawer.plot` style of `matplotlib.pyplot` now have black
   axis labels and are generated at a default DPI of 300.
   [(#4690)](https://github.com/PennyLaneAI/pennylane/pull/4690)
@@ -271,6 +273,7 @@
   [(#4675)](https://github.com/PennyLaneAI/pennylane/pull/4675)
+
 
 Breaking changes 💔
 
 * The device test suite now converts device kwargs to integers or floats if they can be converted
   to integers or floats.
@@ -457,6 +460,7 @@ This release contains contributions from (in alphabetical order):
 
 Utkarsh Azad,
+Jack Brown,
 Stepan Fomichev,
 Joana Fraxanet,
 Diego Guala,
diff --git a/pennylane/data/base/hdf5.py b/pennylane/data/base/hdf5.py
index 45f4517e887..ac6401754b9 100644
--- a/pennylane/data/base/hdf5.py
+++ b/pennylane/data/base/hdf5.py
@@ -15,7 +15,7 @@
 from collections.abc import MutableMapping
 from pathlib import Path
-from typing import Literal, Optional, TypeVar, Union
+from typing import Literal, TypeVar, Union
 from uuid import uuid4
 
 from numpy.typing import ArrayLike
 
@@ -96,15 +96,19 @@ def copy_all(
     dest.attrs.update(source.attrs)
 
 
-def open_hdf5_s3(s3_url: str, cache_dir: Optional[Path] = None) -> HDF5Group:
+def open_hdf5_s3(s3_url: str, *, block_size: int = 8388608) -> HDF5Group:
     """Uses ``fsspec`` module to open the HDF5 file at ``s3_url``.
 
     This requires both ``fsspec`` and ``aiohttp`` to be installed.
+
+    Args:
+        s3_url: URL of dataset file in S3
+        block_size: Number of bytes to fetch per read operation. Larger values
+            may improve performance for large datasets
     """
-    if cache_dir is not None:
-        fs = fsspec.open(f"blockcache::{s3_url}", blockcache={"cache_storage": str(cache_dir)})
-    else:
-        fs = fsspec.open(s3_url)
+    # Tells fsspec to fetch data in 8MB chunks for faster loading
+    memory_cache_args = {"cache_type": "mmap", "block_size": block_size}
+    fs = fsspec.open(s3_url, **memory_cache_args)
 
     return h5py.File(fs.open())
 
diff --git a/pennylane/data/data_manager/__init__.py b/pennylane/data/data_manager/__init__.py
index 1fdc70be900..877fb37df80 100644
--- a/pennylane/data/data_manager/__init__.py
+++ b/pennylane/data/data_manager/__init__.py
@@ -56,14 +56,14 @@ def _get_data_struct():
 
 
 def _download_partial(
-    s3_url: str, dest: Path, attributes: typing.Iterable[str], overwrite: bool
+    s3_url: str, dest: Path, attributes: typing.Iterable[str], overwrite: bool, block_size: int
 ) -> None:
     """Download only the requested attributes of the Dataset at ``s3_path`` into ``dest``.
 
     If a dataset already exists at ``dest``, the attributes will be loaded into
     the existing dataset.
     """
-    remote_dataset = Dataset(open_hdf5_s3(s3_url))
+    remote_dataset = Dataset(open_hdf5_s3(s3_url, block_size=block_size))
     remote_dataset.write(dest, "a", attributes, overwrite=overwrite)
 
     del remote_dataset
@@ -73,6 +73,7 @@ def _download_dataset(
     data_path: DataPath,
     dest: Path,
     attributes: Optional[typing.Iterable[str]],
+    block_size: int,
     force: bool = False,
 ) -> None:
     """Downloads the dataset at ``data_path`` to ``dest``, optionally downloading
@@ -84,7 +85,7 @@
     s3_path = f"{S3_URL}/{url_safe_datapath}"
 
     if attributes is not None:
-        _download_partial(s3_path, dest, attributes, overwrite=force)
+        _download_partial(s3_path, dest, attributes, overwrite=force, block_size=block_size)
         return
 
     if dest.exists() and not force:
@@ -120,7 +121,7 @@ def load(  # pylint: disable=too-many-arguments
     folder_path: Path = Path("./datasets/"),
     force: bool = False,
     num_threads: int = 50,
-    cache_dir: Optional[Path] = Path(".cache"),
+    block_size: int = 8388608,
     **params: Union[ParamArg, str, List[str]],
 ):
     r"""Downloads the data if it is not already present in the directory and returns it as a list of
@@ -133,7 +134,9 @@
         folder_path (str)  : Path to the directory used for saving datasets. Defaults to './datasets'
         force (Bool)       : Bool representing whether data has to be downloaded even if it is still present
         num_threads (int)  : The maximum number of threads to spawn while downloading files (1 thread per file)
-        cache_dir (str)    : Directory used for HTTP caching. Defaults to '{folder_path}/.cache'
+        block_size (int)   : The number of bytes to fetch per read operation when fetching datasets from S3.
+            Larger values may improve performance for large datasets, but will slow down small reads. Defaults
+            to 8MB
         params (kwargs)    : Keyword arguments exactly matching the parameters required for the data type.
             Note that these are not optional
 
@@ -212,8 +215,6 @@
     _validate_attributes(data_struct, data_name, attributes)
 
     folder_path = Path(folder_path)
-    if cache_dir and not Path(cache_dir).is_absolute():
-        cache_dir = folder_path / cache_dir
 
     data_paths = [data_path for _, data_path in foldermap.find(data_name, **params)]
 
@@ -224,7 +225,14 @@
 
     with ThreadPoolExecutor(min(num_threads, len(dest_paths))) as pool:
         futures = [
-            pool.submit(_download_dataset, data_path, dest_path, attributes, force=force)
+            pool.submit(
+                _download_dataset,
+                data_path,
+                dest_path,
+                attributes,
+                force=force,
+                block_size=block_size,
+            )
             for data_path, dest_path in zip(data_paths, dest_paths)
         ]
         results = wait(futures, return_when=FIRST_EXCEPTION)
diff --git a/tests/data/base/test_hdf5.py b/tests/data/base/test_hdf5.py
index f2112c833b0..9d492475824 100644
--- a/tests/data/base/test_hdf5.py
+++ b/tests/data/base/test_hdf5.py
@@ -50,26 +50,15 @@ def patch_h5py(monkeypatch):
     monkeypatch.setattr(hdf5, "h5py", MagicMock())
 
 
-@pytest.mark.parametrize(
-    "kwargs, call_args, call_kwargs",
-    [
-        ({}, ("/bucket",), {}),
-        (
-            {"cache_dir": "/cache"},
-            ("blockcache::/bucket",),
-            {"blockcache": {"cache_storage": "/cache"}},
-        ),
-    ],
-)
-def test_open_hdf5_s3(
-    mock_fsspec, kwargs, call_args, call_kwargs
-):  # pylint: disable=redefined-outer-name
+def test_open_hdf5_s3(mock_fsspec):  # pylint: disable=redefined-outer-name
     """Test that open_hdf5_s3 calls fsspec.open() with the expected arguments."""
-    ret = hdf5.open_hdf5_s3("/bucket", **kwargs)
+    ret = hdf5.open_hdf5_s3("/bucket")
 
     assert isinstance(ret, h5py.File)
-    mock_fsspec.open.assert_called_once_with(*call_args, **call_kwargs)
+    mock_fsspec.open.assert_called_once_with(
+        "/bucket", **{"cache_type": "mmap", "block_size": 8 * (2**20)}
+    )
 
 
 def test_copy_all_conflict_overwrite(tmp_path):
diff --git a/tests/data/data_manager/test_dataset_access.py b/tests/data/data_manager/test_dataset_access.py
index 5c797a63ab8..97a55cfeee9 100644
--- a/tests/data/data_manager/test_dataset_access.py
+++ b/tests/data/data_manager/test_dataset_access.py
@@ -234,7 +234,7 @@ def test_list_attributes(self):
 
 @pytest.fixture
 def mock_download_dataset(monkeypatch):
-    def mock(data_path, dest, attributes, force):
+    def mock(data_path, dest, attributes, force, block_size):
         dset = Dataset.open(Path(dest), "w")
         dset.close()
 
@@ -262,6 +262,7 @@ def test_load(tmp_path, data_name, params, expect_paths):
     dsets = pennylane.data.data_manager.load(
         data_name=data_name,
         folder_path=folder_path,
+        block_size=1,
         **params,
     )
 
@@ -289,9 +290,7 @@ def test_download_dataset_full(tmp_path):
     using requests if all attributes are requested."""
 
     pennylane.data.data_manager._download_dataset(
-        "dataset/path",
-        tmp_path / "dataset",
-        attributes=None,
+        "dataset/path", tmp_path / "dataset", attributes=None, block_size=1
     )
 
     with open(tmp_path / "dataset", "rb") as f:
@@ -311,7 +310,7 @@ def test_download_dataset_full_already_exists(tmp_path, force, expect_data):
         f.write(b"This is local data")
 
     pennylane.data.data_manager._download_dataset(
-        "dataset/path", tmp_path / "dataset", attributes=None, force=force
+        "dataset/path", tmp_path / "dataset", attributes=None, force=force, block_size=1
     )
 
     with open(tmp_path / "dataset", "rb") as f:
@@ -331,7 +330,7 @@ def test_download_dataset_partial(tmp_path, monkeypatch):
     )
 
     pennylane.data.data_manager._download_dataset(
-        "dataset/path", tmp_path / "dataset", attributes=["x"]
+        "dataset/path", tmp_path / "dataset", attributes=["x"], block_size=1
     )
 
     local = Dataset.open(tmp_path / "dataset")
@@ -351,7 +350,9 @@ def test_download_dataset_escapes_url(_, mock_get_args, datapath, escaped):
     dest = MagicMock()
     dest.exists.return_value = False
 
-    pennylane.data.data_manager._download_dataset(DataPath(datapath), dest=dest, attributes=None)
+    pennylane.data.data_manager._download_dataset(
+        DataPath(datapath), dest=dest, attributes=None, block_size=1
+    )
 
     mock_get_args.assert_called_once()
     assert mock_get_args.call_args[0] == (f"{S3_URL}/{escaped}",)
@@ -370,11 +371,11 @@ def test_download_dataset_escapes_url_partial(mock_download_partial, datapath, e
 
     force = False
 
     pennylane.data.data_manager._download_dataset(
-        DataPath(datapath), dest=dest, attributes=attributes, force=force
+        DataPath(datapath), dest=dest, attributes=attributes, force=force, block_size=1
     )
 
     mock_download_partial.assert_called_once_with(
-        f"{S3_URL}/{escaped}", dest, attributes, overwrite=force
+        f"{S3_URL}/{escaped}", dest, attributes, overwrite=force, block_size=1
     )
 
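
A minimal usage sketch of the new `block_size` argument (not part of the patch above). Only `block_size` is introduced by this change; the `attributes`, `molname`, `basis`, and `bondlength` arguments are assumed from the existing `qml.data.load` API and the datasets documentation, and the values shown are illustrative:

    import pennylane as qml

    # Hypothetical example: partially load a qchem dataset, fetching the remote
    # HDF5 file in 1 MB blocks instead of the default 8388608 bytes (8 MB).
    [dataset] = qml.data.load(
        "qchem",
        molname="H2",
        basis="STO-3G",
        bondlength=1.1,
        attributes=["molecule", "fci_energy"],  # download only these attributes
        block_size=1048576,  # bytes fetched per S3 read operation
    )

    print(dataset.fci_energy)

Smaller blocks reduce over-fetching when only a few attributes are requested, while the 8 MB default favours full-dataset downloads; in both cases reads now go through the `cache_type="mmap"` fsspec cache that `open_hdf5_s3` configures.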