Skip to content

Commit

Permalink
enforce max context_length_estimate <= max n_positions
Browse files (browse the repository at this point in the history)
GitOrigin-RevId: 3e9481845f4cb4ac6b0e6af94383e48bb56e4722
  • Loading branch information
aws-yishanm authored and hannanjgaws committed Dec 23, 2024
1 parent c9359b9 commit f4e3818
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/transformers_neuronx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,12 @@ def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
def _sanity_check(**kwargs):
context_length_estimate = kwargs.get("context_length_estimate", None)
n_positions = kwargs.get("n_positions", 2048)
max_n_pos = max(n_positions) if isinstance(n_positions, list) else n_positions
max_cle = max(context_length_estimate) if isinstance(context_length_estimate, list) else context_length_estimate
# max_n_pos or max_cle could be None if customer intends to use defaults
if isinstance(max_n_pos, int) and isinstance(max_cle, int):
assert max_n_pos >= max_cle, \
f"Max context_length_estimate {max_cle} cannot be more than max n_positions {max_n_pos}."
neuron_config = kwargs.get("neuron_config", None)
bsh_cache_layout = False
if neuron_config is not None:
Expand Down

0 comments on commit f4e3818

Please sign in to comment.