forked from Farama-Foundation/Minigrid
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwrappers.py
126 lines (91 loc) · 3.22 KB
/
wrappers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import math
import operator
from functools import reduce
import numpy as np
import gym
from gym import error, spaces, utils
class ActionBonus(gym.core.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.

    The bonus added to the environment reward is 1/sqrt(N), where N is
    the number of times the (position, direction, action) triple has
    been taken so far, so the bonus decays with repeated visits.
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counts, keyed by (agentPos, agentDir, action)
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Key the count on the underlying (unwrapped) env's state
        env = self.unwrapped
        tup = (env.agentPos, env.agentDir, action)

        # Update the count for this (s, a) pair
        newCnt = self.counts.get(tup, 0) + 1
        self.counts[tup] = newCnt

        # Decaying exploration bonus: 1/sqrt(visit count)
        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info
class StateBonus(gym.core.Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.

    The bonus added to the environment reward is 1/sqrt(N), where N is
    the number of times the agent's current position has been visited.
    """

    def __init__(self, env):
        super().__init__(env)
        # Visit counts, keyed by agent position
        self.counts = {}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # Key on the position *after* the step. Note: the original
        # `(env.agentPos)` was not a tuple -- parentheses alone do not
        # create one -- so we wrap explicitly with tuple() to guarantee
        # a hashable dict key even if agentPos is a list/array.
        env = self.unwrapped
        tup = tuple(env.agentPos)

        # Update the count for this position
        newCnt = self.counts.get(tup, 0) + 1
        self.counts[tup] = newCnt

        # Decaying exploration bonus: 1/sqrt(visit count)
        bonus = 1 / math.sqrt(newCnt)
        reward += bonus

        return obs, reward, done, info
class FlatObsWrapper(gym.core.ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.

    Each character of the (lowercased) mission string is one-hot encoded
    over 27 codes (26 letters + space) into a maxStrLen x 27 array, which
    is flattened and concatenated after the flattened image observation.
    """

    def __init__(self, env, maxStrLen=64):
        super().__init__(env)

        self.maxStrLen = maxStrLen
        # 26 lowercase letters + 1 code for the space character
        self.numCharCodes = 27

        imgSpace = env.observation_space.spaces['image']
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, imgSize + self.numCharCodes * self.maxStrLen),
            dtype='uint8'
        )

        # Cache of the last-encoded mission string and its encoding,
        # since the mission rarely changes between steps
        self.cachedStr = None
        self.cachedArray = None

    def observation(self, obs):
        image = obs['image']
        # Lower-case BEFORE the cache comparison: cachedStr stores the
        # lowered string, so comparing against the raw mission would
        # defeat the cache whenever the mission contains uppercase.
        mission = obs['mission'].lower()

        # Re-encode only when the mission string actually changed
        if mission != self.cachedStr:
            assert len(mission) <= self.maxStrLen, "mission string too long"

            strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype='float32')

            for idx, ch in enumerate(mission):
                if ch >= 'a' and ch <= 'z':
                    chNo = ord(ch) - ord('a')
                elif ch == ' ':
                    chNo = ord('z') - ord('a') + 1
                else:
                    # Previously, an unsupported character left chNo
                    # unbound (NameError) or stale from the prior
                    # iteration (silent mis-encoding); fail loudly instead.
                    raise ValueError('unsupported character in mission string: %r' % ch)
                assert chNo < self.numCharCodes, '%s : %d' % (ch, chNo)
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        # Flatten image and mission encoding into a single 1-D array
        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs