sac.jl · 65 lines (59 loc) · 2.61 KB
# Soft (entropy-regularized) TD target: bootstrap from the minimum of the twin critics
# and subtract the α-weighted log-probability of the sampled next action.
function SAC_target(π)
    (π⁻, 𝒫, 𝒟, γ; kwargs...) -> begin
        ap, logprob = exploration(actor(π), 𝒟[:sp])
        y = 𝒟[:r] .+ γ .* (1.f0 .- 𝒟[:done]) .* (min.(value(π⁻, 𝒟[:sp], ap)...) .- exp(𝒫[:SAC_log_α][1]).*logprob)
    end
end
# TD target without the entropy bonus: bootstrap from the actor's deterministic action.
function SAC_deterministic_target(π)
    (π⁻, 𝒫, 𝒟, γ; kwargs...) -> begin
        y = 𝒟[:r] .+ γ .* (1.f0 .- 𝒟[:done]) .* min.(value(π⁻, 𝒟[:sp], action(actor(π), 𝒟[:sp]))...)
    end
end
function SAC_max_Q_target(π)
    (π⁻, 𝒫, 𝒟, γ; kwargs...) -> begin
        error("not implemented")
        #TODO: Sample some number of actions and then choose the max
    end
end
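# A minimal sketch (not part of the original file) of one way the TODO above could be
# filled in, assuming "sample some number of actions and choose the max" means drawing
# several candidate next actions from the current actor and bootstrapping from the best
# clipped Q-value. The name SAC_max_Q_target_sketch and the n_samples keyword are
# hypothetical illustrations, not package API.
function SAC_max_Q_target_sketch(π; n_samples=10)
    (π⁻, 𝒫, 𝒟, γ; kwargs...) -> begin
        # Clipped (min of twin critics) Q-value for each sampled candidate next action.
        qs = map(1:n_samples) do _
            ap, _ = exploration(actor(π), 𝒟[:sp])
            min.(value(π⁻, 𝒟[:sp], ap)...)
        end
        # Keep the elementwise maximum over candidates and form the bootstrapped target.
        q_max = reduce((a, b) -> max.(a, b), qs)
        y = 𝒟[:r] .+ γ .* (1.f0 .- 𝒟[:done]) .* q_max
    end
end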
# Actor loss: mean of α·logπ(a|s) minus the clipped (min of twin critics) Q-value,
# with actions sampled from the current policy.
function SAC_actor_loss(π, 𝒫, 𝒟; info = Dict())
    a, logprob = exploration(π.A, 𝒟[:s])
    ignore_derivatives() do
        info["entropy"] = -mean(logprob)
    end
    mean(exp(𝒫[:SAC_log_α][1]).*logprob .- min.(value(π, 𝒟[:s], a)...))
end
# Temperature loss: adjust log α so the policy entropy tracks the entropy target SAC_H_target.
function SAC_temp_loss(π, 𝒫, 𝒟; info = Dict())
    ignore_derivatives() do
        info["SAC alpha"] = exp(𝒫[:SAC_log_α][1])
    end
    _, logprob = exploration(π.A, 𝒟[:s])
    target_α = logprob .+ 𝒫[:SAC_H_target]
    -mean(exp(𝒫[:SAC_log_α][1]) .* target_α)
end
# Construct an off-policy SAC solver with twin Q critics, automatic temperature tuning,
# and the soft TD target defined above. Remaining kwargs are forwarded to OffPolicySolver.
function SAC(;π::ActorCritic{T, DoubleNetwork{ContinuousNetwork, ContinuousNetwork}},
              ΔN=50,
              SAC_α::Float32=1f0,
              SAC_H_target::Float32 = Float32(-prod(dim(action_space(π)))),
              π_explore=GaussianNoiseExplorationPolicy(0.1f0),
              SAC_α_opt::NamedTuple=(;),
              a_opt::NamedTuple=(;),
              c_opt::NamedTuple=(;),
              a_loss=SAC_actor_loss,
              c_loss=double_Q_loss(),
              target_fn=SAC_target(π),
              prefix="",
              log::NamedTuple=(;),
              𝒫::NamedTuple=(;),
              param_optimizers=Dict(),
              kwargs...) where T
    # Store the log-temperature and the entropy target in the solver's extra parameters.
    𝒫 = (SAC_log_α=[Base.log(SAC_α)], SAC_H_target=SAC_H_target, 𝒫...)
    OffPolicySolver(;agent=PolicyParams(π=π, π_explore=π_explore, π⁻=deepcopy(π)),
                     ΔN=ΔN,
                     𝒫=𝒫,
                     log=LoggerParams(;dir = "log/sac", log...),
                     param_optimizers=Dict(Flux.params(𝒫[:SAC_log_α]) => TrainingParams(;loss=SAC_temp_loss, name="temp_", SAC_α_opt...), param_optimizers...),
                     a_opt=TrainingParams(;loss=a_loss, name=string(prefix, "actor_"), a_opt...),
                     c_opt=TrainingParams(;loss=c_loss, name=string(prefix, "critic_"), epochs=ΔN, c_opt...),
                     target_fn=target_fn,
                     kwargs...)
end
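# A minimal usage sketch (not part of the original file): one way the SAC constructor
# above might be called, assuming Crux's ContinuousNetwork, DoubleNetwork, and ActorCritic
# constructors plus Flux's Chain/Dense, and assuming S (state space) and N (total steps)
# are forwarded to OffPolicySolver via kwargs.... The helper name example_sac_solver, the
# actor_policy argument (any stochastic actor, e.g. a squashed-Gaussian policy), and the
# network sizes are illustrative assumptions.
function example_sac_solver(actor_policy, S; N=30_000, adim=1)
    # Twin state-action critics: the DoubleNetwork{ContinuousNetwork, ContinuousNetwork}
    # required by the SAC constructor's type constraint.
    Q() = ContinuousNetwork(Chain(Dense(first(dim(S)) + adim, 64, relu), Dense(64, 64, relu), Dense(64, 1)))
    SAC(π=ActorCritic(actor_policy, DoubleNetwork(Q(), Q())), S=S, N=N)
end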