Can RLHF with Preference Optimization Techniques Help LLMs Surpass GPT4-Quality Models?
LLaMA-3, Supervised finetuning (SFT), RLHF, Preference-Alignment, DPO, ORPO, Instruction-Following Models
Introduction
In recent research, many different finetuning (FT) techniques claim to outperform the latest state-of-the-art models like GPT-4 on specific tasks. I was able to successfully use some of these techniques to significantly improve LLM performance on classifying news articles into bank-specific and systemic risk events, while also producing the reasoning needed to support each decision.
The work is done in the context of news articles about banking risk, where smaller language models, such as BERT-based models, lack the reasoning capabilities needed for complex tasks. Through this comprehensive evaluation, I aim to provide insights into the effectiveness of different FT methods, including Reinforcement Learning from Human Feedback (RLHF) methods, and their potential to enhance NLP performance across text classification and generation tasks.
This blog post not only shows how these experiments can be run with limited hardware resources, but also compares the finetuned models' performance to GPT-4-32k (0613) - spoiler alert: I was able to obtain better performance with the FT models :).
Previous work
Previous iterations of my work utilized prompt engineering to enhance a BERT-based model, which initially exhibited an F1-score of 55-60%.
This performance was hindered by two primary issues:
- Lack of Representativeness: The dataset comprised only one year of data.
- Unbalanced Dataset: Classes representing systemic and bank-specific risks were underrepresented.
To address these issues, GPT-3.5-turbo and prompt engineering were employed to augment the dataset for minority classes and subsequently retrain the BERT-based model. While prompt engineering proved effective in many cases, the model's performance was highly sensitive to minor modifications in the prompt. Additionally, few-shot prompts often yielded suboptimal results.
Finetuning techniques
Starting point
The initial performance metrics are displayed below. The LLMs compared with the BERT-based model were tuned solely via prompt engineering techniques.
SFT and preference optimization
The image below highlights different tuning techniques that can be applied on top of the base model or even a finetuned model to improve the model performance.
Typically, full fine-tuning updates all weights of the base model during training, but this becomes unfeasible with larger models. Instead, we will use QLoRA, a Parameter-Efficient Fine-Tuning (PEFT) method that freezes the base model and trains only a small set of additional LoRA parameters on top of a quantized version of it (e.g., 4-bit), further reducing memory requirements while maintaining efficiency.
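For illustration, a QLoRA setup with transformers, peft, and bitsandbytes looks roughly like the sketch below; the model name, target modules, and hyperparameters are assumptions for this example rather than the exact configuration used in the experiments.

```python
# Minimal QLoRA setup sketch (assumes transformers, peft, bitsandbytes and accelerate are installed;
# model name and hyperparameters are illustrative, not the exact values used in this work).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumed base model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                        # 4-bit quantization of the frozen base model
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,   # small trainable adapters
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()            # only the LoRA weights are trainable
```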
Proximal Policy Optimization (PPO) is a reinforcement learning algorithm used for alignment that maximizes rewards. It relies on three models when fine-tuning LLMs: (1) the active model (the main model being fine-tuned), (2) the value model (which estimates the value of different actions to guide learning), and (3) the reference model (which provides a baseline to prevent updates from being too drastic). PPO tries to balance improving the policy with keeping changes within reasonable bounds (read the PPO paper for more insights).
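To illustrate the "reasonable bounds" idea, here is a minimal sketch of PPO's clipped surrogate objective in PyTorch; the tensors and clip range are illustrative, and this is not the full RLHF pipeline.

```python
# Minimal sketch of PPO's clipped surrogate objective (illustrative tensors, not a full RLHF loop).
import torch

def ppo_clipped_loss(new_logprobs, old_logprobs, advantages, clip_range=0.2):
    # Probability ratio between the updated policy and the policy that collected the data.
    ratio = torch.exp(new_logprobs - old_logprobs)
    # Unclipped vs. clipped objective: taking the minimum keeps updates within reasonable bounds.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
    return -torch.min(unclipped, clipped).mean()

# Toy usage: per-token log-probs and advantages for a small batch.
loss = ppo_clipped_loss(torch.randn(8), torch.randn(8), torch.randn(8))
```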
However, given the cost implications of running this approach (three models need to be loaded), I will focus on RLHF techniques that, instead of the traditional [instructions, input, output] format, use preferred and rejected outputs ([instructions, input, desired output, rejected output]). In particular, Direct Preference Optimization (DPO) and Odds Ratio Preference Optimization (ORPO) will use this new dataset format to align the LLM outputs with the experts' preferences.
DPO introduces a new objective in which the model learns to optimize preferences directly from a dataset of human preferences. It uses two models: the active model (the main model being fine-tuned) and the reference model. The loss combines the log probabilities of the chosen responses with those of the rejected ones, aiming to maximize the margin between them. Because the objective is differentiable, the model can be trained towards this goal with standard backpropagation.
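As a rough sketch (the actual implementation lives inside TRL's DPOTrainer), the DPO objective can be written as below; the sequence-level log-probabilities and the beta value are assumptions for illustration.

```python
# Minimal sketch of the DPO loss (sequence-level log-probs are assumed precomputed; beta is illustrative).
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards: how much more likely the policy makes each answer vs. the reference model.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    margins = chosen_rewards - rejected_rewards   # the "rewards/margins" metric seen in the dashboards below
    loss = -F.logsigmoid(margins).mean()          # push chosen outputs above rejected ones
    return loss, margins

loss, margins = dpo_loss(torch.randn(4), torch.randn(4), torch.randn(4), torch.randn(4))
```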
Lastly, another method explored in this article is the ORPO technique. It changes the objective by integrating a penalty for unwanted generation styles directly into the loss function, without requiring a separate reference model. This penalty quantifies the preference disparity between chosen and rejected answers, which helps increase performance and refine the specific style of the answers. For detailed insights, refer to the Odds Ratio Preference Optimization paper.
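To make the odds-ratio penalty concrete, here is a rough sketch that assumes average per-token log-probabilities are already computed; the lambda weight is illustrative, and TRL's ORPOTrainer implements the full objective.

```python
# Minimal sketch of ORPO's odds-ratio penalty added on top of the SFT loss
# (average per-token log-probs assumed precomputed; lambda is illustrative).
import torch
import torch.nn.functional as F

def orpo_loss(nll_chosen, chosen_logps, rejected_logps, lam=0.1):
    # Log-odds of generating each answer: log(p / (1 - p)), from the average log-probs.
    log_odds_chosen = chosen_logps - torch.log1p(-torch.exp(chosen_logps))
    log_odds_rejected = rejected_logps - torch.log1p(-torch.exp(rejected_logps))
    ratio = log_odds_chosen - log_odds_rejected   # preference disparity between chosen and rejected
    penalty = -F.logsigmoid(ratio).mean()         # penalize rejected-style generations
    return nll_chosen + lam * penalty             # SFT loss + odds-ratio penalty, no reference model

loss = orpo_loss(torch.tensor(1.5), torch.tensor(-0.8), torch.tensor(-1.6))
```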
Methodology
Below, you can see the methodology pipeline used to gather the necessary training data for the experiment.
An example from the training dataset is shown in the table below (formatted for readability).
role | content |
---|---|
system | You are an expert in assessing bank-specific and systemic risks in the European Banking Sector. |
user | I will send you a request. To answer that, use as context the following list of important TARGET_BANKS = ['BNP PARIBAS', 'CREDIT SUISSE', 'UBS'] |
assistant | I'm ready! Please go ahead and send your request. I'll use the list of TARGET_BANKS = ['BNP PARIBAS', 'CREDIT SUISSE', 'UBS'] as context to answer your question. |
user | Assign a label of Bank-Specific Risk (limited set of banks impacted), Systemic Risk (the whole European Banking Sector can be impacted, directly or indirectly) or not applicable for the following news snippet: INPUT={ News headline: "UBS boss Ermotti flags Credit Suisse client migration challenge"; Snippet: "UBS, which acquired Credit Suisse in March 2023 after it collapsed due to a series of financial setbacks, recently began transferring the fallen bank's clients onto its system. Ermotti said client migration was the most difficult stage of the integration, and that UBS would shortly move across accounts based in Singapore. Last month, UBS said it had completed those transfers in Luxembourg and Hong Kong." } |
assistant | News headline: "UBS boss Ermotti flags Credit Suisse client migration challenge", Snippet: "UBS, which acquired Credit Suisse in March 2023 after it collapsed due to a series of financial setbacks, recently began transferring the fallen bank's clients onto its system. Ermotti said client migration was the most difficult stage of the integration, and that UBS would shortly move across accounts based in Singapore. Last month, UBS said it had completed those transfers in Luxembourg and Hong Kong.", Label: "Bank-Specific Risk (limited set of banks impacted)", Reasoning: "The client migration challenge mentioned in the news snippet is specific to the integration of Credit Suisse into UBS, which is a one-time event related to the acquisition of a single bank. While UBS and Credit Suisse are both supervised institutions, the challenge does not seem to have a broader impact on the entire European Banking Sector. The fact that UBS has already completed client transfers in Luxembourg and Hong Kong, and is still working on transferring accounts based in Singapore, suggests that this is a limited issue that is being addressed by the bank. Therefore, I would assign a label of Bank-Specific Risk (limited set of banks impacted) to this news snippet." |
Dataset remarks:
- Including (1) the definition of the type of risk, (2) the extra interaction with the assistant to provide context on the relevant banks, and (3) the request for a reasoning field in the prompt instructions has been shown to deliver better results across multiple experiments.
- The dataset (15k training samples, 2.5k validation samples, and 2.5k test samples) was reviewed by experts, focusing exclusively on the assigned labels. The reasoning field was generated using GPT-4 to avoid the significant time and effort required for manual creation.
- For training the DPO and ORPO algorithms, the dataset was augmented with rejected outputs, whose reasoning was derived from LLM predictions that did not align with the labels validated by the experts (the sketch below shows the resulting record format).
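For illustration, a preference-pair record roughly follows the prompt/chosen/rejected convention expected by TRL; the values below are shortened placeholders rather than real entries from the dataset.

```python
# Illustrative preference-pair record for DPO/ORPO training (shortened placeholder values).
preference_example = {
    "prompt": "Assign a label of Bank-Specific Risk, Systemic Risk or not applicable "
              "for the following news snippet: INPUT={...}",
    "chosen": 'Label: "Bank-Specific Risk (limited set of banks impacted)", '
              'Reasoning: "The client migration challenge is specific to the UBS/Credit Suisse integration..."',
    "rejected": 'Label: "Systemic Risk (the whole European Banking Sector can be impacted)", '
                'Reasoning: "Any issue involving UBS and Credit Suisse affects the whole sector..."',
}
```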
Training Tips:
The HuggingFace team provides the TRL library, which simplifies the implementation of DPO and ORPO training techniques. To learn more about running SFT with preference optimization methods like DPO, check out this article by Philipp Schmid.
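As a rough sketch of what a DPO run looks like with TRL (hyperparameters and file paths are placeholders, and argument names can differ slightly between TRL versions):

```python
# Minimal DPO training sketch with TRL (illustrative hyperparameters and placeholder dataset;
# argument names may differ slightly across TRL versions).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "path/to/sft-finetuned-llama3"       # assumed: the SFT checkpoint produced earlier
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Dataset with "prompt", "chosen" and "rejected" columns, as in the example above.
train_dataset = load_dataset("json", data_files="preferences_train.jsonl", split="train")

training_args = DPOConfig(
    output_dir="llama3-dpo",
    beta=0.1,                                    # strength of the preference term
    learning_rate=5e-5,                          # the DPO learning rate discussed below
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=10,
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,                  # older TRL versions use tokenizer= instead
)
trainer.train()
```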
Evaluating the results
In this section, I focus on identifying the best FT method for this specific classification problem by evaluating the previously described techniques. I investigate whether these methods can surpass the performance of BERT-based models on the classification task. Additionally, I explore the capability of these techniques to perform reasoning for summarization and question-answering (Q&A) tasks, leveraging the knowledge gained during fine-tuning.
This comparison highlights that SFT with DPO on top is the most balanced and effective approach for tasks requiring high precision, recall, and F1-score.
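For reference, the per-class comparison can be computed with a few lines of scikit-learn; the labels and predictions below are placeholders rather than the actual evaluation data.

```python
# Illustrative computation of precision/recall/F1 for the classification comparison
# (label names and predictions are placeholders, not the actual evaluation data).
from sklearn.metrics import classification_report

labels = ["Bank-Specific Risk", "Systemic Risk", "Not applicable"]
y_true = ["Bank-Specific Risk", "Systemic Risk", "Not applicable", "Bank-Specific Risk"]
y_pred = ["Bank-Specific Risk", "Systemic Risk", "Bank-Specific Risk", "Bank-Specific Risk"]

print(classification_report(y_true, y_pred, labels=labels, zero_division=0))
```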
Analyzing examples
The results show that not only is the LLaMA model able to assign the correct classification label, but it also effectively captures the nuances and reasoning behind its decisions, while the GPT-family models failed to understand and apply the required reasoning.
Important metrics to monitor during the FT experiments
To understand whether the models are actually learning and improving, there are several relevant metrics to keep an eye on. This monitoring can be performed using tools like TensorBoard (as shown below) or Weights & Biases (wandb) dashboards.
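Enabling this logging typically only requires a couple of arguments in the trainer configuration (the same fields exist in the SFT, DPO, and ORPO configs, since they all extend the transformers training arguments); the values below are illustrative.

```python
# Illustrative logging setup so that training metrics (loss, rewards/margins, accuracies)
# are written to TensorBoard and/or Weights & Biases dashboards.
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="llama3-dpo",
    logging_steps=10,                       # how often metrics are written
    report_to=["tensorboard", "wandb"],     # pick the dashboards you use
)
```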
The image displays four plots that compare the performance of three techniques—ORPO, SFT with QLoRA, and DPO—based on metrics such as learning rates, rewards/margins, and accuracy across 600 training steps. Here's a detailed explanation:
Top-Left Plot: Learning Rates - This plot compares the learning rates for the three techniques:
- ORPO (5e-6): Starts with a low learning rate that gradually decays.
- SFT with QLoRA (2e-5): Begins with a moderately high learning rate, decaying steadily.
- DPO (5e-5): Initiates with the highest learning rate and decays faster than the others.
- The variations in learning rates indicate different optimization strategies tailored for these methods.
Top-Central Plot: DPO Rewards/Margins - This plot tracks the rewards/margins for the DPO method. The metric reflects the difference between rewards for chosen and rejected outputs.
- The upward trend signifies that the model is effectively learning to assign higher rewards to chosen outputs compared to rejected ones, improving its discriminative capability.
Top-Right Plot: ORPO Rewards/Accuracies - Tracks ORPO's reward accuracy (the fraction of samples where the chosen output receives a higher reward than the rejected one), which peaks at around 0.8.
Bottom-Right Plot: ORPO Rewards/Margins - This plot shows the margin between rewards for chosen and rejected outputs for ORPO.
- While accuracy looks high (around 0.8), the margin between chosen and rejected rewards remains very low, indicating potential challenges in distinguishing between the two categories despite strong performance on accuracy.
Overall Comparison: DPO appears superior in terms of learning to assign higher rewards to chosen outputs. Although ORPO achieves high accuracy, it struggles with differentiation, which may suggest that it did not fully learn the distinction between chosen and rejected outputs. These observations are consistent with the evaluation results presented in the charts.
Note: These dashboards are extremely useful for understanding what is happening during training. For instance, in the initial iterations with DPO, I realized that I was using a very low learning rate, which negatively impacted the model’s ability to learn effectively.
Conclusions
Final considerations on dataset size: In my initial experiments, I used a much smaller dataset (<5k samples). However, this was insufficient to deliver any improvements in the model's reasoning capabilities. While the model could predict the class accurately, it simultaneously lost its ability to perform well in instruction-based chat scenarios, highlighting the importance of using adequately large datasets for achieving balanced performance.
The current results (with the 20k-sample dataset) show that applying SFT with DPO is effective in improving the reasoning capabilities of the model for the task at hand, demonstrating that preference optimization techniques can significantly enhance the performance of LLMs on specific problems.
Future work will focus on broader evaluations to further explore the utility of FT LLMs for other complex generative scenarios.
Nevertheless, these performance results alone should not be the ultimate deciding factor when determining an optimal approach. Key aspects such as deployment costs and the flexibility of the solution must be carefully analyzed, especially if the model is intended for production setups. Further exploration into combining RLHF with alternative preference optimization strategies or hybrid methods could offer promising pathways for future research and application.