Skip to content

Commit

Permalink
Avoid raising exception when info has inconsistent values
Browse files Browse the repository at this point in the history
  • Loading branch information
ishihara-y committed Nov 22, 2023
1 parent 889bd58 commit 4ae9fc7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
13 changes: 9 additions & 4 deletions nnabla_rl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

import nnabla as nn
from nnabla_rl.logger import logger
from nnabla_rl.typing import TupledData

T = TypeVar('T')
Expand Down Expand Up @@ -77,10 +78,14 @@ def marshal_dict_experiences(dict_experiences: Sequence[Dict[str, Any]]) -> Dict
dict_of_list = list_of_dict_to_dict_of_list(dict_experiences)
marshaled_experiences = {}
for key, data in dict_of_list.items():
if isinstance(data[0], Dict):
marshaled_experiences.update({key: marshal_dict_experiences(data)})
else:
marshaled_experiences.update({key: add_axis_if_single_dim(np.asarray(data))})
try:
if isinstance(data[0], Dict):
marshaled_experiences.update({key: marshal_dict_experiences(data)})
else:
marshaled_experiences.update({key: add_axis_if_single_dim(np.asarray(data))})
except ValueError as e:
# do nothing
logger.warn(f'key: {key} contains inconsistent elements!. Details: {e}')
return marshaled_experiences


Expand Down
22 changes: 22 additions & 0 deletions tests/utils/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import pytest
from packaging.version import parse

import nnabla as nn
import nnabla_rl.environments as E
Expand Down Expand Up @@ -133,6 +134,27 @@ def test_marshal_triple_nested_dict_experiences(self):
np.testing.assert_allclose(np.asarray(key1_experiences), 1)
np.testing.assert_allclose(np.asarray(key2_experiences), 2)

def test_marashal_dict_experiences_with_inhomogeneous_part(self):
installed_numpy_version = parse(np.__version__)
numpy_version1_24 = parse('1.24.0')

if installed_numpy_version < numpy_version1_24:
# no need to test
return

experiences = {'key1': 1, 'key2': 2}
inhomgeneous_experiences = {'key1': np.empty(shape=(6, )), 'key2': 2}
dict_experiences = [{'key_parent': experiences}, {'key_parent': inhomgeneous_experiences}]

marshaled_experience = marshal_dict_experiences(dict_experiences)

assert 'key1' not in marshaled_experience['key_parent']

key2_experiences = marshaled_experience['key_parent']['key2']
assert key2_experiences.shape == (2, 1)

np.testing.assert_allclose(np.asarray(key2_experiences), 2)

def test_list_of_dict_to_dict_of_list(self):
list_of_dict = [{'key1': 1, 'key2': 2}, {'key1': 1, 'key2': 2}]
dict_of_list = list_of_dict_to_dict_of_list(list_of_dict)
Expand Down

0 comments on commit 4ae9fc7

Please sign in to comment.