forked from jameswdelancey/llama3.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rung.c
1461 lines (1279 loc) · 56.9 KB
/
rung.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Inference for Llama-3 Transformer model in pure C, targeting Nvidia RTX 3090 GPU with BF16 support */
#include <ctype.h>
#include <cublas_v2.h>
#include <cublas_api.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <fcntl.h>
#include <immintrin.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#if defined _WIN32
#include "win.h"
#else
#include <sys/mman.h>
#include <unistd.h>
#endif
// Driver-API handles for the custom CUDA kernels. Each one is passed to
// cuLaunchKernel by the matching *_gpu wrapper further down in this file.
// NOTE(review): no assignment to these handles is visible in this part of the
// file -- presumably they are loaded at startup (e.g. from a compiled module);
// confirm they are initialized before the first call to forward().
static CUfunction batched_softmax_kernel;
static CUfunction fp32_to_bf16_kernel;
static CUfunction swiGLU_kernel;
static CUfunction rope_rotary_encoding_kernel;
static CUfunction rmsnorm_kernel;
// Define USE_CUDA to enable CUDA GPU acceleration
// (defined unconditionally here; the CUDA headers above are included regardless of it)
#define USE_CUDA
// ----------------------------------------------------------------------------
// CUDA error checking
// CHECK_CUDA(call): evaluates `call` exactly once and stores the result in a
// cudaError_t. If the result is not cudaSuccess, prints the source file, line
// and cudaGetErrorString() text to stderr and exits with EXIT_FAILURE.
// The do { ... } while (0) wrapper makes the macro behave as one statement.
// NOTE(review): several wrappers in this file pass cuLaunchKernel (driver API)
// through this macro -- confirm that its return type is compatible with
// cudaError_t / cudaGetErrorString for the error codes that matter.
#define CHECK_CUDA(call) \
do { \
cudaError_t err = (call); \
if (err != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// CHECK_CUBLAS(call): evaluates `call` exactly once and stores the result in a
// cublasStatus_t. If the result is not CUBLAS_STATUS_SUCCESS, prints the source
// file, line and the numeric status code to stderr and exits with EXIT_FAILURE.
#define CHECK_CUBLAS(call) \
do { \
cublasStatus_t status = (call); \
if (status != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error at %s:%d: %d\n", __FILE__, __LINE__, status); \
exit(EXIT_FAILURE); \
} \
} while (0)
// ----------------------------------------------------------------------------
// BF16 Utilities
// Widen a bfloat16 bit pattern to a float.
// The 16 input bits become the upper half of the 32-bit float word; the lower
// half is zero-filled.
static inline float bf16_to_fp32(uint16_t bf16) {
    const uint32_t word = (uint32_t)bf16 << 16;
    float widened;
    memcpy(&widened, &word, sizeof widened); // reinterpret the bits, no numeric conversion
    return widened;
}
// Narrow a float to a bfloat16 bit pattern by truncation.
// Only the upper 16 bits of the 32-bit float word are kept; the lower 16 bits
// are discarded without rounding.
static inline uint16_t fp32_to_bf16(float fp32) {
    uint32_t word;
    memcpy(&word, &fp32, sizeof word); // reinterpret the bits, no numeric conversion
    return (uint16_t)(word >> 16);
}
// Convert `size` floats from `input` into bfloat16 bit patterns in `output`.
// Every element is truncated (upper 16 bits of the float word are kept), both
// in the 8-wide vector loop and in the scalar tail.
// The call is a no-op when either pointer is NULL or size is 0.
// The original author notes that subnormal inputs are not expected.
void fp32_to_bf16_array(uint16_t *output, float *input, size_t size) {
    if (!input || !output || !size) {
        return;
    }
    size_t idx = 0;
    // Vector part: eight floats per iteration.
    while (idx + 8 <= size) {
        __m256i words = _mm256_castps_si256(_mm256_loadu_ps(input + idx));
        // Move the upper 16 bits of each 32-bit lane down to the low half.
        __m256i high_bits = _mm256_srli_epi32(words, 16);
        __m128i first_four = _mm256_extracti128_si256(high_bits, 0);
        __m128i last_four = _mm256_extracti128_si256(high_bits, 1);
        // Each lane is now <= 0xFFFF, so the saturating pack keeps the values unchanged.
        __m128i packed = _mm_packus_epi32(first_four, last_four);
        _mm_storeu_si128((__m128i *)(output + idx), packed);
        idx += 8;
    }
    // Scalar tail for the remaining (size % 8) elements.
    while (idx < size) {
        output[idx] = fp32_to_bf16(input[idx]);
        idx++;
    }
}
// Widen `n` bfloat16 bit patterns from `in` into floats in `out`, one element
// at a time via bf16_to_fp32().
void bf16_to_fp32_array(float *out, uint16_t *in, size_t n) {
    float *dst = out;
    const uint16_t *src = in;
    for (size_t remaining = n; remaining > 0; remaining--) {
        *dst++ = bf16_to_fp32(*src++);
    }
}
// ----------------------------------------------------------------------------
// Transformer model
// Model hyperparameters.
// read_checkpoint() fills this struct with a single fread() from the start of
// the checkpoint file, so the field order and types define the header layout.
typedef struct {
int dim; // transformer dimension
int hidden_dim; // for ffn layers
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size; a negative value in the file signals unshared classifier weights and is made positive by read_checkpoint()
int seq_len; // max sequence length
} Config;
// Model weights. The same struct type is used for two instances in Transformer:
//  - `weights`: memory_map_weights() sets token_embedding_table and the *_fp32
//    pointers to locations inside the fp32 checkpoint data; it does not set
//    the other members.
//  - `weights_gpu`: malloc_weights_gpu() allocates the non-_fp32 members
//    (except token_embedding_table) with cudaMalloc; copy_weights_to_gpu()
//    fills them. The uint16_t arrays hold bf16 bit patterns produced by
//    fp32_to_bf16_array().
typedef struct {
// token embedding table
float *token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
float *rms_att_weight; // (layer, dim) rmsnorm weights
float *rms_ffn_weight; // (layer, dim)
// weights for matmuls. note dim == n_heads * head_size
uint16_t *wq; // (layer, dim, n_heads * head_size)
uint16_t *wk; // (layer, dim, n_kv_heads * head_size)
uint16_t *wv; // (layer, dim, n_kv_heads * head_size)
uint16_t *wo; // (layer, n_heads * head_size, dim)
// weights for ffn
uint16_t *w1; // (layer, hidden_dim, dim)
uint16_t *w2; // (layer, dim, hidden_dim)
uint16_t *w3; // (layer, hidden_dim, dim)
// final rmsnorm
float *rms_final_weight; // (dim,)
// (optional) classifier weights for the logits, on the last layer
uint16_t *wcls;
// storage for the original fp32 weights
// (pointers into the checkpoint data; same element counts as the members above)
float *rms_att_weight_fp32;
float *rms_ffn_weight_fp32;
float *wq_fp32;
float *wk_fp32;
float *wv_fp32;
float *wo_fp32;
float *w1_fp32;
float *w2_fp32;
float *w3_fp32;
float *rms_final_weight_fp32;
// equals token_embedding_table when the checkpoint uses shared weights
float *wcls_fp32;
} TransformerWeights;
// Activation buffers for one forward pass. Used for a host instance
// (malloc_run_state) and a device instance (malloc_run_state_gpu).
typedef struct {
// current wave of activations
float *x; // activation at current time stamp (dim,)
float *xb; // same, but inside a residual branch (dim,)
float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
float *q; // query (dim,)
float *k; // key; not allocated separately -- forward() points it into key_cache for the current layer and position
float *v; // value; not allocated separately -- forward() points it into value_cache for the current layer and position
float *att; // buffer for scores/attention values (n_heads, seq_len)
float *logits; // output logits (vocab_size,)
void **ptrs; // Device Pointers: array of 3 * 2 + 5 * n_heads pointers used for the batched cuBLAS calls in forward()
uint16_t *xb_bf16; // bf16 copy of xb (dim,); only allocated by malloc_run_state_gpu
uint16_t *hb_bf16; // bf16 copy of hb (hidden_dim,); only allocated by malloc_run_state_gpu
// kv cache
float *key_cache; // allocated as n_layers * seq_len * kv_dim floats, kv_dim = dim * n_kv_heads / n_heads
float *value_cache; // allocated as n_layers * seq_len * kv_dim floats
} RunState;
// Everything needed to run the model: hyperparameters, host and device
// weights, host and device activation buffers, the cuBLAS handle and the
// bookkeeping for the memory-mapped checkpoint.
// Populated by build_transformer() and released by free_transformer().
typedef struct {
Config config; // the hyperparameters of the architecture (the blueprint)
TransformerWeights weights; // the weights of the model
RunState state; // buffers for the "wave" of activations in the forward pass
TransformerWeights weights_gpu; // GPU version of weights
RunState state_gpu; // GPU version of RunState
cublasHandle_t handle; // created with cublasCreate in build_transformer()
// some more state needed to properly clean up the memory mapping (sigh)
int fd; // file descriptor for memory mapping
float *data; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;
// Allocate a zero-filled buffer of num * size bytes aligned to 64 bytes.
//
// Returns NULL when num * size would overflow size_t or when posix_memalign
// fails. A successful result is released with free().
void *calloc_aligned(size_t num, size_t size) {
    // Reject element counts whose byte size wraps around; otherwise the caller
    // would receive a buffer far smaller than requested.
    if (size != 0 && num > SIZE_MAX / size) {
        return NULL;
    }
    size_t total_size = num * size;
    void *ptr = NULL;
    if (posix_memalign(&ptr, 64, total_size) != 0) {
        return NULL;
    }
    // For a zero-byte request posix_memalign may hand back NULL; never pass
    // that to memset.
    if (ptr != NULL) {
        memset(ptr, 0, total_size);
    }
    return ptr;
}
// Allocate the host-side RunState buffers for the model described by `p`.
//
// All buffers are zero-initialized (calloc). On any allocation failure the
// function prints "malloc failed!" to stderr and exits.
// k, v, xb_bf16 and hb_bf16 are not allocated on the host; they are set to
// NULL so the struct never carries indeterminate pointers.
void malloc_run_state(RunState *s, Config *p) {
    // we calloc instead of malloc to keep valgrind happy
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    // Element counts are computed in size_t: n_layers * seq_len * kv_dim can
    // exceed INT_MAX for long sequence lengths.
    size_t kv_cache_elems = (size_t)p->n_layers * (size_t)p->seq_len * (size_t)kv_dim;
    size_t att_elems = (size_t)p->n_heads * (size_t)p->seq_len;
    size_t ptr_count = 3 * 2 + 5 * (size_t)p->n_heads;

    s->x = calloc(p->dim, sizeof(float));
    s->xb = calloc(p->dim, sizeof(float));
    s->hb = calloc(p->hidden_dim, sizeof(float));
    s->hb2 = calloc(p->hidden_dim, sizeof(float));
    s->q = calloc(p->dim, sizeof(float));
    s->key_cache = calloc(kv_cache_elems, sizeof(float));
    s->value_cache = calloc(kv_cache_elems, sizeof(float));
    s->att = calloc(att_elems, sizeof(float));
    s->logits = calloc(p->vocab_size, sizeof(float));
    s->ptrs = calloc(ptr_count, sizeof(void *));
    s->k = NULL;
    s->v = NULL;
    s->xb_bf16 = NULL;
    s->hb_bf16 = NULL;
    // ensure all mallocs went fine (including the pointer table)
    if (!s->x || !s->xb || !s->hb || !s->hb2 || !s->q || !s->key_cache || !s->value_cache || !s->att || !s->logits ||
        !s->ptrs) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}
// Allocate the device-side RunState buffers with cudaMalloc.
//
// Any CUDA failure terminates the process through CHECK_CUDA. The buffers are
// not cleared here. k and v are not allocated: forward() points them into the
// kv cache, so they start out as NULL.
void malloc_run_state_gpu(RunState *s, Config *p) {
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    // Sizes are computed in size_t from the first factor on: the original
    // int product n_layers * seq_len * kv_dim overflows for long contexts
    // before it is widened by sizeof(float).
    size_t dim_bytes = (size_t)p->dim * sizeof(float);
    size_t hidden_bytes = (size_t)p->hidden_dim * sizeof(float);
    size_t kv_cache_bytes = (size_t)p->n_layers * (size_t)p->seq_len * (size_t)kv_dim * sizeof(float);
    size_t att_bytes = (size_t)p->n_heads * (size_t)p->seq_len * sizeof(float);
    size_t ptr_bytes = (3 * 2 + 5 * (size_t)p->n_heads) * sizeof(float *);

    CHECK_CUDA(cudaMalloc((void **)&s->x, dim_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->xb, dim_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->hb, hidden_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->hb2, hidden_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->q, dim_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->key_cache, kv_cache_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->value_cache, kv_cache_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->att, att_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->logits, (size_t)p->vocab_size * sizeof(float)));
    CHECK_CUDA(cudaMalloc((void **)&s->ptrs, ptr_bytes));
    CHECK_CUDA(cudaMalloc((void **)&s->xb_bf16, (size_t)p->dim * sizeof(uint16_t)));
    CHECK_CUDA(cudaMalloc((void **)&s->hb_bf16, (size_t)p->hidden_dim * sizeof(uint16_t)));
    s->k = NULL;
    s->v = NULL;
}
// Release every host buffer that malloc_run_state() allocated.
// k, v, xb_bf16 and hb_bf16 are not freed here.
void free_run_state(RunState *s) {
    void *owned[] = {
        s->x, s->xb, s->hb, s->hb2, s->q, s->att, s->logits, s->key_cache, s->value_cache, s->ptrs,
    };
    const size_t owned_count = sizeof owned / sizeof owned[0];
    for (size_t i = 0; i < owned_count; i++) {
        free(owned[i]);
    }
}
// Release the device buffers of a RunState with cudaFree.
// NULL members are skipped; a cudaFree failure terminates the process through
// CHECK_CUDA. k and v are not freed (they are not separate allocations).
void free_run_state_gpu(RunState *s) {
    void *device_buffers[] = {
        s->x,   s->xb,     s->hb,   s->hb2,     s->q,       s->key_cache, s->value_cache,
        s->att, s->logits, s->ptrs, s->xb_bf16, s->hb_bf16,
    };
    const size_t buffer_count = sizeof device_buffers / sizeof device_buffers[0];
    for (size_t i = 0; i < buffer_count; i++) {
        if (device_buffers[i] != NULL) {
            CHECK_CUDA(cudaFree(device_buffers[i]));
        }
    }
}
// Carve the contiguous fp32 weight data starting at `ptr` into the individual
// tensors and record their start addresses in `w`.
//
// Only token_embedding_table and the *_fp32 members of `w` are written.
// Layout, in order: token embeddings, attention rmsnorm, wq, wk, wv, wo,
// ffn rmsnorm, w1, w2, w3, final rmsnorm, two skipped blocks of
// seq_len * head_size / 2 floats each, then the classifier weights.
// With shared_weights != 0 the classifier aliases the token embedding table.
void memory_map_weights(TransformerWeights *w, Config *p, float *ptr, int shared_weights) {
    const int head_size = p->dim / p->n_heads;
    const unsigned long long n_layers = p->n_layers;

    // Per-tensor element counts; the per-layer ones are computed in 64 bits.
    const unsigned long long norm_elems = n_layers * p->dim;
    const unsigned long long q_elems = n_layers * p->dim * (p->n_heads * head_size);
    const unsigned long long kv_elems = n_layers * p->dim * (p->n_kv_heads * head_size);
    const unsigned long long o_elems = n_layers * (p->n_heads * head_size) * p->dim;
    const unsigned long long ffn_elems = n_layers * p->dim * p->hidden_dim;
    const int skipped_elems = p->seq_len * head_size / 2;

    float *cursor = ptr;

    w->token_embedding_table = cursor;
    cursor += p->vocab_size * p->dim;

    w->rms_att_weight_fp32 = cursor;
    cursor += norm_elems;

    w->wq_fp32 = cursor;
    cursor += q_elems;

    w->wk_fp32 = cursor;
    cursor += kv_elems;

    w->wv_fp32 = cursor;
    cursor += kv_elems;

    w->wo_fp32 = cursor;
    cursor += o_elems;

    w->rms_ffn_weight_fp32 = cursor;
    cursor += norm_elems;

    w->w1_fp32 = cursor;
    cursor += ffn_elems;

    w->w2_fp32 = cursor;
    cursor += ffn_elems;

    w->w3_fp32 = cursor;
    cursor += ffn_elems;

    w->rms_final_weight_fp32 = cursor;
    cursor += p->dim;

    // Two blocks in the file are skipped without being recorded.
    // NOTE(review): presumably legacy RoPE frequency tables -- confirm against the exporter.
    cursor += skipped_elems;
    cursor += skipped_elems;

    if (shared_weights) {
        w->wcls_fp32 = w->token_embedding_table;
    } else {
        w->wcls_fp32 = cursor;
    }
}
// Allocate device memory for every weight tensor except the token embedding
// table (which is not allocated on the device here).
// rmsnorm weights are stored as float, all matmul weights as 16-bit values.
// Any cudaMalloc failure terminates the process through CHECK_CUDA.
void malloc_weights_gpu(TransformerWeights *w, Config *p) {
    const int head_size = p->dim / p->n_heads;
    const unsigned long long n_layers = p->n_layers;

    // Byte sizes of the per-layer tensors.
    const size_t norm_bytes = n_layers * p->dim * sizeof(float);
    const size_t q_bytes = n_layers * p->dim * (p->n_heads * head_size) * sizeof(uint16_t);
    const size_t kv_bytes = n_layers * p->dim * (p->n_kv_heads * head_size) * sizeof(uint16_t);
    const size_t o_bytes = n_layers * (p->n_heads * head_size) * p->dim * sizeof(uint16_t);
    const size_t ffn_bytes = n_layers * p->dim * p->hidden_dim * sizeof(uint16_t);

    // attention
    CHECK_CUDA(cudaMalloc((void **)&w->rms_att_weight, norm_bytes));
    CHECK_CUDA(cudaMalloc((void **)&w->wq, q_bytes));
    CHECK_CUDA(cudaMalloc((void **)&w->wk, kv_bytes));
    CHECK_CUDA(cudaMalloc((void **)&w->wv, kv_bytes));
    CHECK_CUDA(cudaMalloc((void **)&w->wo, o_bytes));
    // feed-forward
    CHECK_CUDA(cudaMalloc((void **)&w->rms_ffn_weight, norm_bytes));
    CHECK_CUDA(cudaMalloc((void **)&w->w1, ffn_bytes));
    CHECK_CUDA(cudaMalloc((void **)&w->w2, ffn_bytes));
    CHECK_CUDA(cudaMalloc((void **)&w->w3, ffn_bytes));
    // final norm and classifier
    CHECK_CUDA(cudaMalloc((void **)&w->rms_final_weight, p->dim * sizeof(float)));
    CHECK_CUDA(cudaMalloc((void **)&w->wcls, p->vocab_size * p->dim * sizeof(uint16_t)));
}
// Release the device weight tensors allocated by malloc_weights_gpu().
// The token embedding table is not freed (it is not allocated on the device).
// Any cudaFree failure terminates the process through CHECK_CUDA.
void free_weights_gpu(TransformerWeights *w) {
    void *device_tensors[] = {
        w->rms_att_weight, w->wq, w->wk, w->wv, w->wo, w->rms_ffn_weight, w->w1, w->w2, w->w3, w->rms_final_weight, w->wcls,
    };
    const size_t tensor_count = sizeof device_tensors / sizeof device_tensors[0];
    for (size_t i = 0; i < tensor_count; i++) {
        CHECK_CUDA(cudaFree(device_tensors[i]));
    }
}
void copy_weights_to_gpu(TransformerWeights *dest_gpu, TransformerWeights *src, Config *p) {
int head_size = p->dim / p->n_heads;
unsigned long long n_layers = p->n_layers;
// Allocate temporary buffers for BF16 conversion
uint16_t *temp_buffer;
// 1. token_embedding_table
size_t size = p->vocab_size * p->dim;
// CHECK_CUDA(cudaMemcpy(dest_gpu->token_embedding_table, src->token_embedding_table, size * sizeof(float), cudaMemcpyHostToDevice));
// 2. rms_att_weight
size = n_layers * p->dim;
CHECK_CUDA(cudaMemcpy(dest_gpu->rms_att_weight, src->rms_att_weight_fp32, size * sizeof(float), cudaMemcpyHostToDevice));
// 3. wq
size = n_layers * p->dim * (p->n_heads * head_size);
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->wq_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->wq, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
// 4. wk
size = n_layers * p->dim * (p->n_kv_heads * head_size);
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->wk_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->wk, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
// 5. wv
size = n_layers * p->dim * (p->n_kv_heads * head_size);
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->wv_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->wv, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
// 6. wo
size = n_layers * (p->n_heads * head_size) * p->dim;
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->wo_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->wo, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
// 7. rms_ffn_weight
size = n_layers * p->dim;
CHECK_CUDA(cudaMemcpy(dest_gpu->rms_ffn_weight, src->rms_ffn_weight_fp32, size * sizeof(float), cudaMemcpyHostToDevice));
// 8. w1
size = n_layers * p->dim * p->hidden_dim;
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->w1_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->w1, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
// 9. w2
size = n_layers * p->hidden_dim * p->dim;
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->w2_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->w2, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
// 10. w3
size = n_layers * p->dim * p->hidden_dim;
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->w3_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->w3, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
// 11. rms_final_weight
size = p->dim;
CHECK_CUDA(cudaMemcpy(dest_gpu->rms_final_weight, src->rms_final_weight_fp32, size * sizeof(float), cudaMemcpyHostToDevice));
// 12. wcls
size = p->vocab_size * p->dim;
CHECK_CUDA(cudaMallocHost((void **)&temp_buffer, size * sizeof(uint16_t)));
fp32_to_bf16_array(temp_buffer, src->wcls_fp32, size);
CHECK_CUDA(cudaMemcpy(dest_gpu->wcls, temp_buffer, size * sizeof(uint16_t), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaFreeHost(temp_buffer));
}
// Load a checkpoint: read the Config header, memory-map the whole file and
// point the fp32 weight pointers of `weights` into the mapping.
//
// checkpoint: path of the checkpoint file.
// config:     receives the header; vocab_size is made positive.
// weights:    receives pointers into the mapping (see memory_map_weights).
// fd:         receives the descriptor backing the mapping (left open).
// data:       receives the start of the mapping.
// file_size:  receives the size of the file in bytes.
//
// Every failure prints a message to stderr and exits with EXIT_FAILURE.
void read_checkpoint(char *checkpoint, Config *config, TransformerWeights *weights, int *fd, float **data, ssize_t *file_size) {
    FILE *file = fopen(checkpoint, "rb");
    if (!file) {
        fprintf(stderr, "Couldn't open file %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    // read in the config header
    if (fread(config, sizeof(Config), 1, file) != 1) {
        fprintf(stderr, "Couldn't read config header from %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    // negative vocab size is hacky way of signaling unshared weights. bit yikes.
    int shared_weights = config->vocab_size > 0 ? 1 : 0;
    config->vocab_size = abs(config->vocab_size);
    // n_heads and n_kv_heads are used as divisors when the weights are laid
    // out and in the forward pass; reject headers that would divide by zero.
    if (config->n_heads <= 0 || config->n_kv_heads <= 0) {
        fprintf(stderr, "Invalid head count in config header of %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    // figure out the file size
    int seek_failed;
#if defined _WIN32
    seek_failed = _fseeki64(file, 0, SEEK_END) != 0; // move file pointer to end of file
    *file_size = _ftelli64(file);                    // get the file size, in bytes
#else
    seek_failed = fseek(file, 0, SEEK_END) != 0; // move file pointer to end of file
    *file_size = ftell(file);                    // get the file size, in bytes
#endif
    fclose(file);
    // A failed seek/tell (size -1) or a file shorter than the header must not
    // reach mmap.
    if (seek_failed || *file_size < (ssize_t)sizeof(Config)) {
        fprintf(stderr, "Couldn't determine a valid size for %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    // memory map the Transformer weights into the data pointer
    *fd = open(checkpoint, O_RDONLY); // open in read only mode
    if (*fd == -1) {
        fprintf(stderr, "open failed!\n");
        exit(EXIT_FAILURE);
    }
    *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) {
        fprintf(stderr, "mmap failed!\n");
        exit(EXIT_FAILURE);
    }
    // the weights start right after the header
    float *weights_ptr = *data + sizeof(Config) / sizeof(float);
    memory_map_weights(weights, config, weights_ptr, shared_weights);
}
// Fully initialize `t` from the checkpoint at `checkpoint_path`:
// load config and host weights, allocate host activations, allocate and fill
// the device weights, allocate device activations and create the cuBLAS handle.
// Failures in any step terminate the process (see the called functions and the
// CHECK_* macros).
void build_transformer(Transformer *t, char *checkpoint_path) {
    Config *cfg = &t->config;
    TransformerWeights *host_weights = &t->weights;
    TransformerWeights *device_weights = &t->weights_gpu;

    // host side: config, mapped weights, activation buffers
    read_checkpoint(checkpoint_path, cfg, host_weights, &t->fd, &t->data, &t->file_size);
    malloc_run_state(&t->state, cfg);

    // device side: weights first, then activation buffers
    malloc_weights_gpu(device_weights, cfg);
    copy_weights_to_gpu(device_weights, host_weights, cfg);
    malloc_run_state_gpu(&t->state_gpu, cfg);

    // cuBLAS handle with the default math mode
    CHECK_CUBLAS(cublasCreate(&t->handle));
    CHECK_CUBLAS(cublasSetMathMode(t->handle, CUBLAS_DEFAULT_MATH));
}
// Release everything build_transformer() acquired: the checkpoint mapping and
// its file descriptor, the host and device activation buffers, the device
// weights and the cuBLAS handle.
void free_transformer(Transformer *t) {
    // close the memory mapping
    if (t->data != MAP_FAILED) {
        munmap(t->data, t->file_size);
    }
    if (t->fd != -1) {
        close(t->fd);
    }
    // free the RunState buffers
    // The host buffers come from malloc_run_state() in build_transformer();
    // leaving this call out leaked all of them.
    free_run_state(&t->state);
    free_run_state_gpu(&t->state_gpu);
    free_weights_gpu(&t->weights_gpu);
    // best effort during teardown: the status is intentionally not checked
    cublasDestroy(t->handle);
}
// ----------------------------------------------------------------------------
// neural net blocks; the dynamics of the Transformer
void rmsnorm(float *o, float *x, float *weight, int size) {
// calculate sum of squares
float ss = 0.0f;
for (int j = 0; j < size; j++) {
ss += x[j] * x[j];
}
ss /= size;
ss += 1e-5f;
ss = 1.0f / sqrtf(ss);
// normalize and scale
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
}
}
void softmax(float *x, int size) {
// find max value (for numerical stability)
float max_val = x[0];
for (int i = 1; i < size; i++) {
if (x[i] > max_val) {
max_val = x[i];
}
}
// exp and sum
float sum = 0.0f;
for (int i = 0; i < size; i++) {
x[i] = expf(x[i] - max_val);
sum += x[i];
}
// normalize
for (int i = 0; i < size; i++) {
x[i] /= sum;
}
}
// Launch rmsnorm_kernel on the default stream with arguments
// (d_o, d_x, d_weight, size).
// The launch uses a single block of 1024 threads and requests one float of
// dynamic shared memory per 32 threads. `cublas_handle` is accepted but unused.
// A launch failure terminates the process through CHECK_CUDA.
void rmsnorm_gpu(float *d_o, float *d_x, float *d_weight, int size, cublasHandle_t cublas_handle) {
    const int block_threads = 1024;
    const int grid_blocks = 1;
    const size_t shared_bytes = (block_threads / 32) * sizeof(float);
    void *kernel_params[] = {&d_o, &d_x, &d_weight, &size};

    CHECK_CUDA(cuLaunchKernel(rmsnorm_kernel,
                              grid_blocks, 1, 1,   // grid dimensions
                              block_threads, 1, 1, // block dimensions
                              shared_bytes, NULL,  // shared memory and stream
                              kernel_params, NULL)); // arguments, no extra options
}
// Launch batched_softmax_kernel on the default stream with arguments
// (x, size, batch_size).
// The grid has batch_size blocks of 1024 threads each; one float of dynamic
// shared memory is requested per 32 threads (used for reductions, per the
// original comment). A launch failure terminates the process through CHECK_CUDA.
void batched_softmax_gpu(float *x, int size, int batch_size) {
    const int block_threads = 1024;
    const int grid_blocks = batch_size;
    const int shared_bytes = (block_threads / 32) * sizeof(float);
    void *kernel_params[] = {&x, &size, &batch_size};

    CHECK_CUDA(cuLaunchKernel(batched_softmax_kernel,
                              grid_blocks, 1, 1,   // grid dimensions
                              block_threads, 1, 1, // block dimensions
                              shared_bytes, NULL,  // shared memory and stream
                              kernel_params, NULL));
}
// Launch fp32_to_bf16_kernel on the default stream to convert `n` floats at
// `in` into 16-bit values at `out` (both are passed to the kernel as-is).
// Note the kernel argument order is (in, out, n).
// The grid is ceil(n / 256) blocks of 256 threads, no dynamic shared memory.
// A launch failure terminates the process through CHECK_CUDA.
void fp32_to_bf16_array_gpu(uint16_t *out, float *in, size_t n) {
    const int block_threads = 256;
    const int grid_blocks = (n + block_threads - 1) / block_threads; // round up
    void *kernel_params[] = {&in, &out, &n};

    CHECK_CUDA(cuLaunchKernel(fp32_to_bf16_kernel,
                              grid_blocks, 1, 1,   // grid dimensions
                              block_threads, 1, 1, // block dimensions
                              0, NULL,             // shared memory and stream
                              kernel_params, NULL));
}
// Launch swiGLU_kernel on the default stream with arguments
// (hb, hb2, hidden_dim).
// The grid is ceil(hidden_dim / 256) blocks of 256 threads, no dynamic shared
// memory. A launch failure terminates the process through CHECK_CUDA.
void swiGLU_gpu(float *hb, float *hb2, int hidden_dim) {
    const int block_threads = 256;
    const int grid_blocks = (hidden_dim + block_threads - 1) / block_threads; // round up
    void *kernel_params[] = {&hb, &hb2, &hidden_dim};

    CHECK_CUDA(cuLaunchKernel(swiGLU_kernel,
                              grid_blocks, 1, 1,   // grid dimensions
                              block_threads, 1, 1, // block dimensions
                              0, NULL,             // shared memory and stream
                              kernel_params, NULL));
}
// Launch rope_rotary_encoding_kernel on the default stream with arguments
// (q_device, k_device, n_heads, n_kv_heads, head_size, pos).
// The grid has one block per query head and head_size / 2 threads per block
// (each thread handles a pair of elements, per the original comment).
// A launch failure terminates the process through CHECK_CUDA.
void rope_rotary_encoding_gpu(float *q_device, float *k_device, int n_heads, int n_kv_heads, int head_size, int pos) {
    const int block_threads = head_size / 2;
    const int grid_blocks = n_heads;
    void *kernel_params[] = {&q_device, &k_device, &n_heads, &n_kv_heads, &head_size, &pos};

    CHECK_CUDA(cuLaunchKernel(rope_rotary_encoding_kernel,
                              grid_blocks, 1, 1,   // grid dimensions
                              block_threads, 1, 1, // block dimensions
                              0, NULL,             // shared memory and stream
                              kernel_params, NULL));
}
// Run the transformer on the GPU for one `token` at sequence position `pos`.
//
// Reads the token's embedding row from the host table, then for every layer:
// rmsnorm -> q/k/v projections (k and v are written straight into the kv
// cache) -> rotary encoding -> batched attention scores, softmax and weighted
// sum over positions 0..pos -> output projection accumulated into x ->
// rmsnorm -> batched w1/w3 projections -> swiGLU -> w2 projection accumulated
// into x. Finally applies the last rmsnorm and the classifier.
//
// Returns the logits buffer of the device RunState (allocated with cudaMalloc,
// vocab_size floats). Any CUDA or cuBLAS failure terminates the process.
float *forward(Transformer *transformer, int token, int pos) {
    // a few convenience variables
    Config *p = &transformer->config;
    TransformerWeights *w = &transformer->weights_gpu;
    RunState *s = &transformer->state_gpu;
    float *x = s->x;
    int dim = p->dim;
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
    int hidden_dim = p->hidden_dim;
    int head_size = dim / p->n_heads;
    cublasHandle_t handle = transformer->handle;
    const float one = 1.0f;
    const float zero = 0.0f;
    // copy the token embedding row from the host-side table into x on the device
    // (row offset computed in size_t so token * dim cannot overflow int)
    CHECK_CUDA(cudaMemcpy(x, transformer->weights.token_embedding_table + (size_t)token * dim, dim * sizeof(float), cudaMemcpyHostToDevice));
    // Setup device memory pointers for batched operations.
    // Layout of the pointer table: five groups of n_heads entries (q, k, v,
    // att, xb) followed by three groups of two entries (w1/w3, bf16 input, hb/hb2).
    void **ptrs = s->ptrs;
    float **q_pointers_d, **k_pointers_d, **v_pointers_d, **att_pointers_d, **xb_pointers_d;
    void **w_pointers_d, **xb_bf16_pointers_d, **h_pointers_d;
    q_pointers_d = (float **)&ptrs[p->n_heads * 0];
    k_pointers_d = (float **)&ptrs[p->n_heads * 1];
    v_pointers_d = (float **)&ptrs[p->n_heads * 2];
    att_pointers_d = (float **)&ptrs[p->n_heads * 3];
    xb_pointers_d = (float **)&ptrs[p->n_heads * 4];
    w_pointers_d = &ptrs[p->n_heads * 5 + 2 * 0];
    xb_bf16_pointers_d = &ptrs[p->n_heads * 5 + 2 * 1];
    h_pointers_d = &ptrs[p->n_heads * 5 + 2 * 2];
    // Setup host memory for device pointers (same layout, filled per layer and copied over)
    void **ptrs_h = transformer->state.ptrs;
    float **q_pointers_h = (void *)&ptrs_h[p->n_heads * 0];
    float **k_pointers_h = (void *)&ptrs_h[p->n_heads * 1];
    float **v_pointers_h = (void *)&ptrs_h[p->n_heads * 2];
    float **att_pointers_h = (void *)&ptrs_h[p->n_heads * 3];
    float **xb_pointers_h = (void *)&ptrs_h[p->n_heads * 4];
    float **w_pointers_h = (void *)&ptrs_h[p->n_heads * 5 + 2 * 0];
    uint16_t **xb_bf16_pointers_h = (void *)&ptrs_h[p->n_heads * 5 + 2 * 1];
    float **h_pointers_h = (void *)&ptrs_h[p->n_heads * 5 + 2 * 2];
    float invsqrt_head_size = 1.0f / sqrtf(head_size);
    // forward all the layers
    for (unsigned long long l = 0; l < p->n_layers; l++) {
        // key and value point to the kv cache
        // The layer offset is kept in size_t: n_layers * seq_len * kv_dim can
        // exceed INT_MAX for long contexts, which truncated the former int.
        size_t loff = (size_t)l * p->seq_len * kv_dim; // kv cache layer offset for convenience
        s->k = s->key_cache + loff + (size_t)pos * kv_dim;
        s->v = s->value_cache + loff + (size_t)pos * kv_dim;
        // Initialize the host pointers to point to the correct locations in the GPU memory
        for (int h = 0; h < p->n_heads; ++h) {
            q_pointers_h[h] = s->q + h * head_size;
            // query head h reads the kv head h / kv_mul, starting at position 0 of this layer
            k_pointers_h[h] = s->key_cache + loff + (h / kv_mul) * head_size;
            v_pointers_h[h] = s->value_cache + loff + (h / kv_mul) * head_size;
            // each head owns pos + 1 consecutive scores
            att_pointers_h[h] = s->att + h * (pos + 1);
            xb_pointers_h[h] = s->xb + h * head_size;
        }
        // Initialize host pointers for w1 and w3
        w_pointers_h[0] = (float *)(w->w1 + l * dim * hidden_dim);
        w_pointers_h[1] = (float *)(w->w3 + l * dim * hidden_dim);
        // both ffn projections read the same bf16 input
        xb_bf16_pointers_h[0] = s->xb_bf16;
        xb_bf16_pointers_h[1] = s->xb_bf16;
        // Initialize host pointers for hb and hb2 (outputs)
        h_pointers_h[0] = s->hb;
        h_pointers_h[1] = s->hb2;
        // Copy the arrays of pointers from host to device
        CHECK_CUDA(cudaMemcpy(ptrs, ptrs_h, (2 * 3 + p->n_heads * 5) * sizeof(void *), cudaMemcpyHostToDevice));
        // attention rmsnorm
        rmsnorm_gpu(s->xb, x, w->rms_att_weight + l * dim, dim, handle);
        fp32_to_bf16_array_gpu(s->xb_bf16, s->xb, dim);
        // q, k, v projections: bf16 inputs, fp32 outputs
        CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, dim, dim, &one, s->xb_bf16, CUDA_R_16BF, 1, w->wq + l * dim * dim, CUDA_R_16BF, dim, &zero, s->q, CUDA_R_32F, 1,
                                  CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
        CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, kv_dim, dim, &one, s->xb_bf16, CUDA_R_16BF, 1, w->wk + l * dim * kv_dim, CUDA_R_16BF, dim, &zero, s->k,
                                  CUDA_R_32F, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
        CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, kv_dim, dim, &one, s->xb_bf16, CUDA_R_16BF, 1, w->wv + l * dim * kv_dim, CUDA_R_16BF, dim, &zero, s->v,
                                  CUDA_R_32F, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
        rope_rotary_encoding_gpu(s->q, s->k, p->n_heads, p->n_kv_heads, head_size, pos);
        // 2. Multiply Q by K^T for each head to get attention scores (scaled by 1 / sqrt(head_size))
        CHECK_CUBLAS(cublasGemmBatchedEx(transformer->handle, CUBLAS_OP_T, CUBLAS_OP_N, pos + 1, 1, head_size, &invsqrt_head_size, (const void *const *)k_pointers_d, CUDA_R_32F,
                                         kv_dim, (const void *const *)q_pointers_d, CUDA_R_32F, head_size, &zero, (void *const *)att_pointers_d, CUDA_R_32F, pos + 1, p->n_heads,
                                         CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
        // 3. Softmax over the pos + 1 scores of every head
        batched_softmax_gpu(s->att, pos + 1, p->n_heads);
        // 4. Multiply each attention matrix by V
        CHECK_CUBLAS(cublasGemmBatchedEx(transformer->handle, CUBLAS_OP_N, CUBLAS_OP_N, head_size, 1, pos + 1, &one, (const void *const *)v_pointers_d, CUDA_R_32F, kv_dim,
                                         (const void *const *)att_pointers_d, CUDA_R_32F, pos + 1, &zero, (void *const *)xb_pointers_d, CUDA_R_32F, head_size, p->n_heads, CUDA_R_32F,
                                         CUBLAS_GEMM_DEFAULT));
        // final matmul to get the output of the attention
        // (beta = one: the result is accumulated into x, i.e. the residual connection)
        fp32_to_bf16_array_gpu(s->xb_bf16, s->xb, dim);
        CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, dim, dim, &one, s->xb_bf16, CUDA_R_16BF, 1, w->wo + l * dim * dim, CUDA_R_16BF, dim, &one, x, CUDA_R_32F, 1,
                                  CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
        // ffn rmsnorm
        rmsnorm_gpu(s->xb, x, w->rms_ffn_weight + l * dim, dim, handle);
        fp32_to_bf16_array_gpu(s->xb_bf16, s->xb, dim);
        // --- Perform batched matrix multiplication ---
        // (batch of two: w1 -> hb and w3 -> hb2, both from xb_bf16)
        CHECK_CUBLAS(cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, hidden_dim, dim, &one, (const void *const *)xb_bf16_pointers_d, CUDA_R_16BF, 1,
                                         (const void *const *)w_pointers_d, CUDA_R_16BF, dim, &zero, (void *const *)h_pointers_d, CUDA_R_32F, 1, 2, CUDA_R_32F,
                                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
        swiGLU_gpu(s->hb, s->hb2, p->hidden_dim);
        fp32_to_bf16_array_gpu(s->hb_bf16, s->hb, hidden_dim);
        // w2 projection, accumulated into x (beta = one)
        CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, dim, hidden_dim, &one, s->hb_bf16, CUDA_R_16BF, 1, w->w2 + l * dim * hidden_dim, CUDA_R_16BF, hidden_dim, &one,
                                  x, CUDA_R_32F, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }
    // final rmsnorm
    rmsnorm_gpu(x, x, w->rms_final_weight, dim, handle);
    // classifier into logits
    fp32_to_bf16_array_gpu(s->xb_bf16, x, dim);
    CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, p->vocab_size, dim, &one, s->xb_bf16, CUDA_R_16BF, 1, w->wcls, CUDA_R_16BF, dim, &zero, s->logits, CUDA_R_32F, 1,
                              CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    return s->logits;
}
// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
// One entry of the sorted vocabulary: a token string and its token id.
// Entries are ordered by compare_tokens() (strcmp on str) and searched with
// bsearch() in str_lookup().
typedef struct {
char *str;
int id;
} TokenIndex;
// Tokenizer state, filled by build_tokenizer() and released by free_tokenizer().
typedef struct {
char **vocab; // vocab_size NUL-terminated token strings, each individually allocated
float *vocab_scores; // vocab_size scores read from the tokenizer file
TokenIndex *sorted_vocab; // set to NULL by build_tokenizer(); initialized lazily
int vocab_size; // number of tokens
unsigned int max_token_length; // first value read from the tokenizer file
unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;
int compare_tokens(const void *a, const void *b) { return strcmp(((TokenIndex *)a)->str, ((TokenIndex *)b)->str); }
// Load the tokenizer from `tokenizer_path` into `t`.
//
// File layout as read here: one int (max token length), then for each of the
// vocab_size tokens a float score, an int byte length and that many bytes of
// token text (not NUL-terminated in the file).
// Also fills byte_pieces with the 256 one-byte strings and leaves
// sorted_vocab NULL (it is built lazily elsewhere).
// Any allocation or read failure prints a message to stderr and exits.
void build_tokenizer(Tokenizer *t, char *tokenizer_path, int vocab_size) {
    // i should have written the vocab_size into the tokenizer file... sigh
    t->vocab_size = vocab_size;
    // malloc space to hold the scores and the strings
    t->vocab = (char **)malloc(vocab_size * sizeof(char *));
    t->vocab_scores = (float *)malloc(vocab_size * sizeof(float));
    if (vocab_size > 0 && (!t->vocab || !t->vocab_scores)) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
    t->sorted_vocab = NULL; // initialized lazily
    for (int i = 0; i < 256; i++) {
        t->byte_pieces[i * 2] = (unsigned char)i;
        t->byte_pieces[i * 2 + 1] = '\0';
    }
    // read in the file
    FILE *file = fopen(tokenizer_path, "rb");
    if (!file) {
        fprintf(stderr, "couldn't load %s\n", tokenizer_path);
        exit(EXIT_FAILURE);
    }
    if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "failed read\n");
        exit(EXIT_FAILURE);
    }
    int len;
    for (int i = 0; i < vocab_size; i++) {
        if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) {
            fprintf(stderr, "failed read\n");
            exit(EXIT_FAILURE);
        }
        if (fread(&len, sizeof(int), 1, file) != 1) {
            fprintf(stderr, "failed read\n");
            exit(EXIT_FAILURE);
        }
        // the length comes from the file: never size an allocation from a negative value
        if (len < 0) {
            fprintf(stderr, "invalid token length\n");
            exit(EXIT_FAILURE);
        }
        t->vocab[i] = (char *)malloc((size_t)len + 1);
        if (!t->vocab[i]) {
            fprintf(stderr, "malloc failed!\n");
            exit(EXIT_FAILURE);
        }
        // fread() with an element size of 0 returns 0, so an empty token must
        // not be treated as a failed read
        if (len > 0 && fread(t->vocab[i], len, 1, file) != 1) {
            fprintf(stderr, "failed read\n");
            exit(EXIT_FAILURE);
        }
        t->vocab[i][len] = '\0'; // add the string terminating token
    }
    fclose(file);
}
// Frees every vocab string, then the vocab, score and sorted-vocab arrays
// owned by t. The Tokenizer struct itself is not freed.
void free_tokenizer(Tokenizer *t) {
    int idx = 0;
    while (idx < t->vocab_size) {
        free(t->vocab[idx]);
        idx++;
    }
    free(t->vocab);
    free(t->vocab_scores);
    free(t->sorted_vocab); // may still be NULL if encode() never ran; free(NULL) is a no-op
}
// Returns the text for `token`. The result points into the tokenizer's own
// storage (t->vocab or t->byte_pieces); nothing is allocated.
// prev_token is accepted but not used.
char *decode(Tokenizer *t, int prev_token, int token) {
    (void)prev_token;
    char *text = t->vocab[token];
    // some tokens designate raw bytes and are spelled like '<0x01>';
    // for those, hand back the matching one-byte string instead
    unsigned char raw;
    int fields = sscanf(text, "<0x%02hhX>", &raw);
    if (fields != 1) {
        return text;
    }
    return (char *)t->byte_pieces + raw * 2;
}
// Prints `piece` to stdout, except that NULL, the empty string, and any
// one-byte string whose byte is neither printable nor whitespace are skipped.
// Raw-byte tokens can be control codes (backspace etc.) that must not be echoed.
void safe_printf(char *piece) {
    if (piece == NULL || piece[0] == '\0') {
        return;
    }
    if (piece[1] == '\0') {
        // single-byte piece: filter it through the ctype checks
        unsigned char only = piece[0];
        int printable = isprint(only) || isspace(only);
        if (!printable) {
            return; // bad byte, don't print it
        }
    }
    printf("%s", piece);
}
// Binary-searches sorted_vocab (vocab_size entries, ordered by compare_tokens)
// for an exact match of str. Returns the matching token id, or -1 if none.
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
    TokenIndex key = {.str = str}; // only the string takes part in the comparison
    TokenIndex *hit = (TokenIndex *)bsearch(&key, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
    if (hit == NULL) {
        return -1;
    }
    return hit->id;
}
// Encodes the NUL-terminated UTF-8 string `text` into token ids.
//
//   t         tokenizer; its sorted_vocab is built on the first call and kept
//   text      input string; NULL is a fatal error
//   bos       non-zero: write the BOS token (128000) first
//   eos       non-zero: write the EOS token (128001) last
//   tokens    caller-allocated output array. The first pass can emit one token
//             per input byte, so it needs room for strlen(text) + 2 entries
//             (the +2 covers BOS and EOS).
//   n_tokens  receives the number of ids written to tokens
//
// Two phases: every UTF-8 codepoint is first turned into one token (or one
// token per byte when the codepoint is not in the vocab), then the highest
// scoring adjacent pair -- or, if no pair merges, triple -- is repeatedly
// merged until nothing merges any more.
//
// Exits with EXIT_FAILURE if text is NULL or an allocation fails.
void encode(Tokenizer *t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
    if (text == NULL) {
        fprintf(stderr, "cannot encode NULL text\n");
        exit(EXIT_FAILURE);
    }

    if (t->sorted_vocab == NULL) {
        // lazily malloc and sort the vocabulary
        t->sorted_vocab = (TokenIndex *)malloc(t->vocab_size * sizeof(TokenIndex));
        if (t->sorted_vocab == NULL && t->vocab_size > 0) {
            fprintf(stderr, "malloc failed!\n");
            exit(EXIT_FAILURE);
        }
        for (int i = 0; i < t->vocab_size; i++) {
            t->sorted_vocab[i].str = t->vocab[i];
            t->sorted_vocab[i].id = i;
        }
        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
    }

    // Scratch buffer for merge candidates. The merge loop concatenates up to
    // THREE tokens, so size it for three (not two) maximum-length tokens,
    // +1 for the terminator, +2 slack. The per-codepoint pass below stores up
    // to 4 bytes plus a terminator, hence the floor of 5.
    size_t buf_size = (size_t)t->max_token_length * 3 + 1 + 2;
    if (buf_size < 5) {
        buf_size = 5;
    }
    char *str_buffer = (char *)malloc(buf_size);
    if (str_buffer == NULL) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
    size_t str_len = 0;

    // start at 0 tokens
    *n_tokens = 0;

    // add optional BOS (=128000) token, if desired
    if (bos) {
        tokens[(*n_tokens)++] = 128000;
    }

    // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
    // Code point <-> UTF-8 conversion
    // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
    // U+0000            U+007F           0xxxxxxx
    // U+0080            U+07FF           110xxxxx  10xxxxxx
    // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
    // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

    // process the raw (UTF-8) byte sequence of the input string
    for (char *c = text; *c != '\0'; c++) {
        // In UTF-8 every continuation byte has the top two bits "10"
        // (0xC0 masks those two bits, 0x80 is 10000000). Anything else is an
        // ASCII char or a leading byte and therefore starts a new codepoint.
        if ((*c & 0xC0) != 0x80) {
            str_len = 0;
        }

        // append the current byte to the buffer
        str_buffer[str_len++] = *c;
        str_buffer[str_len] = '\0';

        // while the next character is a continuation byte, continue appending
        // but if there are too many of them, just stop to avoid overrunning str_buffer.
        if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
            continue;
        }

        // ok c+1 is not a continuation byte, so we've read in a full codepoint
        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
        if (id != -1) {
            // we found this codepoint in vocab, add it as a token
            tokens[(*n_tokens)++] = id;
        } else {
            // byte_fallback encoding: just encode each byte as a token.
            // NOTE(review): the +3 assumes the byte tokens start at vocab index 3
            // (right after three leading special tokens) -- confirm this matches
            // the vocab in use, given that BOS/EOS here are 128000/128001.
            for (size_t i = 0; i < str_len; i++) {
                tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
            }
        }
        str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
    }

    // merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
    while (1) {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;
        int best_len = 2; // length of the best merge sequence (2 for pair, 3 for triple)

        // first, try to find the best pair to merge
        for (int i = 0; i < (*n_tokens - 1); i++) {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            int written = snprintf(str_buffer, buf_size, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i + 1]]);
            if (written < 0 || (size_t)written >= buf_size) {
                continue; // truncated: longer than three max-length tokens, never look up a cut-off string
            }
            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
            if (id != -1 && t->vocab_scores[id] > best_score) {
                // this merge pair exists in vocab! record its score and position
                best_score = t->vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
        }

        // if no pair was found, try to find the best triple to merge
        if (best_idx == -1) {
            for (int i = 0; i < (*n_tokens - 2); i++) {
                // check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
                int written = snprintf(str_buffer, buf_size, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i + 1]],
                                       t->vocab[tokens[i + 2]]);
                if (written < 0 || (size_t)written >= buf_size) {
                    continue; // truncated, see the pair loop
                }
                int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
                if (id != -1 && t->vocab_scores[id] > best_score) {
                    // this merge triple exists in vocab! record its score and position
                    best_score = t->vocab_scores[id];
                    best_id = id;
                    best_idx = i;
                    best_len = 3;
                }
            }
        }

        if (best_idx == -1) {
            break; // we couldn't find any more pairs or triples to merge, so we're done
        }

        // merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
        tokens[best_idx] = best_id;
        // delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
        for (int i = best_idx + 1; i < (*n_tokens - best_len + 1); i++) {
            tokens[i] = tokens[i + best_len - 1];
        }
        (*n_tokens) -= (best_len - 1); // token length decreased by the number of merged tokens minus one
    }

    // add optional EOS (=128001) token, if desired
    if (eos) {
        tokens[(*n_tokens)++] = 128001;
    }

    free(str_buffer);
}
// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
// struct used when sorting probabilities during top-p sampling:
// a probability together with the index it belongs to
typedef struct ProbIndex {
    float prob; // probability value
    int index;  // index associated with that probability
} ProbIndex;
// Configuration and scratch state of the sampler that turns logits into a token.
typedef struct Sampler {
    int vocab_size;               // number of candidate tokens (presumably the model's vocab size -- confirm at the call site)
    ProbIndex *probindex;         // buffer used in top-p sampling
    float temperature;            // sampling temperature
    float topp;                   // top-p setting
    unsigned long long rng_state; // random number generator state
} Sampler;
int sample_argmax(float *probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;