-
Notifications
You must be signed in to change notification settings - Fork 5
/
intersimple-expert-rollout-setobs2.py
65 lines (56 loc) · 2.06 KB
/
intersimple-expert-rollout-setobs2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import functools
from src.core.sampling import rollout_sb3
from intersim.envs import IntersimpleLidarFlatIncrementingAgent
from intersim.envs.intersimple import speed_reward
from intersim.expert import NormalizedIntersimpleExpert
from src.util.wrappers import CollisionPenaltyWrapper, Setobs
import numpy as np
from gym.wrappers import TransformObservation
# Flat per-feature lower/upper bounds used by main() to min-max scale
# observations. Each bound is built as a 6x6 table (one distinct first row
# followed by five identical rows) and flattened to 36 entries.
# NOTE(review): presumably the first row is the ego state and the five
# repeated rows match the n_rays=5 lidar rays used in main() -- confirm.
obs_min = np.array(
    [[-1000, -1000, 0, -np.pi, -1e-1, 0.]]
    + 5 * [[0, -np.pi, -20, -20, -np.pi, -1e-1]]
).reshape(-1)
obs_max = np.array(
    [[1000, 1000, 20, np.pi, 1e-1, 0.]]
    + 5 * [[50, np.pi, 20, 20, np.pi, 1e-1]]
).reshape(-1)
def main(track: int, loc: int = 0):
    """Roll out the normalized Intersimple expert on one track and save the data.

    Creates a 5-ray flat lidar environment for ``loc``/``track``, builds the
    expert policy from that raw environment, then wraps the environment with
    a collision penalty, min-max observation scaling and ``Setobs``. It then
    collects 150 episodes of at most 200 steps each, prints summary
    statistics, and writes the rollout tuple with ``torch.save``.

    Args:
        track: Track index passed to the environment and embedded in the
            output filename.
        loc: Location index passed to the environment and embedded in the
            output filename.
    """
    reward_fn = functools.partial(speed_reward, collision_penalty=0)
    base_env = IntersimpleLidarFlatIncrementingAgent(
        loc=loc,
        track=track,
        n_rays=5,
        reward=reward_fn,
    )

    # The expert is built from the unwrapped environment, before any of the
    # wrappers below are applied.
    policy = NormalizedIntersimpleExpert(base_env, mu=0.001)

    def scale_observation(obs):
        # Min-max scaling against the module-level bounds; the epsilon keeps
        # the denominator non-zero where obs_min == obs_max.
        return (obs - obs_min) / (obs_max - obs_min + 1e-10)

    penalized_env = CollisionPenaltyWrapper(
        base_env,
        collision_distance=6, collision_penalty=100
    )
    env = Setobs(TransformObservation(penalized_env, scale_observation))
    print(env.nv, 'vehicles')

    expert_data = rollout_sb3(env, policy, n_episodes=150, max_steps_per_episode=200)
    states, _actions, rewards, dones = expert_data

    not_done = ~dones
    # NOTE(review): the leading dimension is used as the episode count in the
    # averages below -- confirm against rollout_sb3's return layout.
    leading_dim = states.shape[0]
    print(f'Expert mean episode length {not_done.sum() / leading_dim}')
    print(f'Expert mean reward per episode {rewards[not_done].sum() / leading_dim}')
    print('Observation mean', states[not_done].mean(0))
    print('Observation std', states[not_done].std(0))

    torch.save(expert_data, f'intersimple-expert-data-setobs2-loc{loc}-track{track}.pt')
def loop(tracks=(0,), loc: int = 0):
    """Run ``main`` once for every requested track.

    Args:
        tracks: Track indices to process. Any iterable of ints is accepted,
            as is a single int (which a command-line value such as
            ``--tracks=3`` may be parsed into). Defaults to track 0 only.
        loc: Location index forwarded to ``main`` for every track.
    """
    # A bare int is not iterable; treat it as a one-element selection instead
    # of failing with a TypeError.
    if isinstance(tracks, int):
        tracks = (tracks,)
    for track in tracks:
        main(track, loc)
if __name__ == '__main__':
    # Imported inside the guard so that importing this module does not
    # require `fire`; it is only needed for command-line use.
    from fire import Fire

    # Expose `loop` as the command-line entry point.
    Fire(loop)