Skip to content

Commit

Permalink
Add Seq Packing in NeMo / Neva2 (#11633)
Browse files Browse the repository at this point in the history
* api updates and fixes

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix

Signed-off-by: yaoyu-33 <[email protected]>

* fix arg

Signed-off-by: yaoyu-33 <[email protected]>

* update seq packing in mock ds

Signed-off-by: yaoyu-33 <[email protected]>

* save

Signed-off-by: yaoyu-33 <[email protected]>

* update preprocess_data

Signed-off-by: yaoyu-33 <[email protected]>

* update seq packing

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix sp

Signed-off-by: yaoyu-33 <[email protected]>

* save

Signed-off-by: yaoyu-33 <[email protected]>

* fix seq packing

Signed-off-by: yaoyu-33 <[email protected]>

* add truncation and padding

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* Fix issues

Signed-off-by: yaoyu-33 <[email protected]>

* change LLaVATemplateConfig variables to class variables

* change to use field with default attributes

* Apply isort and black reformatting

Signed-off-by: yashaswikarnati <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* Add seq packing option in energon

Signed-off-by: yaoyu-33 <[email protected]>

* Fix energon conversation

Signed-off-by: yaoyu-33 <[email protected]>

* add energon option in neva training script

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* add ci test for packed seq

Signed-off-by: yaoyu-33 <[email protected]>

* fix mock dataset seq packing

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix mock dataset seq packing

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix lint and update seq pack func

Signed-off-by: yaoyu-33 <[email protected]>

* fix energon module

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix comments

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* address lightning issues

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* Update sequence_packing.py

Signed-off-by: Yu Yao <[email protected]>

* update energon requirements

Signed-off-by: yaoyu-33 <[email protected]>

* Fix for energon update

Signed-off-by: yaoyu-33 <[email protected]>

* fix for test

Signed-off-by: yaoyu-33 <[email protected]>

---------

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yashaswikarnati <[email protected]>
Signed-off-by: Yu Yao <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
Co-authored-by: ykarnati <[email protected]>
Co-authored-by: yashaswikarnati <[email protected]>
  • Loading branch information
4 people authored Jan 15, 2025
1 parent 3591cf8 commit 1626ddd
Show file tree
Hide file tree
Showing 24 changed files with 611 additions and 138 deletions.
18 changes: 16 additions & 2 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4329,11 +4329,24 @@ jobs:
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/vlm/neva_train.py \
python tests/collections/vlm/test_neva_train.py \
--devices=1 \
--max-steps=5 \
--experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }}
L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/vlm/test_neva_train.py \
--devices=1 \
--max-steps=5 \
--experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }} \
--use_packed_sequence
L2_NeMo_2_MLLAMA_MOCK_TRAINING:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
Expand All @@ -4342,7 +4355,7 @@ jobs:
RUNNER: self-hosted-azure
SCRIPT: |
TRANSFORMERS_OFFLINE=1 \
python tests/collections/vlm/mllama_train.py \
python tests/collections/vlm/test_mllama_train.py \
--devices=1 \
--max-steps=5 \
--experiment-dir=/tmp/nemo2_mllama_results/${{ github.run_id }}
Expand Down Expand Up @@ -5060,6 +5073,7 @@ jobs:
- Speech_Checkpoints_tests
- L2_Stable_Diffusion_Training
- L2_NeMo_2_NEVA_MOCK_TRAINING
- L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING
- L2_NeMo_2_MLLAMA_MOCK_TRAINING
- L2_NeMo_2_GPT_Pretraining_no_transformer_engine
- L2_NeMo_2_GPT_DDP_Param_Parity_check
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/llm/peft/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from pathlib import Path
from typing import Tuple, Union

import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
from lightning.pytorch.trainer.states import TrainerFn
from megatron.core import dist_checkpointing
from pytorch_lightning.trainer.states import TrainerFn
from rich.console import Console

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/multimodal/data/energon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(),
task_encoder: Optional[MultiModalTaskEncoder] = None,
decoder_seq_length: Optional[int] = None,
packing_buffer_size: Optional[int] = None,
) -> None:
"""
Initialize the EnergonMultiModalDataModule.
Expand All @@ -84,6 +85,8 @@ def __init__(
Defaults to MultiModalSampleConfig().
task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
decoder_seq_length (int, optional): The maximum sequence length for the decoder. Used in encoder-decoder models.
packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None.
"""

super().__init__()
Expand Down Expand Up @@ -113,6 +116,7 @@ def __init__(
)
self.train_dataloader_object = None
self.val_dataloader_object = None
self.packing_buffer_size = packing_buffer_size

def io_init(self, **kwargs) -> fdl.Config[Self]:

Expand Down Expand Up @@ -146,6 +150,7 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val
task_encoder=self.task_encoder,
worker_config=worker_config,
max_samples_per_sequence=None,
packing_buffer_size=self.packing_buffer_size,
shuffle_buffer_size=100,
split_part=split,
)
Expand Down
24 changes: 22 additions & 2 deletions nemo/collections/multimodal/data/energon/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import List
from typing import List, Tuple, Union

