diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index e6772fc0..ac234d53 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -276,6 +276,9 @@ def forward(self, input_dict): return result else: selected_action = distr.sample() + choice = torch.rand_like(selected_action) > 0.1 + choice = choice.float() + selected_action = selected_action * choice + mu * (1 - choice) neglogp = self.neglogp(selected_action, mu, sigma, logstd) result = { 'neglogpacs' : torch.squeeze(neglogp),