feature(pu): add sampled_unizero multitask pipeline
puyuan1996 committed Dec 24, 2024
1 parent d06ce61 commit 3a88f46
Showing 21 changed files with 2,068 additions and 93 deletions.
11 changes: 9 additions & 2 deletions lzero/entry/train_unizero_multitask_segment_ddp.py
@@ -196,8 +196,15 @@ def train_unizero_multitask_segment_ddp(
for config in tasks_for_this_rank:
config[1][0].policy.task_num = tasks_per_rank

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero_multitask'], "Currently only the 'unizero_multitask' policy type is supported"
# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero_multitask',
'sampled_unizero_multitask'], "The train_unizero entry currently only supports 'unizero_multitask' and 'sampled_unizero_multitask'"

if create_cfg.policy.type == 'unizero_multitask':
from lzero.mcts import UniZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_unizero_multitask':
from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer


# Set the device according to CUDA availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
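The entry now selects the replay buffer class from the policy type. Below is a minimal sketch of the same dispatch factored into a helper; the helper name select_game_buffer_cls and the commented usage are assumptions for illustration, the entry itself keeps the two conditional imports inline as shown above.

def select_game_buffer_cls(policy_type: str):
    # Return the game buffer class that matches the multi-task policy type.
    if policy_type == 'unizero_multitask':
        from lzero.mcts import UniZeroGameBuffer as GameBuffer
    elif policy_type == 'sampled_unizero_multitask':
        from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer
    else:
        raise ValueError(f"Unsupported policy type: {policy_type}")
    return GameBuffer

# Hypothetical usage inside the entry (cfg and create_cfg as in the code above):
# GameBuffer = select_game_buffer_cls(create_cfg.policy.type)
# replay_buffer = GameBuffer(cfg.policy)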
10 changes: 7 additions & 3 deletions lzero/entry/train_unizero_multitask_segment_serial.py
@@ -14,7 +14,6 @@
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.mcts import UniZeroGameBuffer as GameBuffer
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
@@ -59,7 +58,12 @@ def train_unizero_multitask_segment_serial(
task_id, [cfg, create_cfg] = input_cfg_list[0]

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero_multitask'], "The train_unizero entry currently only supports 'unizero_multitask'"
assert create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], "The train_unizero entry currently only supports 'unizero_multitask' and 'sampled_unizero_multitask'"

if create_cfg.policy.type == 'unizero_multitask':
from lzero.mcts import UniZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_unizero_multitask':
from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer

# Set the device according to CUDA availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
@@ -181,7 +185,7 @@ def train_unizero_multitask_segment_serial(
print(f'Start collecting for task id: {task_id}...')

# Reset the initial data before each collection; this is very important in the multi-task setting
collector._policy.reset(reset_init_data=True)
collector._policy.reset(reset_init_data=True, task_id=task_id)
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# 确定每次收集后的更新次数
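The serial entry resets the policy with the current task_id before every collection round. Below is a sketch of that per-task collection step wrapped in a helper for illustration; the function name collect_for_all_tasks and the collectors mapping (task_id -> MuZeroSegmentCollector) are assumptions, the entry keeps this loop inline.

from typing import Any, Dict

def collect_for_all_tasks(collectors: Dict[int, Any], learner: Any, collect_kwargs: Dict[str, Any]) -> Dict[int, Any]:
    # Iterate over per-task collectors and gather fresh segments for each task.
    all_new_data = {}
    for task_id, collector in collectors.items():
        print(f'Start collecting for task id: {task_id}...')
        # Resetting the initial data before each collection is important in the
        # multi-task setting; task_id tells the shared policy which task it acts in.
        collector._policy.reset(reset_init_data=True, task_id=task_id)
        all_new_data[task_id] = collector.collect(
            train_iter=learner.train_iter, policy_kwargs=collect_kwargs
        )
    return all_new_data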
17 changes: 10 additions & 7 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -64,9 +64,12 @@ def __init__(self, cfg: dict):
if hasattr(self._cfg, 'task_id'):
self.task_id = self._cfg.task_id
print(f"Task ID is set to {self.task_id}.")
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]

else:
self.task_id = None
print("No task_id found in configuration. Task ID is set to None.")
self.action_space_size = self._cfg.model.action_space_size

def reset_runtime_metrics(self):
"""
@@ -153,7 +156,7 @@ def sample(
self.compute_target_re_time += self._compute_target_timer.value

batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self._cfg.model.action_space_size
policy_non_re_context, self.action_space_size
)

# fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies
@@ -605,7 +608,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
reward_pool = reward_pool.squeeze().tolist()
policy_logits_pool = policy_logits_pool.tolist()
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size
).astype(np.float32).tolist() for _ in range(transition_batch_size)
]
if self._cfg.mcts_ctree:
@@ -651,7 +654,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

if policy_mask[policy_index] == 0:
# NOTE: for the invalid padding target policy, 0 is used to make sure the corresponding cross_entropy_loss = 0
target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
target_policies.append([0 for _ in range(self.action_space_size)])
else:
# NOTE: It is very important to use the latest MCTS visit count distribution.
sum_visits = sum(distributions)
@@ -660,7 +663,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
if distributions is None:
# if at some obs, the legal_action is None, add the fake target_policy
target_policies.append(
list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
list(np.ones(self.action_space_size) / self.action_space_size)
)
else:
# Update the data in game segment:
@@ -677,7 +680,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
target_policies.append(policy)
else:
# for board games that have two players and where legal_actions is dynamic
policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
policy_tmp = [0 for _ in range(self.action_space_size)]
# to make sure target_policies have the same dimension
sum_visits = sum(distributions)
policy = [visit_count / sum_visits for visit_count in distributions]
@@ -706,7 +709,7 @@ def _compute_target_policy_non_reanalyzed(
- game_segment_lens
- action_mask_segment
- to_play_segment
- policy_shape: self._cfg.model.action_space_size
- policy_shape: self.action_space_size
Returns:
- batch_target_policies_non_re
"""
@@ -729,7 +732,7 @@ def _compute_target_policy_non_reanalyzed(
]
# NOTE: in continuous action space env: we set all legal_actions as -1
legal_actions = [
[-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
[-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size)
]
else:
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
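Both game buffers now resolve the action space per task when a task_id is present in the config, and fall back to the single global action_space_size otherwise. Below is a self-contained sketch of that resolution; the helper name and the SimpleNamespace configs are illustrative stand-ins for the buffer's self._cfg.

from types import SimpleNamespace

def resolve_action_space_size(cfg) -> int:
    # Multi-task case: action spaces may differ across tasks and are indexed by task_id.
    if hasattr(cfg, 'task_id'):
        return cfg.model.action_space_size_list[cfg.task_id]
    # Single-task fallback: one global action space size.
    return cfg.model.action_space_size

# Hypothetical configs for illustration only.
multi_task_cfg = SimpleNamespace(task_id=1, model=SimpleNamespace(action_space_size_list=[6, 18, 9]))
single_task_cfg = SimpleNamespace(model=SimpleNamespace(action_space_size=6))
assert resolve_action_space_size(multi_task_cfg) == 18
assert resolve_action_space_size(single_task_cfg) == 6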
63 changes: 44 additions & 19 deletions lzero/mcts/buffer/game_buffer_sampled_unizero.py
@@ -48,9 +48,19 @@ def __init__(self, cfg: dict):
self.game_segment_buffer = []
self.game_pos_priorities = []
self.game_segment_game_pos_look_up = []
# self.task_id = self._cfg.task_id
self.sample_type = self._cfg.sample_type # 'transition' or 'episode'

if hasattr(self._cfg, 'task_id'):
self.task_id = self._cfg.task_id
print(f"Task ID is set to {self.task_id}.")
self.action_space_size = self._cfg.model.action_space_size_list[self.task_id]

else:
self.task_id = None
print("No task_id found in configuration. Task ID is set to None.")
self.action_space_size = self._cfg.model.action_space_size


def reanalyze_buffer(
self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
@@ -116,18 +126,18 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) ->
# pad random action
if self._cfg.model.continuous_action_space:
actions_tmp += [
np.random.randn(self._cfg.model.action_space_size)
np.random.randn(self.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
root_sampled_actions_tmp += [
np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size)
np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size)
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
]
else:
# generate random `padded actions_tmp`
actions_tmp += generate_random_actions_discrete(
self._cfg.num_unroll_steps - len(actions_tmp),
self._cfg.model.action_space_size,
self.action_space_size,
1 # Number of sampled actions for actions_tmp is 1
)

@@ -136,7 +146,7 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) ->
reshape = True if self._cfg.mcts_ctree else False
root_sampled_actions_tmp += generate_random_actions_discrete(
self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp),
self._cfg.model.action_space_size,
self.action_space_size,
self._cfg.model.num_of_sampled_actions,
reshape=reshape
)
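Rollouts shorter than num_unroll_steps are padded with random actions so that every sample in the batch has the same length, using Gaussian vectors for continuous action spaces and random indices for discrete ones. Below is a numpy-only sketch of that padding, assuming a list of per-step action arrays; it mirrors the branches above but does not reproduce generate_random_actions_discrete itself.

import numpy as np

def pad_actions(actions, num_unroll_steps, action_space_size, continuous):
    # Pad a possibly short action list with random actions up to num_unroll_steps.
    pad_len = max(0, num_unroll_steps - len(actions))
    if continuous:
        # Continuous control: pad with Gaussian action vectors of dimension action_space_size.
        pads = [np.random.randn(action_space_size) for _ in range(pad_len)]
    else:
        # Discrete control: pad with uniformly sampled action indices.
        pads = [np.random.randint(0, action_space_size, size=(1,)) for _ in range(pad_len)]
    return list(actions) + pads

padded = pad_actions([np.array([2])], num_unroll_steps=5, action_space_size=6, continuous=False)
assert len(padded) == 5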
@@ -273,18 +283,18 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
# pad random action
if self._cfg.model.continuous_action_space:
actions_tmp += [
np.random.randn(self._cfg.model.action_space_size)
np.random.randn(self.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
root_sampled_actions_tmp += [
np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size)
np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size)
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
]
else:
# generate random `padded actions_tmp`
actions_tmp += generate_random_actions_discrete(
self._cfg.num_unroll_steps - len(actions_tmp),
self._cfg.model.action_space_size,
self.action_space_size,
1 # Number of sampled actions for actions_tmp is 1
)

@@ -293,7 +303,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
reshape = True if self._cfg.mcts_ctree else False
root_sampled_actions_tmp += generate_random_actions_discrete(
self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp),
self._cfg.model.action_space_size,
self.action_space_size,
self._cfg.model.num_of_sampled_actions,
reshape=reshape
)
@@ -317,7 +327,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
if self._cfg.model.continuous_action_space:
# pad random action
bootstrap_action_tmp += [
np.random.randn(self._cfg.model.action_space_size)
np.random.randn(self.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(bootstrap_action_tmp))
]
bootstrap_action_list.append(bootstrap_action_tmp)
@@ -430,7 +440,7 @@ def _prepare_policy_reanalyzed_context(
]
return policy_re_context

def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray:
def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, batch_action) -> np.ndarray:
"""
Overview:
prepare policy targets from the reanalyzed context of policies
@@ -475,9 +485,15 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

# =============== NOTE: The key difference with MuZero =================
# calculate the target value
# action_batch.shape (32, 10)
# batch_action.shape (32, 10)
# batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352
m_output = model.initial_inference(batch_obs, action_batch[:self.reanalyze_num]) # NOTE: :self.reanalyze_num

if self.task_id is not None:
m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num
else:
m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num

# =======================================================================

if not model.training:
@@ -503,18 +519,24 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
# cpp mcts_tree
# roots = MCTSCtree.roots(transition_batch_size, legal_actions)
roots = MCTSCtree.roots(
transition_batch_size, legal_actions, self._cfg.model.action_space_size,
transition_batch_size, legal_actions, self.action_space_size,
self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
)
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
if self.task_id is not None:
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
else:
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
else:
# python mcts_tree
roots = MCTSPtree.roots(transition_batch_size, legal_actions)
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
if self.task_id is not None:
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
else:
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)

roots_legal_actions_list = legal_actions
roots_distributions = roots.get_distributions()
@@ -594,7 +616,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
return batch_target_policies_re, root_sampled_actions


def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[
def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action) -> Tuple[
Any, Any]:
"""
Overview:
@@ -618,7 +640,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
if self._cfg.model.continuous_action_space is True:
# when the action space of the environment is continuous, action_mask[:] is None.
action_mask = [
list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
]
# NOTE: in continuous action space env: we set all legal_actions as -1
legal_actions = [
@@ -636,7 +658,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
# =============== NOTE: The key difference with MuZero =================
# calculate the target value
# batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352
m_output = model.initial_inference(batch_obs, action_batch)
if self.task_id is not None:
m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id)
else:
m_output = model.initial_inference(batch_obs, batch_action)
# ======================================================================

if not model.training:
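Throughout the sampled buffer, calls into the world model and into MCTS forward task_id only when the buffer was created for a specific task, so the single-task call signature stays untouched. Below is a minimal sketch of that dispatch pattern; the free function is illustrative, the buffer performs the check inline at each call site.

def initial_inference_with_task(model, batch_obs, batch_action, task_id=None):
    # Forward task_id to the shared multi-task model only when it is set.
    if task_id is not None:
        return model.initial_inference(batch_obs, batch_action, task_id=task_id)
    return model.initial_inference(batch_obs, batch_action)

Keeping the task_id is None branch preserves backward compatibility with single-task game buffers that never configure a task id.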
