A busy llama studying for a professional certificate (SDXL v 1.0)
Large language models (LLMs) have made fine-tuning in NLP a far more involved process. When models like ChatGPT first popped on the scene, the dominant approach was to train a reward model and then optimize the LLM policy against it. Reinforcement Learning from Human Feedback (RLHF) moved the needle significantly and resolved many long-standing challenges in NLP. However, it was hard work: it required appropriate and relevant data as well as a complex multi-model setup. Furthermore, the quality gains were not always evident, and models acquired a tendency to imitate and hallucinate.
However, recent advancements have introduced simpler and more efficient methods. One such method is Direct Preference Optimization (DPO).
What is Direct Preference Optimization (DPO)?
DPO is a method introduced to achieve precise control over LLMs without the full RLHF machinery. RLHF relies on training a reward model and then using Proximal Policy Optimization (PPO) to align the language model's output with human preferences. This approach, while effective, is complex and can be unstable to train.
DPO, on the other hand, treats the constrained reward maximization problem as a classification problem on human preference data. This approach is stable, efficient, and computationally lightweight. It eliminates the need for fitting a separate reward model, sampling from the language model during fine-tuning, or extensive hyperparameter tuning.
How Does DPO Work?
The DPO pipeline can be broken down into two main stages:
- Supervised Fine-tuning (SFT): This is the initial step where the model is fine-tuned on the dataset(s) of interest.
- Preference Learning: After SFT, the model undergoes preference learning using preference data, ideally from the same distribution as the SFT examples.
The beauty of DPO lies in its simplicity. Instead of training a reward model and then optimizing a policy against it, DPO defines the preference loss directly as a function of the policy, so no separate reward model has to be trained at all.
During preference learning, DPO uses the language model itself as an implicit reward model. It optimizes the policy with a binary cross-entropy objective over human preference data, which records which responses are preferred and which are not. By comparing the model's responses against the preferred ones, the policy is adjusted to improve its outputs.
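To make the objective concrete, here is a minimal sketch of the DPO loss for a batch of preference pairs. It assumes you already have the summed log-probabilities of each full response under the policy and under a frozen reference model; the variable names are illustrative, not TRL's internals.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: how much more likely the policy makes each response
    # compared with the reference model (scaled by beta below).
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # Binary cross-entropy on the reward difference: push the chosen response
    # to earn a higher implicit reward than the rejected one.
    logits = beta * (chosen_logratios - rejected_logratios)
    return -F.logsigmoid(logits).mean()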
Supervised Fine-tuning
Supervised fine-tuning (SFT) is the first step of DPO. SFT is a specialized method where an LLM is further trained on a labelled dataset. This dataset provides a clear mapping between specific inputs and the desired outputs. The essence of SFT, especially when combined with preference learning, is to mould the model's responses based on human-defined criteria, ensuring it aligns more closely with specific requirements.
Imagine a company looking to deploy a conversational AI to assist users in navigating their new application. While an off-the-shelf LLM like Falcon-7B might provide technically accurate answers, it might not resonate with the company's tone or branding. For instance, if a user asks about a feature like "Collaborative Editing," Falcon might offer a generic description. However, for a seamless user experience, the response should be user-friendly, detailed, and even offer troubleshooting tips. SFT refines the model's outputs to ensure they're not just accurate but also appropriate and consistent.
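To make this concrete, here is a purely illustrative SFT record for the scenario above. The wording, the prompt template, and the single text field (matching the dataset_text_field used in the training code later) are all assumptions for illustration, not part of any real dataset.

# Hypothetical SFT example: input and desired output merged into one text field.
sft_example = {
    "text": (
        "### Question: How does Collaborative Editing work?\n"
        "### Answer: Collaborative Editing lets your whole team edit a document "
        "at the same time. Open the document, click 'Share', and invite your "
        "teammates. If changes don't appear, check that everyone is online and "
        "running the latest version of the app."
    )
}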
Understanding Preference Data in NLP
Preference data is a curated set of candidate responses (options) for a specific prompt. These options are evaluated by annotators according to a set of guidelines, with the goal of ranking them from most preferred to least preferred. This ranking captures human preferences, which are then used to fine-tune models to produce outputs that align with human expectations.
The process of creating preference data involves a few steps (a small illustrative record follows them below):
Prompt Selection
The foundation of preference data is the prompt. There are various strategies for selecting prompts: some teams opt for a predefined set, others use templates to generate prompts dynamically, and another approach is to combine predefined prompts with random ones sourced from databases.
Answer Selection
Once the prompt is decided, the next step is to determine the answers. These answers can be generated from a specific version of a model or various checkpoints. The number of answers to be ranked can vary. While some might prefer a binary ranking system (best-worst), others might opt for a more granular approach, ranking answers on a scale, say from 1 to 5.
Annotation Guidelines
It's imperative to have clear annotation guidelines. These guidelines standardize the ranking process and minimize individual biases or interpretations.
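As promised above, here is a small illustrative record of what ranked preference data can look like before it is reduced to chosen/rejected pairs. The prompt, candidate answers, and ranks are invented for the example.

# Hypothetical ranked preference record produced by annotators.
preference_record = {
    "prompt": "Explain what Collaborative Editing does.",
    "answers": [
        "Collaborative Editing lets several people edit the same document in real time.",
        "It is a feature for editing documents.",
        "I don't know.",
    ],
    # 1 = most preferred, 3 = least preferred, assigned following the guidelines.
    "ranks": [1, 2, 3],
}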
Public Datasets for Preference Data
Several datasets are available for those looking to dive into preference data (a quick loading example follows the list below). For instance:
- OpenAI WebGPT Comparisons: This dataset offers 20k comparisons, each consisting of a question, a pair of model answers, and human-rated preference scores for each answer.
- OpenAI Summarization: This dataset provides 64k text summarization examples, inclusive of human-written responses and human-rated model responses.
- Reddit ELI5: Sourced from Q&A subreddits, this dataset boasts 270k examples of questions, answers, and scores.
- Human ChatGPT Comparison Corpus (HC3): This dataset provides 60k human answers and 27k ChatGPT answers for approximately 24k questions.
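Any of these can be pulled from the Hugging Face Hub for inspection. The snippet below uses the WebGPT comparisons set as an example; the Hub identifier and the exact column names are assumptions, so check the dataset card before relying on them.

from datasets import load_dataset

# Assumed Hub identifier; verify it on the dataset card.
webgpt = load_dataset("openai/webgpt_comparisons", split="train")
print(webgpt[0])  # inspect the question, the two answers, and their scores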
Implementing DPO with TRL: A Step-by-Step Guide
For those keen on harnessing the power of Direct Preference Optimization (DPO), the TRL (Transformer Reinforcement Learning) library offers a streamlined approach with its DPO Trainer. Here's a comprehensive guide to get you started:
Supervised Fine-Tuning (SFT)
Begin with training your SFT model. It's imperative to ensure the data used for SFT comes from the same distribution as the preference data you'll use later, setting the stage for the DPO algorithm to work effectively.
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer

# Load the dataset you want the model to specialize on (placeholder name)
dataset = load_dataset("your-domain-dataset", split="train")

# Load the foundation model to fine-tune (placeholder name)
model = AutoModelForCausalLM.from_pretrained("your-foundation-model-of-choice")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",  # column that holds the training text
    max_seq_length=512,         # maximum sequence length used during training
)
trainer.train()
Understanding the Dataset Format
The DPO trainer mandates a specific dataset format. Given that the model is trained to optimize the preference between two sentences directly, the dataset should reflect this structure. The dataset should comprise three key entries: prompt, chosen, and rejected.
- prompt: Contains the context inputs.
- chosen: Houses the corresponding chosen responses.
- rejected: Lists the corresponding negative (or rejected) responses.
Notably, a single prompt can correspond to multiple responses, which is reflected in repeated prompt entries in the dataset's arrays.
For instance, a sample dataset might look like:
dpo_dataset_dict = {
    "prompt": ["hello", "how are you", ...],
    "chosen": ["hi, nice to meet you", "I am fine", ...],
    "rejected": ["leave me alone", "I am not fine", ...],
}
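The DPOTrainer takes its train_dataset as a datasets.Dataset object, so one straightforward way to use a plain dict like the one above is to wrap it first:

from datasets import Dataset

# Wrap the raw dict into a Dataset with prompt/chosen/rejected columns.
train_dataset = Dataset.from_dict(dpo_dataset_dict)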
Leveraging the DPOTrainer
To kickstart the process, initialize the DPOTrainer. This involves specifying the model to be trained, a reference model ref_model (used to compute the implicit rewards for the preferred and rejected responses), the beta hyperparameter for the implicit reward, and the dataset with the three aforementioned entries.
from trl import DPOTrainer

dpo_trainer = DPOTrainer(
    model,                        # the SFT model to optimize
    model_ref,                    # frozen reference model, typically a copy of the SFT model
    args=training_args,           # training arguments (output dir, batch size, etc.)
    beta=0.1,                     # temperature of the implicit reward
    train_dataset=train_dataset,  # dataset with prompt/chosen/rejected columns
    tokenizer=tokenizer,          # tokenizer matching the model
)
Once set up, initiate the training with:
dpo_trainer.train()
Do note that the beta parameter is the temperature for the DPO loss, typically ranging from 0.1 to 0.5. As beta approaches 0, the reference model is disregarded.
Monitoring with Logging
During the training and evaluation phases, several reward metrics are recorded to gauge the model's performance. These include:
- rewards/chosen: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses, scaled by beta.
- rewards/rejected: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses, scaled by beta.
- rewards/accuracies: the mean frequency with which the chosen rewards exceed the corresponding rejected rewards.
- rewards/margins: the mean difference between the chosen and corresponding rejected rewards.
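As a mental model, these metrics boil down to simple arithmetic on the sequence log-probabilities. The sketch below is a conceptual approximation, not TRL's actual implementation, and the tensor values are made up.

import torch

beta = 0.1
# Per-example summed log-probabilities of the full responses (made-up numbers).
policy_chosen_logps = torch.tensor([-12.3, -9.8])
policy_rejected_logps = torch.tensor([-15.1, -11.0])
ref_chosen_logps = torch.tensor([-13.0, -10.5])
ref_rejected_logps = torch.tensor([-14.2, -10.9])

chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)        # rewards/chosen
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)  # rewards/rejected
accuracies = (chosen_rewards > rejected_rewards).float().mean()         # rewards/accuracies
margins = (chosen_rewards - rejected_rewards).mean()                    # rewards/margins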
By following this guide, you'll be well-equipped to implement DPO using the TRL library, optimizing your language models to align more closely with human preferences.
Conclusion
Fine-tuning LLMs is an essential process for generating text that aligns with specific guidelines. While the process has been complex, methods like DPO have simplified the approach, making it more accessible and efficient. With DPO, not only is the process streamlined, but it also offers a stable and computationally lightweight method to achieve precise control over LLMs. As the machine learning community continues to grow and evolve, it's exciting to see such advancements that make the journey of model training and optimization smoother and more efficient.