Skip to content

Commit

Permalink
Merge pull request #116 from sony/feature/20240227-fix-ppo-tuple-support
Browse files Browse the repository at this point in the history
Fix ppo tuple support
  • Loading branch information
ishihara-y authored Mar 27, 2024
2 parents 159a8f8 + 8ddca74 commit ade5c23
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 19 deletions.
50 changes: 32 additions & 18 deletions nnabla_rl/algorithms/ppo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@
import threading as th
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import gym
import numpy as np
Expand Down Expand Up @@ -645,23 +645,15 @@ def _compute_action(self, s, *, begin_of_episode=False):
return action, info

def _fill_result(self, experiences, v_targets, advantages):
# NOTE(review): this span is a rendered diff without +/- markers. It appears to interleave the
# removed nested helpers and their `_np_to_mp_array(...)` calls with the added
# `_copy_np_array_to_mp_array(...)` calls -- confirm against the actual file before editing.
def array_and_dtype(mp_arrays_item):
# Pick elements 0 and 2 of the item; they are forwarded to np_to_mp_array below.
return mp_arrays_item[0], mp_arrays_item[2]

def _np_to_mp_array(np_array, mp_array):
if isinstance(np_array, tuple) and isinstance(mp_array, tuple):
# NOTE(review): the next line is a bare generator expression that is never iterated,
# so np_to_mp_array is not invoked for any element when this branch is taken.
(np_to_mp_array(np_ary, *array_and_dtype(mp_ary)) for np_ary, mp_ary in zip(np_array, mp_array))
else:
np_to_mp_array(np_array, *array_and_dtype(mp_array))
# Unpack the six values returned by marshal_experiences.
(s, a, r, non_terminal, s_next, log_prob) = marshal_experiences(experiences)
_np_to_mp_array(s, self._mp_arrays.state)
_np_to_mp_array(a, self._mp_arrays.action)
_np_to_mp_array(r, self._mp_arrays.reward)
_np_to_mp_array(non_terminal, self._mp_arrays.non_terminal)
_np_to_mp_array(s_next, self._mp_arrays.next_state)
_np_to_mp_array(log_prob, self._mp_arrays.log_prob)
_np_to_mp_array(v_targets, self._mp_arrays.v_target)
_np_to_mp_array(advantages, self._mp_arrays.advantage)
# Each value is passed together with the matching field of self._mp_arrays to the
# module-level helper _copy_np_array_to_mp_array.
_copy_np_array_to_mp_array(s, self._mp_arrays.state)
_copy_np_array_to_mp_array(a, self._mp_arrays.action)
_copy_np_array_to_mp_array(r, self._mp_arrays.reward)
_copy_np_array_to_mp_array(non_terminal, self._mp_arrays.non_terminal)
_copy_np_array_to_mp_array(s_next, self._mp_arrays.next_state)
_copy_np_array_to_mp_array(log_prob, self._mp_arrays.log_prob)
_copy_np_array_to_mp_array(v_targets, self._mp_arrays.v_target)
_copy_np_array_to_mp_array(advantages, self._mp_arrays.advantage)

def _update_params(self, src, dest):
copy_params_to_mp_arrays(src, dest)
Expand Down Expand Up @@ -707,3 +699,25 @@ def _prepare_action_mp_array(self, action_space, env_info):
action_mp_array = mp_array_from_np_array(
np.empty(shape=action_mp_array_shape, dtype=action_space.dtype))
return (action_mp_array, action_mp_array_shape, action_space.dtype)


def _copy_np_array_to_mp_array(
    np_array: Union[np.ndarray, Tuple[np.ndarray, ...]],
    mp_array_shape_type: Union[
        Tuple[np.ndarray, Tuple[int, ...], np.dtype],
        Tuple[Tuple[np.ndarray, Tuple[int, ...], np.dtype], ...],
    ],
) -> None:
    """Copy numpy array to multiprocessing array.

    Two input layouts are accepted:

    * ``np_array`` is a single ``np.ndarray`` and ``mp_array_shape_type`` is one
      ``(mp_array, shape, dtype)`` triple whose first element is an ``np.ndarray``.
    * ``np_array`` is a tuple of arrays and ``mp_array_shape_type`` is a tuple of such
      triples with the same length; the arrays are paired with the triples in order.

    The copy itself is delegated to ``np_to_mp_array(source, mp_array, dtype)``.
    The shape stored at index 1 of each triple is not used here.

    Args:
        np_array (Union[np.ndarray, Tuple[np.ndarray, ...]]): copy source of numpy array.
        mp_array_shape_type
            (Union[Tuple[np.ndarray, Tuple[int, ...], np.dtype],
            Tuple[Tuple[np.ndarray, Tuple[int, ...], np.dtype], ...]]):
            copy target of multiprocessing array, shape and type.

    Raises:
        ValueError: If one argument is a tuple layout and the other is not, or if both are
            tuples but their lengths differ.
    """
    first_item = mp_array_shape_type[0]

    if isinstance(np_array, tuple) and isinstance(first_item, tuple):
        # Raise explicitly instead of asserting: asserts are stripped under ``python -O`` and
        # zip() would then silently ignore the unmatched trailing elements.
        if len(np_array) != len(mp_array_shape_type):
            raise ValueError(
                "Length mismatch between np_array and mp_array! "
                f"{len(np_array)} != {len(mp_array_shape_type)}"
            )
        for np_ary, mp_ary_shape_type in zip(np_array, mp_array_shape_type):
            np_to_mp_array(np_ary, mp_ary_shape_type[0], mp_ary_shape_type[2])
        return

    if isinstance(np_array, np.ndarray) and isinstance(first_item, np.ndarray):
        np_to_mp_array(np_array, first_item, mp_array_shape_type[2])
        return

    raise ValueError("Invalid pair of np_array and mp_array!")
