diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index dd51b2c..2a97675 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,18 +16,18 @@ jobs: fail-fast: false matrix: version: - - "1.6" # Latest + - "1.9" # Latest os: - ubuntu-latest - windows-latest arch: - x64 steps: - # install CUDA - - uses: Jimver/cuda-toolkit@v0.2.4 - id: cuda-toolkit - with: - cuda: '11.2.2' + # # install CUDA + # - uses: Jimver/cuda-toolkit@v0.2.4 + # id: cuda-toolkit + # with: + # cuda: '11.2.2' # check out the project and install Julia - uses: actions/checkout@v2 diff --git a/Project.toml b/Project.toml index 89f2fca..82f239d 100644 --- a/Project.toml +++ b/Project.toml @@ -16,18 +16,20 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7" WeightsAndBiasLogger = "71805093-c1fc-4af5-8d0a-8dde53c6ac46" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] -CUDA = "3, 4" +CUDA = "3, 4, 5" ColorSchemes = "3" Distributions = "0.25" -Flux = "0.13" +Flux = "0.14" Images = "0.25" POMDPTools = "0.1" POMDPs = "0.9" diff --git a/examples/adversarial/continuous_pendulum.jl b/examples/adversarial/continuous_pendulum.jl index 25c8246..139f32d 100644 --- a/examples/adversarial/continuous_pendulum.jl +++ b/examples/adversarial/continuous_pendulum.jl @@ -61,9 +61,6 @@ vline!([deg2rad(20), deg2rad(-20)]) plot(-3:0.1:3, [exp.(logpdf(antagonist(𝒮_isarl.π).A, [0, 0, 0], [x]))[1] for x=-3:0.1:3], label="Px") -antagonist(𝒮_isarl.π).A.logΣ([0, 0, 0]) -, [1])[1] - # Solve with DQN 𝒮_dqn = DQN(π=QS(as), S=S, N=N) π_dqn = solve(𝒮_dqn, mdp) @@ -82,10 +79,6 @@ println("RARL Failure rate: ", pfail_rarl) pfail_isarl = Crux.failure(Sampler(mdp, protagonist(π_isarl), S=S, max_steps=100), Neps=Int(1e5), threshold=100) println("IS Failure rate: ", pfail_isarl) -pol = AdvPol() - -𝒮_isarl.buffer - pol = AdversarialPolicy(π_dqn, Pf(xs), ϵGreedyPolicy(Crux.LinearDecaySchedule(1., 0.1, floor(Int, N/2)), xs)) 𝒮_isarl = ISARL_Discrete(π=pol, S=S, N=N, xlogprobs=xlogprobs) diff --git a/src/Crux.jl b/src/Crux.jl index 45cdea4..32f716c 100644 --- a/src/Crux.jl +++ b/src/Crux.jl @@ -3,13 +3,14 @@ module Crux set_crux_warnings(val::Bool) = global CRUX_WARNINGS = val + using Reexport using Random using Distributions using POMDPs using POMDPTools:render using Parameters using TensorBoardLogger - using Flux + @reexport using Flux using Zygote import Zygote: ignore_derivatives using Flux.Optimise: train! @@ -24,6 +25,7 @@ module Crux using Base.Iterators: partition using WeightsAndBiasLogger using Dates + extra_functions = Dict() function set_function(key, val) diff --git a/src/policies.jl b/src/policies.jl index d80ae2b..11ac448 100644 --- a/src/policies.jl +++ b/src/policies.jl @@ -210,6 +210,7 @@ function exploration(π::MixtureNetwork, s; kwargs...) end indices end + println("weights: ", αs, "indices: ", indices) a = hcat([exploration(d, s[:, i])[1] for (d, i) in zip(π.networks, indices)]...) return a, logpdf(π, s, a) diff --git a/test/policy_tests.jl b/test/policy_tests.jl index 9c6b12b..3779edb 100644 --- a/test/policy_tests.jl +++ b/test/policy_tests.jl @@ -151,7 +151,7 @@ USE_CUDA && @test all([Crux.device(n)== gpu for n in p_gpu.networks]) # @test all(Crux.value(p, s) .≈ Crux.value(p_gpu, s)) # @test Crux.valueall(p, s) == [Crux.value(p.networks[1], s), Crux.value(p.networks[2], s)] # @test Crux.valueall(p, s, a) == [Crux.value(p.networks[1], s, a), Crux.value(p.networks[2], s, a)] -@test_broken try +@test_broken try # I think this is because the mixture network isn't designed for batch actions? action(p, s) true catch