-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bandicoot library example to train an mnist example #219
base: master
Are you sure you want to change the base?
Conversation
Signed-off-by: Omar Shrit <[email protected]>
Signed-off-by: Omar Shrit <[email protected]>
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm looking forward to when this works! We will also need to think about ensuring that the CI workers have a GPU so that we can actually test this example. Or, something along those lines.
* @author Omar Shrit | ||
*/ | ||
#define MLPACK_ENABLE_ANN_SERIALIZATION | ||
#define MLPACK_HAS_COOT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a TODO, we should integrate definition of MLPACK_HAS_COOT
into the mlpack headers (maybe in prereqs.hpp
or something, we try to detect Bandicoot), so that users don't have to write it by hand. You're probably already thinking that, I just wanted to write it down so it doesn't get forgotten.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, on my list,
* @author Eugene Freyman | ||
* @author Omar Shrit | ||
*/ | ||
#define MLPACK_ENABLE_ANN_SERIALIZATION |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Serialization will be a little tricky. When this macro is enabled, we compile all the serialize()
functions for all the layer types---for the type arma::mat
. Instead, we'll need to do CEREAL_REGISTER_MLPACK_LAYERS(coot::mat)
or similar. Also, we'll need to actually implement a serialize()
function for Bandicoot types, a bit like in src/mlpack/core/arma_extend/serialize_armadillo.hpp
. I think the strategy should just be to convert it to an Armadillo matrix (i.e. pull it off the GPU into the CPU), and then serialize that, plus a little bit of code to put the Armadillo matrix back onto the GPU during loading. Anyway, I can help with that when we get there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, for now, the easy solution for now would be to convert it back to Armadillo.
In all cases, we are loading an armadillo matrix, so we will have to do the conversion at least once.
Once we have the load functionality for Bandicoot, if it makes sense, then we can think of serialization.
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍 |
needs to be reimplemented again, but I am re-opening it so I do not forget |
Hi,
This is a draft and still does not compile.