
Running Bamba natively on Pytorch #2

Open · 4 of 5 tasks
ani300 opened this issue Dec 9, 2024 · 5 comments
Comments

ani300 commented Dec 9, 2024

This issue tracks progress on running Bamba natively on PyTorch.

Success for this issue implies the following:

cc @raghukiran1224 @fabianlim @AdnanHoque

@fabianlim

@ani300 it may be that the slow path is failing for SDPA; see the stack trace below:

(vllm-bamba) nmg@css-host-181 nmg$ ./fmwork/github.ibm.com/hcir/v2.0/inference/transformers/dev/driver -m $css22/nmg/models/__cos/9aeedd4bd01c49a2a4a3dcc889904f70/ibm-llm-input/flim/Avengers-Bamba-9B-HF -i 128 -o 128 -b 2 -r 3
The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d
Loading checkpoint shards: 100%|████████████████| 9/9 [00:00<00:00, 19.50it/s]
Traceback (most recent call last):
  File "/net/storage149/mnt/md0/nmg/./fmwork/github.ibm.com/hcir/v2.0/inference/transformers/dev/driver", line 49, in <module>
    dts = fmwork.loop(par.reps, model.generate, kwargs)
  File "/net/storage149/mnt/md0/nmg/fmwork/github.ibm.com/hcir/v2.0/inference/transformers/dev/fmwork.py", line 71, in loop
    function(**kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/transformers/generation/utils.py", line 2231, in generate
    result = self._sample(
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/transformers/generation/utils.py", line 3222, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/transformers/models/bamba/modeling_bamba.py", line 1600, in forward
    outputs = self.model(
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/transformers/models/bamba/modeling_bamba.py", line 1424, in forward
    layer_outputs = decoder_layer(
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/transformers/models/bamba/modeling_bamba.py", line 1171, in forward
    hidden_states, self_attn_weights, cache_output = self.self_attn(
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/net/storage149/mnt/md0/nmg/miniconda3/envs/vllm-bamba/lib/python3.10/site-packages/transformers/models/bamba/modeling_bamba.py", line 612, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (130) must match the existing size (129) at non-singleton dimension 3. Target sizes: [2, 32, 2, 130]. Tensor sizes: [2, 1, 1, 129]
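As a quick triage step, it may help to confirm which of the kernels named in that warning are actually importable in the environment. A minimal sketch (the import paths follow the warning above and the standard mamba-ssm / causal-conv1d packages, so treat them as an assumption):

# Minimal availability check for the Bamba/Mamba fast-path kernels (sketch).
# If any import fails, transformers logs the "fast path is not available"
# warning and falls back to the slow (naive) implementation seen above.
def check_fast_path_kernels():
    missing = []
    try:
        from mamba_ssm.ops.triton.selective_state_update import selective_state_update  # noqa: F401
    except ImportError:
        missing.append("selective_state_update (pip install mamba-ssm)")
    try:
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update  # noqa: F401
    except ImportError:
        missing.append("causal_conv1d_fn / causal_conv1d_update (pip install causal-conv1d)")
    return missing

missing = check_fast_path_kernels()
print("fast path available" if not missing else f"slow path; missing: {missing}")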

fabianlim commented Dec 10, 2024

Other seqlen issues @ani300:

  • Since we inherit from the mamba_ssm kernels, we need to repro the large sequence length fix.
  • Even so, we cannot run 256k prefills at the largest seqlen; we OOM with the error below. Update: it turns out that setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True makes it work, but somehow there is a lot of fragmentation (see the sketch after the traceback).
scan_output = self.norm(scan_output, gate)
  File "/workspace/mamba-vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/mamba-vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/mamba-vllm/lib/python3.10/site-packages/transformers/models/bamba/modeling_bamba.py", line 647, in forward
    hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
  File "/workspace/mamba-vllm/lib/python3.10/site-packages/torch/nn/functional.py", line 2380, in silu
    return torch._C._nn.silu(input)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 79.15 GiB of which 6.01 GiB is free. Process 3326917 has 73.13 GiB memory in use. Of the allocated memory 58.64 GiB is allocated by PyTorch, and 14.00 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
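For reference, a minimal sketch of the workaround mentioned above. The env var has to be set before the first CUDA allocation, so either export it in the shell that launches the driver or set it before importing torch; the memory-stats printout at the end is just an optional way to eyeball fragmentation:

import os

# Must be set before torch makes its first CUDA allocation.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch

if torch.cuda.is_available():
    # "Reserved but unallocated" memory is the fragmentation the OOM message points at.
    reserved_gib = torch.cuda.memory_reserved() / 2**30
    allocated_gib = torch.cuda.memory_allocated() / 2**30
    print(f"reserved: {reserved_gib:.2f} GiB, allocated: {allocated_gib:.2f} GiB")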

fabianlim commented Dec 10, 2024

Impl issues: @ani300

  • Currently we do not support padding-free; this requires switching to calling the mamba_scan kernels varlen-style, as vLLM does. This will require significant effort, as it requires changing the depthwise conv kernels as well. vLLM's kernels are not the same as the ones in the original repo.
  • Currently the fp8 checkpoints are only in vLLM format and cannot be loaded in HF.
  • I noticed we have added partial_rotary_factor to the mamba config, but the checkpoint conversion script uses attn_rotary_emb instead. This is inconsistent, because in modeling_rope_utils the rotary dim is controlled by the former (see the sketch after this list).
  • The default max_position_embeddings value in the configs does not currently match what was used for training; perhaps we can change the default value @divya-kumari32
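A small sketch of the rotary-dim inconsistency from the third bullet; the numbers are hypothetical, and the arithmetic mirrors how modeling_rope_utils derives the rotary dim from partial_rotary_factor rather than from attn_rotary_emb:

# Hypothetical values, for illustration only (not taken from the real Bamba config).
head_dim = 128               # attention head dimension
attn_rotary_emb = 64         # rotary dim the conversion script writes out
partial_rotary_factor = 1.0  # default assumed when the config does not set it

# What modeling_rope_utils effectively computes:
rope_dim = int(head_dim * partial_rotary_factor)   # 128, not the intended 64

# Keeping the two fields consistent would mean deriving the factor from the dim:
consistent_factor = attn_rotary_emb / head_dim     # 0.5
assert int(head_dim * consistent_factor) == attn_rotary_emb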

@fabianlim

@ani300 the issue regarding the slow path here is fixed, right?

@JRosenkranz

@fabianlim @ani300 The initial FMS implementation (which uses the slow path) can be found here: https://github.com/foundation-model-stack/foundation-model-stack/tree/bamba. There is still a bug with rope (something to do with weight adaptation to fms); I will let you know when it is fixed.
