diff --git a/doc/changelog.md b/doc/changelog.md
index 4dead82b69..0e355dc773 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -184,6 +184,40 @@
...
}
```
+
+* Catalyst now has debug interfaces `get_compilation_stage` and `replace_ir` to acquire and
+ recompile the IR from a given pipeline pass. They can only be used with `keep_intermediate=True`.
+ `get_compilation_stage` is renamed from `print_compilation_stage` and now returns a IR string.
+ [(#981)](https://github.com/PennyLaneAI/catalyst/pull/981)
+
+ ```py
+ from catalyst import qjit
+
+ @qjit(keep_intermediate=True)
+ def f(x):
+ return x**2
+ ```
+
+ ```pycon
+ >>> f(2.0)
+ 4.0
+ ```
+
+ ```py
+ from catalyst.debug import get_pipeline_output, replace_ir
+
+ old_ir = get_pipeline_output(f, "HLOLoweringPass")
+ new_ir = old_ir.replace(
+ "%2 = arith.mulf %in, %in_0 : f64\n",
+ "%t = arith.mulf %in, %in_0 : f64\n %2 = arith.mulf %t, %in_0 : f64\n"
+ )
+ replace_ir(f, "HLOLoweringPass", new_ir)
+ ```
+
+ ```pycon
+ >>> f(2.0)
+ 8.0
+ ```
Improvements
diff --git a/doc/dev/debugging.rst b/doc/dev/debugging.rst
index d8f31128e1..c600f1608f 100644
--- a/doc/dev/debugging.rst
+++ b/doc/dev/debugging.rst
@@ -37,7 +37,7 @@ Below is an example of how to obtain a C program that can be linked against the
print(debug.get_cmain(identity, 1.0))
-Using the ``debug.get_cmain`` function, the following string is returned to the user:
+Using the :func:`~.debug.get_cmain` function, the following string is returned to the user:
.. code-block:: C
@@ -181,22 +181,22 @@ to perform. Most of the standard passes are described in the
implemented in Catalyst and can be found in the sources.
All pipelines are executed in sequence, the output MLIR of each non-empty pipeline is stored in
-memory and becomes available via the :func:`~.print_compilation_stage` function in the ``debug`` module.
+memory and becomes available via the :func:`~.debug.get_compilation_stage` function in the ``debug`` module.
It is necessary however, to have compiled with the option ``keep_intermediate=True`` to use
-:func:`~.print_compilation_stage`.
+:func:`~.debug.get_compilation_stage`.
Printing the IR generated by Pass Pipelines
===========================================
We won't get into too much detail here, but sometimes it is useful to look at the output of a
specific pass pipeline.
-To do so, simply use the :func:`~.print_compilation_stage` function.
+To do so, simply use the :func:`~.debug.get_compilation_stage` function and print the return value out.
For example, if one wishes to inspect the output of the ``BufferizationPass`` pipeline, simply run
the following command.
.. code-block:: python
- print_compilation_stage(circuit, "BufferizationPass")
+ print(get_compilation_stage(circuit, "BufferizationPass")
Profiling and instrumentation
=============================
@@ -291,26 +291,26 @@ via standard LLVM-MLIR tooling.
.. code-block:: python
- print_compilation_stage(circuit, "HLOLoweringPass")
+ print(get_compilation_stage(circuit, "HLOLoweringPass"))
The quantum compilation pipeline expands high-level quantum instructions like adjoint, and applies quantum differentiation methods and optimization techniques.
.. code-block:: python
- print_compilation_stage(circuit, "QuantumCompilationPass")
+ print(get_compilation_stage(circuit, "QuantumCompilationPass"))
An important step in getting to machine code from a high-level representation is allocating memory
for all the tensor/array objects in the program.
.. code-block:: python
- print_compilation_stage(circuit, "BufferizationPass")
+ print(get_compilation_stage(circuit, "BufferizationPass"))
The LLVM dialect can be considered the "exit point" from MLIR when using LLVM for low-level compilation:
.. code-block:: python
- print_compilation_stage(circuit, "MLIRToLLVMDialect")
+ print(get_compilation_stage(circuit, "MLIRToLLVMDialect"))
And finally some LLVMIR that is inspired by QIR.
@@ -322,3 +322,85 @@ And finally some LLVMIR that is inspired by QIR.
The LLVMIR code is compiled to an object file using the LLVM static compiler and linked to the
runtime libraries. The generated shared object is stored by the caching mechanism in Catalyst
for future calls.
+
+Recompiling a Function
+=================
+Catalyst offers a way to extract IRs from pipeline stages and feed modified IRs back for recompilation.
+To enable this feature, ``qjit`` decorated function must be compiled with the option ``keep_intermediate=True``.
+
+The following example creates a square function decorated with ``@qjit(keep_intermediate=True)``.
+The function must be compiled first so that the IR from each pipeline stage can be accessed.
+
+.. code-block:: python
+
+ @qjit(keep_intermediate=True)
+ def f(x):
+ return x**2
+ f(2.0)
+ >> 4.0
+
+After compilation, we can use :func:`~.debug.get_compilation_stage` in the ``debug`` module to get the IR from the given compiler stage.
+:func:`~.debug.get_compilation_stage` accepts a ``qjit`` decorated function and a stage name in string. It return the IR after the
+given stage.
+
+The available options are:
+
+* MLIR stages: ``mlir``, ``HLOLoweringPass``, ``QuantumCompilationPass``, ``BufferizationPass`` and ``MLIRToLLVMDialect``.
+* LLVM stages: ``llvm_ir``, ``CoroOpt``, ``O2Opt``, ``Enzyme``, and ``last``.
+
+Note that compiled functions might not always have ``CoroOpt``, ``O2Opt``, and ``Enzyme`` stages.
+The option ``last`` will provide the IR right before generating its object file.
+
+In this example, we request for the IR after ``HLOLoweringPass``.
+
+.. code-block:: python
+
+ from catalyst.debug import get_compilation_stage
+
+ old_ir = get_compilation_stage(f, "HLOLoweringPass")
+
+The output IR is
+
+.. code-block:: mlir
+
+ module @f {
+ func.func public @jit_f(%arg0: tensor) -> tensor attributes {llvm.emit_c_interface} {
+ %0 = tensor.empty() : tensor
+ %1 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%arg0, %arg0 : tensor, tensor) outs(%0 : tensor) {
+ ^bb0(%in: f64, %in_0: f64, %out: f64):
+ %2 = arith.mulf %in, %in_0 : f64
+ linalg.yield %2 : f64
+ } -> tensor
+ return %1 : tensor
+ }
+ func.func @setup() {
+ quantum.init
+ return
+ }
+ func.func @teardown() {
+ quantum.finalize
+ return
+ }
+ }
+
+Here we modify ``%2 = arith.mulf %in, %in_0 : f64`` to turn the square function into a cubic one.
+
+.. code-block:: python
+
+ new_ir = old_ir.replace(
+ "%2 = arith.mulf %in, %in_0 : f64\n",
+ "%t = arith.mulf %in, %in_0 : f64\n %2 = arith.mulf %t, %in_0 : f64\n"
+ )
+
+After that, we can use :func:`~.debug.replace_ir` to make the compiler use the modified
+IR for recompilation.
+:func:`~.debug.replace_ir` accepts a `qjit` decorated function, a checkpoint stage name in string, and a IR in string.
+The recompilation starts after the given checkpoint stage.
+
+.. code-block:: python
+
+ from catalyst.debug import replace_ir
+
+ replace_ir(f, "HLOLoweringPass", new_ir)
+ f(2.0)
+ >> 8.0
\ No newline at end of file
diff --git a/frontend/catalyst/compiled_functions.py b/frontend/catalyst/compiled_functions.py
index 2ff483c6e9..6cea61e91c 100644
--- a/frontend/catalyst/compiled_functions.py
+++ b/frontend/catalyst/compiled_functions.py
@@ -475,3 +475,7 @@ def insert(self, fn, args, out_treedef, workspace):
key = CacheKey(treedef, static_args)
entry = CacheEntry(fn, signature, out_treedef, workspace)
self.cache[key] = entry
+
+ def clear(self):
+ """Clear all previous compiled functions"""
+ self.cache.clear()
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index ca68fd38cb..c24ea994f6 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -87,6 +87,7 @@ class CompileOptions:
static_argnums: Optional[Union[int, Iterable[int]]] = None
abstracted_axes: Optional[Union[Iterable[Iterable[str]], Dict[int, str]]] = None
lower_to_llvm: Optional[bool] = True
+ checkpoint_stage: Optional[str] = ""
disable_assertions: Optional[bool] = False
seed: Optional[int] = None
@@ -545,6 +546,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory):
verbose=self.options.verbose,
pipelines=self.options.get_pipelines(),
lower_to_llvm=lower_to_llvm,
+ checkpoint_stage=self.options.checkpoint_stage,
)
except RuntimeError as e:
raise CompileError(*e.args) from e
@@ -601,7 +603,9 @@ def get_output_of(self, pipeline) -> Optional[str]:
Returns
(Optional[str]): output IR
"""
- if len(dict(self.options.get_pipelines()).get(pipeline, [])) == 0:
+ if not self.last_compiler_output or not self.last_compiler_output.get_pipeline_output(
+ pipeline
+ ):
msg = f"Attempting to get output for pipeline: {pipeline},"
msg += " but no file was found.\n"
msg += "Are you sure the file exists?"
diff --git a/frontend/catalyst/debug/__init__.py b/frontend/catalyst/debug/__init__.py
index bd13305286..8e73fd31c4 100644
--- a/frontend/catalyst/debug/__init__.py
+++ b/frontend/catalyst/debug/__init__.py
@@ -18,7 +18,8 @@
from catalyst.debug.compiler_functions import (
compile_from_mlir,
get_cmain,
- print_compilation_stage,
+ get_compilation_stage,
+ replace_ir,
)
from catalyst.debug.instruments import instrumentation
from catalyst.debug.printing import ( # pylint: disable=redefined-builtin
@@ -30,8 +31,9 @@
"callback",
"print",
"print_memref",
- "print_compilation_stage",
+ "get_compilation_stage",
"get_cmain",
"compile_from_mlir",
"instrumentation",
+ "replace_ir",
)
diff --git a/frontend/catalyst/debug/compiler_functions.py b/frontend/catalyst/debug/compiler_functions.py
index 665066816d..f49a3b225c 100644
--- a/frontend/catalyst/debug/compiler_functions.py
+++ b/frontend/catalyst/debug/compiler_functions.py
@@ -33,18 +33,31 @@
@debug_logger
-def print_compilation_stage(fn, stage):
+def get_compilation_stage(fn, stage):
"""Print one of the recorded compilation stages for a JIT-compiled function.
The stages are indexed by their Catalyst compilation pipeline name, which are either provided
by the user as a compilation option, or predefined in ``catalyst.compiler``.
+ All the available stages are:
+
+ - MILR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect
+
+ - LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.
+
+ Note that `CoroOpt` (Coroutine lowering), `O2Opt` (O2 optimization), and `Enzyme` (Automatic
+ differentiation) passes do not always happen. `last` denotes the stage right before object file
+ generation.
+
Requires ``keep_intermediate=True``.
Args:
fn (QJIT): a qjit-decorated function
stage (str): string corresponding with the name of the stage to be printed
+ Returns:
+ str: output ir from the target compiler stage
+
.. seealso:: :doc:`/dev/debugging`
**Example**
@@ -55,7 +68,7 @@ def print_compilation_stage(fn, stage):
def func(x: float):
return x
- >>> debug.print_compilation_stage(func, "HLOLoweringPass")
+ >>> print(debug.get_compilation_stage(func, "HLOLoweringPass"))
module @func {
func.func public @jit_func(%arg0: tensor)
-> tensor attributes {llvm.emit_c_interface} {
@@ -76,7 +89,9 @@ def func(x: float):
if not isinstance(fn, catalyst.QJIT):
raise TypeError(f"First argument needs to be a 'QJIT' object, got a {type(fn)}.")
- print(fn.compiler.get_output_of(stage))
+ if stage == "last":
+ return fn.compiler.last_compiler_output.get_output_ir()
+ return fn.compiler.get_output_of(stage)
@debug_logger
@@ -160,3 +175,30 @@ def compile_from_mlir(ir, compiler=None, compile_options=None):
result_types = [mlir.ir.RankedTensorType.parse(rt) for rt in func_data[1].split(",")]
return CompiledFunction(shared_object, qfunc_name, result_types, None, compiler.options)
+
+
+@debug_logger
+def replace_ir(fn, stage, new_ir):
+ """Replace the IR at any compilation stage that will be used the next time the function runs.
+
+ It is important that the function signature (inputs & outputs) for the next execution matches
+ that of the provided IR, or else the behaviour is undefined.
+
+ All the available stages are:
+
+ - MILR: mlir, HLOLoweringPass, QuantumCompilationPass, BufferizationPass, and MLIRToLLVMDialect.
+
+ - LLVM: llvm_ir, CoroOpt, O2Opt, Enzyme, and last.
+
+ Note that `CoroOpt` (Coroutine lowering), `O2Opt` (O2 optimization), and `Enzyme` (Automatic
+ differentiation) passes do not always happen. `last` denotes the stage right before object file
+ generation.
+
+ Args:
+ fn (QJIT): a qjit-decorated function
+ stage (str): Recompilation picks up after this stage.
+ new_ir (str): The replacement IR to use for recompilation.
+ """
+ fn.overwrite_ir = new_ir
+ fn.compiler.options.checkpoint_stage = stage
+ fn.fn_cache.clear()
diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py
index c73663a0c2..29680d2538 100644
--- a/frontend/catalyst/jit.py
+++ b/frontend/catalyst/jit.py
@@ -421,6 +421,7 @@ def __init__(self, fn, compile_options):
self.mlir_module = None
self.qir = None
self.out_type = None
+ self.overwrite_ir = None
functools.update_wrapper(self, fn)
self.user_sig = get_type_annotations(fn)
@@ -660,7 +661,15 @@ def compile(self):
# The MLIR function name is actually a derived type from string which has no
# `replace` method, so we need to get a regular Python string out of it.
func_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
- shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
+ if self.overwrite_ir:
+ shared_object, llvm_ir, _ = self.compiler.run_from_ir(
+ self.overwrite_ir,
+ str(self.mlir_module.operation.attributes["sym_name"]).replace('"', ""),
+ self.workspace,
+ )
+ else:
+ shared_object, llvm_ir, _ = self.compiler.run(self.mlir_module, self.workspace)
+
compiled_fn = CompiledFunction(
shared_object, func_name, restype, self.out_type, self.compile_options
)
diff --git a/frontend/catalyst/passes.py b/frontend/catalyst/passes.py
index d7ade7a7cf..da763b2ffa 100644
--- a/frontend/catalyst/passes.py
+++ b/frontend/catalyst/passes.py
@@ -31,7 +31,7 @@
workflow from Python and lowers it into MLIR, performing compiler
optimizations at the MLIR level.
To inspect the compiled MLIR from Catalyst, use
- :func:`~.print_compilation_stage`,
+ :func:`~.get_compilation_stage`,
where ``stage="QuantumCompilationPass"``, and with ``keep_intermediate=True``
in the ``qjit`` decorator.
@@ -64,7 +64,7 @@ def cancel_inverses(fn=None): # pylint: disable=line-too-long
.. code-block:: python
- from catalyst.debug import print_compilation_stage
+ from catalyst.debug import get_compilation_stage
from catalyst.passes import cancel_inverses
dev = qml.device("lightning.qubit", wires=1)
@@ -93,7 +93,7 @@ def g(x: float):
>>> workflow()
(Array(0.54030231, dtype=float64), Array(0.54030231, dtype=float64))
- >>> print_compilation_stage(workflow, "QuantumCompilationPass")
+ >>> print(get_compilation_stage(workflow, "QuantumCompilationPass"))
.. code-block:: mlir
diff --git a/frontend/test/lit/test_peephole_optimizations.py b/frontend/test/lit/test_peephole_optimizations.py
index e188ef13de..8ace8dcfac 100644
--- a/frontend/test/lit/test_peephole_optimizations.py
+++ b/frontend/test/lit/test_peephole_optimizations.py
@@ -29,7 +29,7 @@
import pennylane as qml
from catalyst import qjit
-from catalyst.debug import print_compilation_stage
+from catalyst.debug import get_compilation_stage
from catalyst.passes import cancel_inverses
@@ -41,7 +41,7 @@ def flush_peephole_opted_mlir_to_iostream(QJIT):
to retrieve it with keep_intermediate=True and manually access the "2_QuantumCompilationPass.mlir".
Then we delete the kept intermediates to avoid pollution of the workspace
"""
- print_compilation_stage(QJIT, "QuantumCompilationPass")
+ print(get_compilation_stage(QJIT, "QuantumCompilationPass"))
shutil.rmtree(QJIT.__name__)
diff --git a/frontend/test/lit/test_tensor_ops.mlir.py b/frontend/test/lit/test_tensor_ops.mlir.py
index fc8a9c8542..0b0d889a6d 100644
--- a/frontend/test/lit/test_tensor_ops.mlir.py
+++ b/frontend/test/lit/test_tensor_ops.mlir.py
@@ -18,7 +18,7 @@
from jax import numpy as jnp
from catalyst import measure, qjit
-from catalyst.debug import print_compilation_stage
+from catalyst.debug import get_compilation_stage
# Test methodology:
# Each mathematical function found in numpy
@@ -43,7 +43,7 @@ def test_ewise_arctan2(x, y):
test_ewise_arctan2(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_arctan2, "BufferizationPass")
+print(get_compilation_stage(test_ewise_arctan2, "BufferizationPass"))
# Need more time to test
# jnp.ldexp
@@ -76,7 +76,7 @@ def test_ewise_add(x, y):
test_ewise_add(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_add, "BufferizationPass")
+print(get_compilation_stage(test_ewise_add, "BufferizationPass"))
# CHECK-LABEL: test_ewise_mult
@@ -91,7 +91,7 @@ def test_ewise_mult(x, y):
test_ewise_mult(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_mult, "BufferizationPass")
+print(get_compilation_stage(test_ewise_mult, "BufferizationPass"))
# CHECK-LABEL: test_ewise_div
@@ -106,7 +106,7 @@ def test_ewise_div(x, y):
test_ewise_div(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_div, "BufferizationPass")
+print(get_compilation_stage(test_ewise_div, "BufferizationPass"))
# CHECK-LABEL: test_ewise_power
@@ -121,7 +121,7 @@ def test_ewise_power(x, y):
test_ewise_power(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_power, "BufferizationPass")
+print(get_compilation_stage(test_ewise_power, "BufferizationPass"))
# CHECK-LABEL: test_ewise_sub
@@ -136,7 +136,7 @@ def test_ewise_sub(x, y):
test_ewise_sub(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_sub, "BufferizationPass")
+print(get_compilation_stage(test_ewise_sub, "BufferizationPass"))
@qjit(keep_intermediate=True)
@@ -151,7 +151,7 @@ def test_ewise_true_div(x, y):
test_ewise_true_div(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_true_div, "BufferizationPass")
+print(get_compilation_stage(test_ewise_true_div, "BufferizationPass"))
# Not sure why the following ops are not working
# perhaps they rely on another function?
@@ -170,7 +170,7 @@ def test_ewise_float_power(x, y):
test_ewise_float_power(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_float_power, "BufferizationPass")
+print(get_compilation_stage(test_ewise_float_power, "BufferizationPass"))
# Not sure why the following ops are not working
@@ -195,7 +195,7 @@ def test_ewise_maximum(x, y):
test_ewise_maximum(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_maximum, "BufferizationPass")
+print(get_compilation_stage(test_ewise_maximum, "BufferizationPass"))
# Only single function support
# * jnp.fmax
@@ -213,7 +213,7 @@ def test_ewise_minimum(x, y):
test_ewise_minimum(jnp.array(1.0), jnp.array(2.0))
-print_compilation_stage(test_ewise_minimum, "BufferizationPass")
+print(get_compilation_stage(test_ewise_minimum, "BufferizationPass"))
# Only single function support
# * jnp.fmin
diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py
index 5ba952fcd0..e63d64ecf8 100644
--- a/frontend/test/pytest/test_debug.py
+++ b/frontend/test/pytest/test_debug.py
@@ -9,8 +9,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import platform
import re
+import shutil
import jax.numpy as jnp
import numpy as np
@@ -18,9 +20,14 @@
import pytest
from jax.tree_util import register_pytree_node_class
-from catalyst import debug, for_loop, qjit
+from catalyst import debug, for_loop, qjit, value_and_grad
from catalyst.compiler import CompileOptions, Compiler
-from catalyst.debug import compile_from_mlir, get_cmain, print_compilation_stage
+from catalyst.debug import (
+ compile_from_mlir,
+ get_cmain,
+ get_compilation_stage,
+ replace_ir,
+)
from catalyst.utils.exceptions import CompileError
from catalyst.utils.runtime_environment import get_lib_path
@@ -231,7 +238,7 @@ def test_hlo_lowering_stage(self, capsys):
def func():
return 0
- print_compilation_stage(func, "HLOLoweringPass")
+ print(get_compilation_stage(func, "HLOLoweringPass"))
out, _ = capsys.readouterr()
assert "@jit_func() -> tensor" in out
@@ -246,7 +253,7 @@ def func():
return 0
with pytest.raises(TypeError, match="needs to be a 'QJIT' object"):
- print_compilation_stage(func, "HLOLoweringPass")
+ print(get_compilation_stage(func, "HLOLoweringPass"))
class TestCompileFromIR:
@@ -429,6 +436,112 @@ def f(x: float):
with pytest.raises(TypeError, match="First argument needs to be a 'QJIT' object"):
get_cmain(f, 0.5)
+ @pytest.mark.parametrize(
+ ("pass_name", "target", "replacement"),
+ [
+ (
+ "mlir",
+ "%0 = stablehlo.multiply %arg0, %arg0 : tensor\n",
+ "%x = stablehlo.multiply %arg0, %arg0 : tensor\n"
+ + " %0 = stablehlo.multiply %x, %arg0 : tensor\n",
+ ),
+ (
+ "HLOLoweringPass",
+ "%2 = arith.mulf %in, %in_0 : f64\n",
+ "%t = arith.mulf %in, %in_0 : f64\n" + " %2 = arith.mulf %t, %in_0 : f64\n",
+ ),
+ (
+ "QuantumCompilationPass",
+ "%2 = arith.mulf %in, %in_0 : f64\n",
+ "%t = arith.mulf %in, %in_0 : f64\n" + " %2 = arith.mulf %t, %in_0 : f64\n",
+ ),
+ (
+ "BufferizationPass",
+ "%6 = arith.mulf %in, %in_0 : f64\n",
+ "%t = arith.mulf %in, %in_0 : f64\n" + " %6 = arith.mulf %t, %in_0 : f64\n",
+ ),
+ (
+ "MLIRToLLVMDialect",
+ "%21 = llvm.fmul %19, %20 : f64\n",
+ "%t = llvm.fmul %19, %20 : f64\n" + " %21 = llvm.fmul %t, %20 : f64\n",
+ ),
+ (
+ "llvm_ir",
+ "store double %15, ptr %9, align 8\n",
+ "%t1 = load double, ptr %1, align 8\n"
+ + " %t2 = fmul double %15, %t1\n"
+ + " store double %t2, ptr %9, align 8\n",
+ ),
+ (
+ "last",
+ "store double %15, ptr %9, align 8\n",
+ "%t1 = load double, ptr %1, align 8\n"
+ + " %t2 = fmul double %15, %t1\n"
+ + " store double %t2, ptr %9, align 8\n",
+ ),
+ ],
+ )
+ def test_modify_ir(self, pass_name, target, replacement):
+ """Turn a square function in IRs into a cubic one."""
+
+ @qjit(keep_intermediate=True)
+ def f(x):
+ """Square function."""
+ return x**2
+
+ data = 2.0
+ old_result = f(data)
+ old_ir = get_compilation_stage(f, pass_name)
+ old_workspace = str(f.workspace)
+
+ new_ir = old_ir.replace(target, replacement)
+ replace_ir(f, pass_name, new_ir)
+ new_result = f(data)
+
+ shutil.rmtree(old_workspace, ignore_errors=True)
+ shutil.rmtree(str(f.workspace), ignore_errors=True)
+ assert old_result * data == new_result
+
+ @pytest.mark.parametrize("pass_name", ["HLOLoweringPass", "O2Opt", "Enzyme"])
+ def test_modify_ir_file_generation(self, pass_name):
+ """Test if recompilation rerun the same pass."""
+
+ @qjit
+ def f(x: float):
+ """Square function."""
+ return x**2
+
+ grad_f = qjit(value_and_grad(f), keep_intermediate=True)
+ grad_f(3.0)
+ ir = get_compilation_stage(grad_f, pass_name)
+ old_workspace = str(grad_f.workspace)
+
+ replace_ir(grad_f, pass_name, ir)
+ grad_f(3.0)
+ file_list = os.listdir(str(grad_f.workspace))
+ res = [i for i in file_list if pass_name in i]
+
+ shutil.rmtree(old_workspace, ignore_errors=True)
+ shutil.rmtree(str(grad_f.workspace), ignore_errors=True)
+ assert len(res) == 0
+
+ def test_get_compilation_stage_without_keep_intermediate(self):
+ """Test if error is raised when using get_pipeline_output without keep_intermediate."""
+
+ @qjit
+ def f(x: float):
+ """Square function."""
+ return x**2
+
+ f(2.0)
+
+ with pytest.raises(
+ CompileError,
+ match="Attempting to get output for pipeline: mlir, "
+ "but no file was found.\nAre you sure the file exists?",
+ ):
+ get_compilation_stage(f, "mlir")
+
if __name__ == "__main__":
pytest.main(["-x", __file__])
diff --git a/frontend/test/pytest/test_mid_circuit_measurement.py b/frontend/test/pytest/test_mid_circuit_measurement.py
index af880ca3bb..65c3287f65 100644
--- a/frontend/test/pytest/test_mid_circuit_measurement.py
+++ b/frontend/test/pytest/test_mid_circuit_measurement.py
@@ -313,6 +313,9 @@ def circuit(x):
expected_call_count = 1 if postselect_mode == "hw-like" else 0
assert spy.call_count == expected_call_count
+ @pytest.mark.xfail(
+ reason="Midcircuit measurements with sampling is unseeded and hence this test is flaky"
+ )
@pytest.mark.parametrize("postselect_mode", [None, "fill-shots", "hw-like"])
@pytest.mark.parametrize("mcm_method", [None, "one-shot"])
def test_mcm_method_with_dict_output(self, backend, postselect_mode, mcm_method):
diff --git a/mlir/include/Driver/CompilerDriver.h b/mlir/include/Driver/CompilerDriver.h
index 4724131dfe..a2fac2885b 100644
--- a/mlir/include/Driver/CompilerDriver.h
+++ b/mlir/include/Driver/CompilerDriver.h
@@ -80,6 +80,8 @@ struct CompilerOptions {
std::vector pipelinesCfg;
/// Whether to assume that the pipelines output is a valid LLVM dialect and lower it to LLVM IR
bool lowerToLLVM;
+ /// Specify that the compiler should start after reaching the given pass.
+ std::string checkpointStage;
/// Get the destination of the object file at the end of compilation.
std::string getObjectFile() const
@@ -97,6 +99,8 @@ struct CompilerOutput {
FunctionAttributes inferredAttributes;
PipelineOutputs pipelineOutputs;
size_t pipelineCounter = 0;
+ /// if the compiler reach the pass specified by startAfterPass.
+ bool isCheckpointFound;
// Gets the next pipeline dump file name, prefixed with number.
std::string nextPipelineDumpFilename(Pipeline::Name pipelineName, std::string ext = ".mlir")
diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp
index 886a5c6ecb..df94dfcc8d 100644
--- a/mlir/lib/Driver/CompilerDriver.cpp
+++ b/mlir/lib/Driver/CompilerDriver.cpp
@@ -380,6 +380,10 @@ LogicalResult inferMLIRReturnTypes(MLIRContext *ctx, llvm::Type *returnType,
LogicalResult runCoroLLVMPasses(const CompilerOptions &options,
std::shared_ptr llvmModule, CompilerOutput &output)
{
+ if (options.checkpointStage != "" && !output.isCheckpointFound) {
+ output.isCheckpointFound = options.checkpointStage == "CoroOpt";
+ return success();
+ }
auto &outputs = output.pipelineOutputs;
@@ -424,6 +428,10 @@ LogicalResult runO2LLVMPasses(const CompilerOptions &options,
// opt -O2
// As seen here:
// https://llvm.org/docs/NewPassManager.html#just-tell-me-how-to-run-the-default-optimization-pipeline-with-the-new-pass-manager
+ if (options.checkpointStage != "" && !output.isCheckpointFound) {
+ output.isCheckpointFound = options.checkpointStage == "O2Opt";
+ return success();
+ }
auto &outputs = output.pipelineOutputs;
// Create the analysis managers.
@@ -463,6 +471,11 @@ LogicalResult runO2LLVMPasses(const CompilerOptions &options,
LogicalResult runEnzymePasses(const CompilerOptions &options,
std::shared_ptr llvmModule, CompilerOutput &output)
{
+ if (options.checkpointStage != "" && !output.isCheckpointFound) {
+ output.isCheckpointFound = options.checkpointStage == "Enzyme";
+ return success();
+ }
+
auto &outputs = output.pipelineOutputs;
// Create the new pass manager builder.
// Take a look at the PassBuilder constructor parameters for more
@@ -518,6 +531,10 @@ LogicalResult runLowering(const CompilerOptions &options, MLIRContext *ctx, Modu
// Fill all the pipe-to-pipeline mappings
for (const auto &pipeline : options.pipelinesCfg) {
+ if (options.checkpointStage != "" && !output.isCheckpointFound) {
+ output.isCheckpointFound = options.checkpointStage == pipeline.name;
+ continue;
+ }
size_t existingPasses = pm.size();
if (failed(parsePassPipeline(joinPasses(pipeline.passes), pm, options.diagnosticStream))) {
return failure();
@@ -533,12 +550,11 @@ LogicalResult runLowering(const CompilerOptions &options, MLIRContext *ctx, Modu
}
}
- if (options.keepIntermediate) {
- std::string tmp;
- llvm::raw_string_ostream s{tmp};
+ if (options.keepIntermediate && options.checkpointStage == "") {
+ llvm::raw_string_ostream s{outputs["mlir"]};
s << moduleOp;
dumpToFile(options, output.nextPipelineDumpFilename(options.moduleName.str(), ".mlir"),
- tmp);
+ outputs["mlir"]);
}
catalyst::utils::Timer timer{};
@@ -630,6 +646,8 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
OwningOpRef op =
timer::timer(parseMLIRSource, "parseMLIRSource", /* add_endl */ false, &ctx, *sourceMgr);
catalyst::utils::LinesCount::ModuleOp(*op);
+ output.isCheckpointFound = options.checkpointStage == "mlir";
+
bool enzymeRun = false;
if (op) {
enzymeRun = containsGradients(*op);
@@ -652,7 +670,11 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
catalyst::utils::LinesCount::Module(*llvmModule);
if (options.keepIntermediate) {
- dumpToFile(options, output.nextPipelineDumpFilename("llvm_ir", ".ll"), *llvmModule);
+ auto &outputs = output.pipelineOutputs;
+ llvm::raw_string_ostream rawStringOstream{outputs["llvm_ir"]};
+ llvmModule->print(rawStringOstream, nullptr);
+ auto outFile = output.nextPipelineDumpFilename("llvm_ir", ".ll");
+ dumpToFile(options, outFile, outputs["llvm_ir"]);
}
}
}
@@ -662,6 +684,8 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &
llvm::SMDiagnostic err;
llvmModule = timer::timer(parseLLVMSource, "parseLLVMSource", /* add_endl */ false,
llvmContext, options.source, options.moduleName, err);
+ output.isCheckpointFound = options.checkpointStage == "llvm_ir";
+
if (!llvmModule) {
// If both MLIR and LLVM failed to parse, exit.
err.print(options.moduleName.data(), options.diagnosticStream);
diff --git a/mlir/python/PyCompilerDriver.cpp b/mlir/python/PyCompilerDriver.cpp
index 707b24205c..96faccdc5b 100644
--- a/mlir/python/PyCompilerDriver.cpp
+++ b/mlir/python/PyCompilerDriver.cpp
@@ -72,13 +72,15 @@ PYBIND11_MODULE(compiler_driver, m)
.def("get_function_attributes",
[](const CompilerOutput &co) -> FunctionAttributes { return co.inferredAttributes; })
.def("get_diagnostic_messages",
- [](const CompilerOutput &co) -> std::string { return co.diagnosticMessages; });
+ [](const CompilerOutput &co) -> std::string { return co.diagnosticMessages; })
+ .def("get_is_checkpoint_found",
+ [](const CompilerOutput &co) -> bool { return co.isCheckpointFound; });
m.def(
"run_compiler_driver",
[](const char *source, const char *workspace, const char *moduleName, bool keepIntermediate,
- bool asyncQnodes, bool verbose, py::list pipelines,
- bool lower_to_llvm) -> std::unique_ptr {
+ bool asyncQnodes, bool verbose, py::list pipelines, bool lower_to_llvm,
+ const char *checkpointStage) -> std::unique_ptr {
// Install signal handler to catch user interrupts (e.g. CTRL-C).
signal(SIGINT,
[](int code) { throw std::runtime_error("KeyboardInterrupt (SIGINT)"); });
@@ -96,7 +98,8 @@ PYBIND11_MODULE(compiler_driver, m)
.asyncQnodes = asyncQnodes,
.verbosity = verbose ? Verbosity::All : Verbosity::Urgent,
.pipelinesCfg = parseCompilerSpec(pipelines),
- .lowerToLLVM = lower_to_llvm};
+ .lowerToLLVM = lower_to_llvm,
+ .checkpointStage = checkpointStage};
errStream.flush();
@@ -108,5 +111,5 @@ PYBIND11_MODULE(compiler_driver, m)
py::arg("source"), py::arg("workspace"), py::arg("module_name") = "jit source",
py::arg("keep_intermediate") = false, py::arg("async_qnodes") = false,
py::arg("verbose") = false, py::arg("pipelines") = py::list(),
- py::arg("lower_to_llvm") = true);
+ py::arg("lower_to_llvm") = true, py::arg("checkpoint_stage") = "");
}
diff --git a/setup.py b/setup.py
index 02157925e6..609afd5940 100644
--- a/setup.py
+++ b/setup.py
@@ -185,6 +185,8 @@ def run(self):
variables = sysconfig.get_config_vars()
# Here we need to switch the deault to MacOs dynamic lib
variables["LDSHARED"] = variables["LDSHARED"].replace("-bundle", "-dynamiclib")
+ if sysconfig.get_config_var("LDCXXSHARED"):
+ variables["LDCXXSHARED"] = variables["LDCXXSHARED"].replace("-bundle", "-dynamiclib")
custom_calls_extension = Extension(
"catalyst.utils.libcustom_calls",
sources=["frontend/catalyst/utils/libcustom_calls.cpp"],