Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of qml.data.load() when partially loading a dataset #4674

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@
with two or more terms.
[(#4642)](https://github.com/PennyLaneAI/pennylane/pull/4642)


* `_qfunc_output` has been removed from `QuantumScript`, as it is no longer necessary. There is
still a `_qfunc_output` property on `QNode` instances.
[(#4651)](https://github.com/PennyLaneAI/pennylane/pull/4651)
Expand All @@ -259,6 +258,9 @@
of the computed qubit operator, if imaginary components are smaller than a threshold.
[(#4639)](https://github.com/PennyLaneAI/pennylane/pull/4639)

* Improved performance of `qml.data.load()` when partially loading a dataset
[(#4674)](https://github.com/PennyLaneAI/pennylane/pull/4674)

* Plots generated with the `pennylane.drawer.plot` style of `matplotlib.pyplot` now have black
axis labels and are generated at a default DPI of 300.
[(#4690)](https://github.com/PennyLaneAI/pennylane/pull/4690)
Expand All @@ -271,6 +273,7 @@
[(#4675)](https://github.com/PennyLaneAI/pennylane/pull/4675)



<h3>Breaking changes 💔</h3>

* The device test suite now converts device kwargs to integers or floats if they can be converted to integers or floats.
Expand Down Expand Up @@ -457,6 +460,7 @@
This release contains contributions from (in alphabetical order):

Utkarsh Azad,
Jack Brown,
Stepan Fomichev,
Joana Fraxanet,
Diego Guala,
Expand Down
16 changes: 10 additions & 6 deletions pennylane/data/base/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from collections.abc import MutableMapping
from pathlib import Path
from typing import Literal, Optional, TypeVar, Union
from typing import Literal, TypeVar, Union
from uuid import uuid4

from numpy.typing import ArrayLike
Expand Down Expand Up @@ -96,15 +96,19 @@ def copy_all(
dest.attrs.update(source.attrs)


def open_hdf5_s3(s3_url: str, cache_dir: Optional[Path] = None) -> HDF5Group:
def open_hdf5_s3(s3_url: str, *, block_size: int = 8388608) -> HDF5Group:
"""Uses ``fsspec`` module to open the HDF5 file at ``s3_url``.

This requires both ``fsspec`` and ``aiohttp`` to be installed.

Args:
s3_url: URL of dataset file in S3
block_size: Number of bytes to fetch per read operation. Larger values
may improve performance for large datasets
"""

if cache_dir is not None:
fs = fsspec.open(f"blockcache::{s3_url}", blockcache={"cache_storage": str(cache_dir)})
else:
fs = fsspec.open(s3_url)
# Tells fsspec to fetch data in 8MB chunks for faster loading
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
memory_cache_args = {"cache_type": "mmap", "block_size": block_size}
fs = fsspec.open(s3_url, **memory_cache_args)

return h5py.File(fs.open())
24 changes: 16 additions & 8 deletions pennylane/data/data_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ def _get_data_struct():


def _download_partial(
s3_url: str, dest: Path, attributes: typing.Iterable[str], overwrite: bool
s3_url: str, dest: Path, attributes: typing.Iterable[str], overwrite: bool, block_size: int
) -> None:
"""Download only the requested attributes of the Dataset at ``s3_path`` into ``dest``.
If a dataset already exists at ``dest``, the attributes will be loaded into
the existing dataset.
"""

remote_dataset = Dataset(open_hdf5_s3(s3_url))
remote_dataset = Dataset(open_hdf5_s3(s3_url, block_size=block_size))
remote_dataset.write(dest, "a", attributes, overwrite=overwrite)

del remote_dataset
Expand All @@ -73,6 +73,7 @@ def _download_dataset(
data_path: DataPath,
dest: Path,
attributes: Optional[typing.Iterable[str]],
block_size: int,
force: bool = False,
) -> None:
"""Downloads the dataset at ``data_path`` to ``dest``, optionally downloading
Expand All @@ -84,7 +85,7 @@ def _download_dataset(
s3_path = f"{S3_URL}/{url_safe_datapath}"

if attributes is not None:
_download_partial(s3_path, dest, attributes, overwrite=force)
_download_partial(s3_path, dest, attributes, overwrite=force, block_size=block_size)
return

if dest.exists() and not force:
Expand Down Expand Up @@ -120,7 +121,7 @@ def load( # pylint: disable=too-many-arguments
folder_path: Path = Path("./datasets/"),
force: bool = False,
num_threads: int = 50,
cache_dir: Optional[Path] = Path(".cache"),
block_size: int = 8388608,
**params: Union[ParamArg, str, List[str]],
):
r"""Downloads the data if it is not already present in the directory and returns it as a list of
Expand All @@ -133,7 +134,9 @@ def load( # pylint: disable=too-many-arguments
folder_path (str) : Path to the directory used for saving datasets. Defaults to './datasets'
force (Bool) : Bool representing whether data has to be downloaded even if it is still present
num_threads (int) : The maximum number of threads to spawn while downloading files (1 thread per file)
cache_dir (str): Directory used for HTTP caching. Defaults to '{folder_path}/.cache'
block_size (int) : The number of bytes to fetch per read operation when fetching datasets from S3.
Larger values may improve performance for large datasets, but will slow down small reads. Defaults
to 8MB
params (kwargs) : Keyword arguments exactly matching the parameters required for the data type.
Note that these are not optional

Expand Down Expand Up @@ -212,8 +215,6 @@ def load( # pylint: disable=too-many-arguments
_validate_attributes(data_struct, data_name, attributes)

folder_path = Path(folder_path)
if cache_dir and not Path(cache_dir).is_absolute():
cache_dir = folder_path / cache_dir

data_paths = [data_path for _, data_path in foldermap.find(data_name, **params)]

Expand All @@ -224,7 +225,14 @@ def load( # pylint: disable=too-many-arguments

with ThreadPoolExecutor(min(num_threads, len(dest_paths))) as pool:
futures = [
pool.submit(_download_dataset, data_path, dest_path, attributes, force=force)
pool.submit(
_download_dataset,
data_path,
dest_path,
attributes,
force=force,
block_size=block_size,
)
for data_path, dest_path in zip(data_paths, dest_paths)
]
results = wait(futures, return_when=FIRST_EXCEPTION)
Expand Down
21 changes: 5 additions & 16 deletions tests/data/base/test_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,15 @@ def patch_h5py(monkeypatch):
monkeypatch.setattr(hdf5, "h5py", MagicMock())


@pytest.mark.parametrize(
"kwargs, call_args, call_kwargs",
[
({}, ("/bucket",), {}),
(
{"cache_dir": "/cache"},
("blockcache::/bucket",),
{"blockcache": {"cache_storage": "/cache"}},
),
],
)
def test_open_hdf5_s3(
mock_fsspec, kwargs, call_args, call_kwargs
): # pylint: disable=redefined-outer-name
def test_open_hdf5_s3(mock_fsspec): # pylint: disable=redefined-outer-name
"""Test that open_hdf5_s3 calls fsspec.open() with the expected arguments."""

ret = hdf5.open_hdf5_s3("/bucket", **kwargs)
ret = hdf5.open_hdf5_s3("/bucket")

assert isinstance(ret, h5py.File)
mock_fsspec.open.assert_called_once_with(*call_args, **call_kwargs)
mock_fsspec.open.assert_called_once_with(
"/bucket", **{"cache_type": "mmap", "block_size": 8 * (2**20)}
)


def test_copy_all_conflict_overwrite(tmp_path):
Expand Down
19 changes: 10 additions & 9 deletions tests/data/data_manager/test_dataset_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_list_attributes(self):

@pytest.fixture
def mock_download_dataset(monkeypatch):
def mock(data_path, dest, attributes, force):
def mock(data_path, dest, attributes, force, block_size):
dset = Dataset.open(Path(dest), "w")
dset.close()

Expand Down Expand Up @@ -262,6 +262,7 @@ def test_load(tmp_path, data_name, params, expect_paths):
dsets = pennylane.data.data_manager.load(
data_name=data_name,
folder_path=folder_path,
block_size=1,
**params,
)

Expand Down Expand Up @@ -289,9 +290,7 @@ def test_download_dataset_full(tmp_path):
using requests if all attributes are requested."""

pennylane.data.data_manager._download_dataset(
"dataset/path",
tmp_path / "dataset",
attributes=None,
"dataset/path", tmp_path / "dataset", attributes=None, block_size=1
)

with open(tmp_path / "dataset", "rb") as f:
Expand All @@ -311,7 +310,7 @@ def test_download_dataset_full_already_exists(tmp_path, force, expect_data):
f.write(b"This is local data")

pennylane.data.data_manager._download_dataset(
"dataset/path", tmp_path / "dataset", attributes=None, force=force
"dataset/path", tmp_path / "dataset", attributes=None, force=force, block_size=1
)

with open(tmp_path / "dataset", "rb") as f:
Expand All @@ -331,7 +330,7 @@ def test_download_dataset_partial(tmp_path, monkeypatch):
)

pennylane.data.data_manager._download_dataset(
"dataset/path", tmp_path / "dataset", attributes=["x"]
"dataset/path", tmp_path / "dataset", attributes=["x"], block_size=1
)

local = Dataset.open(tmp_path / "dataset")
Expand All @@ -351,7 +350,9 @@ def test_download_dataset_escapes_url(_, mock_get_args, datapath, escaped):
dest = MagicMock()
dest.exists.return_value = False

pennylane.data.data_manager._download_dataset(DataPath(datapath), dest=dest, attributes=None)
pennylane.data.data_manager._download_dataset(
DataPath(datapath), dest=dest, attributes=None, block_size=1
)

mock_get_args.assert_called_once()
assert mock_get_args.call_args[0] == (f"{S3_URL}/{escaped}",)
Expand All @@ -370,11 +371,11 @@ def test_download_dataset_escapes_url_partial(mock_download_partial, datapath, e
force = False

pennylane.data.data_manager._download_dataset(
DataPath(datapath), dest=dest, attributes=attributes, force=force
DataPath(datapath), dest=dest, attributes=attributes, force=force, block_size=1
)

mock_download_partial.assert_called_once_with(
f"{S3_URL}/{escaped}", dest, attributes, overwrite=force
f"{S3_URL}/{escaped}", dest, attributes, overwrite=force, block_size=1
)


Expand Down
Loading