52 changes: 51 additions & 1 deletion tests/algorithms/test_ppo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,9 +20,11 @@
import nnabla.functions as NF
import nnabla_rl.algorithms as A
import nnabla_rl.environments as E
from nnabla_rl.algorithms.ppo import _copy_np_array_to_mp_array
from nnabla_rl.builders import ModelBuilder
from nnabla_rl.distributions import Gaussian
from nnabla_rl.models import StochasticPolicy, VFunction
from nnabla_rl.utils.multiprocess import mp_array_from_np_array, mp_to_np_array


class TupleStateActor(StochasticPolicy):
Expand Down Expand Up @@ -168,6 +170,54 @@ def test_latest_iteration_state(self):
assert latest_iteration_state['scalar']['pi_loss'] == 0.
assert latest_iteration_state['scalar']['v_loss'] == 1.

def test_copy_np_array_to_mp_array(self):
    """Check that a single numpy array is written into a single mp array entry."""
    shape = (10, 9, 8, 7)
    target = (mp_array_from_np_array(np.random.uniform(size=shape)), shape, np.float64)
    mp_array, _, dtype = target
    source = np.random.uniform(size=shape)

    # The target starts out with different random contents than the source.
    assert not np.allclose(mp_to_np_array(mp_array, shape, dtype=dtype), source)

    _copy_np_array_to_mp_array(source, target)

    # Reading the target back now yields the source values.
    assert np.allclose(mp_to_np_array(mp_array, shape, dtype=dtype), source)

def test_copy_tuple_np_array_to_tuple_mp_array_shape_type(self):
    """Check that a tuple of numpy arrays is written into a tuple of mp array entries in order."""
    shapes = ((10, 9, 8, 7), (6, 5, 4, 3))
    # Every (mp_array, shape, dtype) entry carries the shape of its own array.
    tuple_mp_array_shape_type = tuple(
        (mp_array_from_np_array(np.random.uniform(size=s)), s, np.float64) for s in shapes
    )
    tuple_test_array = tuple(np.random.uniform(size=s) for s in shapes)

    # Each target starts out with different random contents than its source.
    for (mp_array, s, dtype), test_ary in zip(tuple_mp_array_shape_type, tuple_test_array):
        before_copying = mp_to_np_array(mp_array, s, dtype=dtype)
        assert not np.allclose(before_copying, test_ary)

    _copy_np_array_to_mp_array(tuple_test_array, tuple_mp_array_shape_type)

    # Reading each target back now yields the corresponding source values.
    for (mp_array, s, dtype), test_ary in zip(tuple_mp_array_shape_type, tuple_test_array):
        after_copying = mp_to_np_array(mp_array, s, dtype=dtype)
        assert np.allclose(after_copying, test_ary)

def test_copy_np_array_to_tuple_mp_array_shape_type(self):
    """Check that a single numpy array paired with a tuple of mp array entries raises ValueError."""
    shapes = ((10, 9, 8, 7), (6, 5, 4, 3))
    # Every (mp_array, shape, dtype) entry carries the shape of its own array.
    tuple_mp_array_shape_type = tuple(
        (mp_array_from_np_array(np.random.uniform(size=s)), s, np.float64) for s in shapes
    )
    test_array = np.random.uniform(size=shapes[0])

    with pytest.raises(ValueError):
        _copy_np_array_to_mp_array(test_array, tuple_mp_array_shape_type)

def test_copy_tuple_np_array_to_mp_array_shape_type(self):
    """Check that a tuple of numpy arrays paired with a single mp array entry raises ValueError."""
    shapes = ((10, 9, 8, 7), (6, 5, 4, 3))
    # The single entry is built from the first shape, so that is the shape it records.
    mp_array_shape_type = (
        mp_array_from_np_array(np.random.uniform(size=shapes[0])),
        shapes[0],
        np.float64,
    )
    tuple_test_array = tuple(np.random.uniform(size=s) for s in shapes)

    with pytest.raises(ValueError):
        _copy_np_array_to_mp_array(tuple_test_array, mp_array_shape_type)


# Run pytest when this module is executed directly as a script.
if __name__ == "__main__":
pytest.main()

0 comments on commit ade5c23

Please sign in to comment.