Solatorio 2024 - GISTEmbed
GISTEmbed: Guided In-sample Selection of Training Negatives for Text Embedding Fine-tuning
This paper proposes a way to learn text embeddings using contrastive learning. The main idea is to use a guide model to filter out false negatives during training, which improves data quality and downstream performance. At the time of writing, GISTEmbed models ranked at the top of the MTEB leaderboard.
Implementation
The method is implemented as `GISTEmbedLoss` in the sentence-transformers package. We will walk through that implementation here.
First, the loss is initialized with both a `model` to train and a `guide` model, both of which are `SentenceTransformer` instances.
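A minimal initialization sketch using the actual `GISTEmbedLoss` API (the model names are just examples):

```python
from sentence_transformers import SentenceTransformer, losses

# The model being fine-tuned and the guide used to vet negatives.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
guide = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

loss = losses.GISTEmbedLoss(model=model, guide=guide, temperature=0.01)
```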
In the `forward` step:
- `sentence_features` is the input, which is a list of `dict[str, Tensor]`
  - The list has length 2 if we have (anchor, positive) pairs
  - The list has length 3 if we have (anchor, positive, negative) triplets
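  A hypothetical illustration of this structure for a triplet batch (the tensors are dummies, not real tokenizer output):

  ```python
  import torch

  batch_size, seq_len = 2, 8
  # One tokenized dict per column: [anchors, positives, negatives].
  sentence_features = [
      {
          "input_ids": torch.randint(0, 30000, (batch_size, seq_len)),
          "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
      }
      for _ in range(3)  # length 3: (anchor, positive, negative)
  ]
  ```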
- The `sentence_features` is passed into both the `model` and the `guide` to get the respective embeddings
  - For the `guide`, we may need to re-tokenize the inputs if the tokenizer of the `guide` differs from that of the `model`
    - This is done by calling `batch_decode` and then `tokenize` again, as sketched below
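  A sketch of that re-tokenization step, assuming both models expose Hugging Face tokenizers via `.tokenizer` (the helper name is mine):

  ```python
  def retokenize_for_guide(features, model, guide):
      # Decode the already-tokenized batch back into raw strings
      # using the training model's tokenizer...
      texts = model.tokenizer.batch_decode(
          features["input_ids"], skip_special_tokens=True
      )
      # ...then tokenize those strings with the guide's own tokenizer.
      return guide.tokenize(texts)
  ```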
- Now we have `anchor`, `positive` and `negative` embeddings, each of shape `(batch_size, embed_dim)`
- The `sim_matrix` function is used to compute pairwise cosine similarities
  - The code is simply `torch.nn.CosineSimilarity(dim=-1)(embed1.unsqueeze(1), embed2.unsqueeze(0))`
    - `embed1` becomes shape `(batch_size, 1, embed_dim)` and `embed2` becomes shape `(1, batch_size, embed_dim)`
    - The similarity is computed along dimension `-1`
    - Broadcasting ensures that the comparison is done pairwise, so the result has shape `(batch_size, batch_size)`
  - This is a common way to compute pairwise similarity
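  A self-contained sketch of this broadcasting trick (the function name follows the walkthrough):

  ```python
  import torch

  def sim_matrix(embed1: torch.Tensor, embed2: torch.Tensor) -> torch.Tensor:
      # (batch_size, 1, embed_dim) against (1, batch_size, embed_dim):
      # broadcasting compares every row of embed1 with every row of embed2.
      return torch.nn.CosineSimilarity(dim=-1)(
          embed1.unsqueeze(1), embed2.unsqueeze(0)
      )

  a, b = torch.randn(4, 16), torch.randn(4, 16)
  print(sim_matrix(a, b).shape)  # torch.Size([4, 4])
  ```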
- Now we obtain the pairwise similarity matrices:
  - `ap_sim`, `aa_sim`, `pp_sim`, `an_sim` (from the `model` embeddings)
  - `guided_ap_sim`, `guided_aa_sim`, `guided_pp_sim`, `guided_an_sim` (from the `guide` embeddings)
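  Continuing the sketch above (random tensors stand in for the real embeddings):

  ```python
  anchor, positive, negative = (torch.randn(4, 16) for _ in range(3))
  guide_anchor, guide_positive, guide_negative = (torch.randn(4, 16) for _ in range(3))

  # Similarities under the model being trained...
  ap_sim = sim_matrix(anchor, positive)
  aa_sim = sim_matrix(anchor, anchor)
  pp_sim = sim_matrix(positive, positive)
  an_sim = sim_matrix(anchor, negative)

  # ...and the same pairs under the guide's embeddings.
  guided_ap_sim = sim_matrix(guide_anchor, guide_positive)
  guided_aa_sim = sim_matrix(guide_anchor, guide_anchor)
  guided_pp_sim = sim_matrix(guide_positive, guide_positive)
  guided_an_sim = sim_matrix(guide_anchor, guide_negative)
  ```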
- The anchor-positive similarity is used as a threshold to filter away potential false negatives
  - This is simply `guided_ap_sim.diagonal()`, which corresponds to the similarity between the anchor and its positive in each row
  - Note that the guide model is used for determining the threshold
  - This threshold is called `guided_sim`
- `mask_false_negatives` is used to suppress false negatives (see the sketch after this item)
  - Using the `absolute` strategy, entries where `guided_sim_mat > guided_sim - self.margin` are suppressed (set to `-torch.inf`, so they contribute nothing after the softmax)
  - The idea is that negatives should not have a higher similarity than the threshold; otherwise there is a higher probability that they are false negatives
  - This function is applied to `ap_sim`, `aa_sim`, `pp_sim` and `an_sim` to mask false negatives
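  A minimal sketch of the masking, continuing the variables above. Note that the true positives on the diagonal of `ap_sim` must survive the mask; here that is handled with a `keep_mask` argument, which is my reconstruction rather than the exact library code:

  ```python
  def mask_false_negatives(guided_sim_mat, sim_mat, guided_sim, margin=0.0, keep_mask=None):
      # "absolute" strategy: anything the guide rates above the anchor-positive
      # threshold (minus the margin) is treated as a likely false negative.
      suppress = guided_sim_mat > (guided_sim - margin)
      if keep_mask is not None:
          suppress = suppress & ~keep_mask  # never mask the true positives
      return sim_mat.masked_fill(suppress, -torch.inf)

  # Per-row threshold, shaped (batch_size, 1) so it broadcasts across columns.
  guided_sim = guided_ap_sim.diagonal().view(-1, 1)

  eye = torch.eye(ap_sim.size(0), dtype=torch.bool)
  ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, guided_sim, keep_mask=eye)
  aa_sim = mask_false_negatives(guided_aa_sim, aa_sim, guided_sim)
  pp_sim = mask_false_negatives(guided_pp_sim, pp_sim, guided_sim)
  an_sim = mask_false_negatives(guided_an_sim, an_sim, guided_sim)
  ```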
- Finally, we have `scores`
  - `scores = torch.cat([ap_sim, aa_sim, pp_sim, an_sim], dim=1) / self.temperature`
  - This is of shape `(batch_size, 4 * batch_size)`
    - When the batch has no explicit negatives, `an_sim` is simply omitted from the concatenation
- We create `labels`, which is `torch.arange(batch_size)`
  - This is because the correct class for row `i` is at column `i`, i.e. on the diagonal of the `ap_sim` block of the `scores` matrix
- Finally, the loss is computed via `torch.nn.CrossEntropyLoss()(scores, labels)`
  - Each row is treated as a classification task where the label corresponds to the column position of the correct class
  - The log-softmax loss is then computed per row and averaged across rows
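  Putting the final steps together (continuing the variables above; `0.01` is the default temperature of `GISTEmbedLoss`):

  ```python
  temperature = 0.01

  # Four (batch_size, batch_size) blocks side by side: (batch_size, 4 * batch_size).
  scores = torch.cat([ap_sim, aa_sim, pp_sim, an_sim], dim=1) / temperature

  # Row i's correct class is column i: the diagonal of the ap_sim block.
  labels = torch.arange(scores.size(0), device=scores.device)

  loss = torch.nn.CrossEntropyLoss()(scores, labels)
  ```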