
option to not concatenate during pretraining
winglian committed Jan 16, 2025
1 parent 8606093 commit 867cbec
Showing 3 changed files with 24 additions and 0 deletions.
src/axolotl/core/trainer_builder.py (2 additions, 0 deletions)
@@ -1877,6 +1877,8 @@ def build_collator(
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
+            if self.cfg.pretraining_sample_concatenation is False:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
 
         if self.cfg.model_config_type == "mamba":
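For context (an editor's aside, not part of the diff): returning None here leaves the trainer's default collation in place, which is fine when pretraining samples have already been packed to a fixed length, but unconcatenated samples stay ragged and need padding at batch time. Below is a minimal sketch of what DataCollatorForSeq2Seq does with ragged features, assuming a GPT-2 tokenizer and made-up token ids purely for illustration.

# Illustrative sketch only, not axolotl code. DataCollatorForSeq2Seq pads
# input_ids/attention_mask with the tokenizer's pad token and pads labels
# with -100 so the padded positions are ignored by the loss.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed example tokenizer
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token

collator = DataCollatorForSeq2Seq(tokenizer, return_tensors="pt")
features = [
    {"input_ids": [10, 11, 12], "attention_mask": [1, 1, 1], "labels": [10, 11, 12]},
    {"input_ids": [20, 21], "attention_mask": [1, 1], "labels": [20, 21]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 3]), padded to the longest sample
print(batch["labels"][1])        # tensor([ 20,  21, -100])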
src/axolotl/utils/config/models/input/v0_4_1/__init__.py (6 additions, 0 deletions)
@@ -706,6 +706,12 @@ class Config:
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
+    pretraining_sample_concatenation: Optional[bool] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "whether to soft pack/concatenate samples during pretraining",
+        },
+    )
 
     batch_flattening: Optional[Union[Literal["auto"], bool]] = None
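Aside (not part of the commit): the new field is a tri-state Optional[bool], and the trainer only reacts to an explicit False, so leaving it unset keeps the existing packing behaviour. A standalone pydantic sketch of the same pattern, with PretrainingOptions as a made-up stand-in for the real config model:

# Standalone sketch (assumption, not axolotl code) of the tri-state option pattern
# used in the diff: Optional[bool] defaulting to None, where only an explicit False
# changes behaviour because the trainer checks `is False`.
from typing import Optional
from pydantic import BaseModel, Field

class PretrainingOptions(BaseModel):
    pretraining_sample_concatenation: Optional[bool] = Field(
        default=None,
        json_schema_extra={
            "description": "whether to soft pack/concatenate samples during pretraining",
        },
    )

opts_default = PretrainingOptions()  # unset -> None -> keep concatenating
opts_off = PretrainingOptions(pretraining_sample_concatenation=False)
print(opts_default.pretraining_sample_concatenation)  # None
print(opts_off.pretraining_sample_concatenation)      # False

In an axolotl YAML config this presumably surfaces as pretraining_sample_concatenation: false, since the config keys map onto these field names.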
src/axolotl/utils/data/pretraining.py (16 additions, 0 deletions)
@@ -22,6 +22,7 @@ def encode_pretraining(
     max_tokens: int,
     examples: Dict[str, List],
     text_column: str = "text",
+    concatenate: bool = True,
 ) -> Dict[str, List]:
     res = tokenizer(
         examples[text_column],
@@ -33,6 +34,13 @@
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
     targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    if not concatenate:
+        return {
+            "input_ids": [seq.tolist() for seq in input_ids],
+            "labels": [seq.tolist() for seq in targets],
+            "attention_mask": [seq.tolist() for seq in attention_mask],
+        }
+
     new_input_ids = []
     new_labels = []
     new_attention_mask = []
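Aside (illustration only): with concatenate=False the tokenizer's ragged per-sample lists are returned as-is, while the pre-existing path below packs everything into fixed-length rows. A rough, self-contained sketch of the difference, using made-up token ids (the real encode_pretraining also handles attention masks and EOS boundaries):

# Illustrative sketch only, not the axolotl implementation.
# Non-concatenated: each document stays a separate, variable-length sample.
# Concatenated/packed: documents are joined and re-split into max_tokens-sized rows.
def encode_no_concat(token_lists):
    return {"input_ids": token_lists, "labels": [t[:] for t in token_lists]}

def encode_concat(token_lists, max_tokens):
    flat = [tok for seq in token_lists for tok in seq]               # join all docs
    rows = [flat[i:i + max_tokens] for i in range(0, len(flat), max_tokens)]
    return {"input_ids": rows, "labels": [r[:] for r in rows]}

docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(encode_no_concat(docs)["input_ids"])             # [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(encode_concat(docs, max_tokens=4)["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8], [9]]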
@@ -198,6 +206,14 @@ def wrap_pretraining_dataset(
         )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
+    elif cfg.pretraining_sample_concatenation is False:
+        encode = functools.partial(
+            encode_pretraining,
+            tokenizer,
+            max_tokens,
+            text_column=cfg.pretraining_dataset[0].text_column or "text",
+            concatenate=False,
+        )
     else:
         encode = functools.partial(
             encode_pretraining,
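Aside (illustration only): functools.partial freezes the tokenizer, the token budget, and now the concatenate flag, so the downstream dataset pipeline only ever calls encode(examples) on a batch of rows. A minimal sketch of that pattern, with fake_encode standing in for encode_pretraining:

# Illustrative sketch only, not axolotl code. functools.partial binds the leading
# and keyword arguments up front, producing a one-argument callable that a
# batched mapping step can apply to each batch of rows.
import functools

def fake_encode(tokenizer, max_tokens, examples, text_column="text", concatenate=True):
    mode = "packed" if concatenate else "per-sample"
    return {"note": [f"{mode}, max_tokens={max_tokens}"] * len(examples[text_column])}

encode = functools.partial(
    fake_encode,
    "tokenizer-placeholder",  # stands in for the real tokenizer object
    2048,                     # stands in for max_tokens
    text_column="text",
    concatenate=False,
)

print(encode({"text": ["doc one", "doc two"]}))
# {'note': ['per-sample, max_tokens=2048', 'per-sample, max_tokens=2048']}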
