Tay 2022 - Differentiable Search Index

Transformer Memory as a Differentiable Search Index

This paper proposes a new search paradigm. Instead of running Approximate Nearest Neighbour (ANN) search over embeddings, the authors train a transformer seq2seq model to predict the ID of a document directly from a search query.

Although the paper focuses on search and question answering, the approach is general and could in principle be applied to recommendation settings as well.

Comparison to Traditional Search Index

A traditional search index, such as BM25 or ANN search, involves three steps:

  1. Indexing. Indexing involves storing representations of each document to enable easy search later on. For traditional lexical search like BM25, a sparse and high dimensional vector is stored for each document and an inverted index is built. For ANN search, a dense vector is stored and an index like HNSW is built.
  2. Training. For lexical search, the only "training" is to store the TF-IDF quantities in the index. For ANN search, a contrastive learning objective is used to shape the embedding space for effective retrieval.
  3. Retrieval. For lexical search, the inverted index is looked up using the query terms. For ANN search, an approximate Maximum Inner Product Search (MIPS) is run against the index.
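To make the dense retrieval step concrete, here is a minimal sketch of exact (brute-force) maximum inner product search in plain Python. It is only an illustration: a real ANN system would replace the linear scan with an index such as HNSW, and the names `exact_mips` and `doc_vectors` are hypothetical.

```python
from typing import Dict, List, Tuple


def inner_product(a: List[float], b: List[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def exact_mips(
    query: List[float], doc_vectors: Dict[str, List[float]], k: int
) -> List[Tuple[str, float]]:
    """Score every document against the query and return the top k.

    This is the exact search that ANN indexes approximate; it is a
    stand-in for illustration, not how a production index works.
    """
    scored = [(docid, inner_product(query, vec)) for docid, vec in doc_vectors.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy embeddings (made up for this example).
doc_vectors = {"d1": [1.0, 0.0], "d2": [0.6, 0.8], "d3": [0.0, 1.0]}
top = exact_mips([1.0, 0.2], doc_vectors, k=2)
```

The point of contrast with DSI is that `doc_vectors` is a separate artifact that must be stored and searched, whereas DSI folds that information into the model weights.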

The Differentiable Search Index (DSI) operates differently, and the authors argue that it is more elegant:

  1. Indexing. Indexing is just a specific form of model fine-tuning. Specifically, the task is to predict the docid given the content of the document, i.e. f(doc_content) = docid.
    • There is no need to store a separate search index, as the only artifact we need is the transformer model itself.
  2. Training. Training is a form of supervised fine-tuning where we predict f(query) = relevant docid.
  3. Retrieval. Retrieval is simply running the transformer on a new query and decoding the predicted docid.

The nice thing about this setup is that it extends naturally to a multi-task setting, by using different prefixes for different tasks. This is similar to the T5 approach for multi-task learning.
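As a rough sketch of how the indexing and retrieval tasks can share one seq2seq model, both can be expressed as (input text, target text) pairs that differ only in a task prefix. The prefix strings and function names below are assumptions made for illustration, not the ones used in the paper.

```python
from typing import Tuple

# Hypothetical task prefixes, in the spirit of T5-style multi-task training.
INDEX_PREFIX = "index: "
RETRIEVE_PREFIX = "retrieve: "


def indexing_example(doc_content: str, docid: str) -> Tuple[str, str]:
    """Indexing task: f(doc_content) = docid."""
    return INDEX_PREFIX + doc_content, docid


def retrieval_example(query: str, relevant_docid: str) -> Tuple[str, str]:
    """Retrieval task: f(query) = relevant docid."""
    return RETRIEVE_PREFIX + query, relevant_docid


# Toy data (made up for this example).
pairs = [
    indexing_example("The Eiffel Tower is a landmark in Paris.", "4271"),
    retrieval_example("where is the eiffel tower", "4271"),
]
```

Both pairs have the same target, so the model learns to map a document and the queries it answers to the same docid string.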

Ablations

The main idea is simple and intuitive, but there are many details to the implementation that are non-trivial.

Indexing Strategy

The indexing step aims to teach the transformer model how to link each docid with the contents of the document. This is entirely new knowledge to the transformer since the docids are arbitrary constructs.

The following strategies were explored:

  1. f(document content) = docid: this had the best performance.
  2. f(docid) = document content: this performed terribly.
  3. A mix of 1 and 2: training batches contain both directions, with a prefix string to indicate which direction is being trained.
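The three strategies differ only in which side of the pair is the input and which is the target. A minimal sketch follows, with the strategy names and the direction prefixes being hypothetical labels chosen for this example:

```python
from typing import List, Tuple


def make_indexing_examples(
    doc_content: str, docid: str, strategy: str
) -> List[Tuple[str, str]]:
    """Build (input, target) pairs for one document under a given strategy."""
    if strategy == "content_to_docid":
        return [(doc_content, docid)]
    if strategy == "docid_to_content":
        return [(docid, doc_content)]
    if strategy == "bidirectional":
        # Assumed prefixes that tell the model which direction to predict.
        return [
            ("to_docid: " + doc_content, docid),
            ("to_content: " + docid, doc_content),
        ]
    raise ValueError(f"unknown strategy: {strategy}")
```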

Indexing and Training Order

Recall that we have two modes:

  • In indexing mode, our training data is f(document content) = docid
  • In training mode (or retrieval task), our training data is f(query) = relevant docid

The following strategies were explored:

  1. Index first, train second. This means that we train in indexing mode until convergence, then switch to training mode.
  2. Multi-task. This means that we mix training batches from both indexing mode and training mode, and use a prefix to indicate which task each example belongs to.

The multi-task strategy worked much better. Performance was sensitive to the mix between the indexing and retrieval tasks; the ideal mix was found to be around 32 indexing instances for every 1 retrieval instance.
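One simple way to realise a 32:1 mix is to interleave the two example streams deterministically, as sketched below. This is an assumption about the mechanics: the paper may well sample examples at random in that proportion instead, and `mixed_task_stream` is a hypothetical helper.

```python
import itertools
from typing import Iterator, List, Tuple

Example = Tuple[str, str]  # (input text, target docid)


def mixed_task_stream(
    indexing: List[Example], retrieval: List[Example], ratio: int = 32
) -> Iterator[Example]:
    """Yield `ratio` indexing examples, then one retrieval example, forever.

    Both lists are assumed to be non-empty; each is cycled independently.
    """
    index_cycle = itertools.cycle(indexing)
    retrieval_cycle = itertools.cycle(retrieval)
    while True:
        for _ in range(ratio):
            yield next(index_cycle)
        yield next(retrieval_cycle)


# Toy data (made up): take the first 33 examples of the mixed stream.
indexing = [(f"index: document {i}", str(i)) for i in range(5)]
retrieval = [("retrieve: some query", "3")]
first_block = list(itertools.islice(mixed_task_stream(indexing, retrieval), 33))
```

Since the document corpus is usually much larger than the set of labelled queries, a mix like this keeps the model seeing documents most of the time while still practising the retrieval task.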

Representation of Document IDs

The following strategies were explored:

  1. Arbitrary Atomic Identifiers. This means assigning an arbitrary random identifier for each document, and letting each identifier be a new token in the vocabulary.

    The downside is that the vocabulary explodes when there are many documents.

  2. Semantically Structured IDs. This means representing each document as an ordered list of tokens, where the tokens contain some semantic meaning.

    This is implemented as a hierarchical k-means clustering.

    • Firstly, each document is embedded using a universal T5 embedder.
    • Next, k-means with 10 centroids is performed. The cluster ID (0-9) is assigned as the first document ID token for each document.
    • Each cluster is then clustered again, recursively, to generate the subsequent document ID tokens.
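The recursive clustering can be sketched in pure Python as follows. This is a toy illustration under several assumptions: it uses a plain Lloyd's k-means rather than a library implementation, it stops recursing once a cluster holds at most `max_leaf` documents and then numbers the documents within that leaf, and all function and parameter names are hypothetical.

```python
import random
from typing import Dict, List, Sequence

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: List[Vector], k: int, rng: random.Random, iters: int = 20) -> List[int]:
    """Plain Lloyd's k-means; returns one cluster label per point."""
    k = min(k, len(points))
    centroids = [list(p) for p in rng.sample(points, k)]
    labels = [0] * len(points)
    for _ in range(iters):
        labels = [min(range(k), key=lambda c: sq_dist(p, centroids[c])) for p in points]
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # keep the old centroid if a cluster went empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels


def semantic_ids(
    embeddings: Dict[str, Vector], k: int = 10, max_leaf: int = 10, seed: int = 0
) -> Dict[str, List[int]]:
    """Assign each document a list of cluster tokens, coarse to fine."""
    rng = random.Random(seed)
    ids: Dict[str, List[int]] = {name: [] for name in embeddings}

    def number_within_leaf(names: List[str]) -> None:
        for position, name in enumerate(names):
            ids[name].append(position)

    def assign(names: List[str]) -> None:
        if len(names) <= max_leaf:
            number_within_leaf(names)  # assumed stopping rule
            return
        labels = kmeans([embeddings[n] for n in names], k, rng)
        groups: Dict[int, List[str]] = {}
        for name, label in zip(names, labels):
            groups.setdefault(label, []).append(name)
        if len(groups) == 1:
            # Degenerate split (e.g. duplicate vectors): stop to avoid looping.
            number_within_leaf(names)
            return
        for label, members in groups.items():
            for name in members:
                ids[name].append(label)
            assign(members)

    assign(sorted(embeddings))
    return ids
```

Documents that land in the same cluster share a prefix of their ID, which is what gives the decoder semantic structure to exploit: the first token narrows the search to a region of the embedding space, and each later token refines it.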

Inference Time

To generate recommendations at inference time, we pass the query into the transformer model and decode generatively with beam search. At each decoding step, beam search keeps the top k partial candidates by cumulative log probability. The final top k candidates, sorted in descending order of log probability, are our list of recommended documents.
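A minimal beam search sketch is shown below. The transformer is replaced by a toy function, `toy_next_logprobs`, that derives next-token probabilities from a made-up distribution over three docids; that function, the `EOS` marker, and the docid table are all assumptions for illustration.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = "<eos>"
Sequence_ = Tuple[str, ...]

# Made-up distribution over complete docids, standing in for the model.
DOCID_PROBS = {"12": 0.5, "13": 0.2, "21": 0.3}


def toy_next_logprobs(prefix: Sequence_) -> Dict[str, float]:
    """Next-token log probabilities implied by DOCID_PROBS for a prefix."""
    text = "".join(prefix)
    matching = {d: p for d, p in DOCID_PROBS.items() if d.startswith(text)}
    total = sum(matching.values())
    probs: Dict[str, float] = {}
    for docid, p in matching.items():
        token = docid[len(text)] if len(docid) > len(text) else EOS
        probs[token] = probs.get(token, 0.0) + p / total
    return {token: math.log(p) for token, p in probs.items()}


def beam_search(
    next_logprobs: Callable[[Sequence_], Dict[str, float]],
    beam_size: int,
    max_len: int,
) -> List[Tuple[Sequence_, float]]:
    """Return up to beam_size finished sequences, best first."""
    beams: List[Tuple[Sequence_, float]] = [((), 0.0)]
    finished: List[Tuple[Sequence_, float]] = []
    for _ in range(max_len):
        candidates: List[Tuple[Sequence_, float]] = []
        for prefix, score in beams:
            for token, logp in next_logprobs(prefix).items():
                if token == EOS:
                    finished.append((prefix, score + logp))
                else:
                    candidates.append((prefix + (token,), score + logp))
        # Keep only the best partial sequences for the next step.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
        if not beams:
            break
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_size]


ranked = beam_search(toy_next_logprobs, beam_size=2, max_len=3)
docids = ["".join(tokens) for tokens, _ in ranked]
```

With a beam of 2, the docid "13" is pruned at the second step because two other partial sequences score higher, so the beam width directly bounds how many candidates can be returned.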

Results

Overall, the semantic ID approach was the best performing, and had significant gains over a fine-tuned dual encoder model.

Scaling to larger models seemed to help the DSI approach much more than it helped dual encoders. In my experience, dual encoder models also don't seem to benefit much from scaling (minimal effect when scaling from 20M to 500M models).

On hits@10, the DSI approach is closer in performance to the dual encoder, but it had much better hits@1 performance.
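For reference, hits@k is the fraction of queries whose relevant document appears among the top k predictions. A minimal sketch, assuming one relevant docid per query and a hypothetical helper name:

```python
from typing import List


def hits_at_k(ranked_docids: List[List[str]], relevant: List[str], k: int) -> float:
    """Fraction of queries whose relevant docid is in the top k predictions."""
    if not ranked_docids:
        return 0.0
    hits = sum(
        1 for preds, gold in zip(ranked_docids, relevant) if gold in preds[:k]
    )
    return hits / len(ranked_docids)


# Toy predictions for three queries (made up); the relevant docid is "a" each time.
predictions = [["a", "b", "c"], ["b", "a", "c"], ["c", "b", "a"]]
gold = ["a", "a", "a"]
hits_1 = hits_at_k(predictions, gold, k=1)
hits_3 = hits_at_k(predictions, gold, k=3)
```

A model can have strong hits@10 but weak hits@1 if it usually places the right document somewhere in the list without ranking it first, which is why the two metrics can tell different stories.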

Concerns

One finding was that shorter representations of the document (truncating to the first L tokens) worked significantly better. The authors suggested that longer documents might make memorization during indexing more difficult. This seems like a significant bottleneck for the paradigm, but future work might resolve it.

One big shortcoming of the method as-is is that the embeddings used to build the semantic IDs come from an off-the-shelf (OTS) T5 model. Subsequent papers find that embeddings fine-tuned for relevance are much better at generating good semantic IDs to represent each document.