Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error in 'qml.data.load()` when using 'full' parameter value #4663

Merged
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@
still a `_qfunc_output` property on `QNode` instances.
[(#4651)](https://github.com/PennyLaneAI/pennylane/pull/4651)

* `qml.data.load` properly handles parameters that come after `'full'`
[(#4663)](https://github.com/PennyLaneAI/pennylane/pull/4663)

<h3>Breaking changes 💔</h3>

* The device test suite now converts device kwargs to integers or floats if they can be converted to integers or floats.
Expand Down
28 changes: 20 additions & 8 deletions pennylane/data/data_manager/foldermap.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def find(

while curr:
curr_description, curr_level = curr.pop()

if param_arg == ParamArg.FULL:
next_params = curr_level
elif param_arg == ParamArg.DEFAULT:
Expand All @@ -131,18 +132,29 @@ def find(
else:
next_params = param_arg

try:
todo.extend(
for next_param in next_params:
try:
fmap_next = curr_level[next_param]
except KeyError:
continue

todo.append(
(
Description((*curr_description.items(), (param_name, next_param))),
curr_level[next_param],
fmap_next,
)
for next_param in next_params
)
except KeyError as exc:
raise ValueError(
f"{param_name} '{exc.args[0]}' is not available. Available values are: {list(curr_level)}"
) from exc

if len(todo) == 0:
# None of the parameters matched
param_arg_repr = (
repr([param_arg])
if isinstance(param_arg, (str, ParamArg))
else repr(list(param_arg))
)
raise ValueError(
f"{param_name} value(s) {param_arg_repr} are not available. Available values are: {list(curr_level)}"
)

curr, todo = todo, curr

Expand Down
62 changes: 59 additions & 3 deletions tests/data/data_manager/test_foldermap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Tests for the :class:`pennylane.data.data_manger.FolderMapView` class.
"""

import re

import pytest

from pennylane.data.data_manager import DEFAULT, FULL, DataPath
Expand Down Expand Up @@ -125,6 +127,38 @@ class TestFolderMapView:
),
],
),
(
{"missing_default": DEFAULT, "molname": "O2", "basis": FULL, "bondlength": ["0.6"]},
[
(
{"molname": "O2", "basis": "STO-3G", "bondlength": "0.6"},
"qchem/O2/STO-3G/0.6.h5",
),
],
),
(
{
"missing_default": DEFAULT,
"molname": "O2",
"basis": FULL,
"bondlength": ["0.6", "200"],
},
brownj85 marked this conversation as resolved.
Show resolved Hide resolved
[
(
{"molname": "O2", "basis": "STO-3G", "bondlength": "0.6"},
"qchem/O2/STO-3G/0.6.h5",
),
],
),
(
{"missing_default": FULL, "molname": "O2", "bondlength": ["0.6"]},
[
(
{"molname": "O2", "basis": "STO-3G", "bondlength": "0.6"},
"qchem/O2/STO-3G/0.6.h5",
),
],
),
],
)
def test_find(self, foldermap, kwds, expect): # pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -155,14 +189,36 @@ def test_find_missing_arg_no_default(self):
with pytest.raises(ValueError, match="No default available for parameter 'molname'"):
FolderMapView(FOLDERMAP).find("qchem")

def test_find_invalid_parameter(self):
@pytest.mark.parametrize(
"arg, error_fmt", [("Z3", repr(["Z3"])), (("Z3", "Z4"), repr(["Z3", "Z4"]))]
)
def test_find_invalid_parameter(self, arg, error_fmt):
"""Test that a ValueError is raised when a parameter provided
does not exist."""

with pytest.raises(
ValueError, match=r"molname 'Z3' is not available. Available values are: \['O2', 'H2'\]"
ValueError,
match=re.escape(
f"molname value(s) {error_fmt} are not available. Available values are: ['O2', 'H2']"
),
):
FolderMapView(FOLDERMAP).find("qchem", molname="Z3")
FolderMapView(FOLDERMAP).find("qchem", molname=arg)

@pytest.mark.parametrize("basis", [FULL, DEFAULT])
def test_find_invalid_parameters_after_full_default(self, basis):
"""Test that a ValueError is raised when a parameter provided
does not exist, after a 'full' or 'default' parameter has been provided for a
higher-priority parameter."""

with pytest.raises(
ValueError,
match=(
r"bondlength value\(s\) \['0.20', '200'\] are not available. Available values are: \['0.5', '0.6'\]"
),
):
FolderMapView(FOLDERMAP).find(
"qchem", molname="O2", basis=basis, bondlength=["0.20", "200"]
)

@pytest.mark.parametrize(
"init, key, expect",
Expand Down
Loading