Rafailov 2023 - Direct Preference Optimization

Here we trace the derivations from the DPO paper. Denote the model after SFT as $\pi^{\text{SFT}}$, i.e. the policy $\pi^{\text{SFT}}(y \mid x)$ is a probability function over each pair of input and answer $(x, y)$. Naturally, we can use this policy to generate tokens by choosing the response with the highest probability (or approximate it in a greedy token-by-token manner).

To perform RLHF, we first need to build a reward model. First, we prompt the SFT model with prompts $x$ to obtain pairs of answers $(y_1, y_2) \sim \pi^{\text{SFT}}(y \mid x)$. These samples are presented to human labellers who record their preferences in the form $y_w \succ y_l \mid x$, where $y_w$ wins and $y_l$ loses. These preferences are assumed to be generated from an underlying reward model $r^*(x, y)$ which we do not have access to.

We wish to learn this reward model. Since we only have access to pairwise preferences instead of the reward score, a common approach is to model the pairwise preferences using the Bradley-Terry (BT) model. Specifically, we assume that the observed human preference decisions are related to the underlying human reward model $r^*$ in the following way:

$$ p^*(y_1 \succ y_2 \mid x) = \frac{\exp\left(r^*(x, y_1)\right)}{\exp\left(r^*(x, y_1)\right) + \exp\left(r^*(x, y_2)\right)} \tag{1} $$

Suppose we have a static dataset of comparisons $\mathcal{D} = \{x^{(i)}, y_w^{(i)}, y_l^{(i)}\}_{i=1}^N$ sampled from the human preference model $p^*$. We can create a parametrized reward model $r_\phi(x, y)$ and use the BT-model to express the negative log likelihood of $\mathcal{D}$:

$$ \mathcal{L}_R(r_\phi, \mathcal{D}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(r_\phi(x, y_w) - r_\phi(x, y_l)\right)\right] \tag{2} $$

Note that we are using the expression of the BT-model as a sigmoid function, $\sigma(a - b) = \frac{e^a}{e^a + e^b}$. With this NLL expression, we can optimize for $\phi$ using gradient descent to learn the reward model from $\mathcal{D}$. Notice that the heart of Equation (2) is essentially just a difference in reward between the winning answer $y_w$ and the losing answer $y_l$.
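To make Equation (2) concrete, here is a minimal PyTorch sketch of the reward-model loss; `reward_w` and `reward_l` are assumed to be batches of scalar rewards $r_\phi(x, y_w)$ and $r_\phi(x, y_l)$ computed elsewhere:

```python
import torch
import torch.nn.functional as F

def reward_model_nll(reward_w: torch.Tensor, reward_l: torch.Tensor) -> torch.Tensor:
    """Negative log likelihood of the Bradley-Terry model, Equation (2).

    reward_w: (batch,) scalar rewards r_phi(x, y_w) for the winning answers.
    reward_l: (batch,) scalar rewards r_phi(x, y_l) for the losing answers.
    """
    # -log sigma(r_w - r_l), averaged over the batch; logsigmoid is numerically stable
    return -F.logsigmoid(reward_w - reward_l).mean()
```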

Note that $r_\phi$ is usually initialized using the SFT model $\pi^{\text{SFT}}$, but with an additional linear layer over the final transformer layer to generate a scalar value as the reward.
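A rough sketch of what such a reward model could look like, assuming a hypothetical `backbone` module (e.g. the SFT transformer) that returns final-layer hidden states of shape `(batch, seq_len, hidden_dim)`:

```python
import torch
import torch.nn as nn

class RewardModel(nn.Module):
    """SFT backbone plus a scalar head (a sketch, not the paper's exact implementation)."""

    def __init__(self, backbone: nn.Module, hidden_dim: int):
        super().__init__()
        self.backbone = backbone                     # initialized from the SFT model
        self.value_head = nn.Linear(hidden_dim, 1)   # extra linear layer -> scalar reward

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids)            # assumed (batch, seq_len, hidden_dim)
        last_token = hidden[:, -1, :]                # pool on the final token's representation
        return self.value_head(last_token).squeeze(-1)  # (batch,) scalar rewards r_phi(x, y)
```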

Having learned the reward model, we then need to use the learned reward function $r_\phi$ to fine-tune $\pi^{\text{SFT}}$ using reinforcement learning. Specifically, we set $\pi_{\text{ref}} = \pi^{\text{SFT}}$ as the reference model, and initialize a new model $\pi_\theta$ that we wish to train. Usually, $\pi_\theta$ is also initialized as a copy of $\pi^{\text{SFT}}$. The objective function we wish to maximize is:

$$ \max_{\pi_\theta} \ \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\left[r_\phi(x, y)\right] - \beta\, \mathbb{D}_{\text{KL}}\left[\pi_\theta(y \mid x) \,\|\, \pi_{\text{ref}}(y \mid x)\right] \tag{3} $$

Inspecting this objective, we see that we are trying to tune $\pi_\theta$ such that it generates answers that maximize the learned reward function $r_\phi$, while at the same time ensuring that we do not deviate too far from the original reference model. $\beta$ is a hyperparameter controlling the degree of deviation. This penalty constraint serves to:

  1. Ensure that we do not drift too far from the (x, y) distribution on which the reward model is accurate
  2. Ensure that we maintain generational diversity and not just collapse into a single high-reward answer for a given prompt

Objective (3) cannot be directly optimized with ordinary gradient descent, because the expectation is taken over answers $y$ sampled at each step from the current policy $\pi_\theta$ itself, and sampling discrete tokens is not differentiable. Hence this objective is typically optimized with reinforcement learning, usually PPO.
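For concreteness, a sketch of the per-sample reward signal that is typically handed to PPO for objective (3), assuming sequence-level log-probabilities under $\pi_\theta$ and $\pi_{\text{ref}}$ have already been computed for the sampled completions:

```python
import torch

def kl_penalized_reward(reward: torch.Tensor,
                        logp_theta: torch.Tensor,
                        logp_ref: torch.Tensor,
                        beta: float) -> torch.Tensor:
    """Reward used by PPO-style RLHF: r_phi(x, y) minus a beta-weighted KL penalty.

    reward:     (batch,) learned rewards r_phi(x, y) for completions y sampled from pi_theta.
    logp_theta: (batch,) log pi_theta(y | x) under the current policy.
    logp_ref:   (batch,) log pi_ref(y | x) under the frozen reference policy.
    """
    # Maximizing this quantity in expectation corresponds to objective (3)
    return reward - beta * (logp_theta - logp_ref)
```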

Direct Preference Optimization

The RLHF process in general is unstable, requires more memory and computation than supervised fine-tuning, and requires tricks to make it work. Hence the authors of DPO set out to create an optimization procedure that:

  1. Avoids fitting an explicit, standalone reward model
  2. Avoids using reinforcement learning

DPO starts off with the KL-constrained reward maximization objective from Equation (3) above. The first step is to show that the optimal policy for this objective takes the following form for an arbitrary reward function $r$:

$$ \pi_r(y \mid x) = \frac{1}{Z(x)}\, \pi_{\text{ref}}(y \mid x) \exp\left(\frac{1}{\beta} r(x, y)\right) \tag{4} $$

where $Z(x) = \sum_y \pi_{\text{ref}}(y \mid x) \exp\left(\frac{1}{\beta} r(x, y)\right)$ is the partition function.

The derivation for Equation (4) is as follows. For a given reward function $r$:

$$
\begin{aligned}
& \max_{\pi} \ \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi(y \mid x)}\left[r(x, y)\right] - \beta\, \mathbb{D}_{\text{KL}}\left[\pi(y \mid x) \,\|\, \pi_{\text{ref}}(y \mid x)\right] \\
= \ & \max_{\pi} \ \mathbb{E}_{x \sim \mathcal{D}}\, \mathbb{E}_{y \sim \pi(y \mid x)}\left[r(x, y) - \beta \log \frac{\pi(y \mid x)}{\pi_{\text{ref}}(y \mid x)}\right] \\
= \ & \min_{\pi} \ \mathbb{E}_{x \sim \mathcal{D}}\, \mathbb{E}_{y \sim \pi(y \mid x)}\left[\log \frac{\pi(y \mid x)}{\pi_{\text{ref}}(y \mid x)} - \frac{1}{\beta} r(x, y)\right] \\
= \ & \min_{\pi} \ \mathbb{E}_{x \sim \mathcal{D}}\, \mathbb{E}_{y \sim \pi(y \mid x)}\left[\log \frac{\pi(y \mid x)}{\frac{1}{Z(x)}\pi_{\text{ref}}(y \mid x)\exp\left(\frac{1}{\beta} r(x, y)\right)} - \log Z(x)\right]
\end{aligned}
$$

For line 2 above, recall that $\mathbb{D}_{\text{KL}}\left[\pi \,\|\, \pi_{\text{ref}}\right]$ is the expected value of $\log \frac{\pi(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$ if the random variable $y$ is drawn from $\pi(y \mid x)$. Since the expectation of the reward term is also over draws from $\pi(y \mid x)$, we can break down the KL-divergence by bringing the log difference into the same expectation. Line 3 simply divides by $\beta$ and flips max to min. Line 4 uses $\frac{1}{\beta} r(x, y) = \log \exp\left(\frac{1}{\beta} r(x, y)\right)$ to bring the reward term into the denominator of the left term, then introduces an arbitrary partition function $Z(x)$. Note that the two $Z(x)$ terms would cancel out if we brought them together, but we will be using them later on.

Now let us define the optimal policy

$$ \pi^*(y \mid x) = \frac{1}{Z(x)}\, \pi_{\text{ref}}(y \mid x) \exp\left(\frac{1}{\beta} r(x, y)\right). $$

We will need to prove that $\pi^*$ is indeed optimal. Note that $\pi^*$ is a valid probability distribution as:

  • $\pi^*(y \mid x) \ge 0$ for all $y$; and
  • $\sum_y \pi^*(y \mid x) = 1$, since the denominator $Z(x)$ is just the sum over $y$ of the numerator

Since $Z(x)$ is not a function of $y$, we can sub $\pi^*$ in and take $\log Z(x)$ out of the inner expectation. The left term becomes a KL-divergence between $\pi$, which we are optimizing over, and the optimal policy $\pi^*$:

$$ \min_{\pi} \ \mathbb{E}_{x \sim \mathcal{D}}\left[\mathbb{D}_{\text{KL}}\left[\pi(y \mid x) \,\|\, \pi^*(y \mid x)\right] - \log Z(x)\right] $$

Finally, note that $Z(x)$ does not depend on $\pi$, so we only need to consider the KL-divergence term. Gibbs' inequality tells us that the KL-divergence is minimized at zero if and only if the two distributions are identical. This completes our derivation of (4) by showing that $\pi^*$ is indeed the optimal policy.

Now that we have completed the derivation, let's consider what Equation (4) is saying. It tells us that we have an analytical solution for the policy that optimizes (3), and that it can be expressed in terms of $\pi_{\text{ref}}$ (which we already have) and a given reward function $r$.

Since we previously learned a reward model $r_\phi$, we could simply plug that into (4) to get our optimal policy. Specifically, for a new input prompt $x$, we can compute $\pi_{\text{ref}}(y \mid x)\exp\left(\frac{1}{\beta} r_\phi(x, y)\right)$ for all possible values of $y$ and pick the best $y$. We can ignore $Z(x)$ since it is fixed for a given prompt $x$. Intuitively, the new model scales up the probability of high-reward answers with an exponential multiplier, and the degree of scaling is controlled by $\beta$. However, this is not computationally practical, as we need to evaluate over a very large space of $y$ (i.e. all possible answers for a given prompt $x$).
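A toy illustration of this exponential reweighting over a small, made-up answer space (both the reference probabilities and the rewards below are invented), showing how a smaller $\beta$ concentrates mass on the high-reward answer:

```python
import torch

# Hypothetical reference probabilities and rewards for 4 candidate answers to one prompt
pi_ref = torch.tensor([0.40, 0.30, 0.20, 0.10])
reward = torch.tensor([1.0, 2.0, 0.5, 3.0])

for beta in (1.0, 0.1):
    unnormalized = pi_ref * torch.exp(reward / beta)  # pi_ref(y|x) * exp(r(x,y) / beta)
    pi_star = unnormalized / unnormalized.sum()       # divide by the partition function Z(x)
    print(f"beta={beta}: {pi_star.tolist()}")
```

Enumerating candidates like this is only possible because the toy answer space has four elements; for free-form text generation the sum over $y$ is intractable.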

Hence, we want to find a form of the optimal policy which involves neither the partition function $Z(x)$ nor the reward model $r_\phi$. We start by taking the log of both sides of Equation (4) and re-arranging for the reward:

$$ r(x, y) = \beta \log \frac{\pi_r(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x) \tag{5} $$

Since Equation (5) holds for any arbitrary reward function, we can apply it to the optimal (unknown) human reward model $r^*$ from Equation (1). Also, note that $\pi_r$ in Equation (5) refers to the optimal policy under $r$, so since we are using the optimal reward $r^*$, we can call this optimal policy $\pi^*$ as well. Now we plug Equation (5) back into Equation (1):

$$ p^*(y_1 \succ y_2 \mid x) = \frac{1}{1 + \exp\left(\beta \log \frac{\pi^*(y_2 \mid x)}{\pi_{\text{ref}}(y_2 \mid x)} - \beta \log \frac{\pi^*(y_1 \mid x)}{\pi_{\text{ref}}(y_1 \mid x)}\right)} \tag{6} $$

The derivation of the above is quite simple: we just need to note that Equation (1) is of the form $\frac{e^a}{e^a + e^b} = \sigma(a - b)$, i.e. we use the expression of the BT-model as a sigmoid function. The great thing is that the pesky partition function $Z(x)$ cancels out, because the BT-model simply ends up with the difference between two scores / rewards.
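Spelling out the algebra, the two $\beta \log Z(x)$ terms cancel inside the sigmoid:

$$
\begin{aligned}
p^*(y_1 \succ y_2 \mid x)
&= \sigma\left(r^*(x, y_1) - r^*(x, y_2)\right) \\
&= \sigma\left(\beta \log \frac{\pi^*(y_1 \mid x)}{\pi_{\text{ref}}(y_1 \mid x)} + \beta \log Z(x) - \beta \log \frac{\pi^*(y_2 \mid x)}{\pi_{\text{ref}}(y_2 \mid x)} - \beta \log Z(x)\right) \\
&= \sigma\left(\beta \log \frac{\pi^*(y_1 \mid x)}{\pi_{\text{ref}}(y_1 \mid x)} - \beta \log \frac{\pi^*(y_2 \mid x)}{\pi_{\text{ref}}(y_2 \mid x)}\right)
\end{aligned}
$$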

Equation (6) looks simple but is quite remarkable when we compare it to Equation (1), because we now have the probability of the human preference data in terms of just the reference policy $\pi_{\text{ref}}$ and the optimal policy $\pi^*$, without the reward model $r^*$ at all! Of course, the reward model is implicitly embedded in the equation. If we stare at Equation (6), we see the implicit reward function is $\hat{r}(x, y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$.

The benefit of this new formulation is that we can write the maximum likelihood objective in terms of the optimal policy. We do not know the optimal policy, but we can parametrize it as $\pi_\theta$ and use the maximum likelihood objective to train the model. Our new policy objective becomes:

$$ \mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right] \tag{7} $$
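A minimal PyTorch sketch of Equation (7); the inputs are assumed to be summed token log-probabilities $\log \pi(y \mid x)$ for the chosen and rejected answers under the policy being trained and under the frozen reference model:

```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_theta_w: torch.Tensor, logp_theta_l: torch.Tensor,
             logp_ref_w: torch.Tensor, logp_ref_l: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """DPO objective, Equation (7). All inputs are (batch,) sequence log-probabilities."""
    # Implicit rewards: beta * log(pi_theta / pi_ref) for winning and losing answers
    reward_w = beta * (logp_theta_w - logp_ref_w)
    reward_l = beta * (logp_theta_l - logp_ref_l)
    # Maximize the BT likelihood that the winner beats the loser
    return -F.logsigmoid(reward_w - reward_l).mean()
```

The reference log-probabilities should be computed with gradients disabled, so that only $\pi_\theta$ is updated.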

Application to Dual Encoder Retrieval

The DPO framework offers a way to fine-tune a policy according to human preferences, whilst ensuring stability against a reference model. In the original formulation, $\pi_\theta(y \mid x)$ represents the probability of the generative model generating answer $y$ given an input prompt $x$.

A related problem is that of fine-tuning embedding models for search retrieval, in what is known as the dual encoder framework. In this case, we also have preference data in the form of triplets (query, positive_passage, negative_passage), and we wish to fine-tune the embeddings such that the dot-product or cosine similarity between (query, positive_passage) is high whilst that of (query, negative_passage) is low. In this formulation, we could let the policy $\pi(p \mid q)$ represent the normalized probability of a relevant match between query $q$ and passage $p$. We can then borrow the framework of DPO to fine-tune our embeddings. The benefit of DPO compared to typical dual encoder optimization objectives is the stability of the policy against the reference policy, which hopefully is a good form of regularization even when preference data is limited.

Specifically, we define a dual encoder policy $\pi(p \mid q)$, where $q$ represents a query and $p$ represents a passage, like so:

  1. Encode $q$ into a vector using a BERT model
  2. Encode $p$ into a vector using a BERT model (can be the same model as the query encoder)
  3. Take the dot product of the two vectors
  4. Run the dot product through a sigmoid layer to convert it into a probability

We can then apply this recipe to a reference encoder model and call it $\pi_{\text{ref}}$, and optimize another encoder model $\pi_\theta$ against this reference model, exactly as in Equation (7). We will need to conduct experiments to see if this method offers gains over typical dual encoder objectives, such as Karpukhin 2020.
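A sketch of how this proposal might be implemented, assuming hypothetical `encoder_q` / `encoder_p` modules that return pooled `(batch, dim)` embeddings, and treating the sigmoid of the dot product as the policy's probability of a match:

```python
import torch
import torch.nn.functional as F

def match_logprob(encoder_q, encoder_p, query, passage):
    """log pi(match | q, p): log-sigmoid of the dot product of the pooled embeddings."""
    q_vec = encoder_q(query)               # (batch, dim), hypothetical query encoder
    p_vec = encoder_p(passage)             # (batch, dim), hypothetical passage encoder
    score = (q_vec * p_vec).sum(dim=-1)    # dot product similarity, (batch,)
    return F.logsigmoid(score)

def dual_encoder_dpo_loss(enc_q, enc_p, ref_enc_q, ref_enc_p,
                          query, pos_passage, neg_passage, beta: float = 0.1):
    """DPO-style loss over (query, positive_passage, negative_passage) triplets."""
    logp_theta_w = match_logprob(enc_q, enc_p, query, pos_passage)
    logp_theta_l = match_logprob(enc_q, enc_p, query, neg_passage)
    with torch.no_grad():  # the reference encoders stay frozen
        logp_ref_w = match_logprob(ref_enc_q, ref_enc_p, query, pos_passage)
        logp_ref_l = match_logprob(ref_enc_q, ref_enc_p, query, neg_passage)
    reward_w = beta * (logp_theta_w - logp_ref_w)  # implicit reward for the positive passage
    reward_l = beta * (logp_theta_l - logp_ref_l)  # implicit reward for the negative passage
    return -F.logsigmoid(reward_w - reward_l).mean()
```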

References