
Register Kernels as AutoGrad Ops #91

Open
fabianlim opened this issue Oct 11, 2024 · 1 comment
Labels
future: Will be affected in future versions (e.g., deprecation)
help wanted: Extra attention is needed

Comments

@fabianlim
Contributor

fabianlim commented Oct 11, 2024

We have quite a few custom autograd functions in the FOAK plugin.

We should test torch.compile with these autograd functions, and register them. Note that it is better to avoid what kernel-hyperdrive does, which is to register them as custom_ops, see here:

  • for kernel-hyperdrive it cannot be helped, as there is a stride issue
  • but if it is possible, it is better to use this kind of wrapping (see the sketch below).

If any of these autograd functions need to be changed, the bench needs to be rerun for accuracy and performance checks.
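
As a starting point, here is a minimal sketch of the autograd-function pattern in question together with a fullgraph=True check; the kernel and class names below are hypothetical stand-ins for illustration, not the actual FOAK kernels.

import torch

# Hypothetical stand-ins for a FOAK triton kernel pair; the real ones
# (e.g. _rms_layernorm_forward / _rms_layernorm_backward) launch triton grids.
def _kernel_forward(x):
    return x * 2.0

def _kernel_backward(grad_out):
    return grad_out * 2.0

class ScaleByTwo(torch.autograd.Function):
    # Same shape as the FOAK autograd functions: one kernel in forward,
    # another in backward.
    @staticmethod
    def forward(ctx, x):
        return _kernel_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        return _kernel_backward(grad_out)

def fn(x):
    return ScaleByTwo.apply(x).sum()

# fullgraph=True makes torch.compile raise on any graph break, which gives a
# quick signal on whether a given autograd function needs explicit registration.
torch.compile(fn, fullgraph=True)(torch.randn(4, requires_grad=True))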

@fabianlim added the future and help wanted labels Nov 4, 2024
@fabianlim changed the title from "Register Kernels as AutoGrad Ops: Torch Deprecation Warning" to "Register Kernels as AutoGrad Ops" Nov 8, 2024
@fabianlim
Contributor Author

Using rms_layer_norm as an example, here is my attempt to list out a set of prescriptive tasks.

  1. Look at all the different kernels that are attached to a model, e.g., llama. Go through them one by one.
  2. For example, start with rms_layer_norm. In the above example, we replace the LlamaRMSNorm with the fast_rms_layernorm.
  3. The implementation of fast_rms_layernorm is found here; it is an autograd function Fast_RMS_Layernorm that has a triton kernel _rms_layernorm_forward in the forward, and _rms_layernorm_backward in the backward.
  4. So to make this compilable, you must follow the pattern and register it as a graph op. One way to do this is custom_op, as is done here (a sketch is given after this list).
  5. Using custom_op can add overhead, so if it is easier we can do this as a first pass, but we need a clean way to disable the custom_op when compile is not enabled.
  6. Finally, the more "standard" way to register ops is the torch.library.define pattern, see this issue for example:
torch.library.define("mylib::cvmm_triton", "(Tensor x, Tensor sel_index, Tensor sel, Tensor keys, ScalarType out_dtype, Tensor out_index) -> Tensor")

@torch.library.impl("mylib::cvmm_triton", "default")
def cvmm_triton(x, sel_index, sel, keys, out_dtype, out_index):
    ...  # launch the triton kernel here and return the output tensor
  7. Lastly, after compile works, you need to run the bench to test it:
tox -e run_benches --  "1 2" "4 8" benchmark_outputs scenarios.yaml full-finetuning
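
To make the registration step concrete for the rms_layer_norm case, here is a minimal, hypothetical sketch using torch.library.custom_op and torch.library.register_autograd (available in PyTorch >= 2.4). The op name foak::rms_layernorm and the plain-PyTorch kernel bodies are stand-ins for illustration; in the plugin the bodies would instead call the real _rms_layernorm_forward / _rms_layernorm_backward triton kernels, whose exact signatures may differ.

import torch

# Plain-PyTorch stand-ins for the triton kernels (signatures assumed).
def _rms_forward(X, W, eps):
    r = torch.rsqrt(X.float().pow(2).mean(-1, keepdim=True) + eps)
    return ((X.float() * r) * W.float()).to(X.dtype), r

def _rms_backward(dY, X, W, r):
    Xf = X.float()
    dYW = dY.float() * W.float()
    dX = r * dYW - Xf * r.pow(3) * (dYW * Xf).mean(-1, keepdim=True)
    dW = (dY.float() * Xf * r).sum(dim=tuple(range(X.dim() - 1)))
    return dX.to(X.dtype), dW.to(W.dtype)

# Register the forward as a custom op so torch.compile treats it as a single
# graph node instead of tracing into the kernel launch.
@torch.library.custom_op("foak::rms_layernorm", mutates_args=())
def rms_layernorm(X: torch.Tensor, W: torch.Tensor, eps: float) -> torch.Tensor:
    Y, _ = _rms_forward(X, W, eps)
    return Y

# Fake (meta) implementation so compile can infer shapes and dtypes.
@rms_layernorm.register_fake
def _(X, W, eps):
    return torch.empty_like(X)

# Attach the backward kernel, i.e. register the kernel as an autograd op.
def _setup_context(ctx, inputs, output):
    X, W, eps = inputs
    ctx.save_for_backward(X, W)
    ctx.eps = eps

def _backward(ctx, dY):
    X, W = ctx.saved_tensors
    r = torch.rsqrt(X.float().pow(2).mean(-1, keepdim=True) + ctx.eps)
    dX, dW = _rms_backward(dY, X, W, r)
    return dX, dW, None  # no gradient for eps

torch.library.register_autograd("foak::rms_layernorm", _backward, setup_context=_setup_context)

# Usage: the registered op composes with autograd and torch.compile.
X, W = torch.randn(2, 8, requires_grad=True), torch.ones(8, requires_grad=True)
torch.compile(lambda x, w: rms_layernorm(x, w, 1e-6).sum(), fullgraph=True)(X, W).backward()

Disabling this when compile is not enabled (point 5 above) could then be a matter of falling back to the plain autograd function instead of the registered op.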

You would add a compiled bench, so that we can measure the speedups that compile gives, in addition to the existing benches.
