
add minibatch subsampling (doubly stochastic) objective #84

Open · wants to merge 8 commits into master
Conversation

Red-Portal
Member

@Red-Portal Red-Portal commented Aug 11, 2024

This is a draft of the subsampling variational objective, which addresses #38. Any perspectives/concerns/comments are welcome! The current plan is to implement only random reshuffling, as I recently showed that there is no point in implementing independent subsampling. Importance sampling of datapoints could be an interesting addition, but it would require custom DynamicPPL.Contexts.

The key design decision is the following function:

# This function/signature will be moved to src/AdvancedVI.jl
"""
    subsample(model, batch)

# Arguments
- `model`: Model subject to subsampling. Could be the target model or the variational approximation.
- `batch`: Data points or indices corresponding to the subsampled "batch."

# Returns 
- `sub`: Subsampled model.
"""
prob_sub = subsample(prob, batch)
q_sub = subsample(q, batch)
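For concreteness, here is a minimal sketch (with hypothetical names; only the index bookkeeping from this PR's random-reshuffling plan is shown) of how an objective could form the per-epoch minibatches that get passed to `subsample`: shuffle the indices once per epoch and slice them into batches.

```julia
using Random

# Hypothetical sketch: random reshuffling forms each epoch's minibatches
# by permuting the datapoint indices once and slicing the permutation.
function epoch_batches(rng::AbstractRNG, n::Int, batchsize::Int)
    perm = randperm(rng, n)  # reshuffle once per epoch
    return [perm[i:min(i + batchsize - 1, n)] for i in 1:batchsize:n]
end

batches = epoch_batches(Xoshiro(1), 10, 3)
# Every index appears exactly once per epoch; the trailing batch may be shorter.
```

Each batch would then be fed to `subsample(prob, batch)` and `subsample(q, batch)` before estimating the gradient.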

Following a previous Slack DM thread with @yebai, this interface could be straightforwardly implemented for Turing models as

@model function pPCA(X::AbstractMatrix{<:Real}, k::Int; data_or_indices = 1:size(X, 1))
    N, D = size(X)
    N_sub = length(data_or_indices)

    W ~ filldist(Normal(), D, k)
    Z ~ filldist(Normal(), k, N)

    # Subsampling
    # Some AD backends are not happy about `view`.
    # In that case, this step will incur a copy and therefore is not free.
    Z_sub = view(Z, :, data_or_indices)
    X_sub = view(X, data_or_indices, :)

    genes_mean = W * Z_sub
    return X_sub ~ arraydist([MvNormal(m, Eye(N_sub)) for m in eachcol(genes_mean')])
end;

where `data_or_indices` could be made a reserved keyword argument for Turing models. Then, I think

using Accessors

function subsample(m::DynamicPPL.Model, batch)
    n, b = length(m.defaults.data_or_indices), length(batch)
    m = @set m.defaults.data_or_indices = batch
    m = @set m.context = MiniBatchContext(m.context; batch_size=b, npoints=n)
    m
end

should generally work?

My current guess would be that subsample(m::DynamicPPL.Model, batch) would have to end up in the main Turing repository.
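The role of the `MiniBatchContext` above is to rescale the minibatch log-likelihood by `n/b`, which makes it an unbiased estimate of the full-data log-likelihood. This can be checked numerically in plain Julia (a toy per-datapoint log-likelihood, no DynamicPPL needed):

```julia
using Random, Statistics

# Numerical check that scaling a minibatch log-likelihood by n/b is
# unbiased for the full-data log-likelihood (up to Monte Carlo error).
loglik(x) = -0.5 * x^2            # toy per-datapoint log-likelihood
rng = Xoshiro(0)
data = randn(rng, 100)
full = sum(loglik, data)          # full-data log-likelihood

n, b = length(data), 10
estimates = map(1:10_000) do _
    idx = randperm(rng, n)[1:b]   # fresh subsample without replacement
    (n / b) * sum(loglik, view(data, idx))
end
mean(estimates)                   # close to `full` up to Monte Carlo error
```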

@Red-Portal Red-Portal marked this pull request as draft August 11, 2024 21:41
@Red-Portal Red-Portal added this to the v0.3.0 milestone Aug 11, 2024

codecov bot commented Aug 11, 2024

Codecov Report

Attention: Patch coverage is 0% with 29 lines in your changes missing coverage. Please review.

Project coverage is 82.92%. Comparing base (1b36c6e) to head (644d314).
Report is 12 commits behind head on master.

Files Patch % Lines
src/objectives/subsampling.jl 0.00% 29 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master      #84       +/-   ##
===========================================
- Coverage   96.09%   82.92%   -13.18%     
===========================================
  Files          11       12        +1     
  Lines         205      246       +41     
===========================================
+ Hits          197      204        +7     
- Misses          8       42       +34     


@Red-Portal
Member Author

related: TuringLang/DynamicPPL.jl#633

@Red-Portal Red-Portal marked this pull request as ready for review September 10, 2024 05:46
@Red-Portal Red-Portal requested review from sunxd3, yebai and mhauru and removed request for sunxd3 September 10, 2024 07:39
@Red-Portal
Member Author

Here is a comparison of the convergence speed of full-batch vs. subsampling with respect to wall-clock time (figure: subsampling_wallclocktime).

@yebai
Member

yebai commented Sep 10, 2024

I suggest delaying this to AdvancedVI v0.4 so that the required syntax is first implemented in DynamicPPL.

@Red-Portal
Member Author

@yebai Sounds good to me.

@Red-Portal Red-Portal modified the milestones: v0.3.0, v0.4.0 Sep 10, 2024
@Red-Portal Red-Portal added the enhancement New feature or request label Sep 10, 2024
Member

@mhauru mhauru left a comment


I understand neither the theory nor the wider context of AdvancedVI, so my review is quite shallow, but I don't see any significant problems here. I left a few small local suggestions.

end
```

Notice that, when computing the log-density, we multiple by a constant `likeadj`.
Member

Suggested change
Notice that, when computing the log-density, we multiple by a constant `likeadj`.
Notice that, when computing the log-density, we multiply by a constant `likeadj`.

```
Let's first compare the convergence of full-batch `RepGradELBO` versus subsampled `RepGradELBO` with respect to the number of iterations:

![](subsampling_iteration.svg)
Member

Are these .svg files missing from the repo? I haven't looked at the built docs, just don't see them in the PR.

# Returns
- `sub`: Subsampled model.
"""
subsample(model::Any, ::Any) = model
Member

Do I understand correctly that subsampling is a more general operation than just a VI thing? If that's the case, could this be moved to DynamicPPL, or even AbstractPPL?

Also, I wonder if an empty function without methods would make more sense. Is returning the unmodified original model a reasonable fallback? I could imagine it confusing users, who would call it and get a return value, not realising it's actually just the original model.

Member Author

@yebai Any comments on the current direction on the PPL side?

"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state, objargs...; kwargs...)
Member

Could the varargs be explained in the docstring?

Subsampled(objective, batchsize, data)

Subsample `objective` over the dataset represented by `data` with minibatches of size `batchsize`.

Member

Could you comment on what happens if batchsize does not divide length(data), or whether that's significant at all?
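For illustration, one common convention (an assumption, not necessarily what this PR implements) is to simply let the trailing batch of each epoch be shorter when `batchsize` does not divide `length(data)`, with each batch's likelihood adjustment computed from its actual size:

```julia
# Hypothetical illustration: when batchsize does not divide length(data),
# the trailing batch is simply shorter.
function partition(data::AbstractVector, batchsize::Int)
    return [data[i:min(i + batchsize - 1, end)] for i in 1:batchsize:length(data)]
end

partition(collect(1:7), 3)  # [[1, 2, 3], [4, 5, 6], [7]]
```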
