# spec_decoding_deployment.py
import torch
import torch.nn as nn
import re
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
import deepspeed
import torch.distributed as dist
import json
import os


def json_loader(file_path):
    """
    Input: file_path: Path to a JSONL file containing all the queries,
    one JSON object per line, e.g.:
    {"prompt": "A seated man cleans a shoe in a classroom setting with other individuals. the man"}
    {"prompt": "Two girls are sailing on a lake. they"}
    Output: A list of dicts whose "prompt" fields are used as inputs to the draft LLM.
    """
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            data.append(json.loads(line.strip()))
    return data
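
# A minimal usage sketch (not executed here), assuming a HellaSwag-style JSONL
# file such as ./hellaswag.json with one {"prompt": ...} object per line; the
# variable names below are illustrative only:
#
#     queries = json_loader("./hellaswag.json")
#     prompts = [q["prompt"] for q in queries]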


def rewrite_embedding(model_name="meta-llama/Llama-2-70b-hf"):
    """
    Rewrite your target LLM so that it includes a [PAD] token, which does not
    exist in the original LLaMA/LLaMA-2 tokenizer, and save the result to disk.
    """
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # Only initialize the process group if the caller has not already done so.
    if not dist.is_initialized():
        dist.init_process_group(backend='nccl', rank=local_rank, world_size=world_size)
    access_token = "Your HF access token"  # Replace with your HF access token
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=access_token,
                                              torch_dtype=torch.float16, cache_dir=os.environ['TRANSFORMERS_CACHE'])
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.pad_token = "[PAD]"
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    def tokenize_function(examples):
        # Unused helper for batch-tokenizing a list of prompts.
        return tokenizer.batch_encode_plus(examples, padding="longest")['input_ids']

    oracle_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    # Grow the embedding table to cover the new [PAD] token and register its id.
    oracle_model.resize_token_embeddings(len(tokenizer))
    oracle_model.config.pad_token_id = tokenizer.pad_token_id
    oracle_model.embed_tokens = nn.Embedding(oracle_model.config.vocab_size, oracle_model.config.hidden_size,
                                             padding_idx=oracle_model.config.pad_token_id)
    oracle_model.save_pretrained("/path/to/your/new/model")
    print("done")


def oracle_verification(model, input_tensor, max_new_tokens, local_rank, iter_count, past_key_values):
    """
    Verifies the predictions of an oracle model by comparing the generated tokens against actual tokens.

    Args:
        model (torch.nn.Module): The oracle model used for generating predictions.
        input_tensor (torch.Tensor): The input tensor containing tokens to verify.
        max_new_tokens (int): The number of new tokens proposed by the draft model.
        local_rank (int): The local rank identifier for distributed inference.
        iter_count (int): The current iteration count, used to detect the first call to the model.
        past_key_values: Cached past key values for accelerating generation in subsequent calls.

    Returns:
        tuple: A tuple containing:
            - The position of the first incorrect prediction in each row.
            - The updated past_key_values, truncated at the first incorrect prediction.
    """
    # If it's the first iteration, generate outputs using the full input tensor
    if iter_count == 1:
        outputs = model(input_tensor.unsqueeze(0).cuda(local_rank), use_cache=True)
        new_past_key_values = outputs.past_key_values
        # Extract the token IDs predicted for the last max_new_tokens positions
        next_token_id = outputs.logits[:, -max_new_tokens - 1:-1, :].argmax(dim=-1, keepdim=False)
    else:
        # For subsequent iterations, use past key values to accelerate generation
        outputs = model(input_tensor.unsqueeze(0).cuda(local_rank), past_key_values=past_key_values, use_cache=True)
        new_past_key_values = outputs.past_key_values
        next_token_id = outputs.logits[:, :, :].argmax(dim=-1, keepdim=False)
    # Extract the actual next tokens (the draft model's proposals) from the input tensor
    actual_next_tokens_tensor = input_tensor[-max_new_tokens:]
    # Compare the predicted next tokens with the actual next tokens
    correct_predictions = (next_token_id == actual_next_tokens_tensor)
    # Convert the boolean tensor to a float tensor for further processing
    correct_predictions_float = correct_predictions.float()
    # Default to "all tokens matched": the first false position is one past the last column
    first_false_positions = torch.full((correct_predictions_float.size(0),), correct_predictions_float.size(1),
                                       device=correct_predictions_float.device)
    # Locate the first incorrect prediction in each row, if any
    mismatches = correct_predictions_float.eq(0)
    has_mismatch = mismatches.any(dim=1)
    first_false_positions[has_mismatch] = mismatches.float().argmax(dim=1)[has_mismatch]
    # If there is an incorrect prediction within the max_new_tokens limit, drop the cached
    # keys/values for the rejected positions. This assumes the legacy tuple-of-(key, value)
    # cache layout with shape [batch, num_heads, seq_len, head_dim] per layer.
    if first_false_positions < max_new_tokens:
        keep = int(first_false_positions) - max_new_tokens  # negative offset from the end of the sequence
        trimmed_past_key_values = tuple(
            (k[:, :, :keep, :], v[:, :, :keep, :]) for k, v in new_past_key_values
        )
        return first_false_positions, trimmed_past_key_values
    return first_false_positions, new_past_key_values
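
# A toy illustration (not part of the deployment flow) of how the acceptance
# length falls out of the token-level comparison above. With max_new_tokens = 4
# and a single row whose third draft token disagrees with the oracle:
#
#     correct = torch.tensor([[1., 1., 0., 1.]])
#     first_false = torch.full((1,), correct.size(1))   # default: all matched
#     mismatch = correct.eq(0)
#     first_false[mismatch.any(dim=1)] = mismatch.float().argmax(dim=1)[mismatch.any(dim=1)]
#     # first_false == tensor([2]); the caller then counts 2 + 1 = 3 matched
#     # tokens for this step (matched_tokens = first_false_positions + 1).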


if __name__ == "__main__":
    dist.init_process_group(backend='nccl')
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    print(world_size, rank, local_rank)

    os.environ['TRANSFORMERS_CACHE'] = "/your/tf/cache/dir"  # Replace with your transformers cache directory
    test_json = json_loader("./hellaswag.json")  # Replace with your file path
    rewrite_embedding(model_name="meta-llama/Llama-2-70b-hf")

    # Set your target LLM and desired tp degree here.
    model_name = "/path/to/your/new/model"
    tensor_parallel_degrees = 4

    # Load the model on meta tensors from the config file.
    # This prevents deepspeed from loading the model multiple times on each rank.
    config = AutoConfig.from_pretrained(model_name, cache_dir="/your/cache/dir/")
    with deepspeed.OnDevice(dtype=torch.float16, device="meta", enabled=True):
        oracle_model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

    # Define the checkpoint dict. You may need to convert *.safetensors to
    # *.bin for this to work. Make sure you get all the *.bin and *.pt files in
    # the checkpoint_files list.
    checkpoint_dir = "/your/ckpt/dir/"
    # Change ckpt names if your .bin files are named differently
    checkpoint_files = [
        os.path.join(checkpoint_dir, f"pytorch_model-{i:05d}-of-00029.bin")
        for i in range(1, 30)  # Change the number of bin files based on your model
    ]
    checkpoint_dict = {
        "type": "DS_MODEL",
        "checkpoints": checkpoint_files,
        "version": 1.0,
    }
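
    # If your checkpoint ships only *.safetensors shards, one way to produce the
    # *.bin files expected above is a small offline conversion. This is a rough
    # sketch only, assuming the `safetensors` package is installed; adjust the
    # output names so they match checkpoint_files, and run it once on one process:
    #
    #     import glob
    #     from safetensors.torch import load_file
    #     shards = sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors")))
    #     for i, shard in enumerate(shards, start=1):
    #         state_dict = load_file(shard)
    #         torch.save(state_dict, os.path.join(checkpoint_dir, f"pytorch_model-{i:05d}-of-00029.bin"))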
    # Initialize the target (oracle) model with DeepSpeed tensor parallelism,
    # loading the checkpoint shards directly onto the right devices.
    oracle_model = deepspeed.init_inference(
        oracle_model,
        replace_with_kernel_inject=False,
        tp={"tp_size": tensor_parallel_degrees},
        dtype=torch.float16,
        checkpoint=checkpoint_dict,
    )

    # The LLaMA tokenizer does not have a pad token.
    # Modify the tokenizer to add a pad token and change the model configs accordingly.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf", padding_side="left",
                                              torch_dtype=torch.float16)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.pad_token = "[PAD]"
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    # Feel free to change this to the draft model of your choice
    draft_model = AutoModelForCausalLM.from_pretrained("minghaoyan/Wide-Sheared-LLaMA-796M", torch_dtype=torch.float16)
    draft_model.resize_token_embeddings(len(tokenizer))
    draft_model.config.pad_token_id = tokenizer.pad_token_id
    draft_model.embed_tokens = nn.Embedding(draft_model.config.vocab_size, draft_model.config.hidden_size,
                                            padding_idx=draft_model.config.pad_token_id)

    # Launch the draft model with deepspeed on 1 node. Alternatively, you could use HF or load from a checkpoint.
    draft_model = deepspeed.init_inference(
        draft_model,
        replace_with_kernel_inject=False,
        tp={"tp_size": 1},
        dtype=torch.float16,
        # checkpoint=checkpoint_dict,
    )

    current_prompts = []
    curr_count = 0
    # Set hyperparameters for speculative decoding
    batch_size = 1
    max_new_tokens = 7
    output_file = "output-file.txt"
    processed_batches = 0

    for batch in test_json:
        processed_batches += 1
        # Add the prompt from the current batch to a list for processing
        current_prompts.append(batch['prompt'])
        # Keep track of how many prompts have been accumulated
        curr_count += 1
        if curr_count == batch_size:
            # Once the current count reaches the batch size, encode the batch of prompts
            draft_input_ids = tokenizer.batch_encode_plus(current_prompts, padding='longest')
            # Reset the prompts list and count for the next batch
            current_prompts = []
            curr_count = 0
            # Calculate the maximum length for the generated sequence
            max_length = 200 - max_new_tokens - 2
            current_length = 0
            iter_count = 0
            # The oracle has no cached keys/values before its first call
            past_key_values = None
            # Initialize a tensor to keep track of total matched tokens
            total_matched = torch.zeros(batch_size, dtype=torch.int64).cuda(local_rank)
            # Generate sequences up to the max length
            while current_length < max_length:
                if iter_count == 0:
                    # For the first iteration, use the input prompt
                    iter_count += 1
                    output_len = len(draft_input_ids["input_ids"][0]) + max_new_tokens
                    input_tensors = torch.tensor(draft_input_ids["input_ids"]).cuda(local_rank)
                else:
                    # For subsequent iterations, use new inputs based on matched tokens
                    # (new_inputs is already shaped [batch_size, seq_len])
                    output_len = len(new_inputs[0]) + max_new_tokens
                    input_tensors = new_inputs
                # Generate max_new_tokens draft tokens with the draft model
                padded_tensor = draft_model.generate(input_tensors, max_new_tokens=max_new_tokens,
                                                     pad_token_id=tokenizer.eos_token_id).to(dtype=torch.int32)
                # Gather the draft predictions from all devices in the distributed setting
                gathered_padded_tensors = [torch.zeros(output_len, dtype=torch.int32).cuda(local_rank)
                                           for _ in range(world_size)]
                dist.all_gather(gathered_padded_tensors, padded_tensor.int())
                cat_tensor = gathered_padded_tensors[0]
                # Beyond the first iteration, keep only the newly drafted tokens
                if iter_count > 1:
                    cat_tensor = cat_tensor[-max_new_tokens:]
                # Verify the drafted sequence against the oracle model
                first_false_positions, past_key_values = oracle_verification(oracle_model, cat_tensor, max_new_tokens,
                                                                             local_rank, iter_count, past_key_values)
                matched_tokens = first_false_positions + torch.ones_like(first_false_positions)
                # Build the next inputs: left-pad each row so all rows share the same length,
                # then append the accepted draft tokens
                input_list = []
                for idx, matched in enumerate(matched_tokens):
                    input_list.append(torch.cat((torch.zeros(torch.max(matched_tokens) - matched_tokens[idx],
                                                             dtype=torch.int32).cuda(local_rank),
                                                 input_tensors[idx],
                                                 cat_tensor[idx][:matched_tokens[idx]]),
                                                dim=0))
                new_inputs = torch.stack(input_list)
                total_matched += matched_tokens
                # Log the total matched tokens to a file on rank 0
                if local_rank == 0:
                    with open(output_file, "a") as f:
                        f.write(str(total_matched.tolist()) + "\n")
                # Update the current length for the next iteration
                current_length = total_matched.min().item()
        else:
            continue