Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for GlobalMaxPool and GlobalMeanPool #88

Merged
merged 10 commits into from
Jul 14, 2024

Conversation

lambe
Copy link

@lambe lambe commented Oct 27, 2023

  • updates serialization and deserialization code to permit model import and export with Flux.GlobalMaxPool and Flux.GlobalMeanPool operations

@lambe
Copy link
Author

lambe commented Oct 27, 2023

Lumping in some code for Flux's sigmoid_fast activation function (just ties into a regular sigmoid) and the unsqueeze method (opposite of flatten)

@DrChainsaw
Copy link
Owner

Looks good, but please add tests as well.

@DrChainsaw
Copy link
Owner

DrChainsaw commented Nov 2, 2023

You can change serialization tests to use the Flux native global pools here and here.

To make things look consistent, you might want to add a gpvertex function here which looks something like this: gmpvertex(name, inpt::AbstractVertex) = fluxvertex(name, GlobalMeanPool(), inpt) and maybe a similar one for the average pool.

@lambe
Copy link
Author

lambe commented Nov 3, 2023

Thanks for the tip. I just pushed this update.

However, these tests are now broken. Error is of this form:

Shortcut to globpool -> dense: Error During Test at /home/ablambe/toolpath/ONNXNaiveNASflux.jl/test/serialize/serialize.jl:850
  Got exception outside of a @test
  MethodError: no method matching size(::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#302#304"{NaiveNASlib.MutationVertex{NaiveNASlib.OutputsVertex{NaiveNASlib.CompVertex{NaiveNASflux.NoParams{GlobalMeanPool}}}, NaiveNASlib.AfterΔSizeTrait{NaiveNASlib.AfterΔSizeCallback{NaiveNASlib.var"#146#147"{typeof(NaiveNASlib.nameorrepr)}, NaiveNASlib.DefaultJuMPΔSizeStrategy}, NaiveNASlib.NamedTrait{String, NaiveNASlib.SizeInvariant}}}, ONNXNaiveNASflux.var"#306#307"{typeof(genname), Set{String}}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, NTuple{4, Missing}})
  
  Closest candidates are:
    size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted})
     @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:582
    size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer)
     @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:581
    size(::Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRPackedQ})
     @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:585
    ...

The size() call here is within Flux - the size of the input tensor to the layer. Do I need to implement this method for this layer? I'm surprised this error doesn't show up from other layer types.

@lambe
Copy link
Author

lambe commented Nov 3, 2023

Correction to the above - I think I had an environment error in my local test setup.

The errors I'm getting now are mostly shape errors, e.g.,

Conv and batchnorm graph with cat: Error During Test at /home/ablambe/toolpath/ONNXNaiveNASflux.jl/test/serialize/serialize.jl:706
  Got exception outside of a @test
  DimensionMismatch: Wrong input dimension for Dense(5 => 2, relu)! Expected 2 dimensions, got shape (1, 1, 5, missing)

The issue is that the Flux behaviour is to keep the length-1 dimensions but the behaviour of this package is to drop them as a wrapper in the call to ONNXNaiveNASflux.globalpool(). Need to think about the logic a bit more.

@DrChainsaw
Copy link
Owner

Yup. I thought flux Dense supported any number of dimensions these days.

Just add a dropdims as a layer after the global pool in the testcases. Something like this: ddvertex(name, inpt) = invariantvertex(name, x -> dropdims(x; dims=(1,2)), inpt). That is what the wrapper does.

The problem here could be that many onnx models just send a 4D array into the Dense layer after a globpool since the onnx version of Dense supports any number of dimensions iirc. If we just don't touch the deserialization parts we should be ok I hope (unless the test really checks that each op in the loaded model is identical to the original).

@DrChainsaw DrChainsaw marked this pull request as ready for review July 14, 2024 12:59
Copy link

codecov bot commented Jul 14, 2024

Codecov Report

Attention: Patch coverage is 97.33333% with 2 lines in your changes missing coverage. Please review.

Project coverage is 96.12%. Comparing base (af07410) to head (0545cea).

Files Patch % Lines
src/deserialize/ops.jl 97.77% 1 Missing ⚠️
src/serialize/serialize.jl 87.50% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #88      +/-   ##
==========================================
+ Coverage   96.07%   96.12%   +0.04%     
==========================================
  Files          15       15              
  Lines        1044     1109      +65     
==========================================
+ Hits         1003     1066      +63     
- Misses         41       43       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Delete redundant global pool functions since we now use the one shipped by Flux.
@DrChainsaw DrChainsaw merged commit 507b0d9 into DrChainsaw:master Jul 14, 2024
5 checks passed
@DrChainsaw
Copy link
Owner

Thanks alot again @lambe. I will publish a new release with this and Protobuf 1 support right away.

The ddvertex idea above didn't work out in the tests for the following somewhat convoluted reason:

The test harness checks that the loaded CompGraph has the exact same vertices as the saved one.

However, ONNXNaiveNASflux tries to be helpful and merge any global pool followed directly by a Squeeze operation into the same vertex. This is both a reminder from before Flux shipped any global pools as well as something which is useful in the NAS context since one typically don't want the search algorithm to insert operations between the global pool and the squeeze.

As this is probably quite unintuitive for normal (i.e non-NAS practicing) users I improved the printing of the squeeze (and a couple of other unrelated ops) so that it at least is clear from casual inspection that all OPs are actually part of the CompGraph. I have also opened #93 for a more long term solution.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants