-
Notifications
You must be signed in to change notification settings - Fork 4
/
nllb_translation.py
199 lines (153 loc) · 5.99 KB
/
nllb_translation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"Utils for translation models"
import argparse
import time
from typing import List, Tuple
import mlx.core as mx
import numpy as np
from flores200_codes import FLORES_CODES
from transformers import M2M100Config, NllbTokenizer, PreTrainedTokenizerBase
from mlx_transformers.models import (
M2M100ForConditionalGeneration as MlxM2M100ForConditionalGeneration,
)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
    """
    Apply top-p (nucleus) sampling to logits.

    Args:
        logits: The logits from the model's output, with the vocabulary on the
            last axis. Any number of rows (batch entries) is supported.
        top_p: The cumulative probability threshold for top-p filtering.
        temperature: Temperature parameter for softmax distribution reshaping.
            The logits are divided by it, so it must be non-zero.

    Returns:
        One sampled token id per row of ``logits``.
    """
    # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
    probs = mx.softmax(logits / temperature, axis=-1)

    # Sort probs in ascending order. The gather is done row by row so that
    # every batch entry uses its own ordering (not just the first row's).
    sorted_indices = mx.argsort(probs, axis=-1)
    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
    cumulative_probs = mx.cumsum(sorted_probs, axis=-1)

    # Keep only the high-probability tail whose cumulative mass lies above
    # 1 - top_p; everything else gets zero probability.
    top_probs = mx.where(
        cumulative_probs > 1 - top_p,
        sorted_probs,
        mx.zeros_like(sorted_probs),
    )

    # log(0) is -inf, so filtered-out entries are never drawn.
    sorted_token = mx.random.categorical(mx.log(top_probs))

    # Map the sampled position in the sorted order back to a vocabulary id.
    token = mx.take_along_axis(
        sorted_indices, mx.expand_dims(sorted_token, -1), axis=-1
    )
    return token.squeeze(-1)
def load_mlx_nllb_model(
    model_name: str, source_language: str, target_language: str
) -> Tuple[MlxM2M100ForConditionalGeneration, PreTrainedTokenizerBase, int]:
    """
    Load an NLLB model and tokenizer from the given model name.

    Args:
        model_name: Name passed to the ``from_pretrained`` loaders of the
            config, the MLX model and the tokenizer.
        source_language: Source language, either a key of ``FLORES_CODES`` or a
            value the tokenizer accepts directly as ``src_lang``.
        target_language: Target language, either a key of ``FLORES_CODES`` or a
            value the tokenizer accepts directly as ``tgt_lang``.

    Returns:
        Tuple of the MLX model, the tokenizer, and the token id the tokenizer
        assigns to the (resolved) target language string.
    """
    # Resolve each language on its own: a caller may mix a FLORES_CODES key for
    # one side with an already-resolved code for the other.
    if source_language in FLORES_CODES:
        source_language = FLORES_CODES[source_language]
    if target_language in FLORES_CODES:
        target_language = FLORES_CODES[target_language]

    config = M2M100Config.from_pretrained(model_name)
    model = MlxM2M100ForConditionalGeneration(config)
    model.from_pretrained(model_name)

    tokenizer = NllbTokenizer.from_pretrained(
        model_name, src_lang=source_language, tgt_lang=target_language
    )

    # Id of the target-language token; the caller feeds it to the decoder.
    tgt_token_id = tokenizer.convert_tokens_to_ids(target_language)
    return model, tokenizer, tgt_token_id
def sample(logits: mx.array, temp: float, top_p: float) -> Tuple[mx.array, mx.array]:
    """
    Sample a token from the logits.

    Args:
        logits (mx.array): Logits from the model, vocabulary on the last axis
        temp (float): Temperature for sampling; 0 selects the argmax (greedy)
        top_p (float): Top-p sampling value; only used when strictly between
            0 and 1 and ``temp`` is non-zero

    Returns:
        Tuple[mx.array, mx.array]: Tuple of the sampled token(s) and the
        probability of each sampled token under the untempered softmax of its
        own row of ``logits``
    """
    # Normalise over the vocabulary axis only; without an explicit axis the
    # softmax would be taken over the entire array, mixing batch rows.
    softmax_logits = mx.softmax(logits, axis=-1)

    if temp == 0:
        token = mx.argmax(logits, axis=-1)
    elif 0 < top_p < 1.0:
        token = top_p_sampling(logits, top_p, temp)
    else:
        token = mx.random.categorical(logits * (1 / temp))

    # Gather each row's probability for its own token (not always row 0).
    prob = mx.take_along_axis(
        softmax_logits, mx.expand_dims(token, -1), axis=-1
    ).squeeze(-1)
    return token, prob
