From 021b7de57daf90680ca0d3761c8238a1812cc4f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 2 Dec 2024 00:46:38 +0800 Subject: [PATCH 1/2] feature(pu): add mcts_tictactoe_zh.py --- lzero/agent/mcts_tictactoe_zh.py | 192 +++++++++++++++++++++++++++++++ zoo/board_games/mcts_bot.py | 8 +- 2 files changed, 196 insertions(+), 4 deletions(-) create mode 100644 lzero/agent/mcts_tictactoe_zh.py diff --git a/lzero/agent/mcts_tictactoe_zh.py b/lzero/agent/mcts_tictactoe_zh.py new file mode 100644 index 000000000..58057aaff --- /dev/null +++ b/lzero/agent/mcts_tictactoe_zh.py @@ -0,0 +1,192 @@ +import math +import random + +# 游戏类,表示井字棋的状态 +class Game: + def __init__(self): + # 初始化棋盘,使用列表表示9个格子,初始为空格 + self.board = [' ' for _ in range(9)] + # 当前玩家,1表示玩家1(X),-1表示玩家2(O) + self.current_player = 1 + + def get_current_player(self): + # 返回当前玩家 + return self.current_player + + def get_legal_moves(self): + # 返回所有合法的走法,即棋盘中为空的位置的索引 + return [i for i in range(9) if self.board[i] == ' '] + + def make_move(self, move): + # 执行走法,如果目标位置不为空则抛出异常 + if self.board[move] != ' ': + raise ValueError("无效的走法") + # 根据当前玩家标记棋子 + self.board[move] = 'X' if self.current_player == 1 else 'O' + # 切换玩家 + self.current_player *= -1 + + def is_game_over(self): + # 定义所有可能的获胜线路 + lines = [ + [0, 1, 2], [3, 4, 5], [6, 7, 8], # 行 + [0, 3, 6], [1, 4, 7], [2, 5, 8], # 列 + [0, 4, 8], [2, 4, 6] # 对角线 + ] + # 检查是否有玩家获胜 + for line in lines: + a, b, c = line + if self.board[a] == self.board[b] == self.board[c] and self.board[a] != ' ': + return True, self.board[a] # 返回游戏结束和胜利者 + # 检查是否平局 + if ' ' not in self.board: + return True, 0 # 平局 + # 游戏未结束 + return False, None + + def clone(self): + # 克隆当前游戏状态,用于模拟 + cloned_game = Game() + cloned_game.board = self.board.copy() + cloned_game.current_player = self.current_player + return cloned_game + + def print_board(self): + # 打印当前棋盘状态 + print("当前棋盘状态:") + print(f"{self.board[0]} | {self.board[1]} | {self.board[2]}") + print("---------") + print(f"{self.board[3]} | {self.board[4]} | {self.board[5]}") + print("---------") + print(f"{self.board[6]} | {self.board[7]} | {self.board[8]}") + print() + +# 节点类,用于MCTS的树结构 +class Node: + def __init__(self, game, parent=None): + self.game = game # 当前游戏状态 + self.parent = parent # 父节点 + self.children = {} # 子节点,键为走法,值为节点 + self.visits = 0 # 访问次数 + self.value = 0.0 # 累计奖励值 + +# 选择子节点的策略(使用UCB1公式) +def select_child(self): + best_score = -float('inf') + best_move = None + best_child = None + for move, child in self.children.items(): + if child.visits == 0: + score = float('inf') # 未被访问过的节点优先选择 + else: + exploitation = child.value / child.visits # 利用 + exploration = math.sqrt(2 * math.log(self.visits) / child.visits) # 探索 + score = exploitation + exploration + if score > best_score: + best_score = score + best_move = move + best_child = child + return best_move, best_child + +# 为节点扩展所有可能的子节点 +def expand(self, game): + legal_moves = game.get_legal_moves() + for move in legal_moves: + new_game = game.clone() + new_game.make_move(move) + child_node = Node(new_game, parent=self) + self.children[move] = child_node + +# 模拟游戏直到结束,返回游戏结果 +def simulate(self): + game = self.game.clone() + while True: + is_over, result = game.is_game_over() + if is_over: + break + legal_moves = game.get_legal_moves() + move = random.choice(legal_moves) # 随机选择走法 + game.make_move(move) + return result # 返回 'X', 'O' 或 0 + +# 将上述函数绑定到Node类 +Node.select_child = select_child +Node.expand = expand +Node.simulate = simulate + +# MCTS算法实现 +def mcts(root_node, simulations=1000): + for _ in range(simulations): + node = root_node + game = node.game.clone() + # 选择阶段 + while node.children and not game.is_game_over()[0]: + move, node = node.select_child() + game.make_move(move) + # 扩展阶段 + if not node.children and not game.is_game_over()[0]: + node.expand(game) + # 模拟阶段 + if not game.is_game_over()[0]: + result = node.simulate() + else: + _, result = game.is_game_over() + # 回溯阶段 + while node: + node.visits += 1 + if result == 'X': + node.value += 1.0 if node.game.current_player == -1 else -1.0 + elif result == 'O': + node.value += -1.0 if node.game.current_player == -1 else 1.0 + else: + node.value += 0.0 # 平局 + node = node.parent + # 选择访问次数最多的走法作为最佳走法 + best_move = max(root_node.children.keys(), key=lambda move: root_node.children[move].visits) + return best_move + +# 人类玩家的走法输入 +def human_move(game): + while True: + try: + move_input = input("请输入你的走法(1-9):") + move = int(move_input) - 1 # 转换为索引 + if move not in game.get_legal_moves(): + print("无效的走法,请重新输入。") + else: + game.make_move(move) + break + except ValueError: + print("无效的输入,请输入一个数字。") + +# 机器人玩家的走法(使用MCTS) +def bot_move(game): + root_node = Node(game.clone()) + best_move = mcts(root_node, simulations=50) # 可以根据性能调整模拟次数 + game.make_move(best_move) + print(f"Bot选择了走法:{best_move + 1}") + +# 主函数,游戏循环 +def main(): + game = Game() + game.print_board() + + while not game.is_game_over()[0]: + if game.get_current_player() == 1: + human_move(game) # 玩家1(X)走法 + else: + bot_move(game) # 玩家2(O)走法 + game.print_board() + is_over, result = game.is_game_over() + if is_over: + if result == 'X': + print("玩家1(X)获胜!") + elif result == 'O': + print("玩家2(O)获胜!") + else: + print("平局!") + break + +# 运行主函数 +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/board_games/mcts_bot.py b/zoo/board_games/mcts_bot.py index 706c0f468..b7392d6ab 100644 --- a/zoo/board_games/mcts_bot.py +++ b/zoo/board_games/mcts_bot.py @@ -9,15 +9,15 @@ For more details, you can refer to: https://github.com/int8/monte-carlo-tree-search. """ -import time import copy +import os +import time from abc import ABC, abstractmethod from collections import defaultdict -from graphviz import Digraph -import os import numpy as np -import copy +from graphviz import Digraph + class MCTSNode(ABC): """ From e74f05a6987f6a963ab314255f0999f2da2b637c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Wed, 8 Jan 2025 16:06:09 +0800 Subject: [PATCH 2/2] feature(pu): add mcts_tictactoe.py --- lzero/agent/config/muzero/gym_cartpole_v0.py | 28 +++ lzero/agent/mcts_tictactoe.py | 195 +++++++++++++++++++ lzero/agent/mcts_tictactoe_zh.py | 87 +++++---- 3 files changed, 268 insertions(+), 42 deletions(-) create mode 100644 lzero/agent/mcts_tictactoe.py diff --git a/lzero/agent/config/muzero/gym_cartpole_v0.py b/lzero/agent/config/muzero/gym_cartpole_v0.py index 7304ac0e7..d49c8df0b 100644 --- a/lzero/agent/config/muzero/gym_cartpole_v0.py +++ b/lzero/agent/config/muzero/gym_cartpole_v0.py @@ -74,3 +74,31 @@ ) cfg = EasyDict(cfg) + + +if __name__ == "__main__": + # Note: Install the `huggingface_ding` package using the following shell commands + # git clone https://github.com/opendilab/huggingface_ding.git + # cd huggingface_ding + # pip3 install -e . + + # Import the required modules for downloading a pretrained model from Hugging Face Model Zoo + from lzero.agent import MuZeroAgent + from huggingface_ding import pull_model_from_hub + + # Pull the pretrained model and its configuration from the Hugging Face Hub + policy_state_dict, cfg = pull_model_from_hub(repo_id="OpenDILabCommunity/CartPole-v0-MuZero") + + # Instantiate the agent (MuZeroAgent) with the environment, configuration, and policy state + agent = MuZeroAgent( + env_id="CartPole-v0", # Environment ID + exp_name="CartPole-v0-MuZero", # Experiment name + cfg=cfg.exp_config, # Configuration for the experiment + policy_state_dict=policy_state_dict # Pretrained policy states + ) + + # Train the agent for 5000 steps + agent.train(step=5000) + + # Render the performance of the trained agent and save the replay + agent.deploy(enable_save_replay=True) \ No newline at end of file diff --git a/lzero/agent/mcts_tictactoe.py b/lzero/agent/mcts_tictactoe.py new file mode 100644 index 000000000..de0d6a4b2 --- /dev/null +++ b/lzero/agent/mcts_tictactoe.py @@ -0,0 +1,195 @@ +import math +import random + +# Game class representing the state of Tic-Tac-Toe +class Game: + def __init__(self): + # Initialize the board using a list of 9 cells, initially empty + self.board = [' ' for _ in range(9)] + # Current player: 1 represents Player 1 (X), -1 represents Player 2 (O) + self.current_player = 1 + + def get_current_player(self): + # Return the current player + return self.current_player + + def get_legal_moves(self): + # Return all legal moves, i.e., the indices of empty cells on the board + return [i for i in range(9) if self.board[i] == ' '] + + def make_move(self, move): + # Make a move; raise an exception if the target cell is not empty + if self.board[move] != ' ': + raise ValueError("Invalid move") + # Mark the cell based on the current player + self.board[move] = 'X' if self.current_player == 1 else 'O' + # Switch the player + self.current_player *= -1 + + def is_game_over(self): + # Define all possible winning lines + lines = [ + [0, 1, 2], [3, 4, 5], [6, 7, 8], # Rows + [0, 3, 6], [1, 4, 7], [2, 5, 8], # Columns + [0, 4, 8], [2, 4, 6] # Diagonals + ] + # Check if any player has won + for line in lines: + a, b, c = line + if self.board[a] == self.board[b] == self.board[c] and self.board[a] != ' ': + return True, self.board[a] # Return game over and the winner + # Check for a draw + if ' ' not in self.board: + return True, 0 # Draw + # Game is not over + return False, None + + def clone(self): + # Clone the current game state for simulation + cloned_game = Game() + cloned_game.board = self.board.copy() + cloned_game.current_player = self.current_player + return cloned_game + + def print_board(self): + # Print the current state of the board + print("Current board state:") + print(f"{self.board[0]} | {self.board[1]} | {self.board[2]}") + print("---------") + print(f"{self.board[3]} | {self.board[4]} | {self.board[5]}") + print("---------") + print(f"{self.board[6]} | {self.board[7]} | {self.board[8]}") + print() + +# Node class for the MCTS tree structure +class Node: + def __init__(self, game, parent=None): + self.game = game # Current game state + self.parent = parent # Parent node + self.children = {} # Child nodes, key is the move, value is the node + self.visits = 0 # Number of visits to this node + self.value = 0.0 # Accumulated reward value + + # Strategy for selecting child nodes (using the UCB1 formula) + def select_child(self): + best_score = -float('inf') + best_move = None + best_child = None + for move, child in self.children.items(): + if child.visits == 0: + score = float('inf') # Prioritize unvisited nodes + else: + exploitation = child.value / child.visits # Exploitation term + exploration = math.sqrt(2 * math.log(self.visits) / child.visits) # Exploration term + score = exploitation + exploration + if score > best_score: + best_score = score + best_move = move + best_child = child + return best_move, best_child + + # Expand all possible child nodes for this node + def expand(self, game): + legal_moves = game.get_legal_moves() + for move in legal_moves: + new_game = game.clone() + new_game.make_move(move) + child_node = Node(new_game, parent=self) + self.children[move] = child_node + + # Simulate the game until it ends, returning the game result + def simulate(self): + game = self.game.clone() + while True: + is_over, result = game.is_game_over() + if is_over: + break + legal_moves = game.get_legal_moves() + move = random.choice(legal_moves) # Randomly choose a move + game.make_move(move) + return result # Return 'X', 'O', or 0 + +# MCTS algorithm implementation +def mcts(root_node, simulations=1000): + for _ in range(simulations): + node = root_node + game = node.game.clone() + # Selection phase + while node.children and not game.is_game_over()[0]: + move, node = node.select_child() + game.make_move(move) + # Expansion phase + if not node.children and not game.is_game_over()[0]: + node.expand(game) + # Simulation phase + if not game.is_game_over()[0]: + result = node.simulate() + else: + _, result = game.is_game_over() + # Backpropagation phase + while node: + node.visits += 1 + if result == 'X': + node.value += 1.0 if node.game.current_player == -1 else -1.0 + elif result == 'O': + node.value += -1.0 if node.game.current_player == -1 else 1.0 + else: + node.value += 0.0 # Draw + node = node.parent + # Choose the move with the most visits as the best move + best_move = max(root_node.children.keys(), key=lambda move: root_node.children[move].visits) + return best_move + +# Human player move input +def human_move(game): + while True: + try: + move_input = input("Enter your move (1-9): ") + move = int(move_input) - 1 # Convert to index + if move not in game.get_legal_moves(): + print("Invalid move, please try again.") + else: + game.make_move(move) + break + except ValueError: + print("Invalid input, please enter a number.") + +# Bot player move (uses MCTS) +def bot_move(game): + root_node = Node(game.clone()) + best_move = mcts(root_node, simulations=50) # Adjust simulations for performance + game.make_move(best_move) + print(f"Bot chose move: {best_move + 1}") + +# Main function: game loop +def main(): + game = Game() + game.print_board() + + while not game.is_game_over()[0]: + if game.get_current_player() == 1: + human_move(game) # Player 1 (X) move + else: + bot_move(game) # Player 2 (O) move + game.print_board() + is_over, result = game.is_game_over() + if is_over: + if result == 'X': + print("Player 1 (X) wins!") + elif result == 'O': + print("Player 2 (O) wins!") + else: + print("It's a draw!") + break + +# Run the main function +if __name__ == "__main__": + """ + This file is a simple implementation of a Tic-Tac-Toe game, designed for educational purposes. + Features: + - Player 1 (X) competes against a bot (O) powered by Monte Carlo Tree Search (MCTS). + - The game is played via command-line interaction, with the player providing inputs for their moves. + - The bot uses the MCTS algorithm to determine the best moves. + - Demonstrates the basic principles of MCTS: selection, expansion, simulation, and backpropagation. + """ + main() \ No newline at end of file diff --git a/lzero/agent/mcts_tictactoe_zh.py b/lzero/agent/mcts_tictactoe_zh.py index 58057aaff..5688e4b9e 100644 --- a/lzero/agent/mcts_tictactoe_zh.py +++ b/lzero/agent/mcts_tictactoe_zh.py @@ -70,49 +70,44 @@ def __init__(self, game, parent=None): self.visits = 0 # 访问次数 self.value = 0.0 # 累计奖励值 -# 选择子节点的策略(使用UCB1公式) -def select_child(self): - best_score = -float('inf') - best_move = None - best_child = None - for move, child in self.children.items(): - if child.visits == 0: - score = float('inf') # 未被访问过的节点优先选择 - else: - exploitation = child.value / child.visits # 利用 - exploration = math.sqrt(2 * math.log(self.visits) / child.visits) # 探索 - score = exploitation + exploration - if score > best_score: - best_score = score - best_move = move - best_child = child - return best_move, best_child - -# 为节点扩展所有可能的子节点 -def expand(self, game): - legal_moves = game.get_legal_moves() - for move in legal_moves: - new_game = game.clone() - new_game.make_move(move) - child_node = Node(new_game, parent=self) - self.children[move] = child_node - -# 模拟游戏直到结束,返回游戏结果 -def simulate(self): - game = self.game.clone() - while True: - is_over, result = game.is_game_over() - if is_over: - break + # 选择子节点的策略(使用UCB1公式) + def select_child(self): + best_score = -float('inf') + best_move = None + best_child = None + for move, child in self.children.items(): + if child.visits == 0: + score = float('inf') # 未被访问过的节点优先选择 + else: + exploitation = child.value / child.visits # 利用 + exploration = math.sqrt(2 * math.log(self.visits) / child.visits) # 探索 + score = exploitation + exploration + if score > best_score: + best_score = score + best_move = move + best_child = child + return best_move, best_child + + # 为节点扩展所有可能的子节点 + def expand(self, game): legal_moves = game.get_legal_moves() - move = random.choice(legal_moves) # 随机选择走法 - game.make_move(move) - return result # 返回 'X', 'O' 或 0 - -# 将上述函数绑定到Node类 -Node.select_child = select_child -Node.expand = expand -Node.simulate = simulate + for move in legal_moves: + new_game = game.clone() + new_game.make_move(move) + child_node = Node(new_game, parent=self) + self.children[move] = child_node + + # 模拟游戏直到结束,返回游戏结果 + def simulate(self): + game = self.game.clone() + while True: + is_over, result = game.is_game_over() + if is_over: + break + legal_moves = game.get_legal_moves() + move = random.choice(legal_moves) # 随机选择走法 + game.make_move(move) + return result # 返回 'X', 'O' 或 0 # MCTS算法实现 def mcts(root_node, simulations=1000): @@ -189,4 +184,12 @@ def main(): # 运行主函数 if __name__ == "__main__": + """ + 本文件是一个简易的井字棋(Tic-Tac-Toe)游戏实现,采用单文件结构,主要用于教学目的。 + 功能概述: + - 玩家(X)与基于蒙特卡洛树搜索(MCTS)的机器人(O)进行对战。 + - 通过命令行交互进行游戏,玩家可以输入自己的走法,机器人则通过MCTS算法选择最佳走法。 + - 游戏展示了MCTS算法的基本流程,包括选择、扩展、模拟和回溯阶段。 + - 适合用于学习井字棋和MCTS算法的基本原理。 + """ main() \ No newline at end of file