From 1626ddded63af1fe82b0bee482fba44c57b8203d Mon Sep 17 00:00:00 2001 From: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Date: Wed, 15 Jan 2025 13:25:44 -0800 Subject: [PATCH] Add Seq Packing in NeMo / Neva2 (#11633) * api updates and fixes Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * fix Signed-off-by: yaoyu-33 * fix arg Signed-off-by: yaoyu-33 * update seq packing in mock ds Signed-off-by: yaoyu-33 * save Signed-off-by: yaoyu-33 * update preprocess_data Signed-off-by: yaoyu-33 * update seq packing Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * fix sp Signed-off-by: yaoyu-33 * save Signed-off-by: yaoyu-33 * fix seq packing Signed-off-by: yaoyu-33 * add truncation and padding Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * Fix issues Signed-off-by: yaoyu-33 * change LLaVATemplateConfig variables to class variables * change to use field with default attributes * Apply isort and black reformatting Signed-off-by: yashaswikarnati * Apply isort and black reformatting Signed-off-by: yaoyu-33 * Add seq packing option in energon Signed-off-by: yaoyu-33 * Fix energon conversation Signed-off-by: yaoyu-33 * add energon option in neva training script Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * add ci test for packed seq Signed-off-by: yaoyu-33 * fix mock dataset seq packing Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * fix mock dataset seq packing Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * fix lint and update seq pack func Signed-off-by: yaoyu-33 * fix energon module Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * fix comments Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * address lightning issues Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * Update sequence_packing.py Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> * update energon requirements Signed-off-by: yaoyu-33 * Fix for energon update Signed-off-by: yaoyu-33 * fix for test Signed-off-by: yaoyu-33 --------- Signed-off-by: yaoyu-33 Signed-off-by: yaoyu-33 Signed-off-by: yashaswikarnati Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: yaoyu-33 Co-authored-by: ykarnati Co-authored-by: yashaswikarnati --- .github/workflows/cicd-main.yml | 18 +- nemo/collections/llm/peft/api.py | 4 +- .../multimodal/data/energon/base.py | 5 + .../multimodal/data/energon/config.py | 24 ++- .../multimodal/data/energon/conversation.py | 2 +- .../multimodal/data/energon/task_encoder.py | 170 +++++++++++++++--- nemo/collections/vlm/inference/base.py | 2 +- nemo/collections/vlm/neva/data/config.py | 6 +- nemo/collections/vlm/neva/data/lazy.py | 115 ++++++------ nemo/collections/vlm/neva/data/mock.py | 69 ++++++- .../vlm/neva/data/sequence_packing.py | 157 ++++++++++++++++ nemo/collections/vlm/neva/model/base.py | 38 +++- nemo/collections/vlm/recipes/llava15_13b.py | 2 +- nemo/collections/vlm/recipes/llava15_7b.py | 2 +- nemo/collections/vlm/recipes/llava_next_7b.py | 2 +- nemo/lightning/megatron_parallel.py | 5 +- requirements/requirements_multimodal.txt | 2 +- scripts/vlm/llava_next_finetune.py | 2 +- scripts/vlm/llava_next_pretrain.py | 2 +- scripts/vlm/mllama_finetune.py | 2 +- scripts/vlm/neva_finetune.py | 109 ++++++++--- 
.../data/energon/test_data_module.py | 4 +- .../{mllama_train.py => test_mllama_train.py} | 0 .../vlm/{neva_train.py => test_neva_train.py} | 7 + 24 files changed, 611 insertions(+), 138 deletions(-) create mode 100644 nemo/collections/vlm/neva/data/sequence_packing.py rename tests/collections/vlm/{mllama_train.py => test_mllama_train.py} (100%) rename tests/collections/vlm/{neva_train.py => test_neva_train.py} (95%) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 16037920d080..a815be7bdc2f 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4329,11 +4329,24 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - python tests/collections/vlm/neva_train.py \ + python tests/collections/vlm/test_neva_train.py \ --devices=1 \ --max-steps=5 \ --experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }} + L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python tests/collections/vlm/test_neva_train.py \ + --devices=1 \ + --max-steps=5 \ + --experiment-dir=/tmp/nemo2_neva_results/${{ github.run_id }} \ + --use_packed_sequence + L2_NeMo_2_MLLAMA_MOCK_TRAINING: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4342,7 +4355,7 @@ jobs: RUNNER: self-hosted-azure SCRIPT: | TRANSFORMERS_OFFLINE=1 \ - python tests/collections/vlm/mllama_train.py \ + python tests/collections/vlm/test_mllama_train.py \ --devices=1 \ --max-steps=5 \ --experiment-dir=/tmp/nemo2_mllama_results/${{ github.run_id }} @@ -5060,6 +5073,7 @@ jobs: - Speech_Checkpoints_tests - L2_Stable_Diffusion_Training - L2_NeMo_2_NEVA_MOCK_TRAINING + - L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING - L2_NeMo_2_MLLAMA_MOCK_TRAINING - L2_NeMo_2_GPT_Pretraining_no_transformer_engine - L2_NeMo_2_GPT_DDP_Param_Parity_check diff --git a/nemo/collections/llm/peft/api.py b/nemo/collections/llm/peft/api.py index c05fd0b8edde..b70601faf7a3 100644 --- a/nemo/collections/llm/peft/api.py +++ b/nemo/collections/llm/peft/api.py @@ -16,10 +16,10 @@ from pathlib import Path from typing import Tuple, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch +from lightning.pytorch.trainer.states import TrainerFn from megatron.core import dist_checkpointing -from pytorch_lightning.trainer.states import TrainerFn from rich.console import Console from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer diff --git a/nemo/collections/multimodal/data/energon/base.py b/nemo/collections/multimodal/data/energon/base.py index 3dfd495edd82..c29935880889 100644 --- a/nemo/collections/multimodal/data/energon/base.py +++ b/nemo/collections/multimodal/data/energon/base.py @@ -68,6 +68,7 @@ def __init__( multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(), task_encoder: Optional[MultiModalTaskEncoder] = None, decoder_seq_length: Optional[int] = None, + packing_buffer_size: Optional[int] = None, ) -> None: """ Initialize the EnergonMultiModalDataModule. @@ -84,6 +85,8 @@ def __init__( Defaults to MultiModalSampleConfig(). task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. If not provided, a default (MultimodalTaskEncoder) encoder will be created. 
Defaults to None. + decoder_seq_length (int, optional): The maximum sequence length for the decoder. Used in encoder-decoder models. + packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None. """ super().__init__() @@ -113,6 +116,7 @@ def __init__( ) self.train_dataloader_object = None self.val_dataloader_object = None + self.packing_buffer_size = packing_buffer_size def io_init(self, **kwargs) -> fdl.Config[Self]: @@ -146,6 +150,7 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val task_encoder=self.task_encoder, worker_config=worker_config, max_samples_per_sequence=None, + packing_buffer_size=self.packing_buffer_size, shuffle_buffer_size=100, split_part=split, ) diff --git a/nemo/collections/multimodal/data/energon/config.py b/nemo/collections/multimodal/data/energon/config.py index c145c5e51019..abbfd874880f 100644 --- a/nemo/collections/multimodal/data/energon/config.py +++ b/nemo/collections/multimodal/data/energon/config.py @@ -13,8 +13,11 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import List +from typing import List, Tuple, Union + import torch +from megatron.core.packed_seq_params import PackedSeqParams + from nemo.collections.multimodal.data.energon.conversation import LLaVATemplateConfig @@ -34,7 +37,7 @@ class ImageToken(MultiModalToken): @dataclass class ImageTextSample: - '''Sample type for template formatted raw image text sample''' + """Sample type for template formatted raw image text sample""" __key__: str = '' images: torch.Tensor = field(default_factory=lambda: torch.empty(0)) @@ -43,6 +46,15 @@ class ImageTextSample: loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float)) +@dataclass +class PackedImageTextSample(ImageTextSample): + """Sample type for packed image text sample""" + + __restore_key__: Tuple[Union[str, int, tuple], ...] = () + position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float)) + packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams()) + + @dataclass class ImageTextRawBatch: """Sample type for image text raw batch""" @@ -56,6 +68,14 @@ class ImageTextRawBatch: loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float)) +@dataclass +class PackedImageTextRawBatch(ImageTextRawBatch): + """Sample type for image text raw batch""" + + position_ids: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float)) + packed_seq_params: PackedSeqParams = field(default_factory=lambda: PackedSeqParams()) + + @dataclass class MultiModalSampleConfig: image_token: ImageToken = field(default_factory=ImageToken) diff --git a/nemo/collections/multimodal/data/energon/conversation.py b/nemo/collections/multimodal/data/energon/conversation.py index 31019ae9c615..95b0ad184f8c 100644 --- a/nemo/collections/multimodal/data/energon/conversation.py +++ b/nemo/collections/multimodal/data/energon/conversation.py @@ -30,7 +30,7 @@ class LLaVATemplateConfig(BaseConversationTemplateConfig): """LLava-specific template configuration which extends the base config""" system: str = field( - default="A chat between a curious user and artificial assistant agent. " + default="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed and polite answers to user's questions." 
) roles: List[str] = field(default_factory=lambda: ['user', 'assistant']) diff --git a/nemo/collections/multimodal/data/energon/task_encoder.py b/nemo/collections/multimodal/data/energon/task_encoder.py index 7a8d0f0ab033..80b6e156f4a1 100644 --- a/nemo/collections/multimodal/data/energon/task_encoder.py +++ b/nemo/collections/multimodal/data/energon/task_encoder.py @@ -25,14 +25,21 @@ batch_list, batch_pad_stack, ) +from megatron.energon.task_encoder.base import stateless -from nemo.collections.multimodal.data.energon.config import ImageTextRawBatch, ImageTextSample +from nemo.collections.multimodal.data.energon.config import ( + ImageTextRawBatch, + ImageTextSample, + PackedImageTextRawBatch, + PackedImageTextSample, +) from nemo.collections.multimodal.data.energon.sample_encoder import ( InterleavedSampleEncoder, SampleEncoder, SimilarityInterleavedEncoder, VQASampleEncoder, ) +from nemo.utils import logging class MultiModalTaskEncoder( @@ -54,16 +61,34 @@ class MultiModalTaskEncoder( for model input. """ - def __init__(self, tokenizer, image_processor, multimodal_sample_config): + def __init__( + self, + tokenizer, + image_processor, + multimodal_sample_config, + packed_sequence=False, + packed_sequence_size=-1, + num_image_embeddings_per_tile=576, + ): """ Initialize the MultiModalTaskEncoder with specific encoders for different sample types. Parameters: - tokenizer (Tokenizer): The tokenizer used for processing text across different sample types. - image_processor (ImageProcessor): The image processor used for preprocessing images. - multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object. + tokenizer (Tokenizer): The tokenizer used for processing textual components across sample types. + image_processor (ImageProcessor): The image processor responsible for preprocessing image data. + multimodal_sample_config (MultiModalSampleConfig): Configuration object defining properties and + requirements for multimodal samples. + packed_sequence (bool, optional): Flag indicating whether packed sequences are used. Default is False. + packed_sequence_size (int, optional): The size of packed sequences, used when `packed_sequence` is True. + Default is -1. + num_image_embeddings_per_tile (int, optional): Number of image embeddings per image tile. Determines + the granularity of image features. Default is 576. """ self.tokenizer = tokenizer + self.sample_config = multimodal_sample_config + self.packed_sequence = packed_sequence + self.num_image_embeddings_per_tile = num_image_embeddings_per_tile # only used with seq packing + self.packed_sequence_size = packed_sequence_size self.encoders: Dict[str, SampleEncoder] = { VQASample.__name__: VQASampleEncoder( tokenizer=tokenizer, @@ -92,6 +117,7 @@ def register_encoder(self, sample_type: str, encoder: SampleEncoder) -> None: """ self.encoders[sample_type] = encoder + @stateless def encode_sample( self, sample: Union[VQASample, InterleavedSample, SimilarityInterleavedSample, CaptioningSample] ) -> ImageTextSample: @@ -118,7 +144,9 @@ def encode_sample( encoded_sample = encoder.encode(input_sample=sample, output_sample=ImageTextSample()) return encoded_sample - def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch: + def batch( + self, samples: List[Union[ImageTextSample, PackedImageTextSample]] + ) -> Union[ImageTextRawBatch, PackedImageTextRawBatch]: """ Batch a list of encoded samples into a single raw batch. 
@@ -131,26 +159,51 @@ def batch(self, samples: List[ImageTextSample]) -> ImageTextRawBatch: ImageTextRawBatch: The batched data, including images, tokens, labels, and loss masks. """ - keys, images, tokens, labels, loss_mask = [], [], [], [], [] - for sample in samples: - keys.append(sample.__key__) - images.append(sample.images) - tokens.append(sample.tokens) - labels.append(sample.labels) - loss_mask.append(sample.loss_mask) - - batch_keys = batch_list(keys) - batch_images = batch_pad_stack(images) - batch_prompt_tokens = batch_pad_stack(tokens) - batch_labels = batch_pad_stack(labels) - batch_loss_mask = batch_pad_stack(loss_mask) - return ImageTextRawBatch( - __keys__=batch_keys, - images=batch_images, - tokens=batch_prompt_tokens, - labels=batch_labels, - loss_mask=batch_loss_mask, - ) + if self.packed_sequence: + if len(samples) > 1: + raise ValueError( + "Micro batch size should be 1 when training with packed sequence, but your micro batch size " + f"is {len(samples)}. \nThe following config is equivalent to your current setting for " + f"a packed dataset. Please update your config to the following: \n" + f"Set micro batch size to 1 (currently {len(samples)})\n" + f"Set global batch size to `global_batch_size // {len(samples)}` " + f"Set packed sequence length to `original_sample_seq_len * {len(samples)}` " + f"(currently {self.packed_sequence_size}) \n" + f"For details please visit " + f"https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/packed_sequence.html" + ) + # The batching are taken care by packing. + sample = samples[0] + return PackedImageTextRawBatch( + __keys__=sample.__key__, + images=sample.images, + tokens=sample.tokens, + labels=sample.labels, + loss_mask=sample.loss_mask, + position_ids=sample.position_ids, + packed_seq_params=sample.packed_seq_params, + ) + else: + keys, images, tokens, labels, loss_mask = [], [], [], [], [] + for sample in samples: + keys.append(sample.__key__) + images.append(sample.images) + tokens.append(sample.tokens) + labels.append(sample.labels) + loss_mask.append(sample.loss_mask) + + batch_keys = batch_list(keys) + batch_images = batch_pad_stack(images) + batch_prompt_tokens = batch_pad_stack(tokens) + batch_labels = batch_pad_stack(labels) + batch_loss_mask = batch_pad_stack(loss_mask) + return ImageTextRawBatch( + __keys__=batch_keys, + images=batch_images, + tokens=batch_prompt_tokens, + labels=batch_labels, + loss_mask=batch_loss_mask, + ) def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: """ @@ -165,7 +218,7 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: Returns: dict: A dictionary containing the encoded batch data, ready for model input. """ - batch_dict = dataclasses.asdict(batch_data) + batch_dict = batch_data.__dict__ if 'images' in batch_dict: batch_dict['media'] = batch_dict['images'] del batch_dict['images'] @@ -177,3 +230,66 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: if 'attention_mask' not in batch_dict: batch_dict['attention_mask'] = None return batch_dict + + def select_samples_to_pack(self, samples): + """Selects which samples will be packed together. + + NOTE: Energon dataloader calls this method internally if packing is used. 
+ Please see https://nvidia.github.io/Megatron-Energon/packing.html + """ + from nemo.collections.vlm.neva.data.sequence_packing import greedy_knapsack, predict_seq_len + + media_token_id = self.sample_config.image_token.token_id + lengths = [ + predict_seq_len( + sample.tokens, + media_token_index=media_token_id, + num_image_embeddings_per_tile=self.num_image_embeddings_per_tile, + ) + for sample in samples + ] + packed_samples = greedy_knapsack(lengths, samples, self.packed_sequence_size) + avg_samples_per_bin = round(len(lengths) / len(packed_samples)) + logging.info( + f"[Seq Packing Info] - Packing seq len: {self.packed_sequence_size}, " + f"Buffered samples: {len(lengths)}, Total number of bins: {len(packed_samples)}, " + f"Average samples per bin: {avg_samples_per_bin}" + ) + return packed_samples + + @stateless + def pack_selected_samples(self, samples): + """ + Function to pack a list of ImageTaskSample into a single ImageTaskSamplePacked. + + NOTE: Energon dataloader calls this method internally if packing is used. + Please see https://nvidia.github.io/Megatron-Energon/packing.html + + Args: + samples: List of ImageTaskSample instances to pack into one sample. + + Returns: + ImageTaskSamplePacked instance. + """ + from nemo.collections.vlm.neva.data.sequence_packing import convert_to_packed + + packed_images = torch.stack([sample.images for sample in samples]) + media_token_id = self.sample_config.image_token.token_id + packed_tokens, packed_labels, packed_position_ids, packed_loss_mask, packed_seq_params = convert_to_packed( + tokens=[sample.tokens for sample in samples], + labels=[sample.labels for sample in samples], + num_image_embeddings_per_tile=self.num_image_embeddings_per_tile, + media_token_index=media_token_id, + ignore_index=self.sample_config.ignore_place_holder, + ) + + return PackedImageTextSample( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + tokens=packed_tokens, + labels=packed_labels, + images=packed_images, + position_ids=packed_position_ids, + loss_mask=packed_loss_mask, + packed_seq_params=packed_seq_params, + ) diff --git a/nemo/collections/vlm/inference/base.py b/nemo/collections/vlm/inference/base.py index 77918bae26b9..bbceb851edae 100644 --- a/nemo/collections/vlm/inference/base.py +++ b/nemo/collections/vlm/inference/base.py @@ -14,7 +14,7 @@ from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.distributed from megatron.core.inference.common_inference_params import CommonInferenceParams diff --git a/nemo/collections/vlm/neva/data/config.py b/nemo/collections/vlm/neva/data/config.py index 3b22d5a493b3..2cf3dd80f47d 100644 --- a/nemo/collections/vlm/neva/data/config.py +++ b/nemo/collections/vlm/neva/data/config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from .multimodal_tokens import ImageToken, MultiModalToken, VideoToken @@ -31,7 +31,7 @@ class DataConfig: @dataclass class ImageDataConfig(DataConfig): media_type: str = "image" - media_token: MultiModalToken = ImageToken + media_token: MultiModalToken = field(default_factory=lambda: ImageToken()) image_folder: Optional[str] = None image_process_mode: str = 'pad' @@ -39,7 +39,7 @@ class ImageDataConfig(DataConfig): @dataclass class VideoDataConfig(DataConfig): media_type: str = "video" - media_token: MultiModalToken = VideoToken + media_token: MultiModalToken = VideoToken() splice_single_frame: Optional[str] = None # 'first', 'middle', 'last' will represent video as first / middle / last frame only, all other frames discarded. num_frames: int = 8 # Selects the number of frames to use from the video diff --git a/nemo/collections/vlm/neva/data/lazy.py b/nemo/collections/vlm/neva/data/lazy.py index 066310867777..90199d3c6d30 100644 --- a/nemo/collections/vlm/neva/data/lazy.py +++ b/nemo/collections/vlm/neva/data/lazy.py @@ -251,7 +251,6 @@ def __init__( data_config, tokenizer, image_processor, - sequence_length=None, ): super().__init__() if data_path is not None: @@ -269,8 +268,6 @@ def __init__( self.tokenizer = self.tokenizer.tokenizer self.image_processor = image_processor - self.sequence_length = sequence_length - self.conv_template = data_config.conv_template self.conv = supported_conv_templates[self.conv_template] self.image_process_mode = data_config.image_process_mode @@ -381,6 +378,8 @@ def __init__( data_config, tokenizer, image_processor, + packed_sequence=False, + num_image_embeddings_per_tile=576, ): if data_path.endswith(".json"): @@ -414,29 +413,12 @@ def __init__( else: raise ValueError(f"Formatting of {data_path} is not supported in Neva.") + self.packed_sequence = packed_sequence + self.num_image_embeddings_per_tile = num_image_embeddings_per_tile def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: data_config = self.data_config - packed_sequence = "cu_seqlens" in instances[0] - max_len = max(instance['tokens'].shape[0] for instance in instances) - for instance in instances: - pad_len = max_len - instance['tokens'].shape[0] - instance['tokens'] = F.pad(instance['tokens'], (0, pad_len), 'constant', 0) - instance['labels'] = F.pad(instance['labels'], (0, pad_len), 'constant', IGNORE_INDEX) - if packed_sequence and instance["cu_seqlens"][-1] != max_len: - instance["cu_seqlens"] = torch.cat((instance["cu_seqlens"], torch.IntTensor([max_len])), 0) - - if packed_sequence: - max_len_cu = max(instance['cu_seqlens'].shape[0] for instance in instances) - max_len_image = max(instance['image'].shape[0] for instance in instances) - for instance in instances: - pad_len_cu = max_len_cu - instance['cu_seqlens'].shape[0] - instance['cu_seqlens'] = F.pad(instance['cu_seqlens'], (0, pad_len_cu), 'constant', max_len) - - x = instance['image'] - num_pad = max_len_image - x.shape[0] - pad_tensor = torch.zeros(num_pad, *x.shape[1:], dtype=x.dtype, device=x.device) - instance['image'] = torch.cat((x, pad_tensor), dim=0) + packed_sequence = self.packed_sequence media_type = data_config.media_type if media_type == 'image': @@ -447,24 +429,30 @@ def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: else: raise ValueError(f"Unsupported media type {media_type}") - batch = default_collate(instances) - tokenizer = self.tokenizer + if packed_sequence: + 
from nemo.collections.vlm.neva.data.sequence_packing import convert_to_packed + + media_token_id = self.data_config.media_token.token_index + tokens, labels, position_ids, loss_mask, packed_seq_params = convert_to_packed( + tokens=[instance['tokens'] for instance in instances], + labels=[instance['labels'] for instance in instances], + num_image_embeddings_per_tile=self.num_image_embeddings_per_tile, + media_token_index=media_token_id, + ignore_index=IGNORE_INDEX, + ) + attention_mask = None + else: # regular dataset + max_len = max(instance['tokens'].shape[0] for instance in instances) + for instance in instances: + pad_len = max_len - instance['tokens'].shape[0] + instance['tokens'] = F.pad(instance['tokens'], (0, pad_len), 'constant', 0) + instance['labels'] = F.pad(instance['labels'], (0, pad_len), 'constant', IGNORE_INDEX) - tokens = batch['tokens'] - labels = batch['labels'] + batch = default_collate(instances) + tokenizer = self.tokenizer - if packed_sequence: - cu_seqlens = batch["cu_seqlens"] - position_ids = [] - for cu_seqlen in cu_seqlens: - position_ids.append([]) - for ind in range(0, len(cu_seqlen) - 1): - seqlen = cu_seqlen[ind + 1] - cu_seqlen[ind] - position_ids[-1].extend(list(range(seqlen))) - position_ids = torch.LongTensor(position_ids) - loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=tokens.device) - attention_mask = torch.ones(tokens.size(), dtype=torch.long, device=tokens.device) - else: + tokens = batch['tokens'] + labels = batch['labels'] attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( data=tokens, eod_token=tokenizer.eos_token_id, @@ -472,8 +460,7 @@ def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: reset_attention_mask=data_config.reset_attention_mask, reset_position_ids=data_config.reset_position_ids, ) - - loss_mask[labels < 0] = 0.0 + loss_mask[labels < 0] = 0.0 batch = { 'tokens': tokens, @@ -484,7 +471,7 @@ def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 'media': media, } if packed_sequence: - batch["cu_seqlens"] = cu_seqlens + batch["packed_seq_params"] = packed_seq_params return batch @@ -506,7 +493,8 @@ def __init__( num_workers: int = 8, pin_memory: bool = True, persistent_workers: bool = False, - use_packed_sequence: bool = False, + packed_sequence: bool = False, + num_image_embeddings_per_tile: int = 576, seed: int = 1234, ) -> None: super().__init__() @@ -534,7 +522,8 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers self.seed = seed - self.use_packed_sequence = use_packed_sequence + self.packed_sequence = packed_sequence + self.num_image_embeddings_per_tile = num_image_embeddings_per_tile self.init_global_step = 0 if tokenizer is None or image_processor is None: @@ -546,6 +535,20 @@ def __init__( self.tokenizer = tokenizer or AutoTokenizer("llava-hf/llava-1.5-7b-hf") self.image_processor = image_processor or processor.image_processor + if self.packed_sequence: + import dataclasses + + def custom_on_megatron_step_start(self, step): + return dataclasses.replace( + step, + seq_length=self.seq_len, + micro_batch_size=1, # Override the micro_batch_size to 1 (used in PP) + num_microbatches=self.num_microbatches, + decoder_seq_length=self.decoder_seq_len, + ) + + MegatronDataSampler.on_megatron_step_start = custom_on_megatron_step_start + self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, decoder_seq_len=self.decoder_seq_length, @@ -556,14 +559,22 @@ def __init__( def setup(self, stage: str = "") -> 
None: assert len(self.paths) == 1, "not yet support blend dataset in Neva 2.0!" - if self.use_packed_sequence: - pass # TODO - else: - # TODO: - # rng = torch.Generator().manual_seed(self.seed) - # train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=rng) - self._train_ds = NevaDataset(self.paths[0], self.data_config, self.tokenizer, self.image_processor) - self._validation_ds = NevaDataset(self.paths[0], self.data_config, self.tokenizer, self.image_processor) + self._train_ds = NevaDataset( + self.paths[0], + self.data_config, + self.tokenizer, + self.image_processor, + packed_sequence=self.packed_sequence, + num_image_embeddings_per_tile=self.num_image_embeddings_per_tile, + ) + self._validation_ds = NevaDataset( + self.paths[0], + self.data_config, + self.tokenizer, + self.image_processor, + packed_sequence=self.packed_sequence, + num_image_embeddings_per_tile=self.num_image_embeddings_per_tile, + ) def train_dataloader(self) -> TRAIN_DATALOADERS: return self._create_dataloader(self._train_ds) diff --git a/nemo/collections/vlm/neva/data/mock.py b/nemo/collections/vlm/neva/data/mock.py index 7533bf56ac46..495bd9f0dee5 100644 --- a/nemo/collections/vlm/neva/data/mock.py +++ b/nemo/collections/vlm/neva/data/mock.py @@ -42,6 +42,7 @@ def __init__( num_workers: int = 8, pin_memory: bool = True, persistent_workers: bool = False, + packed_sequence: bool = False, ): super().__init__() self.seq_length = seq_length @@ -54,6 +55,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers + self.packed_sequence = packed_sequence if tokenizer is None or image_processor is None: logging.warning(f"Processor or tokenizer are not provided! Fall back to `llava-hf/llava-1.5-7b-hf`.") @@ -72,14 +74,36 @@ def __init__( ) def setup(self, stage: str = "") -> None: + seq_length = self.seq_length + if self.packed_sequence and self.micro_batch_size > 1: + seq_length = seq_length // self.micro_batch_size + logging.warning( + f"Packed sequence is used with mock dataset. Sequence length for each " + f"sample is update to `seq_length // self.micro_batch_size = {seq_length}`!" 
+ ) self._train_ds = _MockNevaDataset( - self.tokenizer, self.image_processor, "train", self.num_train_samples, self.seq_length + self.tokenizer, + self.image_processor, + "train", + self.num_train_samples, + seq_length, + packed_sequence=self.packed_sequence, ) self._validation_ds = _MockNevaDataset( - self.tokenizer, self.image_processor, "valid", self.num_val_samples, self.seq_length + self.tokenizer, + self.image_processor, + "valid", + self.num_val_samples, + seq_length, + packed_sequence=self.packed_sequence, ) self._test_ds = _MockNevaDataset( - self.tokenizer, self.image_processor, "test", self.num_test_samples, self.seq_length + self.tokenizer, + self.image_processor, + "test", + self.num_test_samples, + seq_length, + packed_sequence=self.packed_sequence, ) def train_dataloader(self) -> TRAIN_DATALOADERS: @@ -117,6 +141,8 @@ def __init__( num_samples: int, seq_length: int, seed: int = 42, + packed_sequence: bool = False, + num_image_embeddings_per_tile=576, ) -> None: super().__init__() self.name = name @@ -129,8 +155,10 @@ def __init__( self.length = num_samples self.seed = seed + self.packed_sequence = packed_sequence + self.num_image_embeddings_per_tile = num_image_embeddings_per_tile - self.loss_mask = torch.ones(self.seq_length, dtype=torch.float) + self.loss_mask = torch.ones(self.seq_length + 1 - num_image_embeddings_per_tile, dtype=torch.float) self.position_ids = torch.arange(self.seq_length, dtype=torch.int64) def __len__(self) -> int: @@ -143,7 +171,11 @@ def _get_text(self, idx: int) -> np.ndarray: def __getitem__(self, idx) -> Dict[str, torch.Tensor]: # Generate data of the expected size and datatype (based on GPTDataset). np_gen = np.random.default_rng(seed=(self.seed + idx)) - tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length + 1], dtype=np.int64)) + tokens = torch.from_numpy( + np_gen.integers( + self.vocab_size, size=[self.seq_length + 2 - self.num_image_embeddings_per_tile], dtype=np.int64 + ) + ) tokens[2] = IMAGE_TOKEN_INDEX # ImageToken token index labels = tokens.clone() images = torch.from_numpy(np_gen.random(size=[3, self.image_height, self.image_width], dtype=np.float32)) @@ -164,6 +196,33 @@ def _collate_fn(self, batch): """ collated_batch = data.dataloader.default_collate(batch) collated_batch["attention_mask"] = None + if self.packed_sequence: + from megatron.core.packed_seq_params import PackedSeqParams + + tokens = collated_batch["tokens"] + batch_size = tokens.shape[0] + valid_seqlen = self.seq_length + cu_seqlens = torch.arange( + 0, (batch_size + 1) * (valid_seqlen), step=(valid_seqlen), dtype=torch.int32, device=tokens.device + ) + cu_seqlens_padded = torch.arange( + 0, (batch_size + 1) * (valid_seqlen), step=(valid_seqlen), dtype=torch.int32, device=tokens.device + ) + qkv_format = 'thd' + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=valid_seqlen, + max_seqlen_kv=valid_seqlen, + qkv_format=qkv_format, + ) + collated_batch["packed_seq_params"] = packed_seq_params + + for key in ["tokens", "labels", "loss_mask", "position_ids"]: + collated_batch[key] = collated_batch[key].reshape(1, -1) + return collated_batch def collate_fn(self, batch): diff --git a/nemo/collections/vlm/neva/data/sequence_packing.py b/nemo/collections/vlm/neva/data/sequence_packing.py new file mode 100644 index 000000000000..1ddfe80c5797 --- /dev/null +++ b/nemo/collections/vlm/neva/data/sequence_packing.py 
@@ -0,0 +1,157 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bisect +from typing import List + +import torch +import torch.nn.functional as F +from megatron.core.packed_seq_params import PackedSeqParams + + +# pylint:disable=line-too-long +# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L19 +# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. +def search_for_fit(numbers: List[int], capacity: int) -> int: + """Finds the index of largest number that fits into the knapsack with the given capacity.""" + index = bisect.bisect(numbers, capacity) + return -1 if index == 0 else (index - 1) + + +# pylint: disable=line-too-long +# Based on https://github.com/hiyouga/LLaMA-Factory/blob/641d0dab08d96a93c34657742213d8994d9ed476/src/llamafactory/data/processors/processor_utils.py#L27 +# Copyright (c) 2024 LLaMA-Factory. Apache license 2.0. +def greedy_knapsack(item_sizes: List[int], samples: List, max_capacity: int) -> List: + """Greedy algorithm with binary search for the knapsack problem. + + Pack as many samples as possible given a maximum capacity and capacities of individual samples. + Used if sequence packing is enabled. + """ + assert len(item_sizes) == len(samples), "sample lengths and samples must have the same length." + + knapsacks = [] + + if len(item_sizes) == 0: + return knapsacks + + # Sort sample lengths and samples together. + sorted_item_sizes, sorted_samples = zip(*sorted(zip(item_sizes, samples), key=lambda x: x[0])) + sorted_item_sizes = list(sorted_item_sizes) + sorted_samples = list(sorted_samples) + + # Check if all samples fit in the knapsack capacity. + if sorted_item_sizes[-1] > max_capacity: + raise ValueError( + f"knapsack: A sample is larger {sorted_item_sizes[-1]} than the max_sequence_length {max_capacity}." + ) + + while sorted_item_sizes: + current_knapsack = [] + remaining_capacity = max_capacity + + while True: + idx = search_for_fit(sorted_item_sizes, remaining_capacity) + if idx == -1: + break # Can't fit more samples. + + remaining_capacity -= sorted_item_sizes[idx] + + sorted_item_sizes.pop(idx) + sample = sorted_samples.pop(idx) + current_knapsack.append(sample) + + knapsacks.append(current_knapsack) + + return knapsacks + + +def predict_seq_len(instance_tokens: torch.Tensor, num_image_embeddings_per_tile: int, media_token_index: int) -> int: + """ + Predict the effective sequence length, accounting for media embeddings. + + Args: + instance_tokens (torch.Tensor): Token tensor for a single instance. + num_image_embeddings_per_tile (int): Number of image embeddings per tile. + media_token_index (int): Token ID representing media. + + Returns: + int: Effective sequence length. 
+ """ + num_images = torch.sum(instance_tokens == media_token_index).item() + seqlen = len(instance_tokens) + (num_image_embeddings_per_tile - 1) * num_images + return seqlen + + +def convert_to_packed( + tokens: List[torch.Tensor], + labels: List[torch.Tensor], + num_image_embeddings_per_tile: int, + media_token_index: int, + ignore_index: int, + pad_to_multiple_of: int = 64, +): + """ + Convert tokens, labels, and associated inputs into a packed version with padded sequence parameters. + + Args: + tokens (list[torch.Tensor]): List of token tensors for each instance. + labels (list[torch.Tensor]): List of label tensors for each instance. + num_image_embeddings_per_tile (int): Number of image embeddings per tile. + media_token_index (int): Token ID representing media. + ignore_index (int): Value to use for padding labels. + pad_to_multiple_of (int): Sequence length will be padded to a multiple of this value. Default is 8. + """ + packed_tokens = [] + packed_labels = [] + packed_position_ids = [] + seqlens_padded = [] + cu_seqlens = [0] + cu_seqlens_padded = [0] + + for instance_tokens, instance_labels in zip(tokens, labels): + seqlen = predict_seq_len(instance_tokens, num_image_embeddings_per_tile, media_token_index) + seqlen_padded = (seqlen + pad_to_multiple_of - 1) // pad_to_multiple_of * pad_to_multiple_of + pad_len = seqlen_padded - seqlen + + if pad_len > 0: + instance_tokens = F.pad(instance_tokens, (0, pad_len), 'constant', 0) + instance_labels = F.pad(instance_labels, (0, pad_len), 'constant', ignore_index) + + packed_tokens.append(instance_tokens) + packed_labels.append(instance_labels) + packed_position_ids.append(torch.arange(len(instance_tokens), dtype=torch.int, device=instance_tokens.device)) + seqlens_padded.append(seqlen_padded) + cu_seqlens.append(cu_seqlens[-1] + seqlen) + cu_seqlens_padded.append(cu_seqlens_padded[-1] + seqlen_padded) + + packed_tokens = torch.cat(packed_tokens, dim=0).unsqueeze(0) + packed_labels = torch.cat(packed_labels, dim=0).unsqueeze(0) + packed_position_ids = torch.cat(packed_position_ids, dim=0).unsqueeze(0) + packed_loss_mask = torch.ones_like(packed_labels, dtype=torch.float, device=packed_labels.device) + packed_loss_mask[packed_labels < 0] = 0.0 + + cu_seqlens = torch.IntTensor(cu_seqlens) + cu_seqlens_padded = torch.IntTensor(cu_seqlens_padded) + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=int(max(seqlens_padded)), + max_seqlen_kv=int(max(seqlens_padded)), + qkv_format='thd', + ) + + return packed_tokens, packed_labels, packed_position_ids, packed_loss_mask, packed_seq_params diff --git a/nemo/collections/vlm/neva/model/base.py b/nemo/collections/vlm/neva/model/base.py index 388078484a56..8cead72b4832 100644 --- a/nemo/collections/vlm/neva/model/base.py +++ b/nemo/collections/vlm/neva/model/base.py @@ -121,14 +121,19 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: ) ) + packed_seq_params = _batch.get("packed_seq_params", None) _batch = { key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None for key, val in _batch.items() } - # slice batch along sequence dimension for context parallelism - output = get_batch_on_this_context_parallel_rank(_batch) + if packed_seq_params is not None: + for attr in ["cu_seqlens_q", "cu_seqlens_kv", "cu_seqlens_q_padded", "cu_seqlens_kv_padded"]: + value = getattr(packed_seq_params, attr, None) + if value is not 
None: + setattr(packed_seq_params, attr, value.cuda(non_blocking=True)) + _batch["packed_seq_params"] = packed_seq_params - return output + return _batch def neva_forward_step(model, batch) -> torch.Tensor: @@ -596,6 +601,7 @@ def forward( image_token_index, num_image_tiles, attention_mask, + packed_seq_params, ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] output = self.language_model( @@ -642,6 +648,7 @@ def _preprocess_data( image_token_index, num_image_tiles, attention_mask, + packed_seq_params, ): """Preprocess input data before input to language model. @@ -698,6 +705,8 @@ def _preprocess_data( labels.shape == loss_mask.shape ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}" + packed_sequence = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + # Create indices for new text and label positions. with torch.no_grad(): image_token_mask = input_ids == image_token_index @@ -715,6 +724,16 @@ def _preprocess_data( # Pipeline parallel expects fixed input size. Check if we need to pad. if self._language_is_pipeline_parallel and max_seq_len < self._language_max_sequence_length: max_seq_len = self._language_max_sequence_length + if packed_sequence: + last_seqlen = packed_seq_params.cu_seqlens_q[-1] - packed_seq_params.cu_seqlens_q[-2] + last_seqlen_padded = max_seq_len - packed_seq_params.cu_seqlens_q_padded[-2] + assert ( + last_seqlen_padded >= last_seqlen + ), "`language_max_sequence_length` needs to increase for sequence packing to work properly." + packed_seq_params.cu_seqlens_q_padded[-1] = max_seq_len + packed_seq_params.cu_seqlens_kv_padded[-1] = max_seq_len + packed_seq_params.max_seqlen_q = max(last_seqlen_padded, packed_seq_params.max_seqlen_q) + packed_seq_params.max_seqlen_kv = max(last_seqlen_padded, packed_seq_params.max_seqlen_kv) if self.sequence_parallel_lm: if self.tp_comm_overlap_lm: @@ -835,7 +854,17 @@ def _preprocess_data( # Truncate if exceeding the language model's max sequence length. if final_embedding.shape[0] > self._language_max_sequence_length: final_embedding = final_embedding[: self._language_max_sequence_length] - if self.sequence_parallel_lm: + if packed_sequence: + truncate_len = packed_seq_params.cu_seqlens_q_padded[-1] - self._language_max_sequence_length + packed_seq_params.cu_seqlens_q_padded[-1] = self._language_max_sequence_length + packed_seq_params.cu_seqlens_kv_padded[-1] = self._language_max_sequence_length + packed_seq_params.cu_seqlens_q[-1] -= truncate_len + packed_seq_params.cu_seqlens_kv[-1] -= truncate_len + assert ( + packed_seq_params.cu_seqlens_q[-1] >= packed_seq_params.cu_seqlens_q[-2] + ), "with packed sequence, the truncation can only truncate on the last sequence." + + if self.sequence_parallel_lm and not packed_sequence: # Create an attention mask. This ensures correct computation. # This is done even when no padding was done as we set mask_type to # 'padding' or 'padding_causal' when using SP. 
@@ -858,6 +887,7 @@ def _preprocess_data( # Attention mask True/False meaning flipped in 1.7.0 attention_mask = attention_mask < 0.5 + if self.sequence_parallel_lm: final_embedding = tensor_parallel.scatter_to_sequence_parallel_region(final_embedding) return final_embedding, final_labels, final_loss_mask, attention_mask diff --git a/nemo/collections/vlm/recipes/llava15_13b.py b/nemo/collections/vlm/recipes/llava15_13b.py index d85ba6f2752b..40bc8cc44682 100644 --- a/nemo/collections/vlm/recipes/llava15_13b.py +++ b/nemo/collections/vlm/recipes/llava15_13b.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from megatron.core.distributed import DistributedDataParallelConfig diff --git a/nemo/collections/vlm/recipes/llava15_7b.py b/nemo/collections/vlm/recipes/llava15_7b.py index 2abb50db6c11..9de60e671e38 100644 --- a/nemo/collections/vlm/recipes/llava15_7b.py +++ b/nemo/collections/vlm/recipes/llava15_7b.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from megatron.core.distributed import DistributedDataParallelConfig diff --git a/nemo/collections/vlm/recipes/llava_next_7b.py b/nemo/collections/vlm/recipes/llava_next_7b.py index d23159125823..53609fe589c8 100644 --- a/nemo/collections/vlm/recipes/llava_next_7b.py +++ b/nemo/collections/vlm/recipes/llava_next_7b.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 1b1f5c790b61..e3c6c77f4cda 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -1711,7 +1711,10 @@ def masked_token_loss(tensor: Tensor, mask: Tensor): """ losses = tensor.float() loss_mask = mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll + num_valid_tokens = loss_mask.sum() + if num_valid_tokens < 0.5: # no valid tokens + num_valid_tokens += 1.0 + loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens # sequence level nll return loss diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt index 92ae32659dac..585e277be72a 100644 --- a/requirements/requirements_multimodal.txt +++ b/requirements/requirements_multimodal.txt @@ -6,7 +6,7 @@ diffusers>=0.19.3 einops_exts imageio kornia -megatron-energon<3.0.0 +megatron-energon==4.0.0 nerfacc>=0.5.3 open_clip_torch==2.24.0 PyMCubes diff --git a/scripts/vlm/llava_next_finetune.py b/scripts/vlm/llava_next_finetune.py index 91df8a39452d..9d3e5053c0c1 100644 --- a/scripts/vlm/llava_next_finetune.py +++ b/scripts/vlm/llava_next_finetune.py @@ -25,8 +25,8 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm, vlm diff --git a/scripts/vlm/llava_next_pretrain.py b/scripts/vlm/llava_next_pretrain.py index 0beb9b5b08d0..19bdf47bb668 100644 --- a/scripts/vlm/llava_next_pretrain.py +++ b/scripts/vlm/llava_next_pretrain.py @@ -25,8 +25,8 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from 
pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm, vlm diff --git a/scripts/vlm/mllama_finetune.py b/scripts/vlm/mllama_finetune.py index 15cd8078fd32..9e37d9c3fc0c 100644 --- a/scripts/vlm/mllama_finetune.py +++ b/scripts/vlm/mllama_finetune.py @@ -15,8 +15,8 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import WandbLogger from transformers import AutoProcessor from nemo import lightning as nl diff --git a/scripts/vlm/neva_finetune.py b/scripts/vlm/neva_finetune.py index 4069fb2d9278..3bf0084ea60d 100644 --- a/scripts/vlm/neva_finetune.py +++ b/scripts/vlm/neva_finetune.py @@ -21,11 +21,12 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm, vlm +from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder from nemo.collections.vlm import ImageDataConfig from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.lightning.pytorch.optim import CosineAnnealingScheduler @@ -42,6 +43,33 @@ def main(args): max_steps = args.max_steps decoder_seq_length = 4096 + if args.use_packed_sequence: + decoder_seq_length = 8192 + + # Submodules configurations + language_transformer_config = llm.Llama2Config7B( + seq_length=decoder_seq_length, + ) + vision_transformer_config = vlm.HFCLIPVisionConfig( + pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" + ) + vision_projection_config = vlm.MultimodalProjectorConfig( + projector_type=args.projector_type, + input_size=vision_transformer_config.hidden_size, + hidden_size=language_transformer_config.hidden_size, + ffn_hidden_size=language_transformer_config.hidden_size, + ) + + # NEVA model configuration + neva_config = vlm.NevaConfig( + language_transformer_config=language_transformer_config, + vision_transformer_config=vision_transformer_config, + vision_projection_config=vision_projection_config, + language_model_from_pretrained=args.language_model_path, + freeze_language_model=False, + freeze_vision_model=True, + ) + num_image_embeddings_per_tile = vision_transformer_config.num_image_embeddings_per_tile if args.data_type == "llava": # Data configuration @@ -60,7 +88,50 @@ def main(args): micro_batch_size=mbs, tokenizer=None, image_processor=None, - num_workers=8, + num_workers=4, + packed_sequence=args.use_packed_sequence, + num_image_embeddings_per_tile=num_image_embeddings_per_tile, + ) + elif args.data_type == "energon": + from transformers import AutoProcessor + + from nemo.collections.multimodal.data.energon import ( + EnergonMultiModalDataModule, + ImageToken, + LLaVATemplateConfig, + MultiModalSampleConfig, + ) + + processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + tokenizer = processor.tokenizer + image_processor = processor.image_processor + + # Configure multimodal samples + config = MultiModalSampleConfig( + image_token=ImageToken(token_str="", token_id=-200), + ignore_place_holder=-100, + conversation_template_config=LLaVATemplateConfig(), + ) + + # Initialize the data module + data = EnergonMultiModalDataModule( + path=args.data_path, + tokenizer=tokenizer, + image_processor=image_processor, + seq_length=decoder_seq_length, + micro_batch_size=mbs, + 
global_batch_size=gbs, + num_workers=0, + multimodal_sample_config=config, + task_encoder=MultiModalTaskEncoder( + tokenizer=tokenizer, + image_processor=image_processor, + multimodal_sample_config=config, + packed_sequence=args.use_packed_sequence, + packed_sequence_size=decoder_seq_length, + num_image_embeddings_per_tile=num_image_embeddings_per_tile, + ), + packing_buffer_size=200 if args.use_packed_sequence else None, ) elif args.data_type == "mock": data = vlm.NevaMockDataModule( @@ -70,36 +141,11 @@ def main(args): tokenizer=None, image_processor=None, num_workers=4, + packed_sequence=args.use_packed_sequence, ) else: raise ValueError(f"Data type {args.data_type} not supported") - # Submodules configurations - language_transformer_config = llm.Llama2Config7B( - seq_length=decoder_seq_length, - ) - vision_transformer_config = vlm.HFCLIPVisionConfig( - pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" - ) - vision_projection_config = vlm.MultimodalProjectorConfig( - projector_type=args.projector_type, - input_size=vision_transformer_config.hidden_size, - hidden_size=language_transformer_config.hidden_size, - ffn_hidden_size=language_transformer_config.hidden_size, - ) - - # NEVA model configuration - neva_config = vlm.NevaConfig( - language_transformer_config=language_transformer_config, - vision_transformer_config=vision_transformer_config, - vision_projection_config=vision_projection_config, - language_model_from_pretrained=args.language_model_path, - freeze_language_model=False, - freeze_vision_model=True, - ) - - model = vlm.NevaModel(neva_config, tokenizer=data.tokenizer) - from megatron.core.distributed import DistributedDataParallelConfig # Training strategy setup @@ -118,6 +164,8 @@ def main(args): ), ) + model = vlm.NevaModel(neva_config, tokenizer=data.tokenizer) + # Checkpoint callback setup checkpoint_callback = nl.ModelCheckpoint( save_last=True, @@ -231,6 +279,9 @@ def main(args): parser.add_argument("--gbs", type=int, required=False, default=128, help="Global batch size") parser.add_argument("--mbs", type=int, required=False, default=2, help="Micro batch size") parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate") - + parser.add_argument( + "--use_packed_sequence", + action="store_true", + ) args = parser.parse_args() main(args) diff --git a/tests/collections/multimodal/data/energon/test_data_module.py b/tests/collections/multimodal/data/energon/test_data_module.py index c499ecfe9ca4..dff153388f31 100644 --- a/tests/collections/multimodal/data/energon/test_data_module.py +++ b/tests/collections/multimodal/data/energon/test_data_module.py @@ -21,7 +21,7 @@ import numpy as np import webdataset as wds -from megatron.energon.flavors import BaseWebdataset +from megatron.energon.flavors import BaseWebdatasetFactory from PIL import Image from transformers import AutoProcessor @@ -159,7 +159,7 @@ def create_vqa_test_dataset(self, path: Path, num_samples: int): ) total_shards = shard_writer.shard - BaseWebdataset.prepare_dataset( + BaseWebdatasetFactory.prepare_dataset( path, [f"data-{{0..{total_shards-1}}}.tar"], split_parts_ratio=[("train", 1.0), ("val", 1.0)], diff --git a/tests/collections/vlm/mllama_train.py b/tests/collections/vlm/test_mllama_train.py similarity index 100% rename from tests/collections/vlm/mllama_train.py rename to tests/collections/vlm/test_mllama_train.py diff --git a/tests/collections/vlm/neva_train.py b/tests/collections/vlm/test_neva_train.py similarity index 95% rename from 
tests/collections/vlm/neva_train.py rename to tests/collections/vlm/test_neva_train.py index f1ddf961cb10..e12ce27702c2 100644 --- a/tests/collections/vlm/neva_train.py +++ b/tests/collections/vlm/test_neva_train.py @@ -37,6 +37,10 @@ def get_args(): parser.add_argument( '--experiment-dir', type=str, default=None, help="directory to write results and checkpoints to" ) + parser.add_argument( + "--use_packed_sequence", + action="store_true", + ) return parser.parse_args() @@ -49,6 +53,8 @@ def get_args(): mbs = 2 seq_length = 576 decoder_seq_length = 1024 + if args.use_packed_sequence: + decoder_seq_length = 2048 data = vlm.NevaMockDataModule( seq_length=decoder_seq_length, @@ -57,6 +63,7 @@ def get_args(): tokenizer=None, image_processor=None, num_workers=2, + packed_sequence=args.use_packed_sequence, ) # Transformer configurations
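
Below is a minimal, hedged sketch of how the packing utilities introduced in nemo/collections/vlm/neva/data/sequence_packing.py fit together: predict the effective length of each sample (text tokens plus 575 extra positions per image tile), bin samples with the greedy knapsack, then pack one bin into a single 'thd'-format sequence with PackedSeqParams. The toy token tensors, sizes, and printed values are illustrative assumptions, not outputs from the NeMo test suite.

import torch

from nemo.collections.vlm.neva.data.sequence_packing import (
    convert_to_packed,
    greedy_knapsack,
    predict_seq_len,
)

IMAGE_TOKEN_INDEX = -200                # media placeholder id used throughout the patch
IGNORE_INDEX = -100                     # label padding value
NUM_IMAGE_EMBEDDINGS_PER_TILE = 576     # CLIP ViT-L/14-336: 24 x 24 patches per tile

# Two toy samples, each carrying one image placeholder in its token stream.
tokens = [torch.randint(0, 32000, (128,)), torch.randint(0, 32000, (300,))]
for t in tokens:
    t[2] = IMAGE_TOKEN_INDEX
labels = [t.clone() for t in tokens]

# Effective length counts each image as 576 embeddings: 128 + 575 = 703 and 300 + 575 = 875.
lengths = [predict_seq_len(t, NUM_IMAGE_EMBEDDINGS_PER_TILE, IMAGE_TOKEN_INDEX) for t in tokens]

# Bin samples greedily so each bin stays within the packed sequence size.
bins = greedy_knapsack(lengths, list(zip(tokens, labels)), max_capacity=2048)

# Pack the first bin into one packed sequence plus its PackedSeqParams.
bin_tokens, bin_labels = zip(*bins[0])
packed_tokens, packed_labels, position_ids, loss_mask, packed_seq_params = convert_to_packed(
    tokens=list(bin_tokens),
    labels=list(bin_labels),
    num_image_embeddings_per_tile=NUM_IMAGE_EMBEDDINGS_PER_TILE,
    media_token_index=IMAGE_TOKEN_INDEX,
    ignore_index=IGNORE_INDEX,
)

# Boundaries are cumulative effective lengths; the padded variant rounds each
# sample up to a multiple of 64 (the pad_to_multiple_of default).
print(packed_seq_params.cu_seqlens_q)         # tensor([0, 875, 1578], dtype=torch.int32)
print(packed_seq_params.cu_seqlens_q_padded)  # tensor([0, 896, 1600], dtype=torch.int32)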
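
And here is a hedged sketch of wiring the packed-sequence options end to end with the Energon data module, following the argument names used in scripts/vlm/neva_finetune.py in this patch; the dataset path, image token string, and batch sizes are placeholders for illustration, and packing requires a micro batch size of 1 (MultiModalTaskEncoder.batch raises otherwise).

from transformers import AutoProcessor

from nemo.collections.multimodal.data.energon import (
    EnergonMultiModalDataModule,
    ImageToken,
    LLaVATemplateConfig,
    MultiModalSampleConfig,
)
from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder

PACKED_SEQ_LEN = 8192  # the script doubles decoder_seq_length when packing is enabled

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
sample_config = MultiModalSampleConfig(
    image_token=ImageToken(token_str="<image>", token_id=-200),  # placeholder string; match your dataset
    ignore_place_holder=-100,
    conversation_template_config=LLaVATemplateConfig(),
)

task_encoder = MultiModalTaskEncoder(
    tokenizer=processor.tokenizer,
    image_processor=processor.image_processor,
    multimodal_sample_config=sample_config,
    packed_sequence=True,
    packed_sequence_size=PACKED_SEQ_LEN,
    num_image_embeddings_per_tile=576,
)

data = EnergonMultiModalDataModule(
    path="/data/my_energon_dataset",   # placeholder path to an Energon-prepared dataset
    tokenizer=processor.tokenizer,
    image_processor=processor.image_processor,
    seq_length=PACKED_SEQ_LEN,
    micro_batch_size=1,                # must be 1 when packed_sequence=True
    global_batch_size=8,
    num_workers=0,
    multimodal_sample_config=sample_config,
    task_encoder=task_encoder,
    packing_buffer_size=200,           # Energon buffers this many samples before binning
)

The same path is exercised from the command line via scripts/vlm/neva_finetune.py with --data_type energon and --use_packed_sequence, and the mock-data variant is what the new L2_NeMo_2_NEVA_MOCK_PACKED_TRAINING CI job runs through tests/collections/vlm/test_neva_train.py --use_packed_sequence.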