From d8705e66d032fea08c45d848837bb6ebc9c28854 Mon Sep 17 00:00:00 2001 From: puyuan1996 <2402552459@qq.com> Date: Tue, 24 Dec 2024 21:03:40 +0800 Subject: [PATCH] fix(pu): fix sampled_unizero multitask ddp pipeline --- .../train_unizero_multitask_segment_ddp.py | 3 +++ lzero/model/unizero_world_models/tokenizer.py | 3 ++- .../world_model_multitask.py | 2 +- .../dmc2gym_state_suz_multitask_ddp_config.py | 2 +- ...c2gym_state_suz_multitask_serial_config.py | 20 +++++++++---------- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py index e44722fa7..0423b7ba5 100644 --- a/lzero/entry/train_unizero_multitask_segment_ddp.py +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -330,6 +330,9 @@ def train_unizero_multitask_segment_ddp( print('=' * 20) print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') + # =========TODO========= + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + # 执行安全评估 stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) # 判断评估是否成功 diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index d384be9d0..d0f5e0483 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -93,9 +93,10 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten elif len(shape) == 4: # Case when input is 4D (B, C, H, W) try: - # obs_embeddings = self.encoder[task_id](x) obs_embeddings = self.encoder(x, task_id=task_id) # TODO: for dmc multitask + # obs_embeddings = self.encoder[task_id](x) except Exception as e: + print(e) obs_embeddings = self.encoder(x) # TODO: for memory env obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index 0735f8846..97830f9e3 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -889,8 +889,8 @@ def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor # Copy and store keys_values_wm for a single environment self.update_cache_context(current_obs_embeddings, is_init_infer=True) - # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: elif batch_action is not None and current_obs_embeddings is None: + # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: # ================ calculate the target value in Train phase ================ # [192, 16, 64] -> [32, 6, 16, 64] last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py index f015d942b..6d93ccbad 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py @@ -121,7 +121,7 @@ def generate_configs(env_id_list: List[str], num_segments: int, total_batch_size: int): configs = [] - exp_name_prefix = f'data_suz_mt_ddp_20241224/8gpu_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_seed{seed}/' + exp_name_prefix = f'data_suz_mt_20241224/ddp_8gpu_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_seed{seed}/' action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py index 18f531f27..a31b9f2d0 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py @@ -61,7 +61,7 @@ def create_config(env_id, action_space_size_list, observation_shape_list, collec num_layers=2, num_heads=8, embed_dim=768, - env_num=len(env_id_list), + env_num=max(collector_env_num, evaluator_env_num), task_num=len(env_id_list), use_normal_head=True, use_softmoe_head=False, @@ -111,7 +111,7 @@ def generate_configs(env_id_list, seed, collector_env_num, evaluator_env_num, n_ Generate configurations for all DMC tasks in the environment list. """ configs = [] - exp_name_prefix = f'data_suz_mt_debug/{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}/' + exp_name_prefix = f'data_suz_mt_20241224/{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}/' action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] for task_id, env_id in enumerate(env_id_list): @@ -195,7 +195,7 @@ def create_env_manager(): num_segments = 8 n_episode = 8 num_simulations = 50 - batch_size = [64, 64] # 可以根据需要调整或者设置为列表 + batch_size = [64 for _ in range(len(env_id_list))] num_unroll_steps = 5 infer_context_length = 2 norm_type = 'LN' @@ -206,13 +206,13 @@ def create_env_manager(): update_per_collect = 100 # ========== TODO: debug config ============ - collector_env_num = 2 - evaluator_env_num = 2 - num_segments = 2 - n_episode = 2 - num_simulations = 2 - batch_size = [4,4] # 可以根据需要调整或者设置为列表 - update_per_collect = 1 + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 2 + # batch_size = [4,4] # 可以根据需要调整或者设置为列表 + # update_per_collect = 1 # 生成配置 configs = generate_configs(