SentenceTransformers
SentenceTransformers is a useful library for training various BERT-based models, including two-tower embedding models and cross encoder reranking models.
Cross Encoder
SentenceTransformers v4.0 updated their cross encoder training interface (see the v4.0 blogpost). Here we walk through the key components of cross encoder training using their API.
The main class for training is CrossEncoderTrainer. We rely on a Huggingface datasets.Dataset class to provide training and validation data. CrossEncoderTrainer requires that the dataset format matches the chosen loss function.
The loss overview page provides a summary of cross encoder losses and their required dataset formats. In general, cross encoder training involves two sentences that are either positively or negatively related to each other. Which loss function we choose depends on the specific dataset format we possess.
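To make the moving parts concrete, here is a minimal sketch of the training flow, assuming sentence_transformers >= 4.0 (the model name and toy data are placeholders, and the loss used here, BinaryCrossEntropyLoss, is covered in the next section):

```python
from datasets import Dataset
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss

# Toy dataset of (sentence_A, sentence_B, label) rows.
train_dataset = Dataset.from_dict({
    "sentence_A": ["how to bake bread", "how to bake bread"],
    "sentence_B": ["a simple bread recipe", "top 10 hiking trails"],
    "label": [1.0, 0.0],
})

# Cross encoder with a single output logit, matching BinaryCrossEntropyLoss.
model = CrossEncoder("distilbert-base-uncased", num_labels=1)
loss = BinaryCrossEntropyLoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
```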
BinaryCrossEntropyLoss
Use this loss if we have inputs in the form of (sentence_A, sentence_B) and a label that is either binary (0: negative, 1: positive) or a float score between 0 and 1. In the huggingface dataset, we would need to ensure that the label column is named label or score, and have two other input columns corresponding to sentence_A and sentence_B. For the sentence_transformers package in general, the order of columns matters, so we should set it to sentence_A, sentence_B, label.
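For illustration, here is a sketch of coercing an existing dataset into this format; the raw column names are made up, and rename_column / select_columns are standard datasets.Dataset methods (the order of the selected columns is preserved):

```python
from datasets import Dataset

# Hypothetical raw data with arbitrary column names and ordering.
raw = Dataset.from_dict({
    "relevance": [1.0, 0.0],
    "query": ["how to bake bread", "how to bake bread"],
    "passage": ["a simple bread recipe", "top 10 hiking trails"],
})

# Rename the label column to "label" and order the columns as
# (sentence_A, sentence_B, label) before handing the dataset to the trainer.
train_dataset = (
    raw.rename_column("relevance", "label")
       .select_columns(["query", "passage", "label"])
)
print(train_dataset.column_names)  # ['query', 'passage', 'label']
```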
Inspecting the source code shows that each sentence pair is tokenized and encoded by the cross encoder model. The cross encoder must output a single logit (i.e. it is initialized with num_labels=1). Thus we get a prediction vector of dim batch_size. The torch.nn.BCEWithLogitsLoss is then used to compute the binary cross entropy loss of the prediction logits against the actual labels, according to the standard BCE loss:

$$\mathcal{L}_{\text{BCE}} = -\frac{1}{N} \sum_{i=1}^{N} \Big[ y_i \log \sigma(x_i) + (1 - y_i) \log\big(1 - \sigma(x_i)\big) \Big]$$

where $x_i$ is the predicted logit for pair $i$, $\sigma$ is the sigmoid function, and $y_i \in [0, 1]$ is the label.

This is a simple and effective loss. The user should ensure that the labels are well distributed (between $0$ and $1$) without any severe class imbalance.
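As a quick sanity check of the formula (with made-up logits and labels), torch.nn.BCEWithLogitsLoss on a (batch_size,) logit vector matches the manual computation:

```python
import torch

# Made-up batch of 4 sentence-pair logits (num_labels=1) and their binary labels.
logits = torch.randn(4)
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

bce = torch.nn.BCEWithLogitsLoss()(logits, labels)

# Manual computation of the BCE formula above.
p = torch.sigmoid(logits)
manual = -(labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()

assert torch.allclose(bce, manual)
```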
CrossEntropyLoss
The CrossEntropyLoss is used for a classification task, where for a given input sentence pair (sentence_A, sentence_B), the label is a class. For example, we may have data where each sentence pair is tagged on a 1-5 rating scale. We need to instantiate the CrossEncoder class with num_labels=num_classes for this use case. This creates a prediction head for each class.
Looking at the source code, we see that this loss simply takes the prediction logits from the model (of dimension num_labels) and computes the torch.nn.CrossEntropyLoss against the actual labels.
Note that the cross entropy loss takes the following form. Given num_labels=C and logits $z \in \mathbb{R}^{C}$, where the correct label is index $y$, we have:

$$\mathcal{L}_{\text{CE}} = -\log \frac{\exp(z_{y})}{\sum_{c=1}^{C} \exp(z_{c})}$$
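Again as a sanity check (with made-up logits and class labels), torch.nn.CrossEntropyLoss over (batch_size, num_labels) logits matches the formula:

```python
import torch

batch_size, num_labels = 4, 5                  # e.g. a 1-5 rating scale mapped to classes 0-4
logits = torch.randn(batch_size, num_labels)   # what a CrossEncoder with num_labels=5 would emit
labels = torch.tensor([0, 4, 2, 3])            # class index per sentence pair

ce = torch.nn.CrossEntropyLoss()(logits, labels)

# Manual computation of the cross entropy formula above.
log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
manual = -log_probs[torch.arange(batch_size), labels].mean()

assert torch.allclose(ce, manual)
```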
MultipleNegativesRankingLoss
This is essentially the InfoNCE loss, also known as the in-batch negatives loss. The inputs to this loss can take the following forms:
- (anchor, positive) sentences
- (anchor, positive, negative) sentences
- (anchor, positive, negative_1, ..., negative_n) sentences
The documentation page has a nice description of what this loss does: Given an anchor, assign the highest similarity to the corresponding positive document out of every single positive and negative document in the batch.
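For illustration, a sketch of the corresponding dataset formats (the column names are only examples; there is no label column for this loss, and it is the column order that matters):

```python
from datasets import Dataset

# (anchor, positive) pairs: negatives are mined from the rest of the batch.
pair_dataset = Dataset.from_dict({
    "anchor": ["what is the capital of france", "how do i boil an egg"],
    "positive": ["paris is the capital of france", "place the egg in boiling water"],
})

# (anchor, positive, negative) triplets: the explicit negative acts as a hard negative.
triplet_dataset = Dataset.from_dict({
    "anchor": ["what is the capital of france"],
    "positive": ["paris is the capital of france"],
    "negative": ["berlin is the capital of germany"],
})
```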
Diving into the source code:
- The inputs are list[list[str]], where the outer list corresponds to [anchor, positive, *negatives]. The inner list corresponds to the batch size.
- scores of dimension (batch_size) are computed for each anchor, positive pair
- get_in_batch_negatives is then called to mine random negatives for each anchor (see the sketch after this list):
  - Candidates (positive and negatives) are extracted at inputs[1:] and flattened into a long list
  - A mask is created such that for each anchor, all the matching positive and negative candidates are masked out (not participating)
    - The matching negatives do not participate because they will be added later on
  - Amongst the remaining negatives, torch.multinomial is used to select self.num_negatives documents per anchor at random
    - self.num_negatives defaults to 4
  - These randomly selected negative texts are then returned as list[str]
- For each of the num_negatives mined in-batch negatives:
  - score of dimension (batch_size) is computed for the anchor, negative pair
  - The result is appended to scores
- Similarly, for each hard matching negative:
  - score of dimension (batch_size) is computed for the anchor, hard negative pair
  - The result is appended to scores
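The following is a simplified, standalone sketch of the mining step described above, not the library's actual code; the helper name mine_in_batch_negatives is hypothetical, and it samples one negative per anchor per pass rather than all num_negatives in a single multinomial call:

```python
import torch

def mine_in_batch_negatives(candidate_columns, num_negatives=4):
    """Hypothetical sketch of the in-batch negative mining described above.

    candidate_columns: the columns at inputs[1:], i.e. the positive column plus any
        hard-negative columns, each a list[str] of length batch_size.
    Returns num_negatives columns of length batch_size (one random negative per anchor).
    """
    batch_size = len(candidate_columns[0])
    num_columns = len(candidate_columns)

    # Flatten all candidates into one long list of length num_columns * batch_size.
    flat = [text for column in candidate_columns for text in column]

    # mask[i, j] is True if flat[j] is allowed as a negative for anchor i, i.e. it does
    # not come from anchor i's own row (its positive or its own hard negatives).
    mask = torch.ones(batch_size, num_columns * batch_size, dtype=torch.bool)
    for i in range(batch_size):
        mask[i, i::batch_size] = False

    mined = []
    for _ in range(num_negatives):
        # Sample one allowed candidate per anchor at random via torch.multinomial.
        idx = torch.multinomial(mask.float(), num_samples=1).squeeze(1)
        mined.append([flat[j] for j in idx.tolist()])
    return mined

# Example: a batch of 3 anchors with a positive column and one hard-negative column.
positives = ["p1", "p2", "p3"]
hard_negatives = ["n1", "n2", "n3"]
random_negative_columns = mine_in_batch_negatives([positives, hard_negatives], num_negatives=2)
```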
Now scores is passed into calculate_loss:
- Recall that scores is a list of tensors where the outer list has size 1 + num_rand_negatives + num_hard_negatives, and each tensor is of dimension batch_size
- Thus torch.cat + transpose is called to make it (batch_size, 1 + num_rand_negatives + num_hard_negatives)
  - Note that for each row, the first column corresponds to the positive document
- Hence the labels may be created as torch.zeros(batch_size)
- Then torch.nn.CrossEntropyLoss()(scores, labels) may be called to get the loss (a simplified sketch of this step follows the list)
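A simplified sketch of this final step (the helper name calculate_loss_sketch is hypothetical, and torch.stack followed by a transpose is used here as an equivalent of the cat + transpose described above):

```python
import torch

def calculate_loss_sketch(scores, batch_size):
    """Hypothetical sketch of the calculate_loss step described above.

    scores: list of 1 + num_rand_negatives + num_hard_negatives tensors, each of shape
        (batch_size,), where the first tensor holds the anchor-positive scores.
    """
    # (1 + num_negatives, batch_size) -> (batch_size, 1 + num_negatives)
    logits = torch.stack(scores, dim=0).T

    # The positive document sits in column 0 of every row, so every target class is 0.
    labels = torch.zeros(batch_size, dtype=torch.long)

    return torch.nn.CrossEntropyLoss()(logits, labels)

# Example: batch of 8 anchors, 4 random in-batch negatives, 1 hard negative each.
scores = [torch.randn(8) for _ in range(1 + 4 + 1)]
loss = calculate_loss_sketch(scores, batch_size=8)
```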
This sums up the loss computation for MultipleNegativesRankingLoss.