ORPO: Redefining RLHF

Motivation

Today's Large Language Models (LLMs) are autoregressive models built from stacked transformer blocks: they predict the next token conditioned on the tokens seen so far. This simple objective, scaled up across many transformer layers and vast amounts of text, is what gives these models their impressive capabilities.

After pretraining, the model is typically adapted to instruction data in a phase called Supervised Fine-Tuning (SFT). However, the results of the SFT phase may not align perfectly with our exact requirements. To address this, various alignment strategies have been developed. Two notable methods are Reinforcement Learning from Human Feedback (RLHF) and Direct Preference Optimization (DPO). Recently, a new method called Odds Ratio Preference Optimization (ORPO) has emerged, showing promising benchmark results.

In this blog post, we will delve into the mathematical aspects of ORPO, exploring its loss function and the model adjustments it entails to understand how it achieves these impressive results.

Alignment Techniques

Several techniques for preference alignment have been proposed, most notably RLHF and DPO. RLHF first trains a reward model and then optimizes the policy (the text-generating LLM) against that reward model using the PPO algorithm. In contrast, DPO removes the need for reinforcement learning by collapsing the two stages—1) reward modeling and 2) PPO training—into a single supervised objective that aligns the LLM directly on preference pairs.

But why do we need these techniques at all? Why can't Supervised Fine-Tuning (SFT) alone achieve alignment, given that its cross-entropy loss should, in theory, handle this? The issue lies in the SFT loss function, which never penalizes rejected generations. The goal of cross-entropy loss in model fine-tuning is only to penalize the model if the predicted logits for the reference answers are low:

\[\begin{align} loss = L = \frac{-1}{M} \sum_{k=1}^{M} \sum_{i=1}^{|V|} y_{i}^k \log(p_i^k) \end{align}\]

Here, \(y_i^k\) is a boolean value indicating whether the i-th token in the vocabulary \(V\) is the label token at position k, \(p_i^k\) is the predicted probability of the i-th token at that position, and M is the length of the sequence. Cross-entropy alone gives no direct penalty or compensation for the logits of non-answer tokens, since their \(y_i^k\) is simply 0.
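
To see this concretely, here is a small sketch with a toy vocabulary of five tokens (the logits and label index are made up): the cross-entropy at a position reduces to \(-\log p\) of the label token, and all other tokens enter only through the softmax normalization.

import torch
import torch.nn.functional as F

# Toy example: vocabulary of 5 tokens, one position, label token index 2.
logits = torch.tensor([[1.0, 0.5, 2.0, -1.0, 0.3]])
label = torch.tensor([2])

# Standard cross-entropy as used during SFT.
ce_loss = F.cross_entropy(logits, label)

# Equivalent form: the negative log-probability of the label token only.
log_probs = F.log_softmax(logits, dim=-1)
manual_loss = -log_probs[0, label]

print(ce_loss.item(), manual_loss.item())  # identical values
# Tokens other than the label appear only inside the softmax normalization;
# there is no explicit term pushing down an undesired ("rejected") answer.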

To address this, ORPO adds an odds-ratio-based term to the cross-entropy loss that explicitly penalizes rejected answers. This modification helps ensure that the model is better aligned with desired outcomes, and it gives the method its name: Odds Ratio Preference Optimization.

ORPO

Let's dive into the details of Odds Ratio Policy Optimization (ORPO). Imagine we have a set of parameters, \(\theta\), representing the LLM network that generates words. For each sequence, we can determine the likelihood of generating it.

A sample preference dataset for ORPO looks like this:

orpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}
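
Before training, a dictionary of prompt/chosen/rejected triples like this is typically converted into a Hugging Face Dataset; a minimal sketch, assuming the datasets library is available:

from datasets import Dataset

# Build a preference dataset from the dictionary above; each row holds a prompt,
# a preferred ("chosen") completion, and a dispreferred ("rejected") completion.
dataset = Dataset.from_dict(orpo_dataset_dict)
print(dataset)     # features ['prompt', 'chosen', 'rejected'], num_rows: 7
print(dataset[0])  # {'prompt': 'hello', 'chosen': 'hi nice to meet you', 'rejected': 'leave me alone'}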

Given an input sequence x and an output sequence y of m tokens, \(p_{\theta}(y|x)\) denotes the length-normalized likelihood of generating y, computed as the exponential of the average per-token log-likelihood. The odds of generating \(y\) are given by:

\[ odds_{\theta}(y|x) = \frac{p_{\theta}(y|x)}{1 - p_{\theta}(y|x)} \]

Intuitively, if \(odds_{\theta}(y|x) = k\), it implies that it is k times more likely for the model \(\theta\) to generate the output sequence y than not to generate it. The odds ratio between a preferred sequence \(y_w\) and a less preferred sequence \(y_l\) is:

\[ OR_{\theta}(y_w, y_l) = \frac{odds_{\theta}(y_w|x)}{odds_{\theta}(y_l|x)} \]
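
As a quick numeric illustration with made-up per-token log-probabilities, the length-normalized likelihoods, odds, and odds ratio can be computed as follows:

import torch

# Hypothetical per-token log-probabilities of the chosen and rejected completions
# (as obtained by gathering the model's log-softmax outputs at the label tokens).
chosen_token_logps = torch.tensor([-0.2, -0.3, -0.4])
rejected_token_logps = torch.tensor([-1.2, -1.5, -0.9])

# Length-normalized likelihood p_theta(y|x): exponential of the average log-likelihood.
p_chosen = torch.exp(chosen_token_logps.mean())      # ~0.741
p_rejected = torch.exp(rejected_token_logps.mean())  # ~0.301

odds_chosen = p_chosen / (1 - p_chosen)
odds_rejected = p_rejected / (1 - p_rejected)

odds_ratio = odds_chosen / odds_rejected             # ~6.6
print(p_chosen.item(), p_rejected.item(), odds_ratio.item())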

Optimization aims to increase the odds ratio of the chosen sequence over the rejected one. To turn this into a loss, ORPO wraps the log odds ratio in a log-sigmoid, converting the maximization of the ratio into the minimization of a smooth, bounded loss:

\[ loss_{OR} = -\log (\sigma (\log \frac{odds_{\theta}(y_w|x)}{odds_{\theta}(y_l|x)})) \]

To form the total loss, ORPO adds it to the SFT loss with a weighting coefficient \(\beta\):

\[ L_{ORPO} = L_{SFT} + \beta L_{OR} \]

In practice, we can implement the \(\beta L_{OR}\) component in a few lines of PyTorch. Suppose we have policy_chosen_logps and policy_rejected_logps, the average log probabilities of the chosen and rejected sequences, respectively. The following Python code implements the odds ratio loss:

from typing import Tuple

import torch
import torch.nn.functional as F

def odds_ratio_loss(
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    beta: float,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, float, float]:
    # The inputs are average (length-normalized) log-probabilities, i.e. log p.
    # torch.log1p(-torch.exp(log p)) computes log(1 - p), so each difference below
    # is a log(odds), and log_odds is the log of the odds ratio OR(y_w, y_l).
    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
        torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
    )
    # log(sigmoid(log_odds)) is the negated L_OR term; scaling by beta gives the
    # component that is later subtracted from the SFT loss to form the full objective.
    ratio = F.logsigmoid(log_odds)
    losses = beta * ratio

    # Per-example "rewards", detached from the graph and used only for logging.
    chosen_rewards = beta * policy_chosen_logps.detach()
    rejected_rewards = beta * policy_rejected_logps.detach()

    return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item()

In this code, torch.log1p(-torch.exp(policy_chosen_logps)) computes \(\log(1 - p)\) from the average log probability \(\log p\), so log_odds is exactly \(\log OR_{\theta}(y_w, y_l)\). Applying F.logsigmoid to it yields the \(\log \sigma(\cdot)\) term of the loss in a numerically stable way; the negation and the weighting by \(\beta\) are applied when the total ORPO loss is assembled.
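
To obtain the full \(L_{ORPO}\), this term is combined with the usual SFT (negative log-likelihood) loss on the chosen responses. A minimal usage sketch, assuming policy_nll_loss, policy_chosen_logps, and policy_rejected_logps have already been computed from a forward pass, and using a hypothetical \(\beta = 0.1\):

losses, chosen_rewards, rejected_rewards, mean_ratio, mean_log_odds = odds_ratio_loss(
    policy_chosen_logps, policy_rejected_logps, beta=0.1  # beta value chosen for illustration
)

# losses holds beta * log(sigmoid(log_odds)), which is <= 0, so subtracting its mean
# adds the beta * L_OR penalty on top of the SFT loss: L_ORPO = L_SFT + beta * L_OR.
orpo_loss = policy_nll_loss - losses.mean()
orpo_loss.backward()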

Understanding the Gradient of ORPO

To better understand the loss expression \(L_{OR}\), let's delve into its gradient calculation. We'll start by showing the derivation of the sigmoid function's derivative, which is essential for this process.

The sigmoid function \(\sigma(x)\) is given by:

\[\begin{align} \sigma(x) = \frac{1}{1 + e^{-x}} \end{align}\]

To find its derivative, we proceed as follows:

\[\begin{align} \dfrac{d}{dx} \sigma(x) &= \dfrac{d}{dx} \left[ \dfrac{1}{1 + e^{-x}} \right] \\ &= \dfrac{d}{dx} \left( 1 + e^{-x} \right)^{-1} \\ &= -(1 + e^{-x})^{-2}(-e^{-x}) \\ &= \dfrac{e^{-x}}{\left(1 + e^{-x}\right)^2} \\ &= \dfrac{1}{1 + e^{-x}} \cdot \dfrac{e^{-x}}{1 + e^{-x}} \\ &= \dfrac{1}{1 + e^{-x}} \cdot \dfrac{(1 + e^{-x}) - 1}{1 + e^{-x}} \\ &= \dfrac{1}{1 + e^{-x}} \cdot \left( \dfrac{1 + e^{-x}}{1 + e^{-x}} - \dfrac{1}{1 + e^{-x}} \right) \\ &= \dfrac{1}{1 + e^{-x}} \cdot \left( 1 - \dfrac{1}{1 + e^{-x}} \right) \\ &= \sigma(x) \cdot (1 - \sigma(x)) \end{align}\]

Thus, the derivative of the sigmoid function is:

\[\begin{align} \sigma'(x) = \sigma(x) \cdot (1 - \sigma(x)) \end{align}\]

Additionally, we know that:

\[\begin{align} \sigma(x) + \sigma(-x) = 1 \end{align}\]
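
Both identities are easy to sanity-check numerically, for example with PyTorch autograd on a small grid of points:

import torch

x = torch.linspace(-5.0, 5.0, steps=11, requires_grad=True)
s = torch.sigmoid(x)

# Autograd derivative of sigma(x) versus the closed form sigma(x) * (1 - sigma(x)).
(grad,) = torch.autograd.grad(s.sum(), x)
print(torch.allclose(grad, (s * (1 - s)).detach()))  # True

# sigma(x) + sigma(-x) = 1, checked on the same grid of points.
with torch.no_grad():
    print(torch.allclose(torch.sigmoid(x) + torch.sigmoid(-x), torch.ones_like(x)))  # True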

Now, let's consider the ratio of the odds of the winning and losing responses, denoted as \(g\):

\[\begin{align} g = \frac {\text{odds}_{\theta}(y_w|x)}{\text{odds}_{\theta}(y_l|x)} \end{align}\]

The gradient of the loss function \(L_{OR}\) is then:

\[\begin{align} \nabla_{\theta} L_{OR} &= -\nabla_{\theta} \log \sigma \left( \log g \right) \\ &= -\frac{\sigma' (\log g)}{\sigma (\log g)} \cdot \nabla_{\theta} \log g \\ &= -\frac{\sigma (\log g) \cdot (1 - \sigma (\log g))}{\sigma (\log g)} \cdot \nabla_{\theta} \log g \\ &= -\sigma (-\log g) \cdot \nabla_{\theta} \log g \\ &= -\left(1 + \frac{\text{odds}_{\theta}(y_w|x)}{\text{odds}_{\theta}(y_l|x)}\right)^{-1} \nabla_{\theta} \log\left(\frac{\text{odds}_{\theta}(y_w|x)}{\text{odds}_{\theta}(y_l|x)}\right) \end{align}\]

When the odds of the favored (winning) responses are much higher than those of the disfavored (losing) responses, the coefficient \(\left(1 + g\right)^{-1}\) in the final expression converges to 0, so the \(L_{OR}\) term barely changes the parameters. Conversely, when the model is more likely to generate the rejected (losing) responses, this coefficient approaches 1, and the term acts as a penalty that accelerates the parameter updates.
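
This vanishing coefficient can also be checked directly with autograd: the derivative of \(-\log \sigma(\log g)\) with respect to \(\log g\) is \(-(1 + g)^{-1}\), which shrinks toward 0 as the odds ratio \(g\) grows. A small sketch with made-up log odds values:

import torch
import torch.nn.functional as F

# Hypothetical log odds ratios: rejected favored (negative), neutral, chosen strongly favored.
log_g = torch.tensor([-3.0, 0.0, 3.0], requires_grad=True)

loss = -F.logsigmoid(log_g).sum()   # L_OR summed over the three cases
loss.backward()

g = torch.exp(log_g.detach())
print(log_g.grad)                   # equals -1 / (1 + g)
print(-1.0 / (1.0 + g))             # ~[-0.9526, -0.5000, -0.0474]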

By understanding the gradient of the ORPO loss, we can see how the model adjusts its parameters to favor the generation of preferred responses while penalizing the less favored ones, ensuring improved model performance over time.