Add L2/L1 regularization #179
Conversation
Commits:
- oops
- oops
- correct order of move and collate
- fix missing chain move
- fix forgotten addition to cache
- update doc-string
Codecov Report

```
@@            Coverage Diff             @@
##              dev     #179      +/-   ##
==========================================
- Coverage   90.74%   90.59%   -0.15%
==========================================
  Files           8        9       +1
  Lines         216      234      +18
==========================================
+ Hits          196      212      +16
- Misses         20       22       +2
```

Continue to review the full report at Codecov.
@ayush-1506 Would you like to review this PR? I've tested it locally on a GPU. @DilumAluthge How do I put the GPU tests back for PRs onto dev?
@ablaom Sure, please give me a day.
Just one small question; everything else looks great.
```diff
@@ -43,48 +43,53 @@ end
 true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng

 function MLJModelInterface.fit(model::MLJFluxModel,
-                               verbosity::Int,
+                               verbosity,
```
Please correct me if I'm wrong, but `verbosity` should still be an `Int`, right?
Yes, but you no longer need to explicitly type it. There used to be a type ambiguity that required the type annotation, but that is now long gone.
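For what it's worth, a toy illustration of the point (the `ToyModel`, `fit_old`, and `fit_new` names are made up, not MLJFlux code): with a single applicable method, dropping the annotation changes nothing for callers that pass an `Int`.

```julia
# Toy stand-ins for the real `fit` method, purely to illustrate the dispatch point:
abstract type ToyModel end
struct ToyRegressor <: ToyModel end

fit_old(model::ToyModel, verbosity::Int, X, y) = verbosity  # annotated
fit_new(model::ToyModel, verbosity, X, y) = verbosity       # annotation dropped

X, y = rand(3, 2), rand(3)
@assert fit_old(ToyRegressor(), 1, X, y) == fit_new(ToyRegressor(), 1, X, y) == 1
```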
This PR:

- addresses "`alpha` and `lambda` are not used" #169
- makes a change to the `fit!` method that was natural in addressing the above and anticipates further transfer of responsibility to the planned data front-end (see https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Implementing-a-data-front-end and [Design discussion] Batches and resampling #97).

@ToucheSir Further to the Julia Discourse discussion, no change was actually necessary to the core `train!` loop to add regularization. It's just that the loss function passed to this loop now depends on the `chain` (the Flux model) when the L2/L1 regularisation parameters are non-trivial. To avoid the array-mutation error I needed to avoid broadcasting in the computation of the penalty here. Any performance suggestions regarding these two bits of code would be appreciated.
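To make that last point concrete, here is a minimal sketch, assuming Flux's `params` and standard non-broadcasting reductions, of a chain-dependent penalized loss; the names `penalty` and `penalized_loss` are illustrative and this is not the exact code in the PR.

```julia
using Flux

# A minimal sketch (not the PR's code) of a chain-dependent penalty with
# strength `lambda` and L1/L2 mixing parameter `alpha`. The reductions
# `sum(abs, p)` and `sum(abs2, p)` avoid broadcasting over the parameters,
# which is what sidesteps the array-mutation error mentioned above.
function penalty(chain, lambda, alpha)
    ps = Flux.params(chain)
    isempty(ps) && return zero(float(lambda))
    l1 = sum(p -> sum(abs, p), ps)
    l2 = sum(p -> sum(abs2, p), ps)
    return lambda * (alpha * l1 + (1 - alpha) * l2)
end

# The loss handed to the training loop closes over `chain` when `lambda`
# is non-zero, so the penalty tracks the chain's current parameters:
penalized_loss(loss, chain, lambda, alpha) =
    (x, y) -> loss(chain(x), y) + penalty(chain, lambda, alpha)

# Usage sketch (hypothetical model and data):
chain = Chain(Dense(4, 8, relu), Dense(8, 1))
loss = penalized_loss(Flux.mse, chain, 0.01, 0.5)
x, y = rand(Float32, 4, 16), rand(Float32, 1, 16)
loss(x, y)  # scalar: mse plus the mixed L1/L2 penalty
```

Writing the penalty with `sum(abs, p)` / `sum(abs2, p)` rather than `sum(abs.(p))` / `sum(p .^ 2)` also avoids allocating intermediate arrays, which may help on GPU.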