ThunderFX failure: KeyError: 'l_stack0_' #1293
@kshitij12345, could you please take a look? Is it Thunder's problem or PyTorch FX's problem?
I am seeing two separate failures in different environments (with more recent PyTorch), both of which seem to occur after the splitting phase. So, I think this particular KeyError is coming from the older PyTorch version used in the report.

Env 1:

```
  File "/home/kkalambarkar/lightning-thunder/thunder/core/trace_interpreter.py", line 63, in interpret_trace
    prim_func = symbol_mapper(symbol) if symbol_mapper is not None else symbol.sym
  File "/home/kkalambarkar/lightning-thunder/thunder/core/transforms.py", line 2514, in vjp_symbol_mapper
    raise NotImplementedError(f"VJP for {symbol.sym.id} is not implemented")
NotImplementedError: VJP for PrimIDs.COPY_WITH_SETITEM is not implemented
```

We already have an issue for the same: #1240.
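For context, the Env 1 failure should not need the full model. Here is a toy sketch of my own (not the original repro), assuming an in-place indexed assignment inside a jitted function lowers to PrimIDs.COPY_WITH_SETITEM:

```python
# Toy sketch (not the original repro): an in-place indexed assignment inside a
# jitted function should lower to the COPY_WITH_SETITEM prim, whose VJP is not
# implemented, so the backward pass should raise the NotImplementedError above.
import torch
import thunder

def fn(x, y):
    x = x * 2
    x[0] = y  # expected to lower to copy-with-setitem
    return x.sum()

jfn = thunder.jit(fn)
x = torch.randn(4, requires_grad=True)
y = torch.tensor(1.0)
jfn(x, y).backward()  # expected: NotImplementedError: VJP for PrimIDs.COPY_WITH_SETITEM ...
```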
Env 2 (with the latest versions of PyTorch, nvFuser, and Thunder):

```
An error occurred while executing nvFuser FusionDefinition 32.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
```
Here's a script to reproduce the error:
```python
# CUDA devices:
# 0: NVIDIA RTX 6000 Ada Generation
# 1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+gita777dea
# cuda version: 12.6
# nvfuser version: 0.2.15+git7616b54
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id32(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 597, 128], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T2 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T4 = fd.define_tensor(shape=[1, 597, 128], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T6 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    S7 = fd.define_scalar(1, dtype=DataType.Int)
    S8 = fd.define_scalar(1, dtype=DataType.Int)
    S9 = fd.define_scalar(597, dtype=DataType.Int)
    S10 = fd.define_scalar(128, dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T0, shape=[S7, S8, S9, S10], broadcast_dims=[0, 2, 3])
    S13 = fd.define_scalar(1, dtype=DataType.Int)
    S14 = fd.define_scalar(32, dtype=DataType.Int)
    S15 = fd.define_scalar(597, dtype=DataType.Int)
    S16 = fd.define_scalar(128, dtype=DataType.Int)
    T18 = fd.ops.broadcast_in_dim(T12, shape=[S13, S14, S15, S16], broadcast_dims=[0, 1, 2, 3])
    T19 = fd.ops.cast(T1, dtype=DataType.Float)
    T20 = fd.ops.cast(T2, dtype=DataType.Float)
    T21 = fd.ops.cast(T3, dtype=DataType.Float)
    T22 = fd.ops.cast(T18, dtype=DataType.Float)
    T23 = fd.ops.add(T20, T19)
    T24 = fd.ops.mul(T22, T21)
    T25 = fd.ops.mul(T22, T23)
    T26 = fd.ops.cast(T24, dtype=DataType.BFloat16)
    T27 = fd.ops.cast(T25, dtype=DataType.BFloat16)
    T43 = fd.ops.slice(T26, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 597, 64], strides=[1, 1, 1, 1])
    T59 = fd.ops.slice(T27, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 597, 64], strides=[1, 1, 1, 1])
    T60 = fd.ops.cast(T43, dtype=DataType.Float)
    T61 = fd.ops.cast(T59, dtype=DataType.Float)
    T62 = fd.ops.neg(T60)
    T63 = fd.ops.neg(T61)
    S64 = fd.define_scalar(1, dtype=DataType.Int)
    S65 = fd.define_scalar(1, dtype=DataType.Int)
    S66 = fd.define_scalar(597, dtype=DataType.Int)
    S67 = fd.define_scalar(128, dtype=DataType.Int)
    T69 = fd.ops.broadcast_in_dim(T4, shape=[S64, S65, S66, S67], broadcast_dims=[0, 2, 3])
    T85 = fd.ops.slice(T26, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 597, 128], strides=[1, 1, 1, 1])
    T86 = fd.ops.cast(T62, dtype=DataType.BFloat16)
    T102 = fd.ops.slice(T27, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 597, 128], strides=[1, 1, 1, 1])
    T103 = fd.ops.cast(T63, dtype=DataType.BFloat16)
    S104 = fd.define_scalar(1, dtype=DataType.Int)
    S105 = fd.define_scalar(32, dtype=DataType.Int)
    S106 = fd.define_scalar(597, dtype=DataType.Int)
    S107 = fd.define_scalar(128, dtype=DataType.Int)
    T109 = fd.ops.broadcast_in_dim(T69, shape=[S104, S105, S106, S107], broadcast_dims=[0, 1, 2, 3])
    S110 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T111 = fd.ops.pad(T85, [0, 64, 0, 0, 0, 0, 0, 0], S110)
    S112 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T113 = fd.ops.pad(T86, [64, 0, 0, 0, 0, 0, 0, 0], S112)
    S114 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T115 = fd.ops.pad(T102, [0, 64, 0, 0, 0, 0, 0, 0], S114)
    S116 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T117 = fd.ops.pad(T103, [64, 0, 0, 0, 0, 0, 0, 0], S116)
    T118 = fd.ops.cast(T109, dtype=DataType.Float)
    T119 = fd.ops.cast(T111, dtype=DataType.Float)
    T120 = fd.ops.cast(T113, dtype=DataType.Float)
    T121 = fd.ops.cast(T115, dtype=DataType.Float)
    T122 = fd.ops.cast(T117, dtype=DataType.Float)
    T123 = fd.ops.mul(T118, T21)
    T124 = fd.ops.add(T120, T119)
    T125 = fd.ops.mul(T118, T23)
    T126 = fd.ops.add(T122, T121)
    T127 = fd.ops.cast(T5, dtype=DataType.Float)
    T128 = fd.ops.cast(T6, dtype=DataType.Float)
    T129 = fd.ops.add(T124, T123)
    T130 = fd.ops.add(T126, T125)
    T131 = fd.ops.add(T128, T127)
    T132 = fd.ops.cast(T129, dtype=DataType.BFloat16)
    T133 = fd.ops.cast(T130, dtype=DataType.BFloat16)
    T134 = fd.ops.cast(T131, dtype=DataType.BFloat16)
    T135 = fd.ops.permute(T132, dims=[0, 2, 1, 3])
    T136 = fd.ops.permute(T133, dims=[0, 2, 1, 3])
    T137 = fd.ops.permute(T134, dims=[0, 2, 1, 3])
    T142 = fd.ops.reshape(T135, new_shape=[1, 597, 4096])
    T147 = fd.ops.reshape(T136, new_shape=[1, 597, 4096])
    T152 = fd.ops.reshape(T137, new_shape=[1, 597, 4096])
    T156 = fd.ops.reshape(T142, new_shape=[597, 4096])
    T160 = fd.ops.reshape(T147, new_shape=[597, 4096])
    T164 = fd.ops.reshape(T152, new_shape=[597, 4096])
    T165 = fd.ops.permute(T156, dims=[1, 0])
    T166 = fd.ops.permute(T160, dims=[1, 0])
    T167 = fd.ops.permute(T164, dims=[1, 0])
    fd.add_output(T164)
    fd.add_output(T167)
    fd.add_output(T160)
    fd.add_output(T166)
    fd.add_output(T156)
    fd.add_output(T165)

with FusionDefinition() as fd:
    nvfuser_fusion_id32(fd)

inputs = [
    torch.randn(76416, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 597, 128), (76416, 128, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 128, 4096, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 76416, 128, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 128, 4096, 1)),
    torch.randn(76416, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 597, 128), (76416, 128, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 128, 4096, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 76416, 128, 1)),
]
fd.execute(inputs)
```

Traceback:

```
Traceback (most recent call last):
Error from segmentation group 10: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":1965, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
Couldn't find allocation mapping for T71_l___bfloat[ iblockIdx.x488{( ceilDiv(( ceilDiv(128, 4) ), blockDim.x) )}, iblockIdx.y491{( ceilDiv(( 1 * 597 ), 1) )}, iUS492{1}, iV487{4}, ithreadIdx.x489{blockDim.x} ] ca_pos( 3 )
dim: 2
id: iS301{128}, loops: iblockIdx.x335{( ceilDiv(( ceilDiv(4096, 4) ), blockDim.x) )} iblockIdx.y435{( ceilDiv(( 1 * 597 ), 1) )} iUS436{1} iV487{4} ithreadIdx.x489{blockDim.x}
Use NVFUSER_DISABLE=parallel_compile to simplify error message.
```
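As the message suggests, disabling parallel compilation should simplify the error output. A minimal sketch of doing that from Python, assuming nvFuser reads the variable at import time (setting it in the shell before launching works regardless):

```python
# Sketch: set NVFUSER_DISABLE=parallel_compile (the flag named in the error
# message above) before nvFuser is imported, so the failing segment compiles
# serially and the assert surfaces with a simpler message.
import os

os.environ["NVFUSER_DISABLE"] = "parallel_compile"  # assumption: read at import time

from nvfuser import FusionDefinition, DataType  # import only after setting the flag
```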
On the latest Thunder (dafc79d) I tried the snippet from #1174 (comment) and got the KeyError: 'l_stack0_' from the title. My PyTorch version is '2.4.0a0+git3827810'.
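For anyone trying to reproduce without digging up the #1174 snippet, here is a minimal sketch of how ThunderFX is invoked; the toy function is an assumption, not the original snippet, and `ThunderCompiler` is Thunder's Dynamo backend entry point:

```python
# Minimal sketch of a ThunderFX invocation (not the #1174 snippet): Dynamo
# captures the function and hands the FX graph to Thunder via the backend
# below; the KeyError: 'l_stack0_' reportedly surfaces along this path.
import torch
from thunder.dynamo import ThunderCompiler

def fn(x):
    return torch.nn.functional.gelu(x) + x

backend = ThunderCompiler()
cfn = torch.compile(fn, backend=backend)
cfn(torch.randn(8, 8))
```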