Merge pull request #9 from smessie/tensorflow-agent

Merge code into main branch

Showing 469 changed files with 193,258 additions and 46 deletions.
@@ -130,3 +130,6 @@ dmypy.json

# Jetbrains
.idea

# Secret token
token
@@ -0,0 +1,75 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List

from colorama import Fore, Style, init

from game.agent import Agent
from game.player import Player
from game.table import Table
from util.cards import print_cards

if TYPE_CHECKING:
    from game.card import Card


init()  # Required for colorama to translate ANSI colour codes on Windows


def print_whitespace():
    print('')
    print('-' * 40)
    print('')


class ConsoleAgent(Agent):
    def __init__(self, player_name: str = None):
        super().__init__(Player(player_name if player_name is not None and player_name != '' else 'ConsoleAgent'))

    def make_move(self, table: Table) -> None:
        print_whitespace()

        possible_moves: List[List[Card]] = self.player.get_all_possible_moves(table, self)
        for i, move in enumerate(possible_moves):
            print(f'Move {Style.BRIGHT}{i}{Style.NORMAL}:')
            print_cards(move)

        print('Your cards:')
        print_cards(sorted(self.player.hand, key=lambda x: x.value))

        # Assumes numeric input; a non-numeric entry raises an uncaught ValueError.
        move = int(input('Enter move_nr to take: '))
        while not 0 <= move < len(possible_moves):
            print(f'{Fore.RED}Move {move} is invalid! Try again.{Fore.RESET}')
            move = int(input('Enter move_nr to take: '))
        table.try_move(self, possible_moves[move])

    def move_played_callback(self, move: List[Card], player: Player):
        if move:
            print(f'{player.get_player_name()} made the following move and has {Style.BRIGHT}{len(player.hand)}'
                  f'{Style.NORMAL} cards left:')
            print_cards(move)
        else:
            print(f'{player.get_player_name()} passed.')

    def round_end_callback(self, agent_finish_order: List[Agent], table: Table):
        print_whitespace()
        print(f'{Fore.BLUE}Round has ended! Here is the ranking:{Fore.RESET}')
        for i, agent in enumerate(agent_finish_order):
            prefix = ''
            if agent == self:
                prefix = Fore.BLUE  # Highlight our own position in blue.
            print(f'{prefix}#{i + 1}. {agent.player.get_player_name()}{Fore.RESET}')
        print_whitespace()

    def game_end_callback(self, game_nr: int) -> bool:
        print_whitespace()
        print('Aww, the game is already over :(')
        keep_playing = input("Let's play another game? (y/n) ").lower().strip()
        while keep_playing != 'y' and keep_playing != 'n':
            keep_playing = input("Uhh, what did you say? Let's play another game? (y/n) ").lower().strip()
        return keep_playing == 'n'  # True tells the game loop to stop.

    def trick_end_callback(self, table: Table, playing_agents: List[Agent]):
        print(f'{Fore.YELLOW}Let\'s clear the deck. On to the next trick! '
              f'{", ".join(map(lambda agent: agent.player.get_player_name(), playing_agents))} are still playing.'
              f'{Fore.RESET}')
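
Both prompts in make_move assume numeric input; a stray letter raises ValueError and crashes the agent. A defensive prompt helper could absorb that. The following is a minimal sketch, not part of the commit (read_move_nr is a name invented here):

def read_move_nr(prompt: str, n_moves: int) -> int:
    """Prompt until the user enters an integer in the range [0, n_moves)."""
    while True:
        raw = input(prompt)
        try:
            move = int(raw)
        except ValueError:
            print(f'{raw!r} is not a number! Try again.')
            continue
        if 0 <= move < n_moves:
            return move
        print(f'Move {move} is invalid! Try again.')

With such a helper, make_move could replace its input/validation loop with a single call: move = read_move_nr('Enter move_nr to take: ', len(possible_moves)).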
@@ -0,0 +1,83 @@
from __future__ import annotations

import asyncio
from threading import Thread
from typing import TYPE_CHECKING, List

from discordbot.discord_bot import DiscordBot
from game.agent import Agent
from game.player import Player
from game.table import Table
from util.cards import print_cards_string

if TYPE_CHECKING:
    from game.card import Card


class DiscordAgent(Agent):
    def __init__(self, player_name: str = None):
        super().__init__(Player(player_name if player_name is not None and player_name != '' else 'DiscordAgent'))
        self.discord_bot = DiscordBot()
        # Run the bot's asyncio event loop on a background thread so the
        # synchronous game loop keeps the main thread.
        loop = asyncio.get_event_loop()
        loop.create_task(self.discord_bot.start(input('Enter bot token: ')))
        loop.create_task(self.discord_bot.print_task())
        thread = Thread(target=loop.run_forever, args=())
        thread.start()

    def make_move(self, table: Table) -> None:
        self.print_whitespace()

        to_print = ''
        possible_moves: List[List[Card]] = self.player.get_all_possible_moves(table, self)
        for i, move in enumerate(possible_moves):
            to_print += f'Move {i}:\n'
            move_string = 'PASS\n\n' if move == [] else print_cards_string(move)
            # Flush before exceeding Discord's 2000-character message limit (1950 leaves headroom).
            if len(to_print) + len(move_string) > 1950:
                self.discord_bot.print(to_print)
                to_print = ''
            to_print += move_string

        to_print += 'Your cards:\n'
        move_string = print_cards_string(sorted(self.player.hand, key=lambda x: x.value))
        if len(to_print) + len(move_string) > 1950:
            self.discord_bot.print(to_print)
            to_print = ''
        to_print += move_string
        self.discord_bot.print(to_print)

        move = self.discord_bot.read_int_input('Enter move_nr to take: ')
        while not 0 <= move < len(possible_moves):
            self.discord_bot.print(f'Move {move} is invalid! Try again.')
            move = self.discord_bot.read_int_input('Enter move_nr to take: ')
        table.try_move(self, possible_moves[move])

    def move_played_callback(self, move: List[Card], player: Player):
        if move:
            to_print = f'{player.get_player_name()} made the following move and has {len(player.hand)} cards left:\n'
            to_print += print_cards_string(move)
            self.discord_bot.print(to_print)
        else:
            self.discord_bot.print(f'{player.get_player_name()} passed.')

    def round_end_callback(self, agent_finish_order: List[Agent], table: Table):
        self.print_whitespace()
        to_print = '**Round has ended! Here is the ranking:**\n'
        for i, agent in enumerate(agent_finish_order):
            to_print += f'**#{i + 1}. {agent.player.get_player_name()}**\n'
        self.discord_bot.print(to_print)
        self.print_whitespace()

    def game_end_callback(self, game_nr: int) -> bool:
        self.print_whitespace()
        self.discord_bot.print('Aww, the game is already over :(')
        keep_playing = self.discord_bot.read_bool_input("Let's play another game? (y/n)")
        return not keep_playing  # True tells the game loop to stop.

    def trick_end_callback(self, table: Table, playing_agents: List[Agent]):
        self.discord_bot.print(f'Let\'s clear the deck. On to the next trick! '
                               f'{", ".join(map(lambda agent: agent.player.get_player_name(), playing_agents))}'
                               f' are still playing.')

    def print_whitespace(self):
        self.discord_bot.print('- \n' + ('-' * 40) + '\n- \n')
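
The repeated "> 1950" checks in make_move exist because Discord rejects messages longer than 2000 characters; 1950 leaves headroom. The buffering pattern could be factored into a standalone helper. A minimal sketch, assuming nothing beyond the standard library (chunk_messages is a hypothetical name, not part of the commit):

from typing import List

DISCORD_LIMIT = 1950  # stay under Discord's 2000-character message cap

def chunk_messages(parts: List[str], limit: int = DISCORD_LIMIT) -> List[str]:
    """Greedily pack consecutive parts into messages of at most `limit` characters."""
    messages: List[str] = []
    current = ''
    for part in parts:
        # Flush the buffer before it would overflow, like make_move does inline.
        if current and len(current) + len(part) > limit:
            messages.append(current)
            current = ''
        current += part
    if current:
        messages.append(current)
    return messages

make_move could then build a list of move strings and send each chunk via self.discord_bot.print. As with the inline version, a single part longer than the limit would still go out oversized.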
@@ -0,0 +1,167 @@
from __future__ import annotations

from itertools import chain
from os import mkdir, path
from pathlib import Path
from random import choice, randint
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

from ai.model import PresidentModel
from ai.representation_mapper import map_action_to_cards, map_cards_to_vector
from game.agent import Agent
from game.player import Player
from game.table import Table

if TYPE_CHECKING:
    from game.card import Card


class DQLAgent(Agent):
    """
    lower_eps_over_time adjusts the epsilon-greedy policy by lowering epsilon over time.
    Every round, eps_over_time is decreased by one; as long as eps_over_time is greater than
    zero, eps_over_time / lower_eps_over_time (scaled by start_eps_over_time) is used as the
    exploration epsilon. With training_mode=False, the epsilon parameter is ignored and all
    moves are requested from the model; in this mode there is no learning and no data is
    written to the model.
    """
    def __init__(
            self,
            filepath: str = None,
            csv_filepath: str = None,
            buffer_capacity: int = 1000,
            hidden_layers: List[int] = [64],
            load_checkpoint: bool = False,
            gamma: float = 0.9,
            batch_size: int = 100,
            epsilon: int = 5,
            lower_eps_over_time: int = 0,
            start_eps_over_time: int = 100,
            track_training_loss: bool = False,
            living_reward: float = -0.01,
            training_mode: bool = True,
            early_stopping: bool = False,
            optimizer=None,
            loss=None,
            metrics=None,
            player_name: str = None,
    ):
        super().__init__(Player(player_name if player_name is not None else 'DQLAgent'))
        print(f'Player {self.player.get_player_id()} is {hidden_layers},{buffer_capacity}')
        self.model: PresidentModel = PresidentModel(
            hidden_layers=hidden_layers,
            gamma=gamma,
            sample_batch_size=batch_size,
            track_training_loss=track_training_loss,
            filepath=f'data/results/training_loss-{self.player.get_player_id()}.csv',
            early_stopping=early_stopping,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )
        # Each entry is [input vector (cards in hand, previous move, all played cards),
        # calculated move, reward, next move].
        self.replay_buffer: List[List[Any]] = []
        self.replay_buffer_capacity: int = buffer_capacity
        self.filepath: str = filepath if filepath else f'data/training-{self.player.player_id}/cp.ckpt'
        self.csv_filepath: str = csv_filepath if csv_filepath else f'data/results/wins-{self.player.player_id}.csv'
        self.epsilon: int = epsilon
        self.lower_eps_over_time: int = lower_eps_over_time
        self.start_eps_over_time: float = start_eps_over_time / 100
        self.eps_over_time: int = lower_eps_over_time
        self.living_reward: float = living_reward
        self.training_mode: bool = training_mode

        for p in [Path(self.filepath), Path(self.csv_filepath)]:
            if not path.exists(str(p.parent)):
                mkdir(p.parent)

        self.rounds_positions: Optional[List[int]] = None
        self.triggered_early_stopping = False

        if load_checkpoint:
            self.model.load(self.filepath)  # Falls back to the default path when none was given.

    def make_move(self, table: Table) -> None:
        """
        Agent makes a move by using Deep Q-Learning.
        """
        cards_in_hand_vector: List[int] = map_cards_to_vector(self.player.hand)
        cards_previous_move_vector: List[int] = map_cards_to_vector(table.last_move()[0] if table.last_move() else [])
        all_played_cards_vector: List[int] = map_cards_to_vector(
            list(chain.from_iterable([*map(lambda x: x[0], table.played_cards), *table.discard_pile])))

        input_vector = cards_in_hand_vector + cards_previous_move_vector + all_played_cards_vector

        rand: int = randint(0, 100)

        exploration_chance: float = self.epsilon
        if self.eps_over_time > 0:
            exploration_chance = (self.eps_over_time / self.lower_eps_over_time) * self.start_eps_over_time
        if not self.training_mode:
            exploration_chance = 0
        if rand >= exploration_chance:
            # Exploit: try the actions in descending Q-value order until one maps to a valid move.
            q_values: List[Tuple[int, int]] = sorted(
                [(i, v) for i, v in enumerate(self.model.calculate_next_move(input_vector))],
                key=lambda x: -x[1])

            i = 0
            move: Optional[List[Card]] = map_action_to_cards(q_values[i][0], self.player.hand)
            while i < len(q_values) and (move is None or not table.game.valid_move(move, self)):
                i += 1
                if i >= len(q_values):
                    move = []  # No action yields a valid move: pass.
                else:
                    move = map_action_to_cards(q_values[i][0], self.player.hand)

            table.try_move(self, move)
        else:
            # Explore: play a uniformly random legal move.
            table.try_move(self, choice(self.player.get_all_possible_moves(table, self)))

    def get_preferred_card_order(self, table: Table) -> List[Card]:
        """
        Returns the preferred cards to exchange at the beginning of a round, in descending
        value order.
        """
        possible_cards = list(set(table.deck.card_stack) - set(self.player.hand))
        return sorted(possible_cards, reverse=True)

    def round_end_callback(self, agent_finish_order: List[Agent], table: Table):
        """
        The round has ended. Add this round's moves to the replay buffer and train the model
        on the buffer, which still holds moves from earlier rounds.
        """
        if self.training_mode:
            reward_list = list(map(lambda agent: agent.player.player_id, agent_finish_order))
            # Add the rewards to the moves of the last round.
            for agent in table.game.temp_memory:
                total_living_reward: float = (
                    len(table.game.temp_memory[agent]) * self.living_reward + self.living_reward)
                for move in table.game.temp_memory[agent]:
                    new_move: Any = list(move)
                    if new_move[2] == 0:
                        # If we didn't set a negative reward already, set the reward equal to
                        # the given reward for the game.
                        new_move[2] = (len(agent_finish_order) -
                                       (reward_list.index(agent.player.player_id) - 1) ** 2) + total_living_reward
                    self.replay_buffer.append(new_move)
                    if len(self.replay_buffer) > self.replay_buffer_capacity:
                        self.replay_buffer.pop(0)  # Evict the oldest transition first.
                    total_living_reward -= self.living_reward

            self.triggered_early_stopping = self.model.train_model(self.replay_buffer) or self.triggered_early_stopping

        if self.eps_over_time > 0:
            self.eps_over_time -= 1

        if not self.rounds_positions:
            self.rounds_positions = [0 for _ in range(len(agent_finish_order))]

        self.rounds_positions[agent_finish_order.index(self)] += 1

    def game_end_callback(self, game_nr: int) -> bool:
        if self.training_mode:
            self.model.save(self.filepath)

        with open(self.csv_filepath, 'a+') as file:
            file.write(f'{game_nr},{",".join(map(str, self.rounds_positions))}\n')

        self.rounds_positions = None
        return self.triggered_early_stopping
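
To make the docstring's decay schedule concrete, here is a standalone trace of the exploration chance produced by the constructor parameters, mirroring make_move and round_end_callback as written (epsilon_schedule is a name invented for illustration, not part of the commit):

from typing import List

def epsilon_schedule(lower_eps_over_time: int, start_eps_over_time: int,
                     epsilon: float, rounds: int) -> List[float]:
    """Exploration chance per round: linear decay while the counter is positive,
    then the fixed epsilon, exactly as DQLAgent computes it."""
    start = start_eps_over_time / 100  # the constructor divides by 100
    counter = lower_eps_over_time      # eps_over_time starts at lower_eps_over_time
    trace: List[float] = []
    for _ in range(rounds):
        if counter > 0:
            trace.append((counter / lower_eps_over_time) * start)
        else:
            trace.append(epsilon)
        counter -= 1  # round_end_callback decrements once per round
    return trace

# epsilon_schedule(4, 100, 5, 6) -> [1.0, 0.75, 0.5, 0.25, 5, 5]

Note that the decayed values are fractions of 1, while make_move compares the chance against randint(0, 100); the fixed epsilon sits on that 0-100 scale, so the two phases of the schedule use different scales in the code as written.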