From a79fcc29c1e9f026dfc13919de59568566e675e7 Mon Sep 17 00:00:00 2001
From: Denys Makoviichuk
Date: Fri, 1 Dec 2023 13:11:23 -0500
Subject: [PATCH] Fixed masks with multi discrete space. (#265)

Co-authored-by: Denys Makoviichuk
---
 rl_games/algos_torch/a2c_discrete.py      |  2 -
 rl_games/algos_torch/models.py            |  6 ++-
 rl_games/common/env_configurations.py     | 12 +++--
 rl_games/common/experience.py             |  2 +-
 rl_games/configs/smac/v1/3m_torch_sa.yaml | 54 ++++++++++++++++++++
 rl_games/configs/smac/v1/5m_vs_6m_sa.yaml | 60 +++++++++++++++++++++++
 rl_games/envs/smac_env.py                 | 34 ++++++++++++-
 7 files changed, 161 insertions(+), 9 deletions(-)
 create mode 100644 rl_games/configs/smac/v1/3m_torch_sa.yaml
 create mode 100644 rl_games/configs/smac/v1/5m_vs_6m_sa.yaml

diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index fc1bda89..467d8d86 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -95,8 +95,6 @@ def get_masked_action_values(self, obs, action_masks):
             value = self.get_central_value(input_dict)
             res_dict['values'] = value
 
-        if self.is_multi_discrete:
-            action_masks = torch.cat(action_masks, dim=-1)
         res_dict['action_masks'] = action_masks
         return res_dict
 
diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py
index e6772fc0..9c9dde4d 100644
--- a/rl_games/algos_torch/models.py
+++ b/rl_games/algos_torch/models.py
@@ -144,7 +144,8 @@ def forward(self, input_dict):
             if is_train:
                 if action_masks is None:
                     categorical = [Categorical(logits=logit) for logit in logits]
-                else: 
+                else:
+                    action_masks = np.split(action_masks,len(logits), axis=1)
                     categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]
                 prev_actions = torch.split(prev_actions, 1, dim=-1)
                 prev_neglogp = [-c.log_prob(a.squeeze()) for c,a in zip(categorical, prev_actions)]
@@ -162,7 +163,8 @@ def forward(self, input_dict):
             else:
                 if action_masks is None:
                     categorical = [Categorical(logits=logit) for logit in logits]
-                else: 
+                else:
+                    action_masks = np.split(action_masks, len(logits), axis=1)
                     categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]
                 selected_action = [c.sample().long() for c in categorical]
 
diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py
index e553deaa..d8b335e3 100644
--- a/rl_games/common/env_configurations.py
+++ b/rl_games/common/env_configurations.py
@@ -153,11 +153,12 @@ def create_roboschool_env(name):
     return gym.make(name)
 
 def create_smac(name, **kwargs):
-    from rl_games.envs.smac_env import SMACEnv
+    from rl_games.envs.smac_env import SMACEnv, MultiDiscreteSmacWrapper
     frames = kwargs.pop('frames', 1)
     transpose = kwargs.pop('transpose', False)
     flatten = kwargs.pop('flatten', True)
     has_cv = kwargs.get('central_value', False)
+    as_single_agent = kwargs.pop('as_single_agent', False)
 
     env = SMACEnv(name, **kwargs)
 
@@ -166,6 +167,9 @@ def create_smac(name, **kwargs):
         env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten)
     else:
         env = wrappers.BatchedFrameStack(env, frames, transpose=False, flatten=flatten)
+
+    if as_single_agent:
+        env = MultiDiscreteSmacWrapper(env)
     return env
 
 def create_smac_v2(name, **kwargs):
@@ -184,16 +188,18 @@ def create_smac_cnn(name, **kwargs):
-    from rl_games.envs.smac_env import SMACEnv
+    from rl_games.envs.smac_env import SMACEnv, MultiDiscreteSmacWrapper
     has_cv = kwargs.get('central_value', False)
     frames = kwargs.pop('frames', 4)
     transpose = kwargs.pop('transpose', False)
+
     env = SMACEnv(name, **kwargs)
     if has_cv:
         env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=transpose)
     else:
         env = wrappers.BatchedFrameStack(env, frames, transpose=transpose)
-    
+    if as_single_agent:
+        env = MultiDiscreteSmacWrapper(env)
     return env
 
 def create_test_env(name, **kwargs):
diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py
index 9cc880a6..feea017c 100644
--- a/rl_games/common/experience.py
+++ b/rl_games/common/experience.py
@@ -341,7 +341,7 @@ def _init_from_env_info(self, env_info):
         if self.is_discrete or self.is_multi_discrete:
             self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=int), obs_base_shape)
             if self.use_action_masks:
-                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=bool), obs_base_shape)
+                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=(np.sum(self.actions_num),), dtype=bool), obs_base_shape)
         if self.is_continuous:
             self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
             self.tensor_dict['mus'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
diff --git a/rl_games/configs/smac/v1/3m_torch_sa.yaml b/rl_games/configs/smac/v1/3m_torch_sa.yaml
new file mode 100644
index 00000000..27d76a49
--- /dev/null
+++ b/rl_games/configs/smac/v1/3m_torch_sa.yaml
@@ -0,0 +1,54 @@
+params:
+  algo:
+    name: a2c_discrete
+
+  model:
+    name: multi_discrete_a2c
+
+  network:
+    name: actor_critic
+    separate: True
+    #normalization: layer_norm
+    space:
+      multi_discrete:
+
+    mlp:
+      units: [256, 128]
+      activation: relu
+      initializer:
+        name: default
+      regularizer:
+        name: None
+  config:
+    name: 3m_sa
+    reward_shaper:
+      scale_value: 1
+    normalize_advantage: True
+    gamma: 0.99
+    tau: 0.95
+    learning_rate: 5e-4
+    score_to_win: 20
+    grad_norm: 0.5
+    entropy_coef: 0.001
+    truncate_grads: True
+    env_name: smac
+    e_clip: 0.2
+    clip_value: True
+    num_actors: 8
+    horizon_length: 128
+    minibatch_size: 512
+    mini_epochs: 4
+    critic_coef: 1
+    lr_schedule: None
+    kl_threshold: 0.05
+    normalize_input: True
+    use_action_masks: True
+    ignore_dead_batches : False
+
+    env_config:
+      name: 3m
+      frames: 1
+      transpose: False
+      random_invalid_step: False
+      as_single_agent: True
+      central_value: True
\ No newline at end of file
diff --git a/rl_games/configs/smac/v1/5m_vs_6m_sa.yaml b/rl_games/configs/smac/v1/5m_vs_6m_sa.yaml
new file mode 100644
index 00000000..4dc8b7ef
--- /dev/null
+++ b/rl_games/configs/smac/v1/5m_vs_6m_sa.yaml
@@ -0,0 +1,60 @@
+params:
+  algo:
+    name: a2c_discrete
+
+  model:
+    name: multi_discrete_a2c
+
+  network:
+    name: actor_critic
+    separate: True
+    space:
+      multi_discrete:
+
+    mlp:
+      units: [512, 256, 128]
+      activation: relu
+      initializer:
+        name: default
+
+  config:
+    name: 5m_vs_6m_sa
+    reward_shaper:
+      scale_value: 1
+    normalize_advantage: True
+    gamma: 0.99
+    tau: 0.95
+    learning_rate: 3e-4
+    score_to_win: 20
+    entropy_coef: 0.02
+    truncate_grads: True
+    grad_norm: 1
+    env_name: smac
+    e_clip: 0.2
+    clip_value: False
+    num_actors: 8
+    horizon_length: 256
+    minibatch_size: 1024
+    mini_epochs: 4
+    critic_coef: 2
+    lr_schedule: None
+    kl_threshold: 0.05
+    normalize_input: True
+    normalize_value: False
+    use_action_masks: True
+    use_diagnostics: True
+    seq_length: 8
+    max_epochs: 10000
+    env_config:
+      name: 5m_vs_6m
+      central_value: True
+      reward_only_positive: True
+      obs_last_action: False
+      apply_agent_ids: False
+      as_single_agent: True
+
+    player:
+      render: False
+      games_num: 200
+      n_game_life: 1
+      determenistic: True
diff --git a/rl_games/envs/smac_env.py b/rl_games/envs/smac_env.py
index 79695743..8561ce1e 100644
--- a/rl_games/envs/smac_env.py
+++ b/rl_games/envs/smac_env.py
@@ -101,6 +101,38 @@ def get_action_mask(self):
     def has_action_mask(self):
         return not self.random_invalid_step
 
-    def seed(self, _):
+    def seed(self, seed):
         pass
+        #self.env.seed(seed)
+
+class MultiDiscreteSmacWrapper(gym.Env):
+    def __init__(self, env):
+        gym.Env.__init__(self)
+        self.env = env
+        self.observation_space = self.env.state_space
+        self.action_space = gym.spaces.Tuple([self.env.action_space] * self.env.get_number_of_agents())
+
+    def step(self, actions):
+        fixed_rewards = None
+        obses, reward, done, info = self.env.step(actions)
+        return obses['state'], reward[0], done[0], info
+
+    def reset(self):
+        obses = self.env.reset()
+        return obses['state']
+
+    def has_action_mask(self):
+        return self.env.has_action_mask()
+
+    def get_action_mask(self):
+        action_maks = self.env.get_action_mask()
+        action_maks = action_maks.flatten()
+        return np.expand_dims(action_maks, axis=0)
+
+    def get_number_of_agents(self):
+        return 1
+
+    def seed(self, seed):
+        pass
+        #self.env.seed(seed)
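
A minimal sketch (not part of the patch above) of the mask plumbing this change implements, assuming a hypothetical setup with 2 agents and 9 discrete actions per agent; a real SMAC map defines those sizes. The wrapper side mirrors MultiDiscreteSmacWrapper.get_action_mask (flatten the per-agent masks and add a batch dimension), and the model side mirrors the np.split calls added to models.py (split the concatenated mask back into one chunk per logit head). A plain torch.where over the logits stands in for the library's CategoricalMasked here.

import numpy as np
import torch
from torch.distributions import Categorical

# Hypothetical sizes for illustration only.
num_agents, actions_per_agent = 2, 9

# Per-agent availability masks, as the underlying env would report them.
per_agent_masks = np.ones((num_agents, actions_per_agent), dtype=bool)
per_agent_masks[1, 5:] = False  # pretend agent 1 cannot take actions 5..8

# Wrapper side (cf. MultiDiscreteSmacWrapper.get_action_mask):
# flatten to a single vector and add a batch dimension -> shape (1, 18).
flat_mask = np.expand_dims(per_agent_masks.flatten(), axis=0)

# Model side (cf. the np.split calls added to models.py):
# split the concatenated mask back into one (1, 9) chunk per logit head.
logits = [torch.randn(1, actions_per_agent) for _ in range(num_agents)]
masks = np.split(flat_mask, len(logits), axis=1)

# Simplified stand-in for CategoricalMasked: push unavailable actions'
# logits to a large negative value before sampling.
for logit, mask in zip(logits, masks):
    masked_logit = torch.where(torch.as_tensor(mask), logit, torch.tensor(-1e8))
    action = Categorical(logits=masked_logit).sample()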