Soft Prompt Tuning Modern LLMs
Manual prompt engineering is time-consuming and often feels more like an art than a science. While many researchers argue that carefully crafted prompts can match fine-tuning performance, the process of iteratively refining natural-language (hard) prompts remains a significant bottleneck. Below, we demonstrate how soft prompt tuning can automate this optimization process, improving accuracy from 29% to 83% on the Banking-77 dataset using Llama 3.2 3B Instruct.
Introduced in the 2021 paper The Power of Scale for Parameter-Efficient Prompt Tuning, soft prompt tuning offers a simple but powerful insight: instead of constraining ourselves to manual prompt engineering with natural language tokens, we can use backpropagation to learn optimal prompt embeddings directly. As shown in the diagram below, traditional fine-tuning updates the entire model, and prompt engineering keeps the model frozen but restricts us to natural language; soft prompt tuning offers a middle ground: we keep the model frozen while learning optimal prompts in the full embedding space.
This approach aligns with a broader trend toward parameter-efficient fine-tuning (PEFT) methods. The original paper showed that, at sufficient model scale, training only a small prefix of learnable tokens can match the performance of full model fine-tuning. Similar PEFT methods like LoRA have since demonstrated comparable results while updating only a small fraction of the model's parameters, though soft prompt tuning offers a conceptually simpler approach by just prepending learned tokens rather than modifying the model's internal weights.
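To make the idea concrete, here is a minimal conceptual sketch in PyTorch, assuming a HuggingFace-style causal LM that accepts inputs_embeds. This is not the implementation used later in this post, which instead repurposes reserved vocabulary tokens so it can plug into the standard SFTTrainer pipeline.

import torch
import torch.nn as nn

class SoftPromptWrapper(nn.Module):
    """Prepend trainable prefix embeddings to a frozen LM's input embeddings."""

    def __init__(self, frozen_lm, num_prefix_tokens: int = 32):
        super().__init__()
        self.lm = frozen_lm
        for p in self.lm.parameters():  # the base model never receives gradient updates
            p.requires_grad = False
        hidden_dim = frozen_lm.get_input_embeddings().embedding_dim
        # The only trainable parameters: a (num_prefix_tokens x hidden_dim) matrix
        self.soft_prompt = nn.Parameter(torch.randn(num_prefix_tokens, hidden_dim) * 0.02)

    def forward(self, input_ids, **kwargs):
        token_embeds = self.lm.get_input_embeddings()(input_ids)          # (batch, seq, dim)
        prefix = self.soft_prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)
        inputs_embeds = torch.cat([prefix, token_embeds], dim=1)          # soft prompt goes first
        return self.lm(inputs_embeds=inputs_embeds, **kwargs)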
However, most subsequent research has focused on T5 models and benchmarked against the SuperGLUE dataset. Both the T5 models (2020) and SuperGLUE (2019) are quite dated, while current open-source fine-tuning interest centers on decoder-only models like Llama, Phi, and Mistral. There has been surprisingly little exploration of soft prompt tuning for modern decoder-only LLMs in either academia or the open-source community. At the time of writing, HuggingFace's PEFT implementations of prompt tuning, prefix tuning, and p-tuning (which are similar approaches) do not work with modern decoder-only LLMs.
Below, I walk through a practical implementation of soft prompt tuning on Llama 3.2 models using the Banking-77 intent classification dataset to demonstrate how automated prompt optimization can significantly outperform traditional hard-prompting approaches.
Dataset
Banking77 is an intent classification benchmark containing customer service queries in the banking domain, each labeled with one of 77 fine-grained intents. It is a notable benchmark for current LLMs because 1) it focuses on a specific domain, which mirrors the most common reason to fine-tune a pretrained large language model: specializing it on a specific set of tasks, and 2) it tests in-context learning (ICL), since the model receives a substantial prompt containing all 77 intent labels and must decide which one matches the user query. Recently, the HELMET long-context benchmark adopted Banking77 as one of its ICL tasks.
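As a quick sanity check before doing anything else, you can inspect the label space directly. This assumes the PolyAI/banking77 dataset id used later in this post, whose label column should be a ClassLabel.

from datasets import load_dataset

dataset = load_dataset("PolyAI/banking77")
labels = dataset["train"].features["label"]          # ClassLabel with the 77 intent names
print(labels.num_classes)                            # 77
print(dataset["train"][0]["text"])                   # a raw customer query
print(labels.int2str(dataset["train"][0]["label"]))  # its intent name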
Code
GitHub: https://github.com/NickL77/SoftPromptTuning.
We start by loading the Llama 3.2 1B Instruct model and tokenizer and configuring a few tokenizer settings.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Llama does not define a pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token

# Stop generation on either the EOS token or the end-of-turn token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
The Llama tokenizer reserves roughly 250 special tokens of the form <|reserved_special_token_N|> that are unused by default; we repurpose the first few of them for our soft prompt.
# Collect the reserved special tokens and look up their vocabulary ids;
# only the first few of these are used for the soft prompt later
prefix_token_strs = [f"<|reserved_special_token_{i}|>" for i in range(251)]
prefix_token_ids = tokenizer.convert_tokens_to_ids(prefix_token_strs)
Next, we load the Banking77 dataset. The repo also defines banking77_label_map, a mapping from the integer labels to the label names.
from datasets import load_dataset

def map_labels(example):
    # Attach the human-readable label name alongside the integer label
    example["label_str"] = banking77_label_map[example["label"]]
    return example

dataset = load_dataset("PolyAI/banking77")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

train_dataset = train_dataset.map(map_labels)
test_dataset = test_dataset.map(map_labels)

train_dataset = train_dataset.shuffle(seed=42)
test_dataset = test_dataset.shuffle(seed=42)
Here we define the prompt template, which provides all of the label names to the LLM and specifies the output format. In this case, we ask the model to output the predicted label within <answer></answer> XML tags. We also write a function, along with a quick test, to parse this XML format using regex.
prompt_template = \
"""## Instructions
Classify the provided piece of text into one of the predefined classes.
## Classes
{classes}
## Output Format
Provide your answer in <answer></answer> XML tags. Output the xml tags and answer only.
## Input Text
{{text}}
## Answer""".format(classes="\n".join(banking77_label_map.values()))
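One detail worth calling out: the double braces in {{text}} survive the .format(classes=...) call above as a literal {text}, so the per-example query can be substituted later. A quick check with a made-up query:

# {classes} was filled in when the template was built; {text} is filled per example
example_prompt = prompt_template.format(text="I still have not received my new card")
assert "## Classes" in example_prompt
assert "I still have not received my new card" in example_prompt
print(example_prompt)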
import re

def parse_value_from_xml_with_regex(xml_string, tag_name):
    # re.DOTALL lets the answer span multiple lines
    pattern = f"<{tag_name}>(.*?)</{tag_name}>"
    match = re.search(pattern, xml_string, re.DOTALL)
    if match:
        return match.group(1)
    return ""

assert parse_value_from_xml_with_regex("<answer>foo</answer>", "answer") == "foo"
Now we run a baseline with the original Llama model. We first map the dataset into the chat-template format, then iterate through the first 300 test samples and measure accuracy.
def create_test_messages(row):
    return {"messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt_template.format(text=row["text"])},
    ]}

test_dataset_no_prefix = test_dataset.map(create_test_messages)
from tqdm import tqdm

pred_ls, golden_ls = [], []
num_correct, num_total = 0, 0

for i in tqdm(range(300)):
    messages = test_dataset_no_prefix[i]["messages"]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Near-greedy decoding so the benchmark is (almost) deterministic
    outputs = model.generate(
        input_ids,
        max_new_tokens=128,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.001,
        top_p=0,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens and parse out the predicted label
    response = outputs[0][input_ids.shape[-1]:]
    response = tokenizer.decode(response, skip_special_tokens=False)
    pred = parse_value_from_xml_with_regex(response, "answer")

    pred_ls.append(pred)
    golden_ls.append(test_dataset_no_prefix[i]["label_str"])
    if pred == test_dataset_no_prefix[i]["label_str"]:
        num_correct += 1
    num_total += 1

accuracy = num_correct / num_total
print(f"Accuracy: {accuracy}")
After tokenizing and generating, we use the parsing function to extract the model's prediction and compare it against the true label. Llama 3.2 3B correctly predicts 29% of the first 300 samples in the test dataset, while Llama 3.2 1B scores a very low 0.66%, with most errors caused by not following the XML output format.
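One rough way to verify that claim, using the pred_ls list collected above, is to count how many responses produced no parseable <answer> tag at all:

# An empty prediction means the <answer></answer> tag was missing or malformed
num_format_failures = sum(1 for pred in pred_ls if pred.strip() == "")
print(f"Responses without a parseable <answer> tag: {num_format_failures}/{len(pred_ls)}")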
Now let's actually train the soft prompt. We'll start by setting some hyper-parameters, the most important of which is the length of the soft prompt; here we set it to 32 tokens.
NUM_SPECIAL_TOKENS_IN_PREFIX = 32
LEARNING_RATE = 2e-4
BATCH_SIZE = 4
WARMUP_RATIO = 0.1
WEIGHT_DECAY = 0.01
We then create two datasets, both containing the soft prefix: one with the correct answer in the expected format for training, and one without the answer for evaluating the trained model.
prefix = "".join(prefix_token_strs[:NUM_SPECIAL_TOKENS_IN_PREFIX])

def create_prefix_messages(row):
    return {"messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prefix + prompt_template.format(text=row["text"])},
        {"role": "assistant", "content": "<answer>" + row["label_str"] + "</answer>"},
    ]}

def create_prefix_messages_no_answer(row):
    return {"messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prefix + prompt_template.format(text=row["text"])},
    ]}

train_dataset = train_dataset.map(create_prefix_messages)
test_dataset_with_prefix = test_dataset.map(create_prefix_messages_no_answer)
Now, the juicy part. We freeze the entire model except for the embedding layer. We also want to freeze all of the embeddings except those belonging to the soft prompt. Unfortunately, we cannot freeze part of an embedding layer while keeping another part unfrozen, so we use a small hack: zero out all gradients that don't correspond to a soft-prompt token. We do this by registering a backward hook, a user-specified function that is invoked every time the gradient for that tensor is computed. You can learn more about this in the PyTorch docs.
import torch

# Freeze every parameter, then re-enable gradients only for the input embedding matrix
for param in model.parameters():
    param.requires_grad = False
model.get_input_embeddings().weight.requires_grad = True

# Rows of the embedding matrix that correspond to the soft-prompt tokens
embeddings_to_update = torch.tensor(prefix_token_ids[:NUM_SPECIAL_TOKENS_IN_PREFIX], dtype=torch.long)

def grad_hook(grad):
    # Zero out the gradient for every embedding row except the soft-prompt rows
    mask = torch.zeros_like(grad)
    mask[embeddings_to_update] = 1.0
    return grad * mask

hook_handle = model.get_input_embeddings().weight.register_hook(grad_hook)
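As an optional sanity check (not part of the training script, and assuming a single-GPU setup), you can run one forward/backward pass and confirm the hook zeroes every gradient row outside the soft prompt:

# Hypothetical check: only the soft-prompt rows should end up with nonzero gradient
sample = tokenizer(prefix + " hello", return_tensors="pt").to(model.device)
model(**sample, labels=sample["input_ids"]).loss.backward()

grad = model.get_input_embeddings().weight.grad
nonzero_rows = grad.abs().sum(dim=-1).nonzero().flatten().tolist()
assert set(nonzero_rows) <= set(prefix_token_ids[:NUM_SPECIAL_TOKENS_IN_PREFIX])

model.zero_grad()  # clear the test gradients before real training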
Finally, we initialize some HuggingFace boilerplate and start training the soft prompt. Note that we use the collator to restrict the loss to the assistant completion only. Afterwards, we remove the backward hook.
from transformers import TrainingArguments
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

# Only compute loss on tokens after the assistant header (i.e. the completion)
response_template = "<|start_header_id|>assistant<|end_header_id|>"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=train_dataset,
    data_collator=collator,
    args=TrainingArguments(
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=1,
        warmup_ratio=WARMUP_RATIO,
        num_train_epochs=1,
        learning_rate=LEARNING_RATE,
        logging_steps=16,
        optim="adamw_8bit",
        weight_decay=WEIGHT_DECAY,
        lr_scheduler_type="linear",
        seed=3407,
        gradient_checkpointing=True,
    ),
)

trainer.train()
hook_handle.remove()
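Since only the soft-prompt rows of the embedding matrix were updated, there is no need to save a full model checkpoint. One option (a sketch, not something the linked repo does) is to persist just those rows and copy them back into a freshly loaded base model when you want to reuse the prompt:

# Save only the trained soft-prompt rows (e.g. a 32 x hidden_dim tensor)
prefix_rows = prefix_token_ids[:NUM_SPECIAL_TOKENS_IN_PREFIX]
trained_prefix = model.get_input_embeddings().weight.data[prefix_rows].clone().cpu()
torch.save(trained_prefix, "banking77_soft_prompt.pt")

# Later: reload the base model and copy the rows back in
# fresh_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# emb = fresh_model.get_input_embeddings().weight
# emb.data[prefix_rows] = torch.load("banking77_soft_prompt.pt").to(emb.device, dtype=emb.dtype)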
With training done, we can run the benchmark again on the test dataset that includes the prefix, using very similar code to before.
pred_ls, golden_ls = [], []
num_correct, num_total = 0, 0

for i in tqdm(range(300)):
    messages = test_dataset_with_prefix[i]["messages"]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=128,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.001,
        top_p=0,
        pad_token_id=tokenizer.eos_token_id,
    )

    response = outputs[0][input_ids.shape[-1]:]
    response = tokenizer.decode(response, skip_special_tokens=False)
    pred = parse_value_from_xml_with_regex(response, "answer")

    pred_ls.append(pred)
    golden_ls.append(test_dataset_with_prefix[i]["label_str"])
    if pred == test_dataset_with_prefix[i]["label_str"]:
        num_correct += 1
    num_total += 1

accuracy = num_correct / num_total
print(f"Accuracy: {accuracy}")
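If you want more detail than overall accuracy, the pred_ls and golden_ls lists collected above can be fed into scikit-learn's classification_report (assuming scikit-learn is installed; this is not part of the original walkthrough):

from sklearn.metrics import classification_report

# Per-intent precision/recall; map unparseable outputs to a placeholder class
cleaned_preds = [p if p else "<no_answer>" for p in pred_ls]
print(classification_report(golden_ls, cleaned_preds, zero_division=0))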
Here are some of the results I achieved with the hyper-parameters above.
Llama 3B 16 prefix tokens: 79%
Llama 3B 32 prefix tokens: 83%
Llama 3B 64 prefix tokens: 82.66%
Llama 1B 16 prefix tokens: 64.66%
Llama 1B 32 prefix tokens: 67.66%
Llama 1B 64 prefix tokens: 73.33%
We can see that after training, we achieve a huge improvement in accuracy. Recall that vanilla Llama 3B scored 29% and Llama 1B scored 0.66%. The 1B model can now follow the desired output format, and the 3B model has improved by nearly 3x. In line with the intuition that longer prompts better steer the model, longer prefixes mostly result in higher scores.
Conclusion
Soft prompt tuning demonstrates remarkable potential for modern language models, as shown by our experiments with Llama. By optimizing just 16-64 learnable tokens while keeping the base model frozen, we improved Llama 3B's accuracy on the Banking-77 classification task from 29% to 83%, a nearly threefold increase. This dramatic improvement, achieved with minimal parameter updates, validates soft prompt tuning as an efficient alternative to full model fine-tuning for modern decoder-only architectures. While previous research focused primarily on older encoder-decoder models like T5, our results show that soft prompt tuning can be effectively adapted for contemporary LLMs like Llama, filling an important gap in the field of parameter-efficient fine-tuning.