-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MLX_UNET.py #53
base: main
Are you sure you want to change the base?
MLX_UNET.py #53
Conversation
import torch | ||
import numpy as np | ||
|
||
def test_pytorch_mlp(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
example of how to test a pytorch vs mlx implementation, cc: @bdeanhardt @ethanernst11 @levinkhho
Also the config classes, which define the model's structure, are the same so you don't need to copy in any @dataclass
configurations, you can import them from the pytorch files as done in the file above.
The benefit there is that the exact same config class can be used for either implementation, which keeps it transparent to the user
Because the test case for |
a first pass at the MLX implementation for UNET