From 8ddca74435e0d85dfe05956dcca1ffe6242bb9e4 Mon Sep 17 00:00:00 2001 From: shunichi Date: Tue, 27 Feb 2024 09:27:35 +0000 Subject: [PATCH] Fix ppo tuple support --- nnabla_rl/algorithms/ppo.py | 50 +++++++++++++++++++++------------- tests/algorithms/test_ppo.py | 52 +++++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 19 deletions(-) diff --git a/nnabla_rl/algorithms/ppo.py b/nnabla_rl/algorithms/ppo.py index 96a28043..4dfc49f8 100644 --- a/nnabla_rl/algorithms/ppo.py +++ b/nnabla_rl/algorithms/ppo.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ import threading as th from collections import namedtuple from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union import gym import numpy as np @@ -645,23 +645,15 @@ def _compute_action(self, s, *, begin_of_episode=False): return action, info def _fill_result(self, experiences, v_targets, advantages): - def array_and_dtype(mp_arrays_item): - return mp_arrays_item[0], mp_arrays_item[2] - - def _np_to_mp_array(np_array, mp_array): - if isinstance(np_array, tuple) and isinstance(mp_array, tuple): - (np_to_mp_array(np_ary, *array_and_dtype(mp_ary)) for np_ary, mp_ary in zip(np_array, mp_array)) - else: - np_to_mp_array(np_array, *array_and_dtype(mp_array)) (s, a, r, non_terminal, s_next, log_prob) = marshal_experiences(experiences) - _np_to_mp_array(s, self._mp_arrays.state) - _np_to_mp_array(a, self._mp_arrays.action) - _np_to_mp_array(r, self._mp_arrays.reward) - _np_to_mp_array(non_terminal, self._mp_arrays.non_terminal) - _np_to_mp_array(s_next, self._mp_arrays.next_state) - _np_to_mp_array(log_prob, self._mp_arrays.log_prob) - _np_to_mp_array(v_targets, self._mp_arrays.v_target) - _np_to_mp_array(advantages, self._mp_arrays.advantage) + _copy_np_array_to_mp_array(s, self._mp_arrays.state) + _copy_np_array_to_mp_array(a, self._mp_arrays.action) + _copy_np_array_to_mp_array(r, self._mp_arrays.reward) + _copy_np_array_to_mp_array(non_terminal, self._mp_arrays.non_terminal) + _copy_np_array_to_mp_array(s_next, self._mp_arrays.next_state) + _copy_np_array_to_mp_array(log_prob, self._mp_arrays.log_prob) + _copy_np_array_to_mp_array(v_targets, self._mp_arrays.v_target) + _copy_np_array_to_mp_array(advantages, self._mp_arrays.advantage) def _update_params(self, src, dest): copy_params_to_mp_arrays(src, dest) @@ -707,3 +699,25 @@ def _prepare_action_mp_array(self, action_space, env_info): action_mp_array = mp_array_from_np_array( np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) return (action_mp_array, action_mp_array_shape, action_space.dtype) + + +def _copy_np_array_to_mp_array( + np_array: Union[np.ndarray, Tuple[np.ndarray]], + mp_array_shape_type: Union[Tuple[np.ndarray, Tuple[int], np.dtype], Tuple[Tuple[np.ndarray, Tuple[int], np.dtype]]], +): + """Copy numpy array to multiprocessing array. + + Args: + np_array (Union[np.ndarray, Tuple[np.ndarray]]): copy source of numpy array. + mp_array_shape_type + (Union[Tuple[np.ndarray, Tuple[int], np.dtype], Tuple[Tuple[np.ndarray, Tuple[int], np.dtype]]]): + copy target of multiprocessing array, shape and type. + """ + if isinstance(np_array, tuple) and isinstance(mp_array_shape_type[0], tuple): + assert len(np_array) == len(mp_array_shape_type) + for np_ary, mp_ary_shape_type in zip(np_array, mp_array_shape_type): + np_to_mp_array(np_ary, mp_ary_shape_type[0], mp_ary_shape_type[2]) + elif isinstance(np_array, np.ndarray) and isinstance(mp_array_shape_type[0], np.ndarray): + np_to_mp_array(np_array, mp_array_shape_type[0], mp_array_shape_type[2]) + else: + raise ValueError("Invalid pair of np_array and mp_array!") diff --git a/tests/algorithms/test_ppo.py b/tests/algorithms/test_ppo.py index dfc5513e..1ca82548 100644 --- a/tests/algorithms/test_ppo.py +++ b/tests/algorithms/test_ppo.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,9 +20,11 @@ import nnabla.functions as NF import nnabla_rl.algorithms as A import nnabla_rl.environments as E +from nnabla_rl.algorithms.ppo import _copy_np_array_to_mp_array from nnabla_rl.builders import ModelBuilder from nnabla_rl.distributions import Gaussian from nnabla_rl.models import StochasticPolicy, VFunction +from nnabla_rl.utils.multiprocess import mp_array_from_np_array, mp_to_np_array class TupleStateActor(StochasticPolicy): @@ -168,6 +170,54 @@ def test_latest_iteration_state(self): assert latest_iteration_state['scalar']['pi_loss'] == 0. assert latest_iteration_state['scalar']['v_loss'] == 1. + def test_copy_np_array_to_mp_array(self): + shape = (10, 9, 8, 7) + mp_array_shape_type = (mp_array_from_np_array(np.random.uniform(size=shape)), shape, np.float64) + + test_array = np.random.uniform(size=shape) + before_copying = mp_to_np_array(mp_array_shape_type[0], shape, dtype=mp_array_shape_type[2]) + assert not np.allclose(before_copying, test_array) + + _copy_np_array_to_mp_array(test_array, mp_array_shape_type) + + after_copying = mp_to_np_array(mp_array_shape_type[0], shape, dtype=mp_array_shape_type[2]) + assert np.allclose(after_copying, test_array) + + def test_copy_tuple_np_array_to_tuple_mp_array_shape_type(self): + shape = ((10, 9, 8, 7), (6, 5, 4, 3)) + tuple_mp_array_shape_type = tuple( + [(mp_array_from_np_array(np.random.uniform(size=s)), shape, np.float64) for s in shape] + ) + tuple_test_array = tuple([np.random.uniform(size=s) for s in shape]) + + for mp_ary_shape_type, s, test_ary in zip(tuple_mp_array_shape_type, shape, tuple_test_array): + before_copying = mp_to_np_array(mp_ary_shape_type[0], s, dtype=mp_ary_shape_type[2]) + assert not np.allclose(before_copying, test_ary) + + _copy_np_array_to_mp_array(tuple_test_array, tuple_mp_array_shape_type) + + for mp_ary_shape_type, s, test_ary in zip(tuple_mp_array_shape_type, shape, tuple_test_array): + after_copying = mp_to_np_array(mp_ary_shape_type[0], s, dtype=mp_ary_shape_type[2]) + assert np.allclose(after_copying, test_ary) + + def test_copy_np_array_to_tuple_mp_array_shape_type(self): + shape = ((10, 9, 8, 7), (6, 5, 4, 3)) + tuple_mp_array_shape_type = tuple( + [(mp_array_from_np_array(np.random.uniform(size=s)), shape, np.float64) for s in shape] + ) + test_array = np.random.uniform(size=shape[0]) + + with pytest.raises(ValueError): + _copy_np_array_to_mp_array(test_array, tuple_mp_array_shape_type) + + def test_copy_tuple_np_array_to_mp_array_shape_type(self): + shape = ((10, 9, 8, 7), (6, 5, 4, 3)) + mp_array_shape_type = (mp_array_from_np_array(np.random.uniform(size=shape[0])), shape, np.float64) + tuple_test_array = tuple([np.random.uniform(size=s) for s in shape]) + + with pytest.raises(ValueError): + _copy_np_array_to_mp_array(tuple_test_array, mp_array_shape_type) + if __name__ == "__main__": pytest.main()