Cross Encoders

A cross encoder is a model architecture used to re-rank a relatively small set of candidates (typically 1,000 or fewer) with high precision. In the question-answering or machine-reading literature, the task typically involves finding the top matching documents for a given query. A canonical benchmark is the MS MARCO dataset, which asks for the top documents relevant to a given Bing query.

Basic Setup

Typically, the base model is a pre-trained BERT model with a classification head added on top to output a probability. Each (query, document) pair is concatenated with a [SEP] token in between to form a single sentence. The sentence is fed into the classification model, which outputs a relevance probability. The model is trained using binary cross-entropy loss against 0/1 labels (irrelevant or relevant).

This is the setup used by <Nogueira 2019>, possibly the first paper to propose the cross encoder. Some specifics of their setup:

  • The query is truncated to max 64 tokens, while the passage is truncated such that the concatenated sentence is max 512 tokens. They use the [CLS] embedding as input to a classifier head.
  • The loss for a single query is formulated as below, where $s_j$ refers to the score from the classifier model, $J_{pos}$ refers to the documents that are relevant, and $J_{neg}$ refers to documents in the top 1,000 retrieved by BM25 that are not relevant. Note that this results in a very imbalanced dataset.

    $$L = -\sum_{j \in J_{pos}} \log(s_j) - \sum_{j \in J_{neg}} \log(1 - s_j)$$

  • The model is fine-tuned with a batch size of 128 sentence pairs for 100k batches.
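The point-wise loss above can be computed numerically as follows. The scores and labels are made-up illustrative values; $s_j$ is the classifier's relevance probability for candidate document $j$.

```python
import numpy as np

# Assumed example values: s_j for 3 candidate documents of one query.
scores = np.array([0.9, 0.2, 0.1])   # classifier probabilities s_j
labels = np.array([1.0, 0.0, 0.0])   # 1 = relevant, 0 = not relevant

# L = -sum_{j in pos} log(s_j) - sum_{j in neg} log(1 - s_j),
# written as a single binary cross-entropy sum over all candidates.
loss = -np.sum(labels * np.log(scores) + (1 - labels) * np.log(1 - scores))
print(loss)
```

Because the negatives vastly outnumber the positives in the top-1,000 BM25 candidates, the second sum dominates this loss in practice, which is the imbalance noted above.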

As opposed to bi-encoders (or dual encoders), which score by taking a dot product between the query embedding and the document embedding, we cannot pre-compute embeddings in the cross-encoder setting, because the cross encoder requires a forward pass over the concatenated (query, document) pair. Due to the bi-directional attention over the full concatenated sentence, the score cannot be computed until the full sentence is available, and the query is only seen at inference time. Hence, the cross encoder is limited to re-ranking a small set of candidates, as it requires a full forward pass for each (query, candidate document) pair separately.
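The cost difference can be made concrete with a sketch. In the bi-encoder case, document embeddings are computed offline once, so serving a query is a single matrix-vector product; in the cross-encoder case, every candidate requires its own forward pass. The embedding dimensions and the `forward_pass` callable below are illustrative placeholders, not a real model.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 8
# Bi-encoder: document embeddings are precomputed offline, once.
doc_embs = rng.normal(size=(1000, dim))

def bi_encoder_scores(query_emb):
    # Serving cost: one matrix-vector product against the precomputed index.
    return doc_embs @ query_emb

def cross_encoder_scores(query, docs, forward_pass):
    # Serving cost: one full transformer forward pass per (query, doc) pair.
    # Nothing can be cached, since the query is only known at inference time.
    return [forward_pass(query, doc) for doc in docs]
```

This is why the typical pipeline uses a cheap first-stage retriever (BM25 or a bi-encoder) over the full corpus, followed by the cross encoder over the shortlist.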

Contrastive Loss

The vanilla binary cross-entropy loss proposed above may be thought of as a point-wise loss, in which each document is either relevant or irrelevant in absolute terms. However, treating relevance as a relative concept often better reflects reality. For example, given the first page of search results for a Google query, most of the documents should be relevant to some extent, but some are more relevant than the rest (and get clicked on). Naively treating all clicks as relevant and all non-clicks as irrelevant ignores the context (i.e. the neighbouring search results) in which the clicks were generated: it assumes that the average level of relevance of the results is comparable across query sessions. Treating relevance as a relative concept within the same query session weakens this assumption and hence often works better.

Thus <Gao 2021> proposes the Localized Contrastive Estimation (LCE) loss. For a given query q, a positive document is selected, and a few negative documents are sampled using a retriever (e.g. BM25). The contrastive loss then seeks to maximize the softmax probability of the positive document against the negative documents.
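A minimal numeric sketch of this softmax-based contrastive loss is below. The logits are assumed example values for one query's group of documents, with index 0 as the positive; the exact formulation in the paper may differ in details such as normalization across queries.

```python
import numpy as np

# Assumed raw cross-encoder logits for one query:
# index 0 is the positive document, the rest are sampled negatives.
logits = np.array([2.0, 0.5, -1.0, 0.0])

def lce_loss(logits, pos_index=0):
    # Softmax over {positive} U {negatives}, then negative log-likelihood
    # of the positive: pushes the positive's score above the negatives'.
    z = logits - logits.max()               # shift for numerical stability
    log_probs = z - np.log(np.exp(z).sum())
    return -log_probs[pos_index]

print(lce_loss(logits))
```

Note that, unlike the point-wise loss, this loss only depends on score *differences* within the same query's group, which is what makes relevance relative rather than absolute.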

Multiple experiments in Gao 2021 and Pradeep 2022 confirm that LCE consistently outperforms the point-wise cross-entropy loss. Furthermore, performance consistently improves as the number of negative documents per query increases. In Gao 2021, up to 7 negatives (i.e. a batch size of 8) were used. Pradeep 2022 shows that increasing the batch size up to 32 continues to yield consistent (albeit diminishing) gains.

Other details

Pradeep 2022's experiments show that using a stronger retrieval model (a ColBERT-based model) during inference yields slight gains in final performance over BM25. Although Gao 2021 argues that it is also important to use the same retrieval model during training (so that the cross encoder sees the same distribution of negatives during training and inference), Pradeep 2022 argues that this alignment matters less than having stronger retrieval performance at inference time.

References