
Commit 839b69b
PCN fix: at execution time, always select the action with the highest probability

vaidas-sl committed Nov 1, 2023
1 parent d5bc675 commit 839b69b
Showing 1 changed file with 9 additions and 6 deletions.

morl_baselines/multi_policy/pcn/pcn.py (9 additions, 6 deletions)

@@ -254,22 +254,25 @@ def _choose_commands(self, num_episodes: int):
         desired_return = np.float32(desired_return)
         return desired_return, desired_horizon
 
-    def _act(self, obs: np.ndarray, desired_return, desired_horizon) -> int:
+    def _act(self, obs: np.ndarray, desired_return, desired_horizon, eval_mode=False) -> int:
         log_probs = self.model(
             th.tensor([obs]).float().to(self.device),
             th.tensor([desired_return]).float().to(self.device),
             th.tensor([desired_horizon]).unsqueeze(1).float().to(self.device),
         )
         log_probs = log_probs.detach().cpu().numpy()[0]
-        action = self.np_random.choice(np.arange(len(log_probs)), p=np.exp(log_probs))
+        if eval_mode:
+            action = np.argmax(log_probs)
+        else:
+            action = self.np_random.choice(np.arange(len(log_probs)), p=np.exp(log_probs))
         return action
 
-    def _run_episode(self, env, desired_return, desired_horizon, max_return):
+    def _run_episode(self, env, desired_return, desired_horizon, max_return, eval_mode=False):
         transitions = []
         obs, _ = env.reset()
         done = False
         while not done:
-            action = self._act(obs, desired_return, desired_horizon)
+            action = self._act(obs, desired_return, desired_horizon, eval_mode)
             n_obs, reward, terminated, truncated, _ = env.step(action)
             done = terminated or truncated
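
In effect, _act keeps stochastic sampling during training but switches to a greedy argmax over the action log-probabilities when eval_mode is set. A minimal, self-contained sketch of the two selection modes; the example distribution and variable names are illustrative, not part of the commit:

import numpy as np

rng = np.random.default_rng(seed=0)
log_probs = np.log(np.array([0.1, 0.6, 0.3]))  # illustrative action log-probabilities

# Training-time behaviour: sample an action from the policy distribution,
# which preserves exploration.
sampled_action = rng.choice(len(log_probs), p=np.exp(log_probs))

# Evaluation-time behaviour after this commit: deterministic greedy choice.
greedy_action = int(np.argmax(log_probs))  # always action 1 for this distribution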

@@ -298,7 +301,7 @@ def set_desired_return_and_horizon(self, desired_return: np.ndarray, desired_hor
 
     def eval(self, obs, w=None):
         """Evaluate policy action for a given observation."""
-        return self._act(obs, self.desired_return, self.desired_horizon)
+        return self._act(obs, self.desired_return, self.desired_horizon, eval_mode=True)
 
     def evaluate(self, env, max_return, n=10):
         """Evaluate policy in the given environment."""

@@ -309,7 +312,7 @@ def evaluate(self, env, max_return, n=10):
         horizons = np.float32(horizons)
         e_returns = []
         for i in range(n):
-            transitions = self._run_episode(env, returns[i], np.float32(horizons[i] - 2), max_return)
+            transitions = self._run_episode(env, returns[i], np.float32(horizons[i] - 2), max_return, eval_mode=True)
             # compute return
             for i in reversed(range(len(transitions) - 1)):
                 transitions[i].reward += self.gamma * transitions[i + 1].reward
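
The backward loop above (unchanged context in this hunk) turns per-step rewards into discounted returns-to-go. A standalone sketch of that accumulation, with a made-up discount factor and reward sequence:

import numpy as np

gamma = 0.99  # illustrative discount factor
rewards = np.array([1.0, 0.0, 2.0], dtype=np.float32)  # made-up per-step rewards

# Same backward pass as in evaluate(): after the loop, each entry holds
# the discounted return from that step to the end of the episode.
returns = rewards.copy()
for i in reversed(range(len(returns) - 1)):
    returns[i] += gamma * returns[i + 1]

# returns[2] == 2.0
# returns[1] == 0.0 + 0.99 * 2.0 == 1.98
# returns[0] == 1.0 + 0.99 * 1.98 == 2.9602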
