Skip to content

Commit

Permalink
Use ad_hoc_executor
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 13, 2024
1 parent 94c3409 commit c5621cd
Showing 1 changed file with 77 additions and 59 deletions.
136 changes: 77 additions & 59 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,8 @@ def _convert_pytorchfunc_to_thundertrace(
*args:
**kwargs
"""
from thunder.core.baseutils import sequencify

active_jit_ctx: JitCtx = get_jit_ctx()
active_jit_ctx.computation_trace.push_scope([])
wrapped_func_result = _interpret_call(func, *args, **kwargs)
Expand All @@ -637,8 +639,6 @@ def _convert_pytorchfunc_to_thundertrace(
trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
func_result = unwrap(wrapped_func_result)
if shallow_copy_output:
from thunder.core.baseutils import sequencify

out_to_shallow_copy: dict[Variable, TensorProxy] = {}
for a in sequencify(func_result):
shallow_copy_of_a = prims.shallow_copy.meta(a)
Expand All @@ -648,7 +648,7 @@ def _convert_pytorchfunc_to_thundertrace(
func_result = tree_map(lambda t: out_to_shallow_copy.get(variableify(t), t), func_result)
with tracectx(trace):
prims.python_return(func_result)
return trace, wrapped_func_result.provenance
return trace, sequencify(wrapped_func_result)[0].provenance


@register_general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
Expand Down Expand Up @@ -792,17 +792,17 @@ def _generate_random_str_id() -> str:
# non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
length_of_tensor_args = sum(args_tensor_mask)
new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]
jit_ctx.computation_trace.push_scope([])

fwd_result = _interpret_call(fwd, *new_fwd_args)
if fwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return fwd_result
output, saved_values = unwrap(fwd_result)
wrapped_output = wrap(output, provenance=fwd_result.provenance)
aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args)
if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return aug_fwd_trace
aug_fwd_result = aug_fwd_trace.output
output, saved_values = unwrap(aug_fwd_result)
wrapped_output = wrap(output, provenance=aug_fwd_provenance)

unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)[1:]

fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
fwd_bsyms: list[BoundSymbol] = aug_fwd_trace.bound_symbols
producer_map = utils.producers(fwd_bsyms)
tensor_to_prod_bsym: dict[Variable, BoundSymbol] = {}
for p in tree_flatten((output, saved_values))[0]:
Expand All @@ -813,67 +813,85 @@ def _generate_random_str_id() -> str:
tensor_to_prod_bsym[variableify(p)] = prod_bsym
prod_bsym_to_tensor = {v: unvariableify(k) for k, v in tensor_to_prod_bsym.items()}

# Encapsulate custom fwd into a bsym.
sym_id = f"autograd_function_apply_{_generate_random_str_id()}"
sym = Symbol(
name=sym_id,
id=sym_id,
_module=fwd_bsyms[-1].sym.module,
)
bsym_of_custom_autograd_func = BoundSymbol(
sym,
args=unwrapped_fwd_args,
kwargs={},
output=output,
subsymbols=fwd_bsyms,
header=(
f"output of fwd_body: {output}, saved_values from fwd_body: "
f"{[t.name if isinstance(t, Proxy) else t for t in saved_values]}"
),
source_filename=jit_ctx.computation_trace._current_source_filename,
source_positions=None,
_call_ctx=fwd_bsyms[0]._call_ctx,
_import_ctx=fwd_bsyms[0]._import_ctx,
_object_ctx=fwd_bsyms[0]._object_ctx,
_executor=fwd_bsyms[0]._executor,
vanilla_fwd_trace = TraceCtx()
vanilla_fwd_trace.args = unwrapped_fwd_args
unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), vanilla_fwd_trace.args)
]
for bsym in unpack_bsyms + fwd_bsyms[:-1]:
vanilla_fwd_trace.add_bound_symbol(bsym)
vanilla_fwd_trace.add_bound_symbol(prims.python_return.bind(output, output=()))
vanilla_fwd_trace._siginfo = SigInfo.from_name_and_args(sym_id, vanilla_fwd_trace.args)

@wraps(vanilla_fwd_trace.python_callable())
def core_of_fwd(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(vanilla_fwd_trace, *args, **kwargs)

sym = jit_ctx.ad_hoc_executor.register_operator(
vanilla_fwd_trace._siginfo.name,
like=core_of_fwd,
)
jit_ctx.computation_trace.scopes[-1].append(bsym_of_custom_autograd_func)
unwrapped_forward_result = sym(*unwrapped_fwd_args)

# Define augmented fwd rule and backward rule on the fly.
augmented_fwd_trace = TraceCtx()
for bsym in fwd_bsyms:
augmented_fwd_trace.args = vanilla_fwd_trace.args
for bsym in unpack_bsyms + fwd_bsyms[:-1]:
augmented_fwd_trace.add_bound_symbol(bsym)
augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=()))
si = SigInfo.from_name_and_args(f"augmented_autograd_function_apply_{sym_id}", bsym_of_custom_autograd_func.args)
si = SigInfo.from_name_and_args(f"augmented_autograd_function_apply_{sym_id}", augmented_fwd_trace.args)
augmented_fwd_trace._siginfo = si
augmented_fwd_callable = augmented_fwd_trace.python_callable(include_decorators=False)

def augmented_fwd_rule(*args):
# First arg is `None` or `FunctionCtx`
updated_output, updated_saved_values = augmented_fwd_callable(*args)
residuals = tuple(sequencify(updated_saved_values))
return VJPDual(primal=updated_output, residuals=residuals)
grads = sequencify(tree_map(lambda t: TensorProxy(like=t), sequencify(output)))
bwd_args = (wrap_const(None),)
bwd_tensor_args = grads + tuple(saved_values)
wrapped_bwd_tensor_args = tree_map(lambda t: wrap(t, provenance=aug_fwd_provenance), bwd_tensor_args)
bwd_trace, bwd_trace_provenance = _convert_pytorchfunc_to_thundertrace(
bwd,
False,
*(bwd_args + wrapped_bwd_tensor_args),
)
if bwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return bwd_trace
bwd_trace.args = bwd_tensor_args
bwd_unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), bwd_trace.args)
]
bwd_trace.bound_symbols = bwd_unpack_bsyms + bwd_trace.bound_symbols
bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{sym_id}", saved_values + grads)

augmented_forward_impls[sym.id] = augmented_fwd_rule
@wraps(bwd_trace.python_callable())
def bwd_impl_callable(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(bwd_trace, *args, **kwargs)

jit_ctx.computation_trace.push_scope([])
bwd_trace = TraceCtx()
@wraps(core_of_fwd)
def grad_transform(*args, **kwargs):
from thunder.core.transforms import get_grad, put_grads

grads = sequencify(tree_map(lambda t: TensorProxy(like=t), output))
bwd_args = (wrap_const(None),)
bwd_tensor_args = grads + tuple(saved_values)
wrapped_bwd_tensor_args = tree_map(lambda t: wrap(t, provenance=fwd_result.provenance), bwd_tensor_args)
bwd_result = _interpret_call(bwd, *(bwd_args + wrapped_bwd_tensor_args))
if bwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return bwd_result
unwrapped_bwd_result = unwrap(bwd_result)
bwd_trace.bound_symbols = jit_ctx.computation_trace.pop_scope()
bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=()))

bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{si.name}", saved_values + grads)
backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False)

return wrapped_output
primal, residuals = thunder.core.trace_interpreter.interpret_trace(
augmented_fwd_trace,
*args,
**kwargs,
)
grads = tree_map(lambda t: get_grad(t), sequencify(primal))
bwd_args = grads + residuals
result = bwd_impl_callable(*bwd_args)
put_grads(args, result)
return primal

jit_ctx.ad_hoc_executor.register_implementation(
sym,
execution_transform=core_of_fwd,
grad_transform=grad_transform,
)

return wrap(
unwrapped_forward_result,
provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[fwd.provenance, bwd.provenance]),
)


@register_general_jit_lookaside(torch.autocast.__enter__)
Expand Down

0 comments on commit c5621cd

Please sign in to comment.