From 839b69b122e357503446049e0f1cdb49254caa5d Mon Sep 17 00:00:00 2001
From: vaidas-sl <55625200+vaidas-sl@users.noreply.github.com>
Date: Wed, 1 Nov 2023 22:30:48 +0200
Subject: [PATCH] PCN fix: At execution time always select an action with the
 highest confidence

---
 morl_baselines/multi_policy/pcn/pcn.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/morl_baselines/multi_policy/pcn/pcn.py b/morl_baselines/multi_policy/pcn/pcn.py
index b85d634e..adf2cf76 100644
--- a/morl_baselines/multi_policy/pcn/pcn.py
+++ b/morl_baselines/multi_policy/pcn/pcn.py
@@ -254,22 +254,25 @@ def _choose_commands(self, num_episodes: int):
         desired_return = np.float32(desired_return)
         return desired_return, desired_horizon
 
-    def _act(self, obs: np.ndarray, desired_return, desired_horizon) -> int:
+    def _act(self, obs: np.ndarray, desired_return, desired_horizon, eval_mode=False) -> int:
         log_probs = self.model(
             th.tensor([obs]).float().to(self.device),
             th.tensor([desired_return]).float().to(self.device),
             th.tensor([desired_horizon]).unsqueeze(1).float().to(self.device),
         )
         log_probs = log_probs.detach().cpu().numpy()[0]
-        action = self.np_random.choice(np.arange(len(log_probs)), p=np.exp(log_probs))
+        if eval_mode:
+            action = np.argmax(log_probs)
+        else:
+            action = self.np_random.choice(np.arange(len(log_probs)), p=np.exp(log_probs))
         return action
 
-    def _run_episode(self, env, desired_return, desired_horizon, max_return):
+    def _run_episode(self, env, desired_return, desired_horizon, max_return, eval_mode=False):
         transitions = []
         obs, _ = env.reset()
         done = False
         while not done:
-            action = self._act(obs, desired_return, desired_horizon)
+            action = self._act(obs, desired_return, desired_horizon, eval_mode)
             n_obs, reward, terminated, truncated, _ = env.step(action)
             done = terminated or truncated
 
@@ -298,7 +301,7 @@ def set_desired_return_and_horizon(self, desired_return: np.ndarray, desired_hor
 
     def eval(self, obs, w=None):
         """Evaluate policy action for a given observation."""
-        return self._act(obs, self.desired_return, self.desired_horizon)
+        return self._act(obs, self.desired_return, self.desired_horizon, eval_mode=True)
 
     def evaluate(self, env, max_return, n=10):
         """Evaluate policy in the given environment."""
@@ -309,7 +312,7 @@ def evaluate(self, env, max_return, n=10):
         horizons = np.float32(horizons)
         e_returns = []
         for i in range(n):
-            transitions = self._run_episode(env, returns[i], np.float32(horizons[i] - 2), max_return)
+            transitions = self._run_episode(env, returns[i], np.float32(horizons[i] - 2), max_return, eval_mode=True)
             # compute return
             for i in reversed(range(len(transitions) - 1)):
                 transitions[i].reward += self.gamma * transitions[i + 1].reward
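
Note (not part of the patch): a minimal standalone sketch of what the new eval_mode flag changes. During training, _act keeps sampling the action from the predicted distribution; with eval_mode=True it now deterministically picks the most likely action. The log_probs values below are invented purely for illustration; in the patch they come from self.model(...).

    import numpy as np

    rng = np.random.default_rng(0)

    # Assumed log-probabilities over three actions (illustrative values only).
    log_probs = np.log(np.array([0.1, 0.6, 0.3]))

    # Training-time behaviour (unchanged): sample an action from the distribution.
    sampled_action = rng.choice(np.arange(len(log_probs)), p=np.exp(log_probs))

    # Evaluation-time behaviour after this patch (eval_mode=True): greedy argmax.
    greedy_action = np.argmax(log_probs)

    print(sampled_action, greedy_action)  # greedy_action is always 1 here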