Create vae_mnist_new_architecture.jl #487
base: master
Proposal for updating the VAE MNIST example to use the newer Knet API. The type definitions improve performance (about 20% faster, due to lower GC time) and make the network architecture more readable.
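The description does not show the type definitions themselves. As a general illustration of why typed fields can reduce allocations and GC pressure in Julia, here is a minimal sketch with two hypothetical layer types (`LooseDense` and `TypedDense` are made up for this example and are not taken from the PR or from Knet):

```julia
# Hypothetical layer types for illustration only.

# Untyped fields default to `Any`, so the compiler cannot infer
# what `l.w * x` returns and must dispatch at run time.
struct LooseDense
    w
    b
end

# Parametric fields give every instance concrete field types,
# which lets the compiler specialize the forward pass.
struct TypedDense{M<:AbstractMatrix, V<:AbstractVector}
    w::M
    b::V
end

# Both layers compute the same affine map.
(l::LooseDense)(x) = l.w * x .+ l.b
(l::TypedDense)(x) = l.w * x .+ l.b

w, b, x = randn(3, 4), zeros(3), randn(4, 2)
loose, typed = LooseDense(w, b), TypedDense(w, b)

println(isconcretetype(fieldtype(typeof(loose), :w)))  # false
println(isconcretetype(fieldtype(typeof(typed), :w)))  # true
```

Both layers return the same values; only the amount of type information available to the compiler differs. Whether this accounts for the reported speedup depends on the actual definitions in the PR.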
function train(ae, dtrn, iters)
    img = convert(Atype, reshape(dtrn.x[:,1], (28, 28, 1, 1)))
    for epoch = 1:iters
        @time adam!(ae, dtrn)
If I'm not mistaken, this is not the correct way to iterate over epochs: each call creates a new Adam struct, so information from previous epochs (e.g. the accumulated moments) is lost.
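To make the concern concrete, here is a minimal sketch with a hand-rolled Adam on a toy quadratic objective. It does not use Knet; `AdamState`, `adam_step!`, `train_persistent`, and `train_resetting` are hypothetical names introduced only for this illustration. It contrasts optimizer state created once with state recreated every epoch:

```julia
# Hand-rolled Adam for one parameter vector (not Knet's implementation).
mutable struct AdamState
    m::Vector{Float64}   # first-moment estimate
    v::Vector{Float64}   # second-moment estimate
    t::Int               # step counter used for bias correction
end

AdamState(n::Integer) = AdamState(zeros(n), zeros(n), 0)

function adam_step!(w, g, s::AdamState; lr=1e-3, beta1=0.9, beta2=0.999, epsilon=1e-8)
    s.t += 1
    @. s.m = beta1 * s.m + (1 - beta1) * g
    @. s.v = beta2 * s.v + (1 - beta2) * g^2
    mhat = s.m ./ (1 - beta1^s.t)    # bias-corrected moments
    vhat = s.v ./ (1 - beta2^s.t)
    @. w -= lr * mhat / (sqrt(vhat) + epsilon)
    return w
end

# Toy objective f(w) = sum(w.^2) / 2, whose gradient is w itself.
grad(w) = copy(w)

function train_persistent(w0; epochs=3, steps=5)
    w = copy(w0)
    s = AdamState(length(w))          # created once, before the epoch loop
    for epoch in 1:epochs, step in 1:steps
        adam_step!(w, grad(w), s)
    end
    return w, s
end

function train_resetting(w0; epochs=3, steps=5)
    w = copy(w0)
    s = AdamState(length(w))
    for epoch in 1:epochs
        s = AdamState(length(w))      # recreated each epoch: m, v and t are lost
        for step in 1:steps
            adam_step!(w, grad(w), s)
        end
    end
    return w, s
end

w_p, s_p = train_persistent([1.0, -2.0])
w_r, s_r = train_resetting([1.0, -2.0])
println("steps seen by persistent state: ", s_p.t)  # 15
println("steps seen by resetting state:  ", s_r.t)  # 5
```

In the resetting variant the step counter and both moment estimates restart from zero at every epoch, which is the loss of information described in the comment above.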
Oh yes, completely correct! Thanks for the advice; I have adapted my example proposal accordingly.
Added the proper version of the training design, a more detailed callback, and improved type definitions.
BCE = F(0)
for s = 1:samples
This is probably not efficient. You can run all samples at once by adding a "sample" batch dimension. First, reshape μ to (nz,B,1). Then sample from randn with size (nz,B,Nsample) and broadcast μ over it. Finally, change binary_cross_entropy to handle (nz,B,Nsample) input.
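The batching idea can be sketched in plain Julia as follows. This is only an illustration under stated assumptions: the sizes, the `logvar` parameterization, the toy linear-plus-sigmoid decoder `W`, and the helper names `sigm`, `bce`, `batched_bce`, and `looped_bce` are hypothetical and not taken from the PR. Folding the sample axis into the batch axis before the decoder is one possible way to push all samples through a network that expects (features, batch) input; it is not something the comment itself proposes.

```julia
using Random

sigm(a) = 1 / (1 + exp(-a))

# Summed binary cross-entropy; broadcasting lets x carry a trailing
# singleton sample axis while xhat carries Nsample entries there.
bce(xhat, x) = -sum(x .* log.(xhat) .+ (1 .- x) .* log.(1 .- xhat))

# All samples at once: μ and σ are (nz, B), noise is (nz, B, Nsample).
function batched_bce(W, x, μ, σ, noise)
    nz, B, Nsample = size(noise)
    z = reshape(μ, nz, B, 1) .+ reshape(σ, nz, B, 1) .* noise
    zflat = reshape(z, nz, B * Nsample)      # fold samples into the batch axis
    xhat = reshape(sigm.(W * zflat), size(W, 1), B, Nsample)
    return bce(xhat, reshape(x, size(x, 1), B, 1)) / Nsample
end

# Reference version: one decoder call per sample, as in the reviewed loop.
function looped_bce(W, x, μ, σ, noise)
    Nsample = size(noise, 3)
    total = 0.0
    for s in 1:Nsample
        z = μ .+ σ .* noise[:, :, s]
        total += bce(sigm.(W * z), x)
    end
    return total / Nsample
end

rng = MersenneTwister(0)
nz, B, Nsample, D = 4, 3, 5, 6              # hypothetical sizes
μ      = randn(rng, nz, B)                  # stand-in for the encoder mean
logvar = randn(rng, nz, B)                  # stand-in for the encoder log-variance
σ      = exp.(logvar ./ 2)
W      = randn(rng, D, nz)                  # toy decoder weights
x      = Float64.(rand(rng, D, B) .> 0.5)   # binary targets
noise  = randn(rng, nz, B, Nsample)

b1 = batched_bce(W, x, μ, σ, noise)
b2 = looped_bce(W, x, μ, σ, noise)
println(isapprox(b1, b2))
```

With the same noise tensor, the batched and looped versions should agree up to floating-point summation order, while the batched one needs a single decoder call.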
Correct, that form was not efficient for sampling multiple times within one batch. The suggestion to broadcast is more efficient and much faster. However, I was not able to broadcast this efficiently through the decoder network. As the extra sampling doesn't improve performance as far as I can tell, and the majority of implementations I found do not use it, I have dropped it from the example here.
Multiple sampling has been removed as it is also not used in the original VAE approach.