-
Notifications
You must be signed in to change notification settings - Fork 4
/
agent.py
124 lines (99 loc) · 3.77 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from torch import nn
import torch as th
import numpy as np
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_
from torch.distributions import Categorical
from rocket_learn.agent.policy import Policy
from typing import Optional, List, Tuple
class Opti(nn.Module):
    """Embed the car observations, pool them, and run ``net`` on the result.

    The input is a ``(main, cars)`` pair. ``embedder`` is applied to ``cars``
    and its output is max-pooled over the second-to-last dimension. The pooled
    features are concatenated to ``main`` along dim 1, and that tensor is
    passed to ``net``.

    NOTE(review): presumably ``cars`` is (batch, n_cars, car_features) and
    ``main`` is (batch, main_features) -- confirm against the obs builder.
    """

    def __init__(self, embedder: nn.Module, net: nn.Module):
        super().__init__()
        self.embedder = embedder
        self.net = net
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform init for every parameter with more than one dimension;
        # 1-D parameters (e.g. biases) are left untouched.
        weights = (param for param in self.parameters() if param.dim() > 1)
        for weight in weights:
            xavier_uniform_(weight)

    def forward(self, inp: tuple):
        main_obs, car_obs = inp
        embedded = self.embedder(car_obs)
        # Max over dim -2 collapses the per-car axis; ``.values`` drops the argmax indices.
        pooled = th.max(embedded, -2).values
        features = th.cat((main_obs, pooled), dim=1)
        return self.net(features)
class OptiSelector(nn.Module):
    """Embed and pool car observations, run ``net``, then split its output by ``shape``.

    The input is a ``(main, cars)`` pair. ``embedder`` is applied to ``cars``
    and max-pooled over the second-to-last dimension. The result is
    concatenated to ``main`` along dim 1 and fed to ``net``.

    If the network output has exactly one column, it is treated as a critic
    (value) head and returned unchanged. Otherwise it is split along dim 1
    into a tuple of tensors whose widths are given by ``shape`` (one per
    action sub-space).
    """

    def __init__(self, embedder: nn.Module, net: nn.Module, shape: Tuple[int, ...]):
        super().__init__()
        self.embedder = embedder
        self.net = net
        self.shape = shape
        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform init for every parameter with more than one dimension;
        # 1-D parameters (e.g. biases) keep their existing values.
        for weight in filter(lambda param: param.dim() > 1, self.parameters()):
            xavier_uniform_(weight)

    def forward(self, inp: tuple):
        main_obs, car_obs = inp
        pooled = th.max(self.embedder(car_obs), -2).values
        output = self.net(th.cat((main_obs, pooled), dim=1))
        # A single output column means this is the critic; don't split it.
        is_value_head = output.shape[1] == 1
        return output if is_value_head else output.split(self.shape, 1)
class MaskIndices(nn.Module):
    """Drop selected entries from the last dimension of the input.

    ``indices`` may be either:

    * a boolean mask over the last dimension, where ``True`` marks an entry to
      drop (the original behavior, unchanged), or
    * a sequence / tensor of integer positions to drop (negative positions
      count from the end).
    """

    def __init__(self, indices):
        super().__init__()
        # Stored exactly as given so code reading ``self.indices`` sees the same object.
        self.indices = indices

    def forward(self, x):
        indices = self.indices
        if not isinstance(indices, th.Tensor):
            # Lists / numpy arrays -> tensor, so bool lists and int lists are both handled
            # (``~`` on a plain Python list would raise TypeError).
            indices = th.as_tensor(indices)
        if indices.dtype == th.bool:
            # Boolean mask: keep every entry that is not flagged.
            return x[..., ~indices]
        # Integer positions: ``~`` on an integer tensor is a bitwise NOT (i -> -i - 1),
        # which would silently select the wrong entries, so build an explicit keep-mask.
        keep = th.ones(x.shape[-1], dtype=th.bool, device=x.device)
        keep[indices.to(device=x.device, dtype=th.long)] = False
        return x[..., keep]
class MultiDiscretePolicy(Policy):
    """Policy over several independent categorical action heads.

    ``net`` maps an observation to either one logits tensor or a tuple of
    per-head logits tensors. Heads narrower than ``max(shape)`` are
    right-padded with ``-inf`` logits (zero probability) so all heads can be
    stacked into a single ``Categorical`` distribution. Log-probabilities and
    entropies are summed over the last axis, i.e. across heads.

    NOTE(review): construction currently raises ``NotImplementedError`` (see
    ``__init__``).
    """

    def __init__(self, net: nn.Module, shape: Tuple[int, ...] = (3,) * 5 + (2,) * 3, deterministic=False):
        # NOTE(review): construction is disabled here. It used to be spelled
        # ``return NotImplemented``, which Python rejects with a misleading
        # "__init__() should return None" TypeError. Confirm intent; to re-enable
        # the class, delete this raise -- the initialisation below is complete.
        raise NotImplementedError("MultiDiscretePolicy is currently disabled")
        super().__init__(deterministic)
        self.net = net
        self.shape = shape

    def forward(self, obs):
        """Return the raw output of ``net`` (a tensor or a tuple of per-head tensors)."""
        logits = self.net(obs)
        return logits

    def get_action_distribution(self, obs):
        """Build a ``Categorical`` over all action heads for ``obs``.

        ``obs`` may be a numpy array or a tuple of arrays/tensors; numpy inputs
        are converted to float tensors before being passed through the network.
        """
        if isinstance(obs, np.ndarray):
            obs = th.from_numpy(obs).float()
        elif isinstance(obs, tuple):
            obs = tuple(o if isinstance(o, th.Tensor) else th.from_numpy(o).float() for o in obs)

        logits = self(obs)
        if isinstance(logits, th.Tensor):
            logits = (logits,)

        # Pad every head to the widest head with -inf so the padded categories get
        # probability 0, then stack heads along dim 1.
        # NOTE(review): assumes each head is (batch, n_actions); 1-D heads would be
        # stacked along the wrong axis -- confirm callers always batch.
        max_shape = max(self.shape)
        logits = th.stack(
            [
                l
                if l.shape[-1] == max_shape
                else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf"))
                for l in logits
            ],
            dim=1
        )
        return Categorical(logits=logits)

    def sample_action(
            self,
            distribution: Categorical,
            deterministic=None
    ):
        """Pick one index per head: argmax if deterministic, otherwise a random sample.

        ``deterministic=None`` falls back to ``self.deterministic``.
        """
        if deterministic is None:
            deterministic = self.deterministic
        if deterministic:
            action_indices = th.argmax(distribution.logits, dim=-1)
        else:
            action_indices = distribution.sample()
        return action_indices

    def log_prob(self, distribution: Categorical, selected_action):
        """Return the log-probability of ``selected_action``, summed across heads."""
        log_prob = distribution.log_prob(selected_action).sum(dim=-1)
        return log_prob

    def entropy(self, distribution: Categorical, selected_action):
        """Return the distribution's entropy summed across heads (``selected_action`` is unused)."""
        entropy = distribution.entropy().sum(dim=-1)
        return entropy

    def env_compatible(self, action):
        """Convert a tensor action to a numpy array; other values pass through unchanged."""
        if isinstance(action, th.Tensor):
            # ``.numpy()`` alone raises for CUDA tensors and grad-tracking tensors;
            # detach/cpu are no-ops for plain CPU tensors.
            action = action.detach().cpu().numpy()
        return action