The LLM Sampling Library is a Python package for text generation with large language models (LLMs) using a variety of sampling methods. It provides a simple command-line interface (CLI) for generating text from input prompts, using models from the Hugging Face Transformers library.
- Multiple Sampling Methods: Implements various sampling techniques, including:
  - Unconstrained sampling
  - Top-k sampling
  - Top-p (nucleus) sampling
  - Min-p sampling
  - Typical sampling
  - Epsilon sampling
  - Eta sampling
  - Beam search
  - Chain-of-Thought (CoT) decoding
  - Constrained JSON decoding
  - Speculative sampling
  - Medusa decoding
- Chat Template Support: Optionally apply chat templates for Instruct models (see the sketch after this list)
- Memory Efficient: Uses a KV cache during generation for better memory and compute efficiency
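For Instruct models, `--apply-chat-template` wraps the raw prompt in the model's chat format before generation. Conceptually this corresponds to the standard Transformers call below (a sketch of the general mechanism, not necessarily this library's internals):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
messages = [{"role": "user", "content": "Tell me a story"}]

# Renders the model's chat template around the prompt and appends the
# header that cues the assistant's reply.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
```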
In all implementations, temperature scaling is applied to the logits before any sampling method, following the GPT-2 and Hugging Face Transformers implementations.
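A minimal sketch of that ordering (the function name is illustrative, not part of the library's API):

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    """Temperature is applied to the logits before anything else."""
    scaled = logits / temperature
    # A sampling method's filtering step (top-k, top-p, ...) would operate
    # on `scaled` here, after temperature scaling but before the softmax.
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
```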
Install the dependencies with:

```bash
pip install -r requirements.txt
```
Basic usage:

```bash
python generate.py --model <model_name> --prompt "<input_prompt>" --apply-chat-template --temperature <temperature> --method <sampling_method> --max_new_tokens <max_new_tokens> --hf-token <hugging_face_token> --dtype <data_type>
```
- `--model`: The path or name of the Hugging Face model to use
- `--prompt`: The input sequence for the model
- `--prompt_file`: Alternative to `--prompt`; load the prompt from a file
- `--temperature`: Sampling temperature (default: 1.0)
- `--method`: Sampling method to use (see the list below)
- `--max_new_tokens`: Maximum number of new tokens to generate (default: 500)
- `--hf-token`: Your Hugging Face token for model access
- `--dtype`: Data type for the model (`bfloat16`, `float16`, or `float32`)
- `--seed`: Random seed for reproducibility
Method-specific parameters:
- `--top_k`: K value for top-k sampling
- `--top_p`: P value for nucleus sampling
- `--min_p`: Threshold for min-p sampling
- `--epsilon`: Epsilon value for epsilon/eta sampling
- `--beam_width`: Beam width for beam search
- `--typical_p_mass`: Mass parameter for typical sampling
- `--json_schema`: Schema file path for constrained JSON sampling (an example schema follows this list)
- `--draft-model`: Path to the draft model for speculative sampling
- `--medusa-model-heads`: Path to the Medusa model heads
- `--lookahead`: Lookahead parameter for speculative sampling
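For constrained JSON decoding, the schema file follows the JSON Schema conventions used by Jsonformer (the cited reference implementation). A hypothetical example:

```json
{
  "type": "object",
  "properties": {
    "name": {"type": "string"},
    "age": {"type": "number"},
    "is_student": {"type": "boolean"}
  }
}
```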
Top-k sampling:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Tell me a story" --method top_k --top_k 50 --temperature 0.7
```
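For intuition, top-k filtering masks everything outside the k most likely tokens before sampling. A minimal sketch (illustrative, not the library's actual code):

```python
import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Mask all logits below the k-th largest to -inf."""
    kth_value = torch.topk(logits, k).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))
```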
Nucleus sampling:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Write a poem" --method top_p --top_p 0.9 --temperature 0.8
```
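Nucleus sampling instead keeps the smallest set of tokens whose cumulative probability reaches p. Sketched below (again illustrative, not the library's API):

```python
import torch

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Mask tokens outside the smallest set with cumulative probability >= p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Drop tokens once the cumulative mass exceeds p, but shift the mask
    # right by one so the token that crosses the threshold is kept.
    remove = cumulative > p
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_logits[remove] = float("-inf")
    # Scatter back to the original vocabulary order.
    return torch.empty_like(logits).scatter(-1, sorted_indices, sorted_logits)
```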
Speculative sampling:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Explain quantum physics" --method speculative --draft-model meta-llama/Llama-3.2-1B-Instruct --lookahead 4
```
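Here the draft model proposes `--lookahead` tokens that the target model verifies in a single forward pass. The core accept/reject rule, following Leviathan et al. (2023), looks roughly like this (a simplified sketch, not this library's code):

```python
import torch

def verify_draft_token(p_target: torch.Tensor, q_draft: torch.Tensor,
                       token: int) -> tuple[int, bool]:
    """Accept a draft token with probability min(1, p/q); on rejection,
    resample from the normalized residual distribution max(p - q, 0)."""
    if torch.rand(()) < torch.clamp(p_target[token] / q_draft[token], max=1.0):
        return token, True
    residual = torch.clamp(p_target - q_draft, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, num_samples=1)), False
```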
Medusa decoding:

```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Write code for merge sort" --method medusa --medusa-model-heads <path_to_heads>
```
Acknowledgements:
- Medusa for the Medusa decoding implementation
- vLLM, used as the reference for the Medusa heads architecture
- Jsonformer for the constrained JSON decoding implementation
- Hugging Face Transformers, used for correctness evaluation