Skip to content

Commit

Permalink
pre-commit: running and fixing...
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions[bot] authored and Borda committed Mar 27, 2024
1 parent 043db17 commit ad5b138
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
8 changes: 6 additions & 2 deletions thunder/executors/cudnn_layernormex.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def cudnn_available() -> bool:
def make_cacheable_cudnn_graph_inputs(func):
def wrapper(*args, **kwargs):
cudnn_input_args = [
CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype, args.device_index) if isinstance(arg, torch.Tensor) else arg
CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype, args.device_index)
if isinstance(arg, torch.Tensor)
else arg
for arg in args
]
return func(*cudnn_input_args, **kwargs)
Expand Down Expand Up @@ -93,7 +95,9 @@ def _transform_layer_norm_inputs(a, normalized_shape, weight, bias):
# Assume strides to be NCHW contiguous
assumed_stride = (elements_to_normalize, 1, 1, 1)
a_4d = CudnnTensorAttributes((batch_size, elements_to_normalize, 1, 1), assumed_stride, a.dtype, a.device.index)
weight_4d = CudnnTensorAttributes((1, elements_to_normalize, 1, 1), assumed_stride, weight.dtype, weight.device.index)
weight_4d = CudnnTensorAttributes(
(1, elements_to_normalize, 1, 1), assumed_stride, weight.dtype, weight.device.index
)
bias_4d = CudnnTensorAttributes((1, elements_to_normalize, 1, 1), assumed_stride, bias.dtype, bias.device.index)

return a_4d, weight_4d, bias_4d
Expand Down
5 changes: 4 additions & 1 deletion thunder/executors/cudnnex.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class CudnnTensorAttributes:
dtype: torch.dtype
device_index: int


from collections import OrderedDict


Expand Down Expand Up @@ -222,7 +223,9 @@ def compute_NHWC_strides(shape):

# cudnn does not support boolean attn_mask, so make one with -inf
attn_mask_dtype = query.dtype if attn_mask.dtype in [torch.bool, dtypes.bool8] else attn_mask.dtype
attn_mask_4d = CudnnTensorAttributes(attn_mask_shape, compute_NHWC_strides(attn_mask_shape), attn_mask_dtype, attn_mask.device.index)
attn_mask_4d = CudnnTensorAttributes(
attn_mask_shape, compute_NHWC_strides(attn_mask_shape), attn_mask_dtype, attn_mask.device.index
)

return query_4d, key_4d, value_4d, attn_mask_4d

Expand Down

0 comments on commit ad5b138

Please sign in to comment.