Llama image (#3167)
sanchez-alex authored Jul 18, 2024
1 parent ebf7dda commit ef5569e
Showing 2 changed files with 52 additions and 1 deletion.
@@ -47,7 +47,16 @@ RUN pip install git+https://github.com/stanford-futuredata/megablocks.git@5897cd
 # RUN pip install -e .
 
 # When copied to assets repo, change to install from public pypi
-RUN pip install llm-optimized-inference==0.2.2 --no-cache-dir
+RUN pip install llm-optimized-inference==0.2.3 --no-cache-dir
+RUN pip uninstall transformers -y
+
+COPY ./transformers.patch ./transformers.patch
+RUN git clone https://github.com/huggingface/transformers.git && \
+    cd transformers && \
+    git checkout fc35907f95459d7a6c5281dfadd680b6f7b620e3 && \
+    git apply ../transformers.patch && \
+    python setup.py bdist_wheel && \
+    pip install dist/*.whl
 
 # clean conda and pip caches
 RUN rm -rf ~/.cache/pip
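
The net effect of this hunk is that the image no longer relies on the transformers version pulled in by llm-optimized-inference: it bumps that package to 0.2.3, uninstalls the bundled transformers, and installs a wheel built from a pinned transformers commit with transformers.patch applied. A small sanity check along these lines (illustrative only, not part of the commit) could be run inside the built image to confirm the patched wheel is the one that ends up installed:

import transformers
from transformers.models.llama import modeling_llama

# apply_scaling() exists only in the locally built, patched wheel,
# so its presence distinguishes it from a stock transformers release.
assert hasattr(modeling_llama, "apply_scaling"), "stock transformers installed; patch missing"
print(f"patched transformers {transformers.__version__} is active")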
transformers.patch (new file)
@@ -0,0 +1,42 @@
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5c0c57f3e..f94a4cb37 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -73,6 +73,29 @@ class LlamaRMSNorm(nn.Module):
 
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 
+def apply_scaling(freqs: torch.Tensor):
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 length
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scale_factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
 
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
@@ -82,6 +105,7 @@ class LlamaRotaryEmbedding(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+        inv_freq = apply_scaling(inv_freq)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         # For BC we register cos and sin cached
         self.max_seq_len_cached = max_position_embeddings
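
The patch adds an apply_scaling step to LlamaRotaryEmbedding: inverse frequencies whose wavelength exceeds the original 8192-token Llama 3 context are divided by the scale factor of 8, short-wavelength (high-frequency) components are left untouched, and the band between the low- and high-frequency cutoffs is interpolated smoothly between the two. The self-contained sketch below reproduces that logic outside the library to show the effect on an inv_freq buffer; the dim and base values are illustrative only (base=10000 is the default in the patched __init__, dim=128 is an assumed head dimension).

import math
import torch

def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Constants copied from the patch above.
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 context length

    low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0
    high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (high frequencies) pass through unchanged.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Wavelengths beyond the original context are slowed down 8x.
            new_freqs.append(freq / scale_factor)
        else:
            # Linear interpolation between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

# Illustrative inv_freq buffer, mirroring the expression in __init__ above.
dim, base = 128, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
scaled = apply_scaling(inv_freq)
print(scaled[:2] / inv_freq[:2])    # ~1.0  -> high-frequency terms kept
print(scaled[-2:] / inv_freq[-2:])  # 0.125 -> low-frequency terms divided by 8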
