Skip to content

Commit

Permalink
enable nvfuser matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Apr 16, 2024
1 parent 47eb3dc commit 09d9ab8
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,6 +1948,28 @@ def var_mean(

# Register the var_mean translation together with its checker for the VAR_MEAN prim
register_supported(PrimIDs.VAR_MEAN, var_mean, _var_mean_check)

def _matmul_check(a: TensorProxy, b: TensorProxy) -> bool:
    """Checks whether a matmul of ``a`` and ``b`` may be handled here.

    The "nv_enable_matmul" compile option must be set to a truthy value and
    both operands must pass ``is_supported_tensor``.
    """
    matmul_opt: None | bool = get_compile_option("nv_enable_matmul", "Enable nvFuser matmul.")

    # An option that was never set is treated as an explicit opt-out
    if matmul_opt is None:
        return False

    return matmul_opt and is_supported_tensor(a) and is_supported_tensor(b)

def matmul(
    a: TensorProxy,
    b: TensorProxy,
    *,
    fd: FusionDefinition,
    lc_to_nv_map: dict,
) -> Any:
    """Emits a matmul of ``a`` and ``b`` through ``fd.ops.matmul``.

    Both operands are first converted with ``getnv`` (left operand first),
    and the converted values are handed to the fusion definition.

    Args:
        a: left operand.
        b: right operand.
        fd: fusion definition whose ``ops.matmul`` is invoked.
        lc_to_nv_map: mapping forwarded to ``getnv``
            (presumably proxy -> nvFuser value; confirm against ``getnv``).

    Returns:
        Whatever ``fd.ops.matmul`` returns for the converted operands.
    """
    # The generator is consumed in order, so `a` is converted before `b`
    nv_lhs, nv_rhs = (getnv(operand, fd, lc_to_nv_map) for operand in (a, b))
    return fd.ops.matmul(nv_lhs, nv_rhs)

# Register the matmul translation together with its checker for the MATMUL prim
register_supported(PrimIDs.MATMUL, matmul, _matmul_check)


# Removes excessive float casts, like those that occur when autocasting
# NOTE This pass actually changes a program's semantics, because it will take a sequence like
Expand Down

0 comments on commit 09d9ab8

Please sign in to comment.