Skip to content

Commit

Permalink
support cuda graph
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 committed Oct 20, 2024
1 parent d127e15 commit 46693e3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
14 changes: 12 additions & 2 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def __init__(self, model_runner: "ModelRunner"):

# Capture
try:
self.capture()
with self.model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
Expand All @@ -157,6 +158,16 @@ def __init__(self, model_runner: "ModelRunner"):
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)

@contextmanager
def model_capture_mode(self):
    """Temporarily flag the model as being under CUDA-graph capture.

    If the wrapped model defines a ``capture_mode`` attribute, set it to
    ``True`` for the duration of the ``with`` block and restore it to
    ``False`` afterwards. Models (e.g. mllama in this commit) read this
    flag in ``forward`` to avoid data-dependent branches that must not be
    baked into a captured graph.

    Yields:
        None. Purely a side-effecting context manager.
    """
    if hasattr(self.model_runner.model, "capture_mode"):
        self.model_runner.model.capture_mode = True
    try:
        yield
    finally:
        # Reset in a finally so the flag is cleared even when capture
        # raises — the caller catches RuntimeError and re-raises, and the
        # model must not be left stuck in capture mode.
        if hasattr(self.model_runner.model, "capture_mode"):
            self.model_runner.model.capture_mode = False

def can_run(self, batch_size: int):
if self.disable_padding:
return batch_size in self.graphs
Expand Down Expand Up @@ -190,7 +201,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
out_cache_loc = self.out_cache_loc[:bs]

# Fake encoder lens: just to initialize the attention wrappers
# TODO: support encoder-decode in cuda graph
encoder_lens = torch.zeros_like(seq_lens)

# Attention backend
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ def __init__(
bias=True,
)
self.logits_processor = LogitsProcessor(config.text_config)
self.capture_mode = False

def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
pixel_values = image_inputs.pixel_values
Expand Down Expand Up @@ -917,13 +918,17 @@ def forward(
# TODO: support multi-image by this mask
cross_attention_mask = None

if self.capture_mode:
skip_cross_attention = False
else:
skip_cross_attention = forward_batch.encoder_lens.max() == 0

if batched_images is None:
# For 1) text-only prefill and decode, 2) image-present decode.
full_text_row_masked_out_mask = (
(forward_batch.encoder_lens != 0).reshape(-1, 1).to(input_ids.device)
)
cross_attention_states = None
skip_cross_attention = forward_batch.encoder_lens.max() == 0
else:
# NOTE: llama's reference implementation runs vision model on CPU
cross_attention_states = self.vision_model(
Expand All @@ -942,7 +947,6 @@ def forward(
full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask(
forward_batch
)
skip_cross_attention = False

hidden_states = self.language_model(
input_ids=input_ids,
Expand Down

0 comments on commit 46693e3

Please sign in to comment.