# utils.jl
using Flux, POMDPs, POMDPSimulators
using Random: randperm      # used by sample for minibatch indices
using Statistics: mean      # used by expected_return
# Fixed-capacity circular buffer of transitions, stored column-wise as Float32 matrices.
mutable struct ExperienceBuffer
    s::Array{Float32, 2}     # states (one column per transition)
    a::Array{Float32, 2}     # one-hot encoded actions
    sp::Array{Float32, 2}    # next states
    r::Array{Float32, 2}     # rewards (1 × capacity)
    done::Array{Float32, 2}  # terminal flags (1 × capacity)
    elements::Int64          # number of valid transitions currently stored
    next_ind::Int64          # column where the next transition will be written
end
# Allocate an empty buffer for N transitions with state dimension sdim and action dimension adim.
ExperienceBuffer(sdim, adim, N) = ExperienceBuffer(zeros(Float32, sdim, N), zeros(Float32, adim, N), zeros(Float32, sdim, N), zeros(Float32, 1, N), zeros(Float32, 1, N), 0, 1)
# Allocate a new buffer with the same dimensions and capacity as b.
empty_like(b::ExperienceBuffer) = ExperienceBuffer(size(b.s, 1), size(b.a, 1), size(b.s, 2))
Base.length(b::ExperienceBuffer) = b.elements
# Write one transition at next_ind, overwriting the oldest data once the buffer is full.
function Base.push!(b::ExperienceBuffer, s, a, r, sp, done, mdp)
    b.s[:, b.next_ind] .= convert_s(AbstractVector, s, mdp)
    b.a[:, b.next_ind] .= Flux.onehot(a, actions(mdp))
    b.sp[:, b.next_ind] .= convert_s(AbstractVector, sp, mdp)
    b.r[1, b.next_ind] = r
    b.done[1, b.next_ind] = done
    b.elements = min(length(b.r), b.elements + 1)   # length(b.r) is the capacity
    b.next_ind = mod1(b.next_ind + 1, length(b.r))  # wrap around circularly
end
# Shrink the buffer to only the filled columns. Assumes the buffer has not wrapped around,
# i.e. the valid data occupies columns 1:(next_ind - 1).
function trim!(b::ExperienceBuffer)
    b.elements = b.next_ind - 1
    b.next_ind = 1
    b.s = b.s[:, 1:b.elements]
    b.a = b.a[:, 1:b.elements]
    b.r = b.r[:, 1:b.elements]
    b.sp = b.sp[:, 1:b.elements]
    b.done = b.done[:, 1:b.elements]
end
# Roll out pol in mdp until Neps episodes are collected (or max_tries attempts are made),
# keeping only episodes whose undiscounted return matches desired_return when one is given,
# and optionally only transitions with nonzero reward. Returns a trimmed ExperienceBuffer.
function gen_buffer(mdp, pol, Neps; desired_return = nothing, max_tries = 100*Neps, max_steps = 100, nonzero_transitions_only = false)
    s = rand(initialstate(mdp))
    odim, adim = length(convert_s(AbstractVector, s, mdp)), length(actions(mdp))
    b = ExperienceBuffer(odim, adim, Neps*max_steps)
    i, eps = 0, 0
    while eps < Neps && i < max_tries
        h = simulate(HistoryRecorder(max_steps = max_steps), mdp, pol)
        if isnothing(desired_return) || undiscounted_reward(h) ≈ desired_return
            eps += 1
            for (s, a, r, sp) in eachstep(h, "(s, a, r, sp)")
                if !nonzero_transitions_only || r != 0
                    push!(b, s, a, r, sp, isterminal(mdp, sp), mdp)
                end
            end
        end
        i += 1
    end
    trim!(b)
    N = length(b)
    println("Took $eps episodes to fill buffer of size $N, for an average of $(N/eps) steps per episode")
    b
end
# Draw N distinct transitions uniformly at random as a named tuple of matrices; requires N ≤ length(b).
function sample(b::Union{ExperienceBuffer, Nothing}, N::Int64)
    isnothing(b) && return nothing
    ids = randperm(b.elements)[1:N]
    (s = b.s[:, ids], a = b.a[:, ids], sp = b.sp[:, ids], r = b.r[:, ids], done = b.done[:, ids])
end
# Greedy policy that picks the action maximizing the Q-values produced by a Flux network.
struct ChainPolicy <: Policy
    qnet    # Flux model mapping a state vector to a vector of Q-values
    mdp
end
POMDPs.action(p::ChainPolicy, s) = actions(p.mdp)[argmax(p.qnet(convert_s(AbstractVector, s, p.mdp)))]
# Count how often each state appears in the buffer. GWPos is the grid-world state type
# (from POMDPModels), so this assumes mdp is a grid world whose states convert to/from vectors.
function gen_occupancy(buffer, mdp)
    occupancy = Dict(s => 0 for s in states(mdp))
    for i = 1:length(buffer)
        s = convert_s(GWPos, buffer.s[:, i], mdp)
        occupancy[s] += 1
    end
    occupancy
end
# Monte Carlo estimate of the policy's return over eps episodes (undiscounted by default).
expected_return(mdp, policy, eps = 1000, agg = undiscounted_reward) = mean([agg(simulate(HistoryRecorder(max_steps = 100), mdp, policy)) for _ = 1:eps])
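# Example usage (a minimal sketch, not part of the original utilities): it assumes a grid-world
# MDP from POMDPModels (e.g. SimpleGridWorld) whose states convert to and from vectors via
# convert_s; the network sizes and episode counts below are illustrative only.
#
# using POMDPModels
# mdp  = SimpleGridWorld()
# s0   = rand(initialstate(mdp))
# odim = length(convert_s(AbstractVector, s0, mdp))
# adim = length(actions(mdp))
# qnet = Chain(Dense(odim, 32, relu), Dense(32, adim))  # small Flux Q-network
# pol  = ChainPolicy(qnet, mdp)
# buffer = gen_buffer(mdp, pol, 10, max_steps = 50)     # collect 10 episodes of experience
# batch  = sample(buffer, 32)                           # random minibatch of 32 transitions
# occ    = gen_occupancy(buffer, mdp)                   # state visit counts
# ret    = expected_return(mdp, pol, 100)               # Monte Carlo return estimate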