import torch
from megatron.core.packed_seq_params import PackedSeqParams

from nemo.collections.multimodal.data.energon.conversation import LLaVATemplateConfig


Expand All @@ -34,7 +37,7 @@ class ImageToken(MultiModalToken):

@dataclass
class ImageTextSample:
'''Sample type for template formatted raw image text sample'''
"""Sample type for template formatted raw image text sample"""

__key__: str = ''
images: torch.Tensor = field(default_factory=lambda: torch.empty(0))
Expand All @@ -43,6 +46,15 @@ class ImageTextSample:
loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))


@dataclass
class PackedImageTextSample(ImageTextSample):
"""Sample type for packed image text sample"""

__restore_key__: Tuple[Union[str, int, tuple], ...] = ()
position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams())


@dataclass
class ImageTextRawBatch:
"""Sample type for image text raw batch"""
Expand All @@ -56,6 +68,14 @@ class ImageTextRawBatch:
loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))


@dataclass
class PackedImageTextRawBatch(ImageTextRawBatch):
"""Sample type for image text raw batch"""

position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))
packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams())


@dataclass
class MultiModalSampleConfig:
image_token: ImageToken = field(default_factory=ImageToken)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/multimodal/data/energon/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class LLaVATemplateConfig(BaseConversationTemplateConfig):
"""LLava-specific template configuration which extends the base config"""

system: str = field(
default="A chat between a curious user and artificial assistant agent. "
default="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed and polite answers to user's questions."
)
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
Expand Down
170 changes: 143 additions & 27 deletions nemo/collections/multimodal/data/energon/task_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@
batch_list,
batch_pad_stack,
)
from megatron.energon.task_encoder.base import stateless

from nemo.collections.multimodal.data.energon.config import ImageTextRawBatch, ImageTextSample
from nemo.collections.multimodal.data.energon.config import (
ImageTextRawBatch,
ImageTextSample,
PackedImageTextRawBatch,
PackedImageTextSample,
)
from nemo.collections.multimodal.data.energon.sample_encoder import (
InterleavedSampleEncoder,
SampleEncoder,
SimilarityInterleavedEncoder,
VQASampleEncoder,
)
from nemo.utils import logging


