-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllama.py
128 lines (102 loc) · 3.73 KB
/
llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_additive_causal_mask
from mlx_lm.models.llama import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
    """mlx_lm Llama ``ModelArgs`` extended with the ``Shard`` handled by this node.

    ``shard`` may be supplied either as a ``Shard`` instance or as a dict of
    ``Shard`` constructor keyword arguments; a dict is converted into a
    ``Shard`` in ``__post_init__``. Any other type raises ``TypeError``.
    """

    # Placeholder default: Shard("", 0, 0, 0).
    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

    def __post_init__(self):
        # Keep the parent dataclass's own post-init processing.
        super().__post_init__()

        value = self.shard
        if isinstance(value, Shard):
            # Already the right type; nothing to convert.
            return
        if isinstance(value, dict):
            # Build the Shard from its keyword arguments.
            self.shard = Shard(**value)
            return
        raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(value)} instead")
class LlamaModel(nn.Module):
    """Llama decoder stack restricted to the layer range of ``args.shard``.

    Every index in ``range(num_hidden_layers)`` gets an entry in ``self.layers``:
    indices inside ``[shard.start_layer, shard.end_layer]`` are real
    ``TransformerBlock``s, all others are ``IdentityBlock`` placeholders, so
    layer positions keep their global numbering.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0

        shard = args.shard

        # The first shard embeds token ids. The last shard also needs the
        # embedding matrix when output weights are tied to it: Model.__call__
        # then calls self.model.embed_tokens.as_linear(...), and Model.sanitize
        # keeps the model.embed_tokens weights for exactly that case.
        if shard.is_first_layer() or (shard.is_last_layer() and args.tie_word_embeddings):
            self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

        self.layers = [
            TransformerBlock(args=args)
            if shard.start_layer <= i <= shard.end_layer
            else IdentityBlock()
            for i in range(self.num_hidden_layers)
        ]

        # Only the last shard applies the final normalisation.
        if shard.is_last_layer():
            self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Run this shard's slice of the decoder.

        Args:
            inputs: token ids on the first shard, otherwise the hidden states
                produced by the previous shard. Axis 1 is treated as the
                sequence axis (NOTE(review): assumes (batch, seq[, hidden]) —
                confirm with callers).
            cache: optional per-layer cache list aligned with ``self.layers``
                (it is zipped against them), or ``None``.

        Returns:
            Hidden states; normalised by ``self.norm`` on the last shard.
        """
        shard = self.args.shard
        h = self.embed_tokens(inputs) if shard.is_first_layer() else inputs

        mask = None
        if h.shape[1] > 1:
            # The cache list has one entry per global layer index. Entries for
            # layers outside this shard belong to IdentityBlock placeholders,
            # which are not expected to update their cache (NOTE(review):
            # confirm against .base.IdentityBlock), so cache[0] only tracks the
            # true offset when the shard starts at layer 0. Read the offset
            # from the first layer this shard actually runs.
            offset = cache[shard.start_layer].offset if cache is not None else 0
            mask = create_additive_causal_mask(h.shape[1], offset)
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, layer_cache in zip(self.layers, cache):
            h = layer(h, mask, cache=layer_cache)

        if shard.is_last_layer():
            h = self.norm(h)
        return h
class Model(nn.Module):
    """Causal-LM wrapper around ``LlamaModel`` for a single shard.

    Non-final shards return hidden states. The last shard projects them to
    vocabulary logits, either through the tied embedding matrix or through a
    separate ``lm_head``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        # A dedicated output projection exists only on the last shard and only
        # when the embeddings are not tied.
        if args.shard.is_last_layer() and not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Run the shard; the last shard additionally applies the output projection."""
        out = self.model(inputs, cache)
        if not self.args.shard.is_last_layer():
            return out
        if self.args.tie_word_embeddings:
            return self.model.embed_tokens.as_linear(out)
        return self.lm_head(out)

    def sanitize(self, weights):
        """Return the subset of ``weights`` this shard needs, in input order.

        Drops ``self_attn.rotary_emb.inv_freq`` entries and keeps:
          * ``model.layers.<i>...`` for ``start_layer <= i <= end_layer``;
          * ``model.embed_tokens...`` on the first shard, or on the last shard
            when embeddings are tied;
          * ``lm_head...`` on the last shard when embeddings are not tied;
          * ``model.norm...`` on the last shard.
        Every other key is discarded.
        """
        shard = self.args.shard
        is_first = shard.is_first_layer()
        is_last = shard.is_last_layer()
        tied = self.args.tie_word_embeddings

        def keep(key):
            if "self_attn.rotary_emb.inv_freq" in key:
                return False
            if key.startswith('model.layers.'):
                # Key layout is model.layers.<index>.<...>
                layer_num = int(key.split('.')[2])
                return shard.start_layer <= layer_num <= shard.end_layer
            if key.startswith('model.embed_tokens'):
                return is_first or (is_last and tied)
            if key.startswith('lm_head'):
                return is_last and not tied
            if key.startswith('model.norm'):
                return is_last
            return False

        return {key: value for key, value in weights.items() if keep(key)}

    @property
    def layers(self):
        """The full per-index layer list, including IdentityBlock placeholders."""
        return self.model.layers

    @property
    def head_dim(self):
        """Per-head attention dimension.

        Uses an explicit ``args.head_dim`` when it is present and truthy (some
        mlx_lm ``ModelArgs`` versions define this optional field; verify
        against the installed version). Otherwise falls back to
        ``hidden_size // num_attention_heads``.
        """
        explicit = getattr(self.args, "head_dim", None)
        return explicit or self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        """Number of key/value heads, taken from ``args.num_key_value_heads``."""
        return self.args.num_key_value_heads