Last month I officially released BaldEagle, an open-source repo for easily training EAGLE speculative decoding models that can deliver more than 3x faster inference.
Since then, I've received many questions about how to train your own EAGLE model. Today I'll walk through the process using Qwen2.5-7B-Instruct as an example.
Data Generation
The first step is to generate our training data in the form of hidden_states from the target model. We need to make a few changes in generate_data.py for the new model.
First, we change the assistant_header and user_header in the Tokenizer section. These are used to compute the loss_mask so that the model is only trained on assistant generations, not on system and user prompts (a small conceptual sketch of this masking follows the snippet below).
# ------------------------ 2. Tokenizer ------------------------
# This step tokenizes the conversation and creates the loss mask
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Special token sequences used to identify different parts of the conversation
# For Llama models
# assistant_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
# user_header = "<|eot_id|><|start_header_id|>user<|end_header_id|>"
# For Qwen models
assistant_header = "<|im_start|>assistant\n"
user_header = "<|im_start|>user\n"
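To get a feel for what the loss_mask is doing, here is a minimal, standalone sketch. It is not the repo's exact implementation, just an illustration of marking only the assistant spans of a Qwen-style chat template:
# Illustrative only: build a token-level mask that is 1 inside assistant
# responses and 0 everywhere else, using the headers defined above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
assistant_header = "<|im_start|>assistant\n"
end_token = "<|im_end|>"

messages = [
    {"role": "user", "content": "What is speculative decoding?"},
    {"role": "assistant", "content": "A small draft model proposes tokens that the target model verifies."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)

# Character ranges covered by assistant responses
assistant_spans = []
search_from = 0
while (idx := text.find(assistant_header, search_from)) != -1:
    begin = idx + len(assistant_header)
    end = text.find(end_token, begin)
    end = end if end != -1 else len(text)
    assistant_spans.append((begin, end))
    search_from = end

# Map character spans to tokens via offsets (requires a fast tokenizer)
enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
loss_mask = [
    int(any(s <= a and b <= e for s, e in assistant_spans))
    for a, b in enc["offset_mapping"]
]
print(list(zip(tokenizer.convert_ids_to_tokens(enc["input_ids"]), loss_mask)))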
Currently, the ShareGPT and UltraChat datasets are supported by default, with Mixture-of-Thoughts being experimental. These can be specified with the --dataset parameter when launching the data generation job. Alternatively, if you want to train on your own custom dataset, you can modify Section 1 to load and format the dataset into the messages format by following the examples already provided, as sketched below.
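As a rough illustration (the file name and column names below are hypothetical placeholders for your own data), a custom dataset can be mapped into the messages format like this:
# Hypothetical example: convert your own data into the messages format used
# in Section 1 of generate_data.py.
from datasets import load_dataset

raw = load_dataset("json", data_files="my_conversations.jsonl", split="train")

def to_messages(example):
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["response"]},
        ]
    }

dataset = raw.map(to_messages, remove_columns=raw.column_names)
print(dataset[0]["messages"])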
To start the data generation job, it is recommended to use allocation.py to split the job across multiple GPUs. Modify the file as needed to scale the job across more or fewer GPUs:
python allocation.py --outdir {output_directory} --model_name Qwen/Qwen2.5-7B-Instruct --dataset sharegpt
After data generation, you can run view_data.py to verify that the loss_mask is generated correctly.
python view_data.py --data-path {path_to_data_n.ckpt} --tokenizer Qwen2.5-7B-Instruct
I've uploaded ShareGPT and UltraChat data for Qwen2.5-7B-Instruct here:
https://huggingface.co/datasets/NickL77/Qwen2.5-7B-BaldEagle-ShareGPT
https://huggingface.co/datasets/NickL77/Qwen2.5-7B-BaldEagle-Ultrachat
Notes
This data generation follows the EAGLE-1 paper in training on a fixed dataset, rather than having the target model generate responses.
EAGLE-3 trains on target-model generations rather than a fixed dataset.
Training
Before training, we need to download the target model so we can use its embedding and projection layer weights:
>>> from huggingface_hub import snapshot_download
>>> snapshot_download(repo_id="Qwen/Qwen2.5-7B-Instruct", local_dir="train/models/qwen2-5-7b")
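The training script takes care of wiring these weights into the draft model; purely for intuition, here is a rough sketch of the weights involved (illustrative, not the repo's exact code):
# Illustrative only: the draft model reuses the target model's input embedding
# and LM head ("projection") weights, which we just downloaded.
import torch
from transformers import AutoModelForCausalLM

target = AutoModelForCausalLM.from_pretrained(
    "train/models/qwen2-5-7b", torch_dtype=torch.bfloat16
)
embed_weight = target.get_input_embeddings().weight.detach().clone()
lm_head_weight = target.get_output_embeddings().weight.detach().clone()
print(embed_weight.shape, lm_head_weight.shape)  # both (vocab_size, hidden_size)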
Then we can modify the necessary parts of the training script to train for Qwen2.5-7B-Instruct.
First, we change path="models/qwen2-5-7b" so that we load the model weights we just downloaded. Next, we change some model config parameters to match our model:
model_args = LlamaConfig(
vocab_size=vocab_size,
hidden_size=hidden_dim,
intermediate_size=12288, # This is modified to match Qwen2.5
num_hidden_layers=1,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
num_attention_heads=28, # This is modified to match Qwen2.5
num_key_value_heads=28, # This is modified to match Qwen2.5 (can also be 4 for GQA)
tie_word_embeddings=False,
)
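For reference, hidden_dim and vocab_size above need to match the target model. One way to obtain them (illustrative; train.py may read them differently) is from the downloaded model's config:
# Read the relevant dimensions from the target model's config
from transformers import AutoConfig

target_config = AutoConfig.from_pretrained("models/qwen2-5-7b")
hidden_dim = target_config.hidden_size   # 3584 for Qwen2.5-7B-Instruct
vocab_size = target_config.vocab_size    # 152064 for Qwen2.5-7B-Instruct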
We then load from our generated data paths. Feel free to change the train/test split accordingly; a simple file-level example follows the snippet below.
sharegpt_datapaths = list_local_files("{path_to_sharegpt_data}")
ultra_chat_datapaths = list_local_files("{path_to_ultrachat_data}")
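For example, a split over the generated files might look like this (illustrative only; the paths are placeholders and train.py may structure its split differently):
# Hypothetical split: hold out a small fraction of the generated .ckpt files
import glob
import random

datapaths = glob.glob("{path_to_sharegpt_data}/*.ckpt") + glob.glob("{path_to_ultrachat_data}/*.ckpt")
random.seed(0)
random.shuffle(datapaths)
n_test = max(1, int(0.05 * len(datapaths)))
test_datapaths, train_datapaths = datapaths[:n_test], datapaths[n_test:]
print(f"{len(train_datapaths)} train files, {len(test_datapaths)} test files")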
Make any changes to TrainingArguments and EagleTrainer as needed. Finally, start training with:
python train.py
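As a rough reference, these are the kinds of TrainingArguments knobs you would typically adjust for your hardware. All values below are illustrative, not the repo's defaults:
# Illustrative values only; tune batch size, accumulation, and precision
# to fit your GPU memory. TrainingArguments is the standard Hugging Face class.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs/qwen2-5-7b-draft",   # hypothetical output path
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    num_train_epochs=2,
    bf16=True,
    logging_steps=50,
    save_strategy="epoch",
    report_to="none",
)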
Here, I’ve linked my trained draft model: https://huggingface.co/NickL77/BaldEagle-Qwen-2.5-7B-Instruct
Notes
Even though we're targeting a Qwen2.5 model, we can still use the Llama architecture for the draft model. It's not yet proven that the draft model's architecture needs to match the target model's, but it's worth experimenting with.
Benchmarking
Once the training is done or checkpoints are saved, we can benchmark the model.
In one terminal, we start the SGLang server. Here we use the draft model I've uploaded to Hugging Face, but you should point it at your own checkpoint.
python -m sglang.launch_server \
--model Qwen/Qwen2.5-7B-Instruct \
--speculative-algo EAGLE \
--speculative-draft NickL77/BaldEagle-Qwen-2.5-7B-Instruct \
--speculative-num-steps 5 \
--speculative-eagle-topk 8 \
--speculative-num-draft-tokens 64 \
--dtype bfloat16 \
--port 30000 \
--mem-fraction-static 0.65
then in another terminal:
python benchmark/bench_sglang_eagle_double_turn.py --questions 50 --parallel 1
On an RTX 3090, we get the following results:
#questions: 50, Throughput: 106.22 token/s, Acceptance length: 3.55
runtime: 6 min 17 sec
Note: We're benchmarking on 50 questions out of 80 due to an SGLang issue when running speculative decoding for long periods: https://github.com/sgl-project/sglang/issues/6309
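If you want to poke at the speculative-decoding server directly, outside the benchmark script, SGLang exposes an OpenAI-compatible API on the port we launched it with. A quick sanity check might look like this (illustrative only):
# Send a single request to the running server; not part of the benchmark.
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "messages": [{"role": "user", "content": "Explain speculative decoding in one sentence."}],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])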
Baseline
Now let’s run the baseline. Start the SGLang server with:
python -m sglang.launch_server \
--model Qwen/Qwen2.5-7B-Instruct \
--dtype bfloat16 \
--port 30000 \
--mem-fraction-static 0.65
then in another terminal:
python benchmark/bench_sglang_eagle_double_turn.py --questions 50 --parallel 1
and we get:
#questions: 50, Throughput: 50.43 token/s, Acceptance length: 1.00
runtime: 12 min 57 sec
We see that with the trained draft model, we achieve a roughly 2.1x speedup (50.43 tok/s -> 106.22 tok/s) on Qwen2.5-7B-Instruct!