-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit db0d711
Showing
11 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
/.idea | ||
*.pyc | ||
build/ | ||
dist/ | ||
*.egg-info/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
language: python | ||
python: | ||
- '3.3' | ||
- '3.4' | ||
- '3.5' | ||
- '3.6' | ||
script: | ||
- python -m unittest discover | ||
deploy: | ||
provider: pypi | ||
skip_existing: true | ||
user: janhartigan | ||
password: | ||
secure: UZw0Sm+zqcxNda/VLQPhBLw8YFRqv9VVSLttM/bxRKBnMTu+X56gUTBeN8HfLqbztmMgxQL3M7bRP7voLL10fEP0MnZR0HZGAUPFH0y20GHacxaMP/3dntrSXvhTprRk5Sf70u/wjxTJ1TaB+sNzN75GMxrCljcca8JquRCLSsXH0anWyiGXJQ9oNI0h1zdL4/ujvJmGJI9k9uIhxbqSXfeEsGxz76HWL619jK3a2e7T/trtx7N3721sXKyCB9BlgrpKbyi4kZi1bQvGgnwWsrbKB2fAWPd92Y8ENGu6NR9B/qfeeBcRP77ArP6uqxLT68mUKzWzCEDXjN/wDtf3NgJ4FOB57UOB4QFH7phmPtJM3bq5aIFH+islONgDtS9MniSlmpcdTe6MN4CYLJFYiPQ18fqBtFFSKkbVyhNnYKqw3BUlT6sJd/aKzLG2rQWH3G6Q6T3PKIWlP17pQueBbxcX5YIjByNbLlZVzjSjVrmsEXVwVviOfDzs8xRzWzF/2bXFsdyeQfTnVW8ZpBlUlwUVw8CBHK31pfFgvxZfAkfHm13TPSOWCxLgfBWp1kOTnihqKnwszQFkiOYw0yzj1rEtMfb4NLeEpKgRCPN6xfN3+xUj19155rPF8fExTR2ZIwC3IEQvI1RAXRq4vEp25kxZXEuE6bR1rRWdtIa0N4Y= | ||
on: | ||
tags: true | ||
notifications: | ||
email: | ||
on_success: never |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
The MIT License | ||
|
||
Copyright (c) 2010-2018 Google, Inc. http://angularjs.org | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from random import choice | ||
|
||
class MonteCarlo: | ||
|
||
def __init__(self, root_node): | ||
self.root_node = root_node | ||
self.child_finder = None | ||
self.node_evaluator = lambda child: None | ||
|
||
def make_choice(self): | ||
best_children = [] | ||
most_visits = float('-inf') | ||
|
||
for child in self.root_node.children: | ||
if child.visits > most_visits: | ||
most_visits = child.visits | ||
best_children = [child] | ||
elif child.visits == most_visits: | ||
best_children.append(child) | ||
|
||
return choice(best_children) | ||
|
||
def simulate(self, expansion_count = 1): | ||
for i in range(expansion_count): | ||
current_node = self.root_node | ||
|
||
while current_node.expanded: | ||
current_node = current_node.get_preferred_child() | ||
|
||
self.expand(current_node) | ||
|
||
def expand(self, node): | ||
self.child_finder(node) | ||
|
||
for child in node.children: | ||
child_win_value = self.node_evaluator(child) | ||
|
||
if child_win_value != None: | ||
child.update_win_value(child_win_value) | ||
|
||
if not child.is_scorable(): | ||
self.random_rollout(child) | ||
child.children = [] | ||
|
||
node.expanded = True | ||
|
||
def random_rollout(self, node): | ||
self.child_finder(node) | ||
child = choice(node.children) | ||
node.children = [] | ||
node.add_child(child) | ||
child_win_value = self.node_evaluator(child) | ||
|
||
if child_win_value != None: | ||
node.update_win_value(child_win_value) | ||
else: | ||
self.random_rollout(child) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from random import choice | ||
from math import log, sqrt | ||
|
||
class Node: | ||
|
||
def __init__(self, state): | ||
self.state = state | ||
self.win_value = 0 | ||
self.policy_value = None | ||
self.visits = 0 | ||
self.parent = None | ||
self.children = [] | ||
self.expanded = False | ||
|
||
def update_win_value(self, value): | ||
self.win_value += value | ||
self.visits += 1 | ||
|
||
if self.parent: | ||
self.parent.update_win_value(value) | ||
|
||
def update_policy_value(self, value): | ||
self.policy_value = value | ||
|
||
def add_child(self, child): | ||
self.children.append(child) | ||
child.parent = self | ||
|
||
def add_children(self, children): | ||
for child in children: | ||
self.add_child(child) | ||
|
||
def get_preferred_child(self): | ||
best_children = [] | ||
best_score = float('-inf') | ||
|
||
for child in self.children: | ||
score = child.get_score() | ||
|
||
if score > best_score: | ||
best_score = score | ||
best_children = [child] | ||
elif score == best_score: | ||
best_children.append(child) | ||
|
||
return choice(best_children) | ||
|
||
def get_score(self): | ||
discovery_constant = 0.35 | ||
discovery_operand = discovery_constant * (self.policy_value or 1) * sqrt(log(self.parent.visits) / (self.visits or 1)) | ||
win_operand = self.win_value / (self.visits or 1) | ||
|
||
self.score = win_operand + discovery_operand | ||
|
||
return self.score | ||
|
||
def is_scorable(self): | ||
return self.visits or self.policy_value != None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
A Python3 library that you can use to run a Monte Carlo tree search, either traditionally with drilling down to end game states or with expert policies as you might provide from a neural network. | ||
|
||
- **Version:** 1.0.0 | ||
|
||
[![Build Status](https://travis-ci.org/ImparaAI/monte-carlo-tree-search.png?branch=master)](https://travis-ci.org/ImparaAI/monte-carlo-tree-search) | ||
|
||
# Basics | ||
|
||
If you're unfamiliar with the Monte Carlo tree search algorithm, you should first become familiar with it. Simply put, it helps make a decision from a set of possibile options by doing one of two things: | ||
|
||
- Constructing likely outcomes either by drilling down into random endstates for each option or.. | ||
- Using expert policies to make the similar determinations without having to drill down to end states | ||
|
||
As the user of this library, you only have to provide a mechanism of finding children, and optionally a way of evaluating nodes for end state outcomes. | ||
|
||
# Usage | ||
|
||
## Create instance | ||
|
||
Create a new Monte Carlo tree: | ||
|
||
```python | ||
from game import Game | ||
from montecarlo.node import Node | ||
from montecarlo.montecarlo import MonteCarlo | ||
|
||
montecarlo = MonteCarlo(Node(Game())) | ||
``` | ||
|
||
When instantiating the `MonteCarlo` class, you must pass in the root node of the tree with its state defined. The state of the node can be anything you will need to determine what the children of that node will be. | ||
|
||
For the sake of demonstration, we will assume you have an generic `Game` library that can tell you what moves are possible and make those moves. | ||
|
||
## Traditional Monte Carlo | ||
|
||
Add a child finder and a node evaluator: | ||
|
||
```python | ||
def child_finder(node): | ||
for move in node.state.get_possible_moves(): | ||
child = Node(deepcopy(node.state)) #or however you want to construct the child's state | ||
child.state.move(move) #or however your library works | ||
node.add_child(child) | ||
|
||
def node_evaluator(self, node): | ||
if node.state.won(): | ||
return 1 | ||
elif node.state.lost(): | ||
return -1 | ||
|
||
montecarlo.child_finder = child_finder | ||
montecarlo.node_evaluator = node_evaluator | ||
``` | ||
|
||
The `child_finder` simply needs to add new child nodes to the parent node passed into the function. If there are no children, the library won't try to drill down further. In that scenario, however, the parent should be in an end state, so the `node_evaluator` should return a value between `-1` and `1`. | ||
|
||
## Expert policy (AI) | ||
|
||
If you have an expert policy that you can apply to the children as they're being generated, the library will recognize that it doesn't need to make the costly drill down to an end state. If your neural net produces both an expert policy value for the children and a win value for the parent node, you can skip declaring the `node_evaluator` altogether. | ||
|
||
```python | ||
def child_finder(self, node): | ||
win_value, expert_policy_values = neural_network.predict(node.state) | ||
|
||
for move in node.state.get_possible_moves(): | ||
child = Node(deepcopy(node.state)) | ||
child.state.move(move) | ||
child.policy_value = get_child_policy_value(child, expert_policy_values) #should return a value between 0 and 1 | ||
node.add_child(child) | ||
|
||
node.update_win_value(node.state) # | ||
|
||
montecarlo.child_finder = child_finder | ||
``` | ||
|
||
## Simulate and make a choice | ||
|
||
Run the simulations: | ||
|
||
```python | ||
montecarlo.simulate(50) #number of expansions to run. higher is typically more accurate at the cost of processing time | ||
``` | ||
|
||
Once the simulations have been run you can ask the instance to make a choice: | ||
|
||
```python | ||
chosen_child_node = montecarlo.make_choice() | ||
chosen_child_node.state.do_something() | ||
``` | ||
|
||
After you've chosen a new root node, you can override it on the `montecarlo` instance and do more simulations from the new position in the tree. | ||
|
||
```python | ||
montecarlo.root_node = montecarlo.make_choice() | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import setuptools | ||
|
||
with open("readme.md", "r") as fh: | ||
long_description = fh.read() | ||
|
||
setuptools.setup( | ||
name="imparaai-montecarlo", | ||
version="1.0.0", | ||
license='MIT', | ||
author="ImparaAI", | ||
author_email="[email protected]", | ||
description="Library for running a Monte Carlo tree search either traditionally or with expert policies", | ||
long_description=long_description, | ||
long_description_content_type="text/markdown", | ||
url="https://github.com/ImparaAI/monte-carlo-tree-search", | ||
packages=setuptools.find_packages(), | ||
classifiers=[ | ||
"Programming Language :: Python :: 3", | ||
"License :: OSI Approved :: MIT License", | ||
"Operating System :: OS Independent", | ||
], | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import unittest | ||
from montecarlo.node import Node | ||
from montecarlo.montecarlo import MonteCarlo | ||
|
||
class TestMonteCarlo(unittest.TestCase): | ||
|
||
def test_choice_is_correct(self): | ||
montecarlo = MonteCarlo(Node(0)) | ||
montecarlo.child_finder = self.child_finder | ||
montecarlo.node_evaluator = self.node_evaluator | ||
|
||
montecarlo.simulate(50) | ||
|
||
chosen_node = montecarlo.make_choice() | ||
self.assertIs(chosen_node.state, 1) | ||
|
||
def child_finder(self, node): | ||
if node.state == 0: | ||
node.add_children([Node(1), Node(-1)]) | ||
else: | ||
for i in range(2): | ||
modifier = (100 if i == 1 else 200) * (-1 if node.state < 0 else 1) | ||
node.add_child(Node(node.state + modifier)) | ||
|
||
def node_evaluator(self, node): | ||
if node.state > 1000: | ||
return 1 | ||
elif node.state < -1000: | ||
return -1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import unittest | ||
from montecarlo.node import Node | ||
from montecarlo.montecarlo import MonteCarlo | ||
|
||
class TestPolicyValue(unittest.TestCase): | ||
|
||
def test_choice_is_correct(self): | ||
montecarlo = MonteCarlo(Node(0)) | ||
montecarlo.child_finder = self.child_finder | ||
|
||
montecarlo.simulate(50) | ||
|
||
chosen_node = montecarlo.make_choice() | ||
self.assertIs(chosen_node.state, 1) | ||
|
||
def child_finder(self, node): | ||
node.add_children(self.build_children(node)) | ||
node.update_win_value(node.state) | ||
|
||
def build_children(self, node): | ||
children = [] | ||
|
||
for i in range(2): | ||
child = Node(node.state or (1 if i == 1 else -1)) | ||
child.policy_value = .90 if i == 1 else 0.10 | ||
children.append(child) | ||
|
||
return children |