diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 4342c2be3bd..e0b6a4e3ab4 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -247,7 +247,6 @@
with two or more terms.
[(#4642)](https://github.com/PennyLaneAI/pennylane/pull/4642)
-
* `_qfunc_output` has been removed from `QuantumScript`, as it is no longer necessary. There is
still a `_qfunc_output` property on `QNode` instances.
[(#4651)](https://github.com/PennyLaneAI/pennylane/pull/4651)
@@ -259,6 +258,9 @@
of the computed qubit operator, if imaginary components are smaller than a threshold.
[(#4639)](https://github.com/PennyLaneAI/pennylane/pull/4639)
+* Improved performance of `qml.data.load()` when partially loading a dataset.
+ [(#4674)](https://github.com/PennyLaneAI/pennylane/pull/4674)
+
* Plots generated with the `pennylane.drawer.plot` style of `matplotlib.pyplot` now have black
axis labels and are generated at a default DPI of 300.
[(#4690)](https://github.com/PennyLaneAI/pennylane/pull/4690)
@@ -271,6 +273,7 @@
[(#4675)](https://github.com/PennyLaneAI/pennylane/pull/4675)
+
Breaking changes 💔
* The device test suite now converts device kwargs to integers or floats if they can be converted to integers or floats.
@@ -457,6 +460,7 @@
This release contains contributions from (in alphabetical order):
Utkarsh Azad,
+Jack Brown,
Stepan Fomichev,
Joana Fraxanet,
Diego Guala,
diff --git a/pennylane/data/base/hdf5.py b/pennylane/data/base/hdf5.py
index 45f4517e887..ac6401754b9 100644
--- a/pennylane/data/base/hdf5.py
+++ b/pennylane/data/base/hdf5.py
@@ -15,7 +15,7 @@
from collections.abc import MutableMapping
from pathlib import Path
-from typing import Literal, Optional, TypeVar, Union
+from typing import Literal, TypeVar, Union
from uuid import uuid4
from numpy.typing import ArrayLike
@@ -96,15 +96,19 @@ def copy_all(
dest.attrs.update(source.attrs)
-def open_hdf5_s3(s3_url: str, cache_dir: Optional[Path] = None) -> HDF5Group:
+def open_hdf5_s3(s3_url: str, *, block_size: int = 8388608) -> HDF5Group:
"""Uses ``fsspec`` module to open the HDF5 file at ``s3_url``.
This requires both ``fsspec`` and ``aiohttp`` to be installed.
+
+ Args:
+        s3_url: URL of the dataset file in S3.
+        block_size: Number of bytes to fetch per read operation. Larger values
+            may improve performance for large datasets.
"""
- if cache_dir is not None:
- fs = fsspec.open(f"blockcache::{s3_url}", blockcache={"cache_storage": str(cache_dir)})
- else:
- fs = fsspec.open(s3_url)
+    # Tell fsspec to fetch data in block_size-byte chunks (8 MB by default) for faster loading
+ memory_cache_args = {"cache_type": "mmap", "block_size": block_size}
+ fs = fsspec.open(s3_url, **memory_cache_args)
return h5py.File(fs.open())
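
For context, the change above swaps the optional on-disk `blockcache` for fsspec's in-memory `mmap` cache with a configurable block size. A minimal sketch of the resulting access pattern, assuming `fsspec`, `aiohttp`, and `h5py` are installed; the URL below is a placeholder, not a real dataset location:

```python
import fsspec
import h5py

# Placeholder URL -- stands in for the real S3/HTTPS location of a dataset file.
DATASET_URL = "https://example.com/datasets/h2.h5"

# fsspec fetches the remote file in fixed-size blocks and keeps them in a
# memory-mapped cache, so partial reads only download the blocks they touch.
remote = fsspec.open(DATASET_URL, cache_type="mmap", block_size=8 * 2**20)

with h5py.File(remote.open()) as f:
    # Listing keys only pulls the metadata blocks, not the whole file.
    print(list(f.keys()))
```
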
diff --git a/pennylane/data/data_manager/__init__.py b/pennylane/data/data_manager/__init__.py
index 1fdc70be900..877fb37df80 100644
--- a/pennylane/data/data_manager/__init__.py
+++ b/pennylane/data/data_manager/__init__.py
@@ -56,14 +56,14 @@ def _get_data_struct():
def _download_partial(
- s3_url: str, dest: Path, attributes: typing.Iterable[str], overwrite: bool
+ s3_url: str, dest: Path, attributes: typing.Iterable[str], overwrite: bool, block_size: int
) -> None:
"""Download only the requested attributes of the Dataset at ``s3_path`` into ``dest``.
If a dataset already exists at ``dest``, the attributes will be loaded into
the existing dataset.
"""
- remote_dataset = Dataset(open_hdf5_s3(s3_url))
+ remote_dataset = Dataset(open_hdf5_s3(s3_url, block_size=block_size))
remote_dataset.write(dest, "a", attributes, overwrite=overwrite)
del remote_dataset
@@ -73,6 +73,7 @@ def _download_dataset(
data_path: DataPath,
dest: Path,
attributes: Optional[typing.Iterable[str]],
+ block_size: int,
force: bool = False,
) -> None:
"""Downloads the dataset at ``data_path`` to ``dest``, optionally downloading
@@ -84,7 +85,7 @@ def _download_dataset(
s3_path = f"{S3_URL}/{url_safe_datapath}"
if attributes is not None:
- _download_partial(s3_path, dest, attributes, overwrite=force)
+ _download_partial(s3_path, dest, attributes, overwrite=force, block_size=block_size)
return
if dest.exists() and not force:
@@ -120,7 +121,7 @@ def load( # pylint: disable=too-many-arguments
folder_path: Path = Path("./datasets/"),
force: bool = False,
num_threads: int = 50,
- cache_dir: Optional[Path] = Path(".cache"),
+ block_size: int = 8388608,
**params: Union[ParamArg, str, List[str]],
):
r"""Downloads the data if it is not already present in the directory and returns it as a list of
@@ -133,7 +134,9 @@ def load( # pylint: disable=too-many-arguments
folder_path (str) : Path to the directory used for saving datasets. Defaults to './datasets'
force (Bool) : Bool representing whether data has to be downloaded even if it is still present
num_threads (int) : The maximum number of threads to spawn while downloading files (1 thread per file)
- cache_dir (str): Directory used for HTTP caching. Defaults to '{folder_path}/.cache'
+        block_size (int) : The number of bytes to fetch per read operation when downloading datasets from S3.
+            Larger values may improve performance for large datasets, but will slow down small reads.
+            Defaults to 8MB (8388608 bytes)
params (kwargs) : Keyword arguments exactly matching the parameters required for the data type.
Note that these are not optional
@@ -212,8 +215,6 @@ def load( # pylint: disable=too-many-arguments
_validate_attributes(data_struct, data_name, attributes)
folder_path = Path(folder_path)
- if cache_dir and not Path(cache_dir).is_absolute():
- cache_dir = folder_path / cache_dir
data_paths = [data_path for _, data_path in foldermap.find(data_name, **params)]
@@ -224,7 +225,14 @@ def load( # pylint: disable=too-many-arguments
with ThreadPoolExecutor(min(num_threads, len(dest_paths))) as pool:
futures = [
- pool.submit(_download_dataset, data_path, dest_path, attributes, force=force)
+ pool.submit(
+ _download_dataset,
+ data_path,
+ dest_path,
+ attributes,
+ force=force,
+ block_size=block_size,
+ )
for data_path, dest_path in zip(data_paths, dest_paths)
]
results = wait(futures, return_when=FIRST_EXCEPTION)
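
For reference, a minimal usage sketch of the new keyword on the public entry point. The molecule parameters are illustrative values, and `block_size=4 * 2**20` is an arbitrary choice; omitting it falls back to the 8 MB default:

```python
import pennylane as qml

# Partially load the H2 dataset: only the requested attributes are written
# locally, and the mmap cache means only the blocks backing them are fetched.
[dataset] = qml.data.load(
    "qchem",
    molname="H2",
    basis="STO-3G",
    bondlength=1.1,
    attributes=["molecule", "fci_energy"],
    block_size=4 * 2**20,  # 4 MB per read; tune for your connection
)

print(dataset.fci_energy)
```
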
diff --git a/tests/data/base/test_hdf5.py b/tests/data/base/test_hdf5.py
index f2112c833b0..9d492475824 100644
--- a/tests/data/base/test_hdf5.py
+++ b/tests/data/base/test_hdf5.py
@@ -50,26 +50,15 @@ def patch_h5py(monkeypatch):
monkeypatch.setattr(hdf5, "h5py", MagicMock())
-@pytest.mark.parametrize(
- "kwargs, call_args, call_kwargs",
- [
- ({}, ("/bucket",), {}),
- (
- {"cache_dir": "/cache"},
- ("blockcache::/bucket",),
- {"blockcache": {"cache_storage": "/cache"}},
- ),
- ],
-)
-def test_open_hdf5_s3(
- mock_fsspec, kwargs, call_args, call_kwargs
-): # pylint: disable=redefined-outer-name
+def test_open_hdf5_s3(mock_fsspec): # pylint: disable=redefined-outer-name
"""Test that open_hdf5_s3 calls fsspec.open() with the expected arguments."""
- ret = hdf5.open_hdf5_s3("/bucket", **kwargs)
+ ret = hdf5.open_hdf5_s3("/bucket")
assert isinstance(ret, h5py.File)
- mock_fsspec.open.assert_called_once_with(*call_args, **call_kwargs)
+    mock_fsspec.open.assert_called_once_with(
+        "/bucket", cache_type="mmap", block_size=8 * (2**20)
+    )
def test_copy_all_conflict_overwrite(tmp_path):
diff --git a/tests/data/data_manager/test_dataset_access.py b/tests/data/data_manager/test_dataset_access.py
index 5c797a63ab8..97a55cfeee9 100644
--- a/tests/data/data_manager/test_dataset_access.py
+++ b/tests/data/data_manager/test_dataset_access.py
@@ -234,7 +234,7 @@ def test_list_attributes(self):
@pytest.fixture
def mock_download_dataset(monkeypatch):
- def mock(data_path, dest, attributes, force):
+ def mock(data_path, dest, attributes, force, block_size):
dset = Dataset.open(Path(dest), "w")
dset.close()
@@ -262,6 +262,7 @@ def test_load(tmp_path, data_name, params, expect_paths):
dsets = pennylane.data.data_manager.load(
data_name=data_name,
folder_path=folder_path,
+ block_size=1,
**params,
)
@@ -289,9 +290,7 @@ def test_download_dataset_full(tmp_path):
using requests if all attributes are requested."""
pennylane.data.data_manager._download_dataset(
- "dataset/path",
- tmp_path / "dataset",
- attributes=None,
+ "dataset/path", tmp_path / "dataset", attributes=None, block_size=1
)
with open(tmp_path / "dataset", "rb") as f:
@@ -311,7 +310,7 @@ def test_download_dataset_full_already_exists(tmp_path, force, expect_data):
f.write(b"This is local data")
pennylane.data.data_manager._download_dataset(
- "dataset/path", tmp_path / "dataset", attributes=None, force=force
+ "dataset/path", tmp_path / "dataset", attributes=None, force=force, block_size=1
)
with open(tmp_path / "dataset", "rb") as f:
@@ -331,7 +330,7 @@ def test_download_dataset_partial(tmp_path, monkeypatch):
)
pennylane.data.data_manager._download_dataset(
- "dataset/path", tmp_path / "dataset", attributes=["x"]
+ "dataset/path", tmp_path / "dataset", attributes=["x"], block_size=1
)
local = Dataset.open(tmp_path / "dataset")
@@ -351,7 +350,9 @@ def test_download_dataset_escapes_url(_, mock_get_args, datapath, escaped):
dest = MagicMock()
dest.exists.return_value = False
- pennylane.data.data_manager._download_dataset(DataPath(datapath), dest=dest, attributes=None)
+ pennylane.data.data_manager._download_dataset(
+ DataPath(datapath), dest=dest, attributes=None, block_size=1
+ )
mock_get_args.assert_called_once()
assert mock_get_args.call_args[0] == (f"{S3_URL}/{escaped}",)
@@ -370,11 +371,11 @@ def test_download_dataset_escapes_url_partial(mock_download_partial, datapath, e
force = False
pennylane.data.data_manager._download_dataset(
- DataPath(datapath), dest=dest, attributes=attributes, force=force
+ DataPath(datapath), dest=dest, attributes=attributes, force=force, block_size=1
)
mock_download_partial.assert_called_once_with(
- f"{S3_URL}/{escaped}", dest, attributes, overwrite=force
+ f"{S3_URL}/{escaped}", dest, attributes, overwrite=force, block_size=1
)