SentenceTransformers
SentenceTransformers is a useful library for training various BERT-based models, including two-tower embedding models and cross encoder reranking models.
Cross Encoder
SentenceTransformers v4.0 updated the cross encoder training interface (see the v4.0 blogpost). Here we walk through the key components of cross encoder training using this API.
The main class for training is CrossEncoderTrainer. We rely on a Huggingface datasets.Dataset class to provide training and validation data. CrossEncoderTrainer requires that the dataset format matches the chosen loss function.
The loss overview page provides a summary of cross encoder losses and the required dataset format. In general for cross encoder training, we have two sentences which are either positively or negatively related to each other. Which loss function we choose depends on the specific dataset format we possess.
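To make the moving parts concrete, here is a minimal training sketch assuming the v4-style API from the blogpost; the checkpoint name and the toy data are placeholders, and the loss (one of those discussed below) must match the dataset columns.

```python
from datasets import Dataset
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss

# Toy (sentence_A, sentence_B, label) data; the required columns depend on the loss.
train_dataset = Dataset.from_dict({
    "query": ["how to bake bread", "how to bake bread"],
    "passage": ["a simple bread recipe", "how to change a car tyre"],
    "label": [1.0, 0.0],
})

model = CrossEncoder("distilroberta-base", num_labels=1)  # any encoder checkpoint works
loss = BinaryCrossEntropyLoss(model)                      # must match the dataset format

trainer = CrossEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```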
BinaryCrossEntropyLoss
Use this loss if we have inputs in the form of (sentence_A, sentence_B) and a label that is either binary (0: negative, 1: positive) or a float score between 0 and 1. In the huggingface dataset, we need to ensure that the label column is named label or score, and that there are two other input columns corresponding to sentence_A and sentence_B. For the sentence_transformers package in general, the order of columns matters, so we should order them as sentence_A, sentence_B, label.
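For example, a toy dataset with the expected layout could be built as below; the text column names are illustrative, only the label/score column name and the column order are constrained.

```python
from datasets import Dataset

# For BinaryCrossEntropyLoss: two text columns first, then a "label" (or "score") column.
pairs = Dataset.from_dict({
    "sentence_A": ["what is the capital of France?", "what is the capital of France?"],
    "sentence_B": ["Paris is the capital of France.", "The Nile is a river in Africa."],
    "label": [1, 0],  # 1: positive, 0: negative; float scores in [0, 1] also work
})
print(pairs.column_names)  # ['sentence_A', 'sentence_B', 'label']
```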
Inspecting the source code shows that each sentence pair is tokenized and encoded by the cross encoder model. The cross encoder must output a single logit (i.e. be initialized with num_labels=1), so we get a prediction vector of dimension batch_size. torch.nn.BCEWithLogitsLoss is then used to compute the binary cross entropy of the prediction logits against the actual labels, according to the standard BCE loss:

$$\ell = -\frac{1}{N}\sum_{i=1}^{N}\left[\, y_i \log \sigma(x_i) + (1 - y_i)\log\big(1 - \sigma(x_i)\big) \right]$$

where $x_i$ is the predicted logit for pair $i$, $y_i$ is its label, and $\sigma$ is the sigmoid function.
This is a simple and effective loss. The user should ensure that the labels are well distributed between 0 and 1, without any severe class imbalance.
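As a standalone sanity check (plain torch, not the library code), the following reproduces the BCE computation on a batch of single-logit predictions and matches it against the formula above.

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5])   # one logit per sentence pair (num_labels=1)
labels = torch.tensor([1.0, 0.0, 1.0])    # binary (or soft) labels in [0, 1]

# Library-style computation
bce = torch.nn.BCEWithLogitsLoss()(logits, labels)

# Manual computation of the same loss
p = torch.sigmoid(logits)
manual = -(labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()

print(bce.item(), manual.item())  # the two values agree
```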
CrossEntropyLoss
The CrossEntropyLoss is used for a classification task, where for a given input sentence pair (sentence_A, sentence_B), the label is a class. For example, we may have data where each sentence pair is tagged on a 1-5 rating scale. We need to instantiate the CrossEncoder class with num_labels=num_classes for this use case. This creates one prediction logit per class.
Looking at the source code, we see that this loss simply takes the prediction logits from the model (of dimension num_labels) and computes the torch.nn.CrossEntropyLoss against the actual labels.
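A sketch of this classification setup, assuming the same v4-style trainer API as above; the checkpoint, column names, and the 5-class rating data are illustrative.

```python
from datasets import Dataset
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import CrossEntropyLoss

# Hypothetical 5-class rating task: labels are class indices 0..4
train_dataset = Dataset.from_dict({
    "sentence_A": ["the movie was great", "the movie was fine"],
    "sentence_B": ["I loved every minute", "it was watchable"],
    "label": [4, 2],
})

# num_labels must equal the number of classes -> one logit per class
model = CrossEncoder("distilroberta-base", num_labels=5)
loss = CrossEntropyLoss(model)

trainer = CrossEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```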
Note that the cross entropy loss takes the following form. Given num_labels=C and logits $x \in \mathbb{R}^C$, where the correct label is index $y$, we have:

$$\ell(x, y) = -\log\frac{\exp(x_y)}{\sum_{c=1}^{C}\exp(x_c)}$$
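A quick numeric check of this formula against torch.nn.CrossEntropyLoss:

```python
import torch

logits = torch.tensor([[1.2, 0.3, -0.5, 2.0, 0.1]])  # one pair, shape (1, C) with C=5
y = torch.tensor([3])                                 # correct class index

ce = torch.nn.CrossEntropyLoss()(logits, y)

# Manual computation: negative log-softmax at the correct index
manual = -torch.log(torch.softmax(logits, dim=-1)[0, y])

print(ce.item(), manual.item())  # the two values agree
```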
MultipleNegativesRankingLoss
This is basically InfoNCE loss or in-batch negatives loss. The inputs to this loss can take the following forms:
- (anchor, positive) sentences
- (anchor, positive, negative) sentences
- (anchor, positive, negative_1, ..., negative_n) sentences
The documentation page has a nice description of what this loss does: Given an anchor, assign the highest similarity to the corresponding positive document out of every single positive and negative document in the batch.
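A minimal sketch of this setup, assuming (anchor, positive) pairs and the cross encoder variant of MultipleNegativesRankingLoss; the checkpoint and data are placeholders.

```python
from datasets import Dataset
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer
from sentence_transformers.cross_encoder.losses import MultipleNegativesRankingLoss

# (anchor, positive) pairs; passages from other rows in the batch act as negatives
train_dataset = Dataset.from_dict({
    "anchor": ["how to bake bread", "symptoms of the flu"],
    "positive": ["a simple bread recipe", "fever, cough and fatigue are common flu symptoms"],
})

model = CrossEncoder("distilroberta-base", num_labels=1)
loss = MultipleNegativesRankingLoss(model)  # mines in-batch negatives, as detailed below

trainer = CrossEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```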
Diving into the source code:
- The inputs are list[list[str]], where the outer list corresponds to [anchor, positive, *negatives] and each inner list has length batch_size.
- scores of dimension (batch_size) are computed for each (anchor, positive) pair.
- get_in_batch_negatives is then called to mine negatives for each anchor (see the simplified sketch after this list):
  - Candidates (the positive and any hard negatives) are extracted at inputs[1:] and flattened into a long list.
  - A mask is created such that for each anchor, all of its own positive and negative candidates are masked out (they do not participate in sampling).
    - The matching hard negatives do not participate here because they are added back later on.
  - Amongst the remaining candidates, torch.multinomial is used to select self.num_negatives documents per anchor at random.
    - self.num_negatives defaults to 4.
  - These randomly selected negative texts are then returned as list[str].
- For each of the num_negatives mined in-batch negatives:
  - A score of dimension (batch_size) is computed for the (anchor, negative) pair.
  - The result is appended to scores.
- Similarly, for each matching hard negative:
  - A score of dimension (batch_size) is computed for the (anchor, hard negative) pair.
  - The result is appended to scores.
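The following is a simplified, standalone sketch of the mining step described above. It is not the library source; the function name and argument layout are illustrative assumptions.

```python
import torch

def get_in_batch_negatives_sketch(inputs, num_negatives=4):
    # `inputs` is list[list[str]]: inputs[0] holds the anchors, inputs[1] the positives,
    # inputs[2:] any hard negatives; every inner list has batch_size items.
    anchors = inputs[0]
    batch_size = len(anchors)
    candidate_columns = inputs[1:]                     # [positives, *hard_negative_columns]
    flat = [text for col in candidate_columns for text in col]

    mined = []
    for i in range(batch_size):
        # Mask out anchor i's own positive and hard negatives so they are never sampled
        weights = torch.ones(len(flat))
        for c in range(len(candidate_columns)):
            weights[c * batch_size + i] = 0.0
        # Randomly pick num_negatives of the remaining in-batch candidates
        idx = torch.multinomial(weights, num_negatives)
        mined.append([flat[j] for j in idx.tolist()])
    return mined                                        # num_negatives texts per anchor

# Example: a batch of 5 (anchor, positive) rows -> 4 random in-batch negatives per anchor
batch = [[f"anchor {i}" for i in range(5)], [f"positive {i}" for i in range(5)]]
print(get_in_batch_negatives_sketch(batch))
```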
Now scores is passed into calculate_loss (see the sketch after this list):

- Recall that scores is a list of tensors where the outer list has size 1 + num_rand_negatives + num_hard_negatives, and each tensor is of dimension batch_size.
- Thus torch.cat + transpose is called to make it (batch_size, 1 + num_rand_negatives + num_hard_negatives).
- Note that for each row, the first column corresponds to the positive document.
- Hence the labels may be created as torch.zeros(batch_size).
- Then torch.nn.CrossEntropyLoss()(scores, labels) may be called to get the loss.
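A standalone sketch of this final step, with random score tensors standing in for the model outputs:

```python
import torch

batch_size, num_rand_negatives, num_hard_negatives = 3, 4, 1
num_columns = 1 + num_rand_negatives + num_hard_negatives

# Stand-in for the `scores` list built above: one (batch_size,) tensor per column,
# with column 0 corresponding to the (anchor, positive) pair
scores = [torch.randn(batch_size) for _ in range(num_columns)]

# torch.cat + transpose -> shape (batch_size, num_columns)
logits = torch.cat([s.unsqueeze(0) for s in scores], dim=0).T

# The positive sits in column 0, so the target class for every row is 0
labels = torch.zeros(batch_size, dtype=torch.long)

loss = torch.nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, loss.item())  # torch.Size([3, 6]) and a scalar loss
```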
This sums up the loss computation for MultipleNegativesRankingLoss.