# LLM Sampling Library

## Overview

The LLM Sampling Library is a Python package for generating text from large language models (LLMs) with a variety of sampling methods. It provides a simple command-line interface (CLI) for generating text from an input prompt, using models loaded through the Hugging Face Transformers library.

## Features

In all implementations, temperature scaling is applied to the logits before any other sampling transformation, matching the behavior of the original GPT-2 code and of Hugging Face Transformers.
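
For reference, here is a minimal PyTorch sketch of that ordering; the function name is illustrative, not part of the library's API:

```python
import torch

def scale_logits(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Divide logits by the temperature before any truncation or sampling.

    temperature < 1.0 sharpens the distribution; temperature > 1.0 flattens it.
    """
    return logits / temperature

# Temperature is applied first, then the chosen sampling method:
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
probs = torch.softmax(scale_logits(logits, temperature=0.7), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```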

## Installation

```bash
pip install -r requirements.txt
```

## Usage

Basic usage:

```bash
python generate.py --model <model_name> --prompt "<input_prompt>" --apply-chat-template --temperature <temperature> --method <sampling_method> --max_new_tokens <max_new_tokens> --hf-token <hugging_face_token> --dtype <data_type>
```

### Parameters

- `--model`: Path or name of the Hugging Face model to use
- `--prompt`: Input sequence for the model
- `--prompt_file`: Alternative to `--prompt`; load the prompt from a file
- `--apply-chat-template`: Wrap the prompt in the model's chat template before generation
- `--temperature`: Sampling temperature (default: 1.0)
- `--method`: Sampling method to use (see the method-specific parameters below)
- `--max_new_tokens`: Maximum number of new tokens to generate (default: 500)
- `--hf-token`: Your Hugging Face token for model access
- `--dtype`: Data type for the model weights (`bfloat16`, `float16`, or `float32`)
- `--seed`: Random seed for reproducibility

Method-specific parameters:

- `--top_k`: K value for top-k sampling
- `--top_p`: P value for nucleus sampling
- `--min_p`: Threshold for min-p sampling (see the sketch after this list)
- `--epsilon`: Epsilon value for epsilon/eta sampling
- `--beam_width`: Beam width for beam search
- `--typical_p_mass`: Mass parameter for typical sampling
- `--json_schema`: Schema file path for constrained JSON sampling
- `--draft-model`: Path to the draft model for speculative sampling
- `--medusa-model-heads`: Path to the Medusa model heads
- `--lookahead`: Lookahead parameter for speculative sampling
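
To make the `--min_p` threshold concrete, here is a minimal sketch of min-p filtering under its common definition (keep tokens whose probability is at least `min_p` times that of the most likely token); the library's internals may differ:

```python
import torch

def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    """Mask tokens whose probability is below min_p * p(most likely token)."""
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < threshold, float("-inf"))
```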

## Examples

Top-k sampling:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Tell me a story" --method top_k --top_k 50 --temperature 0.7
```
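
Conceptually, top-k keeps only the `k` highest-scoring tokens before sampling. A minimal sketch, not the library's exact code:

```python
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Mask every token outside the k highest-scoring ones."""
    kth_value = torch.topk(logits, k).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))

# After temperature scaling, sample from the filtered distribution:
# probs = torch.softmax(top_k_filter(logits / 0.7, k=50), dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)
```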

Nucleus sampling:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Write a poem" --method top_p --top_p 0.9 --temperature 0.8
```
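
Nucleus (top-p) sampling keeps the smallest set of tokens whose cumulative probability exceeds `p`. A sketch using the widely used sort-and-cumsum formulation, again illustrative:

```python
import torch

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Mask tokens outside the smallest set with cumulative probability > p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cumulative > p
    # Shift right so the first token that crosses the threshold is kept.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    remove = remove.scatter(-1, sorted_indices, remove)
    return logits.masked_fill(remove, float("-inf"))
```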

Speculative sampling:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Explain quantum physics" --method speculative --draft-model meta-llama/Llama-3.2-1B-Instruct --lookahead 4
```
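
At the core of speculative sampling is an accept/reject rule that preserves the target model's distribution (Leviathan et al., 2023): each draft token `x` is accepted with probability `min(1, p(x)/q(x))`, and on rejection a replacement is drawn from the residual distribution. A minimal sketch of that single step:

```python
import torch

def speculative_accept(p: torch.Tensor, q: torch.Tensor, x: int) -> int:
    """One accept/reject step: p = target probs, q = draft probs, x = draft token."""
    if torch.rand(()) < (p[x] / q[x]).clamp(max=1.0):
        return x  # draft token accepted
    # Rejected: resample from the renormalized residual max(0, p - q).
    residual = (p - q).clamp(min=0.0)
    return torch.multinomial(residual / residual.sum(), num_samples=1).item()
```

The `--lookahead` value controls how many such draft tokens are proposed per target-model forward pass.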

Medusa decoding:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Write code for merge sort" --method medusa --medusa-model-heads <path_to_heads>
```

## Acknowledgements

- Medusa, for the Medusa decoding implementation
- vLLM, used as the reference for the Medusa heads architecture
- Jsonformer, for the constrained JSON decoding implementation
- Hugging Face Transformers, used for correctness evaluation