Chapter 5 — Fine-tuning & Alignment ⭐
Part 1: Supervised Fine-tuning (SFT) — From Raw Model to Helpful Assistant
This is the most practically important chapter. Pre-training gives you a "completion machine." SFT turns it into an assistant. Even tiny SFT datasets can produce dramatic behavior changes.
The Gap: Pretrained vs. Assistant
A freshly pretrained LLM (like LLaMA base) is a completion machine. It predicts the next token, full stop. Give it a question, and it'll complete it — sometimes by generating more questions, sometimes by writing Wikipedia-like text, sometimes by following the pattern it learned from web forums.
Prompt: "What is the capital of France?"
Raw LLM completion (base model):
"What is the capital of France?
What is the capital of Germany?
What is the capital of Italy?
..." ← It continues the pattern of quiz questions!
SFT-tuned LLM:
"The capital of France is Paris." ← Actually answers
SFT teaches the model: "when you see a question/instruction, generate a helpful response."
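You can reproduce this comparison yourself. A minimal sketch with Hugging Face transformers, using an illustrative base checkpoint (any base model shows the same continuation behavior, though the exact output will vary):
# Sketch: a base (non-SFT) checkpoint just continues the text.
# The model name is illustrative; any base checkpoint behaves similarly.
from transformers import AutoModelForCausalLM, AutoTokenizer
name = "meta-llama/Meta-Llama-3-8B"  # base model, NOT the -Instruct variant
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype="auto", device_map="auto")
inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
# Typically continues the quiz-question pattern rather than answering.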
The SFT Data Format
The key is the data format — specifically, how you structure the conversation in the training data.
ChatML Format (popularized by OpenAI and used by models such as Qwen; LLaMA 3 and Mistral use their own, structurally similar special-token templates):
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris, a city known for the Eiffel Tower and its rich cultural history.<|im_end|>
Alpaca Format (LLaMA 1 era):
### Instruction:
What is the capital of France?
### Response:
The capital of France is Paris.
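In practice you rarely assemble these strings by hand. A tokenizer that ships a chat template renders a list of messages into whatever format that particular model was trained with. A minimal sketch (the checkpoint name is just an example of a ChatML-style model):
# Sketch: let the tokenizer render its own chat format.
# Qwen2 happens to use ChatML; other models render different special tokens.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)  # <|im_start|>system ... <|im_start|>assistant, ready for the model to complete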
During Training:
Only the assistant/response portion contributes to the loss. The instruction part is used as input (context), but we don't penalize the model for not predicting it. This is implemented with a loss mask:
import torch
import torch.nn.functional as F
def create_masked_labels(input_ids, response_start_idx):
    """
    Set labels to -100 (ignore) for the instruction portion.
    Only compute loss on the response tokens.
    """
    labels = input_ids.clone()
    labels[:response_start_idx] = -100  # -100 = ignore_index in cross_entropy
    return labels
# F.cross_entropy skips every position where the target == -100.
# (logits and labels are assumed to already be aligned/shifted for next-token prediction.)
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)
Why only train on the response? The instruction/system prompt appears in almost every training example (with minor variations). If you train on it too, the model would learn to generate system prompts — which is useless and wastes gradient updates on boilerplate text.
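In practice the fiddly part is computing response_start_idx. One common approach is to tokenize the prompt portion on its own and use its length; a minimal sketch (the helper name is illustrative, and note the tokenizer caveat in the comment):
# Sketch: find where the response begins by tokenizing the prompt part separately.
# Caveat: BPE tokenizers can merge tokens across the prompt/response boundary, so
# verify that the full sequence actually starts with the prompt tokens.
def find_response_start(tokenizer, prompt_text, response_text):
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    full_ids = tokenizer(prompt_text + response_text, add_special_tokens=False)["input_ids"]
    assert full_ids[: len(prompt_ids)] == prompt_ids, "boundary tokenization mismatch"
    return len(prompt_ids)  # pass this as response_start_idx to create_masked_labels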
Data: What Makes a Good SFT Dataset?
Quality >> Quantity
The Alpaca paper showed that 52,000 examples can produce a usable instruction-following model. The key is the quality of those examples.
Good SFT data characteristics:
- Diverse — covers many task types: summarization, Q&A, coding, math, writing, analysis
- Clear instructions — the instruction should unambiguously specify the desired response
- Accurate responses — wrong answers in SFT directly teach the model to be wrong
- Appropriate length — response length should match the instruction complexity
- Natural — sounds like how a human would actually ask and respond
The LIMA paper (2023) — "Less Is More for Alignment": 1,000 carefully selected examples produced a model competitive with SFT on 52K examples. Quality of each example matters enormously.
Creating SFT Data
Option 1: Human-written — Most expensive, highest quality. Used in InstructGPT's first stage.
Option 2: LLM-generated (GPT-3.5/GPT-4) — Cheap, scalable, good quality. Used in Alpaca (via text-davinci-003), WizardLM, and many open models.
# Prompt template for generating SFT data via GPT-4
prompt = """Generate an instruction and a high-quality response for it.
Format as JSON: {"instruction": "...", "response": "..."}
The instruction should be something a user might ask an AI assistant.
Topics: {topic}
Difficulty: {difficulty}
"""
Option 3: Existing task datasets — Convert NLP datasets (SuperGLUE, BIG-Bench, etc.) into instruction format.
Option 4: Self-instruct — Use the base LLM to generate its own training data. Start from a small seed set of hand-written examples, prompt the LLM to generate more, filter for quality.
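Below is a rough sketch of what a Self-Instruct-style loop (Option 4) can look like. The prompt wording, the filters, and the generate() callable are illustrative stand-ins, not the exact recipe from the paper:
# Sketch of a Self-Instruct-style loop. generate() stands in for sampling from your
# base LLM; the filters are simple illustrative heuristics.
import json
import random
def self_instruct(generate, seed_examples, target_size=1000):
    pool = list(seed_examples)  # hand-written {"instruction": ..., "response": ...} dicts
    while len(pool) < target_size:
        few_shot = random.sample(pool, k=min(3, len(pool)))
        prompt = "Write one new instruction and response as JSON, in the style of:\n"
        prompt += "\n".join(json.dumps(ex) for ex in few_shot)
        try:
            candidate = json.loads(generate(prompt))
        except (json.JSONDecodeError, TypeError):
            continue  # the model did not return valid JSON, skip it
        # Quality filters: both fields present, non-trivial response, no duplicate instruction.
        if not candidate.get("instruction") or not candidate.get("response"):
            continue
        if len(candidate["response"]) < 20:
            continue
        if any(candidate["instruction"] == ex["instruction"] for ex in pool):
            continue
        pool.append(candidate)
    return pool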
The Training Recipe for SFT
# SFT training configuration (typical settings for 7B model)
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="./sft_model",
    # Number of passes over the dataset
    # SFT typically uses 1-3 epochs only!
    # More epochs → overfitting to the SFT data, forgetting pretraining
    num_train_epochs=2,
    # Batch size (effective = per_device × gradient_accumulation)
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch: 4×4 = 16
    # Learning rate: MUCH smaller than pretraining (1e-5 to 3e-5)
    # Too high → catastrophic forgetting of pretrained knowledge
    # Too low → no instruction-following behavior learned
    learning_rate=2e-5,
    # Cosine LR schedule with short warmup
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Mixed precision for efficiency
    bf16=True,
    # Save frequently
    save_strategy="steps",
    save_steps=100,
    logging_steps=10,
    # Gradient clipping (same as pretraining)
    max_grad_norm=1.0,
)
Why such a low LR for SFT? The pretrained model already has enormous capabilities encoded in its weights. SFT should nudge it toward instruction-following behavior, not overwrite its knowledge. High LR causes "catastrophic forgetting" — the model forgets everything it learned during pretraining.
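These arguments then plug into a standard Trainer. A minimal sketch, assuming you already have a pre-tokenized dataset carrying the masked labels from the earlier snippet (the checkpoint name and dataset variable are placeholders):
# Sketch: wiring the config into a Trainer. Placeholders: the checkpoint name and
# tokenized_dataset (assumed to contain input_ids, attention_mask, and masked labels).
from transformers import AutoModelForCausalLM, Trainer, default_data_collator
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
trainer = Trainer(
    model=model,
    args=training_args,               # the TrainingArguments defined above
    train_dataset=tokenized_dataset,  # placeholder: your tokenized SFT examples
    data_collator=default_data_collator,
)
trainer.train()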
Catastrophic Forgetting and How to Avoid It
When you fine-tune on a new distribution, the model can forget its prior capabilities. SFT for instruction-following can cause:
- Loss of factual knowledge encoded during pretraining
- Degraded performance on tasks not in the SFT dataset
- Reduced calibration (the model becomes overconfident)
Mitigation strategies:
- Low learning rate: 2e-5 instead of 3e-4. Smaller updates preserve more pretrained knowledge.
- Few epochs: 1-3 epochs maximum. More epochs → more forgetting.
- Data mixing: Mix SFT data with a small amount (~5%) of pretraining data. Forces the model to maintain its pretraining capabilities (see the sketch after this list).
- LoRA (see Part 2): Only update a small subset of parameters. Pretrained weights stay frozen.
- Diverse SFT data: A narrow SFT dataset (only coding questions) causes forgetting of non-coding capabilities. Use diverse task coverage.
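A minimal sketch of the data-mixing strategy from the list above, using the Hugging Face datasets library (the dataset variables and the exact 95/5 split are illustrative):
# Sketch: blend ~5% pretraining-style raw text back into the SFT mix.
# sft_dataset and pretrain_dataset are placeholders for your own datasets.
from datasets import interleave_datasets
mixed_dataset = interleave_datasets(
    [sft_dataset, pretrain_dataset],
    probabilities=[0.95, 0.05],  # ~5% pretraining data
    seed=42,
)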
Full Fine-tuning vs. LoRA: When to Use Which
| Aspect | Full Fine-tuning | LoRA |
|---|---|---|
| Params updated | All (7B for 7B model) | 0.1-1% (7M-70M) |
| GPU memory | 14GB model + 56GB optimizer states | ~16GB total (with quantization) |
| Training speed | 1× (baseline) | 1.5-2× faster |
| Quality | Marginally better | ~95-98% of full FT quality |
| Overfitting risk | Higher | Lower (fewer params to overfit) |
| Use case | Large compute budget, quality-critical | Limited GPU, most practical cases |
Rule of thumb: Use LoRA for everything unless you have a specific reason not to. The quality tradeoff is minimal and the compute savings are enormous.
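To make the "Params updated" row concrete, here is a back-of-the-envelope count for LoRA applied to the four attention projections of a roughly LLaMA-7B-shaped model (the layer count, hidden size, and rank are illustrative assumptions):
# Back-of-the-envelope LoRA parameter count (illustrative, roughly LLaMA-7B-shaped).
n_layers, d_model, r = 32, 4096, 16
# Each adapted d×d weight gains two low-rank matrices, A (d×r) and B (r×d): 2·d·r params.
params_per_adapter = 2 * d_model * r
# Applied to the q, k, v, o projections in every layer:
lora_params = n_layers * 4 * params_per_adapter
print(f"{lora_params / 1e6:.1f}M trainable params")  # ≈ 16.8M, about 0.24% of 7B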
Interview Corner Cases — SFT 🎯
- "What is catastrophic forgetting and how do you mitigate it in SFT?" → Neural networks tend to overwrite previous learning when trained on new data. For SFT: use very low LR (1e-5 to 3e-5), few epochs (1-3), mix in pretraining data, and prefer LoRA over full fine-tuning.
- "Why do we only compute loss on the response in SFT?" → Two reasons: (1) Computing loss on instructions would make the model try to predict the instruction format during training, which is wasteful and introduces noise. (2) Instruction text is often templated/repeated — training on it would bias the model toward reproducing templates rather than generating good responses.
- "What is instruction tuning vs. RLHF? Is SFT alone enough?" → SFT teaches the format: "instructions get responses." RLHF teaches the quality: "this response is better than that one." SFT alone often produces responses that are correctly formatted but not necessarily helpful, harmless, or honest. RLHF (or DPO) is needed for the final alignment step.
- "How many SFT examples do you need?" → The LIMA paper suggests 1,000 high-quality examples can achieve competitive results. Practically, 5K-100K is common. Quality matters much more than quantity — one carefully crafted example is worth 100 automatically generated mediocre ones.
- "What happens if you SFT a model on data with wrong answers?" → The model learns to confidently produce wrong answers. SFT doesn't teach the model to check its own accuracy — it just teaches behavior patterns. This is why data quality is so critical, and why RLHF/DPO (with human preference data) is needed to teach the model to prefer accurate over inaccurate responses.
- "Why is 2-3 epochs typical for SFT but pretraining uses 1 epoch?" → Pretraining datasets are huge (1T+ tokens) — one epoch already exposes the model to enormous diversity. SFT datasets are tiny (1K-100K examples) — you need multiple passes to learn from them, but not too many passes (overfitting/forgetting).
Next: LoRA and QLoRA: Fine-tuning on a Budget — The practical way to fine-tune without 8×A100s.