class MultiModalTaskEncoder(
Expand All @@ -54,16 +61,34 @@ class MultiModalTaskEncoder(
for model input.
"""

def __init__(self, tokenizer, image_processor, multimodal_sample_config):
def __init__(
self,
tokenizer,
image_processor,
multimodal_sample_config,
packed_sequence=False,
packed_sequence_size=-1,
num_image_embeddings_per_tile=576,
):
"""
Initialize the MultiModalTaskEncoder with specific encoders for different sample types.
Parameters:
tokenizer (Tokenizer): The tokenizer used for processing text across different sample types.
image_processor (ImageProcessor): The image processor used for preprocessing images.
multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object.
tokenizer (Tokenizer): The tokenizer used for processing textual components across sample types.
image_processor (ImageProcessor): The image processor responsible for preprocessing image data.
multimodal_sample_config (MultiModalSampleConfig): Configuration object defining properties and
requirements for multimodal samples.
packed_sequence (bool, optional): Flag indicating whether packed sequences are used. Default is False.
packed_sequence_size (int, optional): The size of packed sequences, used when `packed_sequence` is True.
Default is -1.
num_image_embeddings_per_tile (int, optional): Number of image embeddings per image tile. Determines
the granularity of image features. Default is 576.
"""
self.tokenizer = tokenizer
self.sample_config = multimodal_sample_config
self.packed_sequence = packed_sequence
self.num_image_embeddings_per_tile = num_image_embeddings_per_tile # only used with seq packing
self.packed_sequence_size = packed_sequence_size
self.encoders: Dict[str, SampleEncoder] = {
VQASample.__name__: VQASampleEncoder(
tokenizer=tokenizer,
Expand Down Expand Up @@ -92,6 +117,7 @@ def register_encoder(self, sample_type: str, encoder: SampleEncoder) -> None:
"""
self.encoders[sample_type] = encoder

@stateless
def encode_sample(
self, sample: Union[VQASample, InterleavedSample, SimilarityInterleavedSample, CaptioningSample]
) -> ImageTextSample:
Expand All @@ -118,7 +144,9 @@ def encode_sample(
encoded_sample = encoder.encode(input_sample=sample, output_sample=ImageTextSample())
return encoded_sample

def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch:
def batch(
self, samples: List[Union[ImageTextSample, PackedImageTextSample]]
) -> Union[ImageTextRawBatch, PackedImageTextRawBatch]:
"""
Batch a list of encoded samples into a single raw batch.
Expand All @@ -131,26 +159,51 @@ def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch:
ImageTextRawBatch: The batched data, including images, tokens, labels, and loss masks.
"""

keys, images, tokens, labels, loss_mask = [], [], [], [], []
for sample in samples:
keys.append(sample.__key__)
images.append(sample.images)
tokens.append(sample.tokens)
labels.append(sample.labels)
loss_mask.append(sample.loss_mask)

batch_keys = batch_list(keys)
batch_images = batch_pad_stack(images)
batch_prompt_tokens = batch_pad_stack(tokens)
batch_labels = batch_pad_stack(labels)
batch_loss_mask = batch_pad_stack(loss_mask)
return ImageTextRawBatch(
__keys__=batch_keys,
images=batch_images,
tokens=batch_prompt_tokens,
labels=batch_labels,
loss_mask=batch_loss_mask,
)
if self.packed_sequence:
if len(samples) > 1:
raise ValueError(
"Micro batch size should be 1 when training with packed sequence, but your micro batch size "
f"is {len(samples)}. \nThe following config is equivalent to your current setting for "
f"a packed dataset. Please update your config to the following: \n"
f"Set micro batch size to 1 (currently {len(samples)})\n"
f"Set global batch size to `global_batch_size // {len(samples)}` "
f"Set packed sequence length to `original_sample_seq_len * {len(samples)}` "
f"(currently {self.packed_sequence_size}) \n"
f"For details please visit "
f"https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/packed_sequence.html"
)
# The batching are taken care by packing.
sample = samples[0]
return PackedImageTextRawBatch(
__keys__=sample.__key__,
images=sample.images,
tokens=sample.tokens,
labels=sample.labels,
loss_mask=sample.loss_mask,
position_ids=sample.position_ids,
packed_seq_params=sample.packed_seq_params,
)
else:
keys, images, tokens, labels, loss_mask = [], [], [], [], []
for sample in samples:
keys.append(sample.__key__)
images.append(sample.images)
tokens.append(sample.tokens)
labels.append(sample.labels)
loss_mask.append(sample.loss_mask)

batch_keys = batch_list(keys)
batch_images = batch_pad_stack(images)
batch_prompt_tokens = batch_pad_stack(tokens)
batch_labels = batch_pad_stack(labels)
batch_loss_mask = batch_pad_stack(loss_mask)
return ImageTextRawBatch(
__keys__=batch_keys,
images=batch_images,
tokens=batch_prompt_tokens,
labels=batch_labels,
loss_mask=batch_loss_mask,
)

def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
"""
Expand All @@ -165,7 +218,7 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
Returns:
dict: A dictionary containing the encoded batch data, ready for model input.
"""
batch_dict = dataclasses.asdict(batch_data)
batch_dict = batch_data.__dict__
if 'images' in batch_dict:
batch_dict['media'] = batch_dict['images']
del batch_dict['images']
Expand All @@ -177,3 +230,66 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict:
if 'attention_mask' not in batch_dict:
batch_dict['attention_mask'] = None
return batch_dict

def select_samples_to_pack(self, samples):
"""Selects which samples will be packed together.
NOTE: Energon dataloader calls this method internally if packing is used.
Please see https://nvidia.github.io/Megatron-Energon/packing.html
"""
from nemo.collections.vlm.neva.data.sequence_packing import greedy_knapsack, predict_seq_len

media_token_id = self.sample_config.image_token.token_id
lengths = [
predict_seq_len(
sample.tokens,
media_token_index=media_token_id,
num_image_embeddings_per_tile=self.num_image_embeddings_per_tile,
)
for sample in samples
]
packed_samples = greedy_knapsack(lengths, samples, self.packed_sequence_size)
avg_samples_per_bin = round(len(lengths) / len(packed_samples))
logging.info(
f"[Seq Packing Info] - Packing seq len: {self.packed_sequence_size}, "
f"Buffered samples: {len(lengths)}, Total number of bins: {len(packed_samples)}, "
f"Average samples per bin: {avg_samples_per_bin}"
)
return packed_samples

@stateless
def pack_selected_samples(self, samples):
"""
Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked.
NOTE: Energon dataloader calls this method internally if packing is used.
Please see https://nvidia.github.io/Megatron-Energon/packing.html
Args:
samples: List of ImageTaskSample instances to pack into one sample.
Returns:
ImageTaskSamplePacked instance.
"""
from nemo.collections.vlm.neva.data.sequence_packing import convert_to_packed

packed_images = torch.stack([sample.images for sample in samples])
media_token_id = self.sample_config.image_token.token_id
packed_tokens, packed_labels, packed_position_ids, packed_loss_mask, packed_seq_params = convert_to_packed(
tokens=[sample.tokens for sample in samples],
labels=[sample.labels for sample in samples],
num_image_embeddings_per_tile=self.num_image_embeddings_per_tile,
media_token_index=media_token_id,
ignore_index=self.sample_config.ignore_place_holder,
)

return PackedImageTextSample(
__key__=",".join([s.__key__ for s in samples]),
__restore_key__=(), # Will be set by energon based on `samples`
tokens=packed_tokens,
labels=packed_labels,
images=packed_images,
position_ids=packed_position_ids,
loss_mask=packed_loss_mask,
packed_seq_params=packed_seq_params,
)
2 changes: 1 addition & 1 deletion nemo/collections/vlm/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import List, Optional, Union

import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams
Expand Down
Loading

0 comments on commit 1626ddd

Please sign in to comment.