-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
207 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,12 @@ | ||
# Snake AI | ||
|
||
## Überblick | ||
In diesem Projekt schauen wir uns ein *Reinforcement-Learning* Projekt zum Spiel Snake an. | ||
In diesem Projekt schauen wir uns ein `Reinforcement-Learning` Projekt zum Spiel Snake an. | ||
|
||
## Ziel | ||
Das Ziel dieses Projekts ist es beispielhaft kennenzulernen wie man, mit der [Pytorch](https://pytorch.org/) library, ein etwas komplexeres neuronales Netzwerk implementiert und lernen die Grundlagen von *Reinforcement-Learning* kennen. | ||
Das Ziel dieses Projekts ist es beispielhaft kennenzulernen wie man, mit der [Pytorch](https://pytorch.org/) library, ein etwas komplexeres neuronales Netzwerk implementiert und lernen die Grundlagen von `Reinforcement-Learning` kennen. | ||
|
||
## Projekt | ||
Den Quellcode des Projekts und wie du es ausführst findest du [hier](https://github.com/MINT-EC-KI-Cluster/snake-rl) | ||
|
||
|
||
|
||
|
||
### Der Agent | ||
Ein sogenannter *Agent* ist im Reinforcement-Learning die Schnittstelle zwischen dem Neuronalen-Netzwerk / Modell und dem environment. In unserem Fall ist das *environment* unser Snake spiel. Dieses kann aber in verschiedenen Anwendungen, z.B. dem Autonomen Fahren ein völlig anderes sein.\ | ||
Der *Agent* dient also dazu das Modell anhand des environments zu nutzen und zu trainieren.\ | ||
In unserem Fall ist der Code für den Agenten in der 'agent.py' Datei zu finden. | ||
Er funktioniert wiefolgt: | ||
die train() Methode, ist die Methode welche für das ganze Training zuständig ist: | ||
``` | ||
def train(): | ||
total_score = 0 | ||
record = 0 | ||
agent = Agent() | ||
game = environment.World() | ||
while True: | ||
# get old state | ||
state_old = Agent.get_state(game) | ||
# get move | ||
final_move = agent.get_action(state_old) | ||
# perform move and get new state | ||
reward, done, score = game.step(final_move) | ||
state_new = Agent.get_state(game) | ||
# train short memory | ||
agent.train_short_memory(state_old, final_move, reward, state_new, done) | ||
# remember | ||
agent.remember(state_old, final_move, reward, state_new, done) | ||
if done: | ||
# train long memory, print result | ||
game.reset() | ||
agent.n_games += 1 | ||
agent.train_long_memory() | ||
if score > record: | ||
record = score | ||
agent.model.save() | ||
print('Game', agent.n_games, 'Score', score, 'Record:', record) | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Teil 2 | ||
## Der Agent | ||
Ein sogenannter `Agent` ist im `Reinforcement-Learning` die Schnittstelle zwischen dem Neuronalen-Netzwerk / Modell und dem environment. In unserem Fall ist das `environment` unser Snake spiel. Dieses kann aber in verschiedenen Anwendungen, z.B. dem Autonomen Fahren ein völlig anderes sein.\ | ||
Der `Agent` dient also dazu das Modell anhand des environments zu nutzen und zu trainieren.\ | ||
In unserem Fall ist der Code für den Agenten in der `agent.py` Datei zu finden. | ||
Er funktioniert wiefolgt: | ||
Es gibt eine sogenannte get_state() Methode | ||
``` | ||
def get_state(game: environment.World): | ||
state = [ | ||
# current direction | ||
game.snake.orientation[0] == -1, | ||
game.snake.orientation[0] == 1, | ||
game.snake.orientation[1] == 1, | ||
game.snake.orientation[1] == -1, | ||
# danger ahead | ||
game.danger_in_direction(environment.Direction.LEFT) == 1, | ||
game.danger_in_direction(environment.Direction.FORWARD) == 1, | ||
game.danger_in_direction(environment.Direction.RIGHT) == 1, | ||
# food pos | ||
game.foods[0][0] < game.snake.head.pos[0], # food left | ||
game.foods[0][0] > game.snake.head.pos[0], # food right | ||
game.foods[0][1] < game.snake.head.pos[1], # food above | ||
game.foods[0][1] > game.snake.head.pos[1], # food below | ||
] | ||
return np.array(state, dtype=int) | ||
``` | ||
Diese nimmt als Eingabe das `environment` und gibt einen sogenannten `state` zurück. Dieser repräsentiert die Eingabe für das Neuronale Netzwerk. Hier wird also definiert was das Modell von dem `environment` sieht und basierend auf diesen Informationen berechnet er sich eine statistisch gesehen beste Aktion. | ||
in diesem simplen Fall besteht der `state` aus 11 Werten die entweder 0 oder 1 sind. | ||
4 Werte welche die Ausrichtung der Schlange repräsentieren | ||
3 Werte ob links, rechts oder geradeaus eine Gefahr herrscht | ||
4 Werte für die relative Positon von dem Apfel | ||
|
||
Kommen wir zum Training | ||
hierfür sind mehrere Methoden wichtig | ||
z.B. die get_action() Methode: | ||
``` | ||
def get_action(self, state): | ||
self.epsilon = 80 - self.n_games | ||
final_move = [0,0,0] | ||
if random.randint(0, 200) < self.epsilon: | ||
move = random.randint(0, 2) | ||
final_move[move] = 1 | ||
else: | ||
state0 = torch.tensor(state, dtype=torch.float) | ||
prediction = self.model(state0) | ||
move = torch.argmax(prediction).item() | ||
final_move[move] = 1 | ||
return final_move | ||
``` | ||
Diese ist der sogenannte `forward-pass` für das Modell. Das heißt, dass der `state` dem Modell als Eingabewert gegeben wird und eine Aktion oder auch `prediction` genannt ausgegeben wird. | ||
Zu beachten ist, dass die ersten 80 Spiele, welche das Modell durchtrainiert, Randomness in die Aktionen eingeführt wird. Heißt, dass die Entscheidung was die Schlange macht nicht beim Modell liegt, sondern dem Zufall überlassen wird. Der Grund hierfür ist, dass das Modell noch "unerfahren" ist und durch die Zufälligkeit verschiedene Situationen sammelt, aus denen es lernen kann. | ||
|
||
Der nächste Codeblock ist sehr wichtig für das Training des Modells: | ||
``` | ||
def remember(self, state, action, reward, next_state, done): | ||
self.memory.append((state, action, reward, next_state, done)) | ||
def train_long_memory(self): | ||
if len(self.memory) > BATCH_SIZE: | ||
mini_sample = random.sample(self.memory, BATCH_SIZE) # list of tuples | ||
else: | ||
mini_sample = self.memory | ||
states, actions, rewards, next_states, dones = zip(*mini_sample) | ||
self.trainer.train_step(states, actions, rewards, next_states, dones) | ||
def train_short_memory(self, state, action, reward, next_state, done): | ||
self.trainer.train_step(state, action, reward, next_state, done) | ||
``` | ||
In diesem Codeblock sind 3 Methoden. | ||
Die `remember` Methode wird nach jedem Spielzug aufgerufen. Diese speichert im memory `self.memory = deque(maxlen=MAX_MEMORY) # popleft()` 5 Werte | ||
- state -> der ausgangs state | ||
- action -> die genommene Aktion welche von get_action() kommt | ||
- reward -> die Belohnung die wir von der step() Methode zurückkriegen | ||
- next_state -> der state nach dem Zug | ||
- done -> ob das Spiel nach dem Zug vorbei ist | ||
da jetzt all diese Situationen im "Gedächtnis" gespeichert sind, können sie in train_long_memory() verwendet werden. Diese Methode nimmt einen Zufälligen Batch aus diesen Situationen und trainiert das Modell auf diesen Batch. (Wie der training step verläuft und wie das Modell insgesamt aussieht, wird im nächsten Teil besprochen) | ||
train_long_memory() wird nach jeder Runde Snake aufgerufen. | ||
Im gegensatz zu train_short_memory(). Diese macht das gleiche wie train_long_memory(), nur mit dem jetzigen state und wird jeden Schritt aufgerufen. | ||
|
||
Zuletzt gibt es in der agent.py Datei eine eigenständige Methode die train() heißt. | ||
``` | ||
def train(): | ||
total_score = 0 | ||
record = 0 | ||
agent = Agent() | ||
game = environment.World() | ||
while True: | ||
# get old state | ||
state_old = Agent.get_state(game) | ||
# get move | ||
final_move = agent.get_action(state_old) | ||
# perform move and get new state | ||
reward, done, score = game.step(final_move) | ||
state_new = Agent.get_state(game) | ||
# train short memory | ||
agent.train_short_memory(state_old, final_move, reward, state_new, done) | ||
# remember | ||
agent.remember(state_old, final_move, reward, state_new, done) | ||
if done: | ||
# train long memory, print result | ||
game.reset() | ||
agent.n_games += 1 | ||
agent.train_long_memory() | ||
if score > record: | ||
record = score | ||
agent.model.save() | ||
print('Game', agent.n_games, 'Score', score, 'Record:', record) | ||
``` | ||
Diese besteht aus einem while-True-loop und trainiert das Modell. Hier sieht man die ganzen oben genannten Methoden und wie sie genau genutzt werden. | ||
Als letztes kommt noch das Model | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Teil 3 | ||
|
||
## Das Modell | ||
[Das Modell](https://music.youtube.com/watch?v=o0iga1eNgvA&si=QXGW6CjOd-Iel_sM&feature=xapp_share) ist recht simpel. | ||
Es besteht aus 11 x 256 x 3 Neuronen. | ||
Wie man sieht hat es 11 aktivations Neuronen, was genau der Werte aus einer `state` entspricht. | ||
Außerdem hat sie 3 output Neuronen. 1mal für links, 1mal für geradeaus und 1mal für rechts. | ||
Dazwischen sind 256 hidden layer neuronen, welche als `activation function` `ReLU` haben | ||
``` | ||
class Linear_QNet(nn.Module): | ||
def __init__(self, input_size, hidden_size, output_size): | ||
super().__init__() | ||
self.linear1 = nn.Linear(input_size, hidden_size) | ||
self.linear2 = nn.Linear(hidden_size, output_size) | ||
def forward(self, x): | ||
x = F.relu(self.linear1(x)) | ||
x = self.linear2(x) | ||
return x | ||
def save(self, file_name='model.pth'): | ||
model_folder_path = './model' | ||
if not os.path.exists(model_folder_path): | ||
os.makedirs(model_folder_path) | ||
file_name = os.path.join(model_folder_path, file_name) | ||
torch.save(self.state_dict(), file_name) | ||
``` | ||
|
||
Zuletzt haben wir noch den Trainer für das Modell | ||
``` | ||
class QTrainer: | ||
def __init__(self, model, lr, gamma): | ||
self.lr = lr | ||
self.gamma = gamma | ||
self.model = model | ||
self.optimizer = optim.Adam(model.parameters(), lr=self.lr) | ||
self.criterion = nn.MSELoss() | ||
def train_step(self, state, action, reward, next_state, done): | ||
state = torch.tensor(state, dtype=torch.float) | ||
next_state = torch.tensor(next_state, dtype=torch.float) | ||
action = torch.tensor(action, dtype=torch.long) | ||
reward = torch.tensor(reward, dtype=torch.float) | ||
# (n, x) | ||
if len(state.shape) == 1: | ||
# (1, x) | ||
state = torch.unsqueeze(state, 0) | ||
next_state = torch.unsqueeze(next_state, 0) | ||
action = torch.unsqueeze(action, 0) | ||
reward = torch.unsqueeze(reward, 0) | ||
done = (done, ) | ||
# 1: predicted Q values with current state | ||
pred = self.model(state) | ||
target = pred.clone() | ||
for idx in range(len(done)): | ||
Q_new = reward[idx] | ||
if not done[idx]: | ||
Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx])) | ||
target[idx][torch.argmax(action[idx]).item()] = Q_new | ||
# 2: Q_new = r + y * max(next_predicted Q value) -> only do this if not done | ||
# pred.clone() | ||
# preds[argmax(action)] = Q_new | ||
self.optimizer.zero_grad() | ||
loss = self.criterion(target, pred) | ||
loss.backward() | ||
self.optimizer.step() | ||
``` | ||
Dieser hat die train_step() Methode die wir so oft im Agenten genutzt haben. Im train-step wird (Q-Lernen)[https://de.wikipedia.org/wiki/Q-Lernen] angewandt. Die Wikipedia Seite welche hierzu verlinkt ist, erklärt das komplexe Thema recht gut. | ||
Hierbei wird der (MeanSquaredErrorLoss)[https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html] als Loss-Function genutzt. |
Empty file.