qjit(static_argnums=...) fails when the marked static argument has a default value #1163
Comments
I think this is because the |
Hi! I am Aniket, a PhD candidate at Duke interviewing for a role on the compiler team. I was given this issue as a technical challenge. Apart from the details mentioned in this issue, are there any other pointers that might help me tackle this issue? Thank you! |
Hi @AniketDalvi , do you have a more specific question in mind? I would love to give out more pointers and resolve any confusion you may have! |
Hi! So for installation - it says to download the PyPI wheel. Which version should I be downloading the wheel for? More specifically, I get the following error when downloading the wheel on my linux computer:
|
I don't think there's gonna be any difference w.r.t. this particular issue, any one of the three should be good. cc @rauletorresc who worked on the frontend dev plug-in |
Hey @AniketDalvi, I am not too sure why your pip says that there are only 3 versions available. We have version 0.8.1 in pypi and will be releasing version 0.9.0 soon. |
Okay I am just going to download the |
Hi! Okay, so I have followed the instructions and seem to have successfully installed the repository. Is there a quick sanity check experiment/file I can run to verify the installation? |
You can run the tests. |
Running pytest gave me an error stating that there was an interpreter mismatch. The interpreter is Python 3.10 while the package is compatible only with 3.12 |
When you download via pip, pip enforces this compatibility check, but manually downloading it you need to make sure to download the appropriate one for you. See here for a list of several wheels with different python version compatibility. Maybe it would be easier to install from source? It just takes a long time to build LLVM initially. EDIT: It also looks like the new Catalyst version 0.9.0 is now available for download :) |
Understood, that makes sense. I might re-try it with a different wheel with a compatible python version. If not, I will resort to installing from source. |
Okay when trying to run the |
There are some required packages before building Catalyst. Maybe some of them are missing? See the build from source guide |
(If wheels are too complicated I recommend just building from source.) |
Yup installed all the required packages, but the error persists. I am now just going to build from source instead. |
To avoid all package version issues, I also recommend using a fresh virtual environment when developing, e.g.
after which you can pip install all the requirements and |
Okay, I seem to have gotten it to work from source. Most tests pass, some are skipped, and 4 debugging tests fail (as @erick-xanadu said is expected). I am running all of this from within a conda environment on a linux machine. |
Yeah, a bunch are skipped, 4 failing ones.
Awesome! We don't normally use conda so I am happy to hear this worked for you :) |
Hi! So from my initial analysis, I traced the issue down to this check that throws an exception -
It appears that it checks the index used to specify the static argument with the number of Would like to get your thoughts on this! |
Hi! Usually what we do for these challenges is you can fork the catalyst repo, push your changes, and open a PR. It will be easier to review. It's good if it turns out to be a simple fix, but one thing I'm afraid of is whether loosening the verification would allow in some errors. Can you test the frontend test suite to make sure this does not happen? You can run |
Hi! Okay that sounds good. I am working off of my own branch. Does that work, or does it have to be a fork? |
It doesn't really matter how you develop, as long as you are able to push a pull request :). I think to do that, you do need a fork. But you can always just add a new remote to your local git workspace.
git clone $pennylane/catalyst
# work on the issue
# fork $pennylane/catalyst to $yourrepo/catalyst
git remote add myrepo $yourrepo/catalyst
git push myrepo $yourbranch
# open a PR from $yourrepo/catalyst to $pennylane/catalyst |
Context
When jit-compiling a python function, the arguments of the compiled function lose their concrete values and are replaced by tracers, which at a high level means abstract variables that have the same type and shape as the concrete variable. A compiled program, called the jaxpr, uses these abstract tracers to represent how the arguments of the function are used.
The example below shows how to use Catalyst to jit-compile a function, and how to inspect the compiled jaxpr.
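As a stand-in for inspecting a traced program, here is a minimal sketch in plain JAX rather than Catalyst. It uses jax.make_jaxpr instead of qjit, enables 64-bit mode so the argument types match the ones discussed in this issue, and the function f is made up for illustration.

```python
import jax

# Assumption: 64-bit mode is switched on so that Python ints and floats
# trace as i64 and f64 (plain JAX defaults to 32-bit types).
jax.config.update("jax_enable_x64", True)


def f(x, y):  # hypothetical example function
    return x * y


# Trace f with concrete example arguments; only their type and shape survive.
closed_jaxpr = jax.make_jaxpr(f)(1, 2.0)
print(closed_jaxpr)

# The abstract input values carry the dtypes of the concrete arguments.
arg_types = [str(aval.dtype) for aval in closed_jaxpr.in_avals]
```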
Notice that in the jaxpr, the types of the arguments to the function, i64 and f64, are the same as the types of the concrete arguments of their corresponding calls. The process of converting python to jaxpr is called tracing.
One issue with arguments being abstract is that when their concrete value is needed, for example when being compared to other concrete values, tracing will fail, since abstract tracers cannot be interpreted as concrete values. See here for more details.
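This failure mode can be reproduced in plain JAX. The sketch below is not Catalyst code: jax.jit stands in for qjit, and the function g is hypothetical.

```python
import jax


@jax.jit
def g(x, y):
    # During tracing y is an abstract tracer, so `y > 1` has no concrete
    # truth value and Python's `if` cannot branch on it.
    if y > 1:
        return x * y
    return x + y


try:
    g(3.0, 2)
    failed = False
except jax.errors.ConcretizationTypeError:
    # JAX reports that a concrete value was required for a traced argument.
    failed = True
```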
To avoid this problem, some of the function arguments can be marked static, which essentially means that when tracing, they keep their concrete values and are not replaced with tracers. This marking can be done with the static_argnums keyword argument of qjit, which takes in a list of argument indices to be marked static.
However, currently in Catalyst, arguments with default values cannot be marked as static_argnums:
Goal
We would like to support static_argnums in qjit to mark arguments with default values, as this is supported by native jax.jit:
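A minimal sketch of that jax.jit behaviour in plain JAX follows. It is not the snippet from the original issue, and the names f, x, and y are made up for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: argument index 1 (y) is marked static, so it keeps
# its concrete Python value during tracing and can drive Python control flow.
@partial(jax.jit, static_argnums=(1,))
def f(x, y=2):
    if y > 1:
        return x * y
    return x + y


explicit = f(jnp.float32(3.0), 1)  # static argument passed explicitly
defaulted = f(jnp.float32(3.0))    # static argument falls back to its default
```

Each distinct value of a static argument triggers a separate trace, so the two calls above compile two specializations of f.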
Requirements:
jax.jit. Explicitly:
Technical details
Due to reasons that do not concern us here, all jaxprs produced by qjit will carry a transform_named_sequence. You can safely ignore it.
The qjit function takes in a python function and returns a QJIT object, which is a callable. In the QJIT object, there is a capture method that determines how a python function is traced into a jaxpr. See frontend/catalyst/jit.py.
It should be possible to implement this functionality completely in the capture layer, without delving into the actual underlying machinery of the trace_to_jaxpr methods. For example, one option is to create two versions of the function, both without any default-valued arguments, adjusted to behave correctly, and trace one of these two functions depending on whether a default value was supplied by the user's call. Other options might be possible too.
There are many possible ways a user might call a function. For the purpose of this assessment, only the simplest case is required, i.e. the example above that worked with pure jax (though bonus points if you can make more complicated patterns work).
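One of the "other options" could be to resolve default values before any static-argument handling, so that every static index refers to a real positional value. The sketch below uses only the standard library and is not Catalyst's capture code; bind_with_defaults and split_static are hypothetical helper names, and it is a different approach from the two-function variant described above.

```python
import inspect


def bind_with_defaults(fn, args, kwargs):
    """Return the full positional argument tuple for fn, defaults filled in."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    return bound.args


def split_static(full_args, static_argnums):
    """Separate concrete (static) values from the arguments to be traced."""
    static = {i: full_args[i] for i in static_argnums}
    dynamic = [a for i, a in enumerate(full_args) if i not in static]
    return static, dynamic


def f(x, y=2):  # hypothetical user function with a default-valued argument
    return x * y


# The caller omits y, yet index 1 still resolves to a concrete value.
full = bind_with_defaults(f, (3.0,), {})
static, dynamic = split_static(full, static_argnums=(1,))

# Passing y by keyword resolves to the same positional slot.
full_kw = bind_with_defaults(f, (3.0,), {"y": 5})
```

Because the argument tuple is completed up front, a check comparing static indices against the number of supplied arguments would see the same count whether or not the caller relied on the default.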
Installation help
To save time, instead of installing Catalyst from source it is also possible to download the PyPI wheel and extract it into the frontend directory of a cloned catalyst repository (taking care to match git tags beforehand), followed by make frontend. This then allows modifying the Python files in-place.
Alternatively, complete instructions to install Catalyst from source can be found here, but due to the size of the llvm-project it can take a while (~3 hrs on a personal laptop) to compile.