
Support reverse probability mode
manuelburger committed Nov 6, 2024
1 parent a50ae9e commit b62066a
Showing 3 changed files with 16 additions and 4 deletions.
petagraph/run_train.py: 2 additions & 1 deletion
@@ -241,7 +241,8 @@ def get_dataloader_from_data_stage(
         create_attention_mask=True,
         log_directory=trainer.config.checkpoints.checkpoints_path,
         rank=global_rank,
-        packed=True
+        packed=True,
+        reverse_probability=data.reverse_probability,
     )


src/nanotron/config/config.py: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ class DataArgs:
     sequence_files_path: Optional[str] = None
     prefetch_buffer_seq_size: Optional[int] = 1
     all_sequences_resources_path: Optional[str] = None
+    reverse_probability: float = 0.0

     def __post_init__(self):
         if self.seed is None:
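
For illustration, a minimal sketch of how the new field could be set when building the data config in Python. DataArgs and the field names are taken from the diff context above; the import path, the placeholder values, and the assumption that any remaining fields have defaults are mine, and in a real run the value would normally come from the training config that populates DataArgs.

# Hypothetical usage sketch: enabling sequence reversal through the data config.
# Only reverse_probability is new in this commit; the other values are placeholders.
from nanotron.config.config import DataArgs

data_args = DataArgs(
    sequence_files_path="/path/to/sequence_files.txt",  # placeholder path
    prefetch_buffer_seq_size=1,
    reverse_probability=0.1,  # reverse ~10% of sequences; 0.0 keeps the previous behaviour
)
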
src/nanotron/data/petagraph_dataset.py: 13 additions & 3 deletions
@@ -536,6 +536,8 @@ class PetaGraphStreamDatasetV2(torch.utils.data.IterableDataset):
         The sequence length at which to switch from sampling to always keeping the sequence:
         below the inflection point a sequence is kept with a probability proportional to its
         length; above the inflection point the sequence is always kept.
+    reverse_probability : float
+        The probability of reversing the sequence. Only active if != 0.0.
     """

     def __init__(self,
@@ -550,13 +552,15 @@ def __init__(self,
                  log_directory: Path = None,
                  rank: int = 0,
                  packed: bool = False,
-                 sampling_seq_len_inflection: int = 1024
+                 sampling_seq_len_inflection: int = 1024,
+                 reverse_probability: float = 0.0
                  ):

         self.maxlen = maxlen
         self.create_attention_mask = create_attention_mask
         self.debug = debug
         self.sampling_seq_len_inflection = sampling_seq_len_inflection
+        self.reverse_probability = reverse_probability


         self.logger = logger
@@ -567,6 +571,8 @@ def __init__(self,
         self.logging_func(f"[PetaGraphStreamDataset] Num. URLs: {len(url_list)}")
         self.logging_func(f"[PetaGraphStreamDataset] From Cloud: {from_cloud}")
         self.logging_func(f"[PetaGraphStreamDataset] Sampling Seq. Len. Inflection: {self.sampling_seq_len_inflection}")
+        if self.reverse_probability > 0.0:
+            self.logging_func(f"[PetaGraphStreamDataset] Reverse Probability: {self.reverse_probability}")

         self.VOCAB = vocabulary
         self._pad_token_id = self.VOCAB["PAD"]
@@ -858,6 +864,10 @@ def tokenize_and_pad(self, input_sequence: str, apply_pad: bool = True):
         tokenized_sequence.append(self._eos_token_id)  # end with EOS token
         tokenized_sequence = np.array(tokenized_sequence, dtype=np.int32)

+        if self.reverse_probability > 0.0:
+            if np.random.rand() < self.reverse_probability:
+                tokenized_sequence = tokenized_sequence[::-1]
+
         # Pad the sequence
         if apply_pad and len(tokenized_sequence) < maxlen:
             # 2 is the PAD token
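
As a standalone illustration of the reversal step above: assuming, as the generate() change below implies, that tokenized sequences are built as BOS ... EOS, reversing the array moves BOS to the last position. The token IDs, helper name, and use of numpy's Generator API here are illustrative, not taken from the real vocabulary or implementation.

import numpy as np

# Illustrative token IDs; the real values come from the dataset vocabulary.
BOS_TOKEN_ID, EOS_TOKEN_ID = 0, 1

def maybe_reverse(tokenized_sequence: np.ndarray, reverse_probability: float,
                  rng: np.random.Generator) -> np.ndarray:
    """Reverse the whole token array with probability reverse_probability."""
    if reverse_probability > 0.0 and rng.random() < reverse_probability:
        return tokenized_sequence[::-1]
    return tokenized_sequence

rng = np.random.default_rng(0)
tokens = np.array([BOS_TOKEN_ID, 11, 12, 13, EOS_TOKEN_ID], dtype=np.int32)
print(maybe_reverse(tokens, reverse_probability=1.0, rng=rng))
# [ 1 13 12 11  0] -> the sequence now ends with BOS instead of EOS
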
@@ -964,8 +974,8 @@ def generate(self):
                     current_tokens = new_tokens
                 else:
                     # Check the last token of the current sequence
-                    # is an EOS token
-                    assert current_tokens[-1] == self._eos_token_id
+                    # is an EOS token or BOS token (if reverse_probability > 0.0)
+                    assert current_tokens[-1] == self._eos_token_id or (self.reverse_probability > 0.0 and current_tokens[-1] == self._bos_token_id)
                     current_tokens = np.concatenate([current_tokens, new_tokens])

                 if len(current_tokens) >= self.maxlen:
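
A minimal sketch of the relaxed packing check, under the assumption that generate() concatenates whole tokenized sequences into packed chunks: with reversal enabled, a chunk may legitimately end in BOS rather than EOS, so the assertion accepts either boundary token. Function and parameter names below are illustrative, not the module's API.

import numpy as np

def check_packing_boundary(current_tokens: np.ndarray, eos_token_id: int,
                           bos_token_id: int, reverse_probability: float) -> None:
    """Fail fast if a packed chunk does not end on a sequence-boundary token."""
    last_token = current_tokens[-1]
    ends_on_boundary = (
        last_token == eos_token_id
        or (reverse_probability > 0.0 and last_token == bos_token_id)
    )
    assert ends_on_boundary, f"Unexpected last token {last_token} in packed sequence"

# A reversed sequence packed earlier now ends with BOS (id 0 in this example).
check_packing_boundary(np.array([1, 13, 12, 11, 0]), eos_token_id=1,
                       bos_token_id=0, reverse_probability=0.5)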
