Skip to content

Commit

Permalink
Merge branch 'main' into support-qjit-kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdad2m authored Aug 13, 2024
2 parents 5c80cb4 + 4f18276 commit bc60e24
Show file tree
Hide file tree
Showing 16 changed files with 372 additions and 46 deletions.
34 changes: 34 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@
...
}
```

* Catalyst now has debug interfaces `get_compilation_stage` and `replace_ir` to acquire and
recompile the IR from a given pipeline pass. They can only be used with `keep_intermediate=True`.
`get_compilation_stage` is renamed from `print_compilation_stage` and now returns a IR string.
[(#981)](https://github.com/PennyLaneAI/catalyst/pull/981)

```py
from catalyst import qjit

@qjit(keep_intermediate=True)
def f(x):
return x**2
```

```pycon
>>> f(2.0)
4.0
```

```py
from catalyst.debug import get_pipeline_output, replace_ir

old_ir = get_pipeline_output(f, "HLOLoweringPass")
new_ir = old_ir.replace(
"%2 = arith.mulf %in, %in_0 : f64\n",
"%t = arith.mulf %in, %in_0 : f64\n %2 = arith.mulf %t, %in_0 : f64\n"
)
replace_ir(f, "HLOLoweringPass", new_ir)
```

```pycon
>>> f(2.0)
8.0
```

<h3>Improvements</h3>

Expand Down
100 changes: 91 additions & 9 deletions doc/dev/debugging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Below is an example of how to obtain a C program that can be linked against the
print(debug.get_cmain(identity, 1.0))
Using the ``debug.get_cmain`` function, the following string is returned to the user:
Using the :func:`~.debug.get_cmain` function, the following string is returned to the user:

.. code-block:: C
Expand Down Expand Up @@ -181,22 +181,22 @@ to perform. Most of the standard passes are described in the
implemented in Catalyst and can be found in the sources.

All pipelines are executed in sequence, the output MLIR of each non-empty pipeline is stored in
memory and becomes available via the :func:`~.print_compilation_stage` function in the ``debug`` module.
memory and becomes available via the :func:`~.debug.get_compilation_stage` function in the ``debug`` module.
It is necessary however, to have compiled with the option ``keep_intermediate=True`` to use
:func:`~.print_compilation_stage`.
:func:`~.debug.get_compilation_stage`.

Printing the IR generated by Pass Pipelines
===========================================

We won't get into too much detail here, but sometimes it is useful to look at the output of a
specific pass pipeline.
To do so, simply use the :func:`~.print_compilation_stage` function.
To do so, simply use the :func:`~.debug.get_compilation_stage` function and print the return value out.
For example, if one wishes to inspect the output of the ``BufferizationPass`` pipeline, simply run
the following command.

.. code-block:: python
print_compilation_stage(circuit, "BufferizationPass")
print(get_compilation_stage(circuit, "BufferizationPass")
Profiling and instrumentation
=============================
Expand Down Expand Up @@ -291,26 +291,26 @@ via standard LLVM-MLIR tooling.
.. code-block:: python
print_compilation_stage(circuit, "HLOLoweringPass")
print(get_compilation_stage(circuit, "HLOLoweringPass"))
The quantum compilation pipeline expands high-level quantum instructions like adjoint, and applies quantum differentiation methods and optimization techniques.
.. code-block:: python
print_compilation_stage(circuit, "QuantumCompilationPass")
print(get_compilation_stage(circuit, "QuantumCompilationPass"))
An important step in getting to machine code from a high-level representation is allocating memory
for all the tensor/array objects in the program.
.. code-block:: python
print_compilation_stage(circuit, "BufferizationPass")
print(get_compilation_stage(circuit, "BufferizationPass"))
The LLVM dialect can be considered the "exit point" from MLIR when using LLVM for low-level compilation:
.. code-block:: python
print_compilation_stage(circuit, "MLIRToLLVMDialect")
print(get_compilation_stage(circuit, "MLIRToLLVMDialect"))
And finally some LLVMIR that is inspired by QIR.
Expand All @@ -322,3 +322,85 @@ And finally some LLVMIR that is inspired by QIR.
The LLVMIR code is compiled to an object file using the LLVM static compiler and linked to the
runtime libraries. The generated shared object is stored by the caching mechanism in Catalyst
for future calls.
Recompiling a Function
=================
Catalyst offers a way to extract IRs from pipeline stages and feed modified IRs back for recompilation.
To enable this feature, ``qjit`` decorated function must be compiled with the option ``keep_intermediate=True``.
The following example creates a square function decorated with ``@qjit(keep_intermediate=True)``.
The function must be compiled first so that the IR from each pipeline stage can be accessed.
.. code-block:: python
@qjit(keep_intermediate=True)
def f(x):
return x**2
f(2.0)
>> 4.0
After compilation, we can use :func:`~.debug.get_compilation_stage` in the ``debug`` module to get the IR from the given compiler stage.
:func:`~.debug.get_compilation_stage` accepts a ``qjit`` decorated function and a stage name in string. It return the IR after the
given stage.
The available options are:
* MLIR stages: ``mlir``, ``HLOLoweringPass``, ``QuantumCompilationPass``, ``BufferizationPass`` and ``MLIRToLLVMDialect``.
* LLVM stages: ``llvm_ir``, ``CoroOpt``, ``O2Opt``, ``Enzyme``, and ``last``.
Note that compiled functions might not always have ``CoroOpt``, ``O2Opt``, and ``Enzyme`` stages.
The option ``last`` will provide the IR right before generating its object file.
In this example, we request for the IR after ``HLOLoweringPass``.
.. code-block:: python
from catalyst.debug import get_compilation_stage
old_ir = get_compilation_stage(f, "HLOLoweringPass")
The output IR is
.. code-block:: mlir
module @f {
func.func public @jit_f(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
%0 = tensor.empty() : tensor<f64>
%1 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%arg0, %arg0 : tensor<f64>, tensor<f64>) outs(%0 : tensor<f64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%2 = arith.mulf %in, %in_0 : f64
linalg.yield %2 : f64
} -> tensor<f64>
return %1 : tensor<f64>
}
func.func @setup() {
quantum.init
return
}
func.func @teardown() {
quantum.finalize
return
}
}
Here we modify ``%2 = arith.mulf %in, %in_0 : f64`` to turn the square function into a cubic one.
.. code-block:: python
new_ir = old_ir.replace(
"%2 = arith.mulf %in, %in_0 : f64\n",
"%t = arith.mulf %in, %in_0 : f64\n %2 = arith.mulf %t, %in_0 : f64\n"
)
After that, we can use :func:`~.debug.replace_ir` to make the compiler use the modified
IR for recompilation.
:func:`~.debug.replace_ir` accepts a `qjit` decorated function, a checkpoint stage name in string, and a IR in string.
The recompilation starts after the given checkpoint stage.
.. code-block:: python
from catalyst.debug import replace_ir
replace_ir(f, "HLOLoweringPass", new_ir)
f(2.0)
>> 8.0
4 changes: 4 additions & 0 deletions frontend/catalyst/compiled_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,7 @@ def insert(self, fn, args, out_treedef, workspace):
key = CacheKey(treedef, static_args)
entry = CacheEntry(fn, signature, out_treedef, workspace)
self.cache[key] = entry

def clear(self):
"""Clear all previous compiled functions"""
self.cache.clear()
6 changes: 5 additions & 1 deletion frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class CompileOptions:
static_argnums: Optional[Union[int, Iterable[int]]] = None
abstracted_axes: Optional[Union[Iterable[Iterable[str]], Dict[int, str]]] = None
lower_to_llvm: Optional[bool] = True
checkpoint_stage: Optional[str] = ""
disable_assertions: Optional[bool] = False
seed: Optional[int] = None

Expand Down Expand Up @@ -545,6 +546,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
verbose=self.options.verbose,
pipelines=self.options.get_pipelines(),
lower_to_llvm=lower_to_llvm,
checkpoint_stage=self.options.checkpoint_stage,
)
except RuntimeError as e:
raise CompileError(*e.args) from e
Expand Down Expand Up @@ -601,7 +603,9 @@ def get_output_of(self, pipeline) -> Optional[str]:
Returns
(Optional[str]): output IR
"""
if len(dict(self.options.get_pipelines()).get(pipeline, [])) == 0:
if not self.last_compiler_output or not self.last_compiler_output.get_pipeline_output(
pipeline
):
msg = f"Attempting to get output for pipeline: {pipeline},"
msg += " but no file was found.\n"
msg += "Are you sure the file exists?"
Expand Down
6 changes: 4 additions & 2 deletions frontend/catalyst/debug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from catalyst.debug.compiler_functions import (
compile_from_mlir,
get_cmain,
print_compilation_stage,
get_compilation_stage,
replace_ir,
)
from catalyst.debug.instruments import instrumentation
from catalyst.debug.printing import ( # pylint: disable=redefined-builtin
Expand All @@ -30,8 +31,9 @@
"callback",
"print",
"print_memref",
"print_compilation_stage",
"get_compilation_stage",
"get_cmain",
"compile_from_mlir",
"instrumentation",
"replace_ir",
)
48 changes: 45 additions & 3 deletions frontend/catalyst/debug/compiler_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,31 @@


@debug_logger
def print_compilation_stage(fn, stage):
def get_compilation_stage(fn, stage):
"""Print one of the recorded compilation stages for a JIT-compiled function.
The stages are indexed by their Catalyst compilation pipeline name, which are either provided
by the user as a compilation option, or predefined in ``catalyst.compiler``.
All the available stages are:
- MILR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect
- LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.
Note that `CoroOpt` (Coroutine lowering), `O2Opt` (O2 optimization), and `Enzyme` (Automatic
differentiation) passes do not always happen. `last` denotes the stage right before object file
generation.
Requires ``keep_intermediate=True``.
Args:
fn (QJIT): a qjit-decorated function
stage (str): string corresponding with the name of the stage to be printed
Returns:
str: output ir from the target compiler stage
.. seealso:: :doc:`/dev/debugging`
**Example**
Expand All @@ -55,7 +68,7 @@ def print_compilation_stage(fn, stage):
def func(x: float):
return x
>>> debug.print_compilation_stage(func, "HLOLoweringPass")
>>> print(debug.get_compilation_stage(func, "HLOLoweringPass"))
module @func {
func.func public @jit_func(%arg0: tensor<f64>)
-> tensor<f64> attributes {llvm.emit_c_interface} {
Expand All @@ -76,7 +89,9 @@ def func(x: float):
if not isinstance(fn, catalyst.QJIT):
raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")

print(fn.compiler.get_output_of(stage))
if stage == "last":
return fn.compiler.last_compiler_output.get_output_ir()
return fn.compiler.get_output_of(stage)


@debug_logger
Expand Down Expand Up @@ -160,3 +175,30 @@ def compile_from_mlir(ir, compiler=None, compile_options=None):
result_types = [mlir.ir.RankedTensorType.parse(rt) for rt in func_data[1].split(",")]

return CompiledFunction(shared_object, qfunc_name, result_types, None, compiler.options)


@debug_logger
def replace_ir(fn, stage, new_ir):
"""Replace the IR at any compilation stage that will be used the next time the function runs.
It is important that the function signature (inputs & outputs) for the next execution matches
that of the provided IR, or else the behaviour is undefined.
All the available stages are:
- MILR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect.
- LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.
Note that `CoroOpt` (Coroutine lowering), `O2Opt` (O2 optimization), and `Enzyme` (Automatic
differentiation) passes do not always happen. `last` denotes the stage right before object file
generation.
Args:
fn (QJIT): a qjit-decorated function
stage (str): Recompilation picks up after this stage.
new_ir (str): The replacement IR to use for recompilation.
"""
fn.overwrite_ir = new_ir
fn.compiler.options.checkpoint_stage = stage
fn.fn_cache.clear()
11 changes: 10 additions & 1 deletion frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def __init__(self, fn, compile_options):
self.mlir_module = None
self.qir = None
self.out_type = None
self.overwrite_ir = None

functools.update_wrapper(self, fn)
self.user_sig = get_type_annotations(fn)
Expand Down Expand Up @@ -660,7 +661,15 @@ def compile(self):
# The MLIR function name is actually a derived type from string which has no
# `replace` method, so we need to get a regular Python string out of it.
func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
if self.overwrite_ir:
shared_object, llvm_ir, _ = self.compiler.run_from_ir(
self.overwrite_ir,
str(self.mlir_module.operation.attributes["sym_name"]).replace('"', ""),
self.workspace,
)
else:
shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)

compiled_fn = CompiledFunction(
shared_object, func_name, restype, self.out_type, self.compile_options
)
Expand Down
6 changes: 3 additions & 3 deletions frontend/catalyst/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
workflow from Python and lowers it into MLIR, performing compiler
optimizations at the MLIR level.
To inspect the compiled MLIR from Catalyst, use
:func:`~.print_compilation_stage`,
:func:`~.get_compilation_stage`,
where ``stage="QuantumCompilationPass"``, and with ``keep_intermediate=True``
in the ``qjit`` decorator.
Expand Down Expand Up @@ -64,7 +64,7 @@ def cancel_inverses(fn=None): # pylint: disable=line-too-long
.. code-block:: python
from catalyst.debug import print_compilation_stage
from catalyst.debug import get_compilation_stage
from catalyst.passes import cancel_inverses
dev = qml.device("lightning.qubit", wires=1)
Expand Down Expand Up @@ -93,7 +93,7 @@ def g(x: float):
>>> workflow()
(Array(0.54030231, dtype=float64), Array(0.54030231, dtype=float64))
>>> print_compilation_stage(workflow, "QuantumCompilationPass")
>>> print(get_compilation_stage(workflow, "QuantumCompilationPass"))
.. code-block:: mlir
Expand Down
Loading

0 comments on commit bc60e24

Please sign in to comment.