Skip to content

Commit

Permalink
reexport Flux, update CI
Browse files Browse the repository at this point in the history
  • Loading branch information
ancorso committed Nov 30, 2023
1 parent 864127f commit ac8d7c4
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 17 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ jobs:
fail-fast: false
matrix:
version:
- "1.6" # Latest
- "1.9" # Latest
os:
- ubuntu-latest
- windows-latest
arch:
- x64
steps:
# install CUDA
- uses: Jimver/[email protected]
id: cuda-toolkit
with:
cuda: '11.2.2'
# # install CUDA
# - uses: Jimver/[email protected]
# id: cuda-toolkit
# with:
# cuda: '11.2.2'

# check out the project and install Julia
- uses: actions/checkout@v2
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
WeightsAndBiasLogger = "71805093-c1fc-4af5-8d0a-8dde53c6ac46"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
CUDA = "3, 4"
CUDA = "3, 4, 5"
ColorSchemes = "3"
Distributions = "0.25"
Flux = "0.13"
Flux = "0.14"
Images = "0.25"
POMDPTools = "0.1"
POMDPs = "0.9"
Expand Down
7 changes: 0 additions & 7 deletions examples/adversarial/continuous_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ vline!([deg2rad(20), deg2rad(-20)])

plot(-3:0.1:3, [exp.(logpdf(antagonist(𝒮_isarl.π).A, [0, 0, 0], [x]))[1] for x=-3:0.1:3], label="Px")

antagonist(𝒮_isarl.π).A.logΣ([0, 0, 0])
, [1])[1]

# Solve with DQN
𝒮_dqn = DQN=QS(as), S=S, N=N)
π_dqn = solve(𝒮_dqn, mdp)
Expand All @@ -82,10 +79,6 @@ println("RARL Failure rate: ", pfail_rarl)
pfail_isarl = Crux.failure(Sampler(mdp, protagonist(π_isarl), S=S, max_steps=100), Neps=Int(1e5), threshold=100)
println("IS Failure rate: ", pfail_isarl)

pol = AdvPol()

𝒮_isarl.buffer


pol = AdversarialPolicy(π_dqn, Pf(xs), ϵGreedyPolicy(Crux.LinearDecaySchedule(1., 0.1, floor(Int, N/2)), xs))
𝒮_isarl = ISARL_Discrete=pol, S=S, N=N, xlogprobs=xlogprobs)
Expand Down
4 changes: 3 additions & 1 deletion src/Crux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ module Crux

set_crux_warnings(val::Bool) = global CRUX_WARNINGS = val

using Reexport
using Random
using Distributions
using POMDPs
using POMDPTools:render
using Parameters
using TensorBoardLogger
using Flux
@reexport using Flux
using Zygote
import Zygote: ignore_derivatives
using Flux.Optimise: train!
Expand All @@ -24,6 +25,7 @@ module Crux
using Base.Iterators: partition
using WeightsAndBiasLogger
using Dates


extra_functions = Dict()
function set_function(key, val)
Expand Down
1 change: 1 addition & 0 deletions src/policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ function exploration(π::MixtureNetwork, s; kwargs...)
end
indices
end
println("weights: ", αs, "indices: ", indices)
a = hcat([exploration(d, s[:, i])[1] for (d, i) in zip.networks, indices)]...)

return a, logpdf(π, s, a)
Expand Down
2 changes: 1 addition & 1 deletion test/policy_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ USE_CUDA && @test all([Crux.device(n)== gpu for n in p_gpu.networks])
# @test all(Crux.value(p, s) .≈ Crux.value(p_gpu, s))
# @test Crux.valueall(p, s) == [Crux.value(p.networks[1], s), Crux.value(p.networks[2], s)]
# @test Crux.valueall(p, s, a) == [Crux.value(p.networks[1], s, a), Crux.value(p.networks[2], s, a)]
@test_broken try
@test_broken try # I think this is because the mixture network isn't designed for batch actions?
action(p, s)
true
catch
Expand Down

0 comments on commit ac8d7c4

Please sign in to comment.