def run_translation_mlx(
    model: MlxM2M100ForConditionalGeneration,
    tokenizer: PreTrainedTokenizerBase,
    input_sentences: List[str],
    target_language_token: int,
    max_generation_tokens: int,
    temp: float = 0.0,
    top_p: float = 1.0,
    verbose: bool = False,
) -> List[str]:
    """
    Run Translation using the MLX model.

    Args:
        model: MLX model exposing ``encode``, ``decode``, ``lm_head`` and
            ``config.eos_token_id``.
        tokenizer: Tokenizer used to encode the inputs and decode the outputs.
        input_sentences: Sentences to translate, processed as one batch.
        target_language_token: Token id placed right after the initial EOS
            token at the start of every decoder sequence.
        max_generation_tokens: Maximum number of decoding steps.
        temp: Sampling temperature forwarded to ``sample``.
        top_p: Top-p value forwarded to ``sample``.
        verbose: When True, print the number of generated tokens and the
            tokens-per-second rate.

    Returns:
        List[str]: List of translated sentences
    """
    bsz = len(input_sentences)
    if bsz == 0:
        # Nothing to translate; skip tokenisation and decoding entirely.
        return []

    tokens = tokenizer(input_sentences, return_tensors="np", padding="longest")
    tokens = {key: mx.array(v) for key, v in tokens.items()}

    eos_token_id = model.config.eos_token_id
    # Token appended to rows that have already produced EOS. Fall back to EOS
    # when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = eos_token_id

    # Every decoder sequence starts with [EOS, target-language token].
    decoder_input_ids = mx.array([[eos_token_id, target_language_token]] * bsz)
    decoder_input_mask = mx.array([[1, 1]] * bsz)

    encoder_tokens = model.encode(tokens["input_ids"], tokens["attention_mask"])

    # Per-row flag: True once that sentence has emitted EOS.
    finished = mx.zeros((bsz,), dtype=mx.bool_)

    start_time = time.time()  # Start measuring time
    for _ in range(max_generation_tokens):
        outputs = model.decode(
            decoder_input_ids,
            decoder_input_mask,
            encoder_tokens,
            tokens["attention_mask"],
            None,
        )
        # Only the last position is needed to predict the next token.
        logits = model.lm_head(outputs)[:, -1, :]
        next_token, _ = sample(logits, temp, top_p)
        next_token = next_token.reshape(-1)

        # Rows that already ended must not keep generating text after EOS.
        next_token = mx.where(finished, pad_token_id, next_token)

        decoder_input_ids = mx.concatenate(
            [decoder_input_ids, next_token.reshape(-1, 1)], axis=1
        )
        decoder_input_mask = mx.concatenate(
            [decoder_input_mask, mx.ones((bsz, 1))], axis=1
        )

        # Stop only when every sentence in the batch is done, not just the
        # first one.
        finished = mx.logical_or(finished, next_token == eos_token_id)
        if mx.all(finished).item():
            break
    end_time = time.time()  # Stop measuring time

    if verbose:
        # Exclude the two start tokens each row was seeded with.
        total_tokens = decoder_input_ids.size - (bsz * 2)
        total_time = end_time - start_time
        # Guard against a zero interval (no steps run, or a coarse clock).
        token_per_sec = total_tokens / total_time if total_time > 0 else float("inf")
        print(f"Generated {total_tokens} tokens in {total_time} seconds.")
        print(f"Token/s: {token_per_sec}")

    translated_sentences = tokenizer.batch_decode(
        np.array(decoder_input_ids), skip_special_tokens=True
    )
    return translated_sentences
def main(args):
    """
    Translate the text given on the command line and print the result.

    Args:
        args: Parsed arguments providing ``model_name``, ``source_language``,
            ``target_language``, ``text_to_translate`` and
            ``max_generation_tokens``.
    """
    # Load the model, the tokenizer and the target-language token id.
    nllb_model, nllb_tokenizer, language_token_id = load_mlx_nllb_model(
        model_name=args.model_name,
        source_language=args.source_language,
        target_language=args.target_language,
    )

    # The translation routine works on a batch, so wrap the single sentence.
    sentences = [args.text_to_translate]

    translations = run_translation_mlx(
        model=nllb_model,
        tokenizer=nllb_tokenizer,
        input_sentences=sentences,
        target_language_token=language_token_id,
        max_generation_tokens=args.max_generation_tokens,
        verbose=True,
    )
    print(translations)
if __name__ == "__main__":
    # Command-line entry point: collect the options and run the translation.
    parser = argparse.ArgumentParser()

    # The model and both languages are mandatory string options.
    for required_flag in ("--model_name", "--source_language", "--target_language"):
        parser.add_argument(required_flag, type=str, required=True)

    # Optional settings with defaults.
    parser.add_argument(
        "--text_to_translate",
        type=str,
        default="Let us translate text to Yoruba",
    )
    parser.add_argument("--max_generation_tokens", type=int, default=20)

    args = parser.parse_args()
    main(args)