Yi 2019 - LogQ Correction for In Batch Sampling

Yi 2019 - Sampling Bias Corrected Neural Modelling for Large Corpus Item Recommendations

This paper proposes a way to perform logQ correction for the sampling bias introduced by in-batch negative sampling when training two-tower models. The proposed method is a streaming algorithm that estimates item frequencies, updating its estimates after each mini batch.

Setup

Let $x$ and $y$ denote a user (query) and an item respectively, where there are $N$ users and $M$ items. Let $u(x, \theta)$ and $v(y, \theta)$ denote the user and item embedding functions that map each $x$ and $y$ to $\mathbb{R}^k$. These functions are typically:

  • A sentence transformer model for text inputs
  • A hash embedding of IDs in the collaborative filtering setting

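The output of the model is the inner product of the embeddings, i.e. $s(x, y) = \langle u(x, \theta), v(y, \theta) \rangle$. The goal is to train the model from a training dataset of $T$ user-item interactions, denoted by $\mathcal{T} := \{(x_i, y_i, r_i)\}_{i=1}^{T}$, where $x_i, y_i$ are the interacting query and item and $r_i \in \mathbb{R}$ is the associated reward.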

  • Typically $r_i = 1$ to denote an interaction
  • We can also use $r_i$ to denote some quality weight, e.g. time spent on the product
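Here is a minimal sketch of the two-tower scoring function $s(x, y)$ in the collaborative filtering setting. It uses hashed ID embedding tables for both towers. `HashEmbedding`, `score`, the bucket count and the random initialization are all hypothetical choices made for illustration. A real model would learn the table rows (part of $\theta$) during training instead of leaving them random.

```python
import random
import zlib


class HashEmbedding:
    """Hypothetical hashed ID embedding table mapping any string ID to a dim-sized vector."""

    def __init__(self, num_buckets: int, dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.num_buckets = num_buckets
        # One randomly initialized row per bucket; training would update these rows.
        self.table = [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(num_buckets)]

    def __call__(self, key: str) -> list[float]:
        # crc32 is deterministic across runs, unlike Python's salted built-in hash() for str.
        bucket = zlib.crc32(key.encode("utf-8")) % self.num_buckets
        return self.table[bucket]


def score(u_vec: list[float], v_vec: list[float]) -> float:
    """s(x, y) = <u(x), v(y)>: the inner product of the two tower outputs."""
    return sum(a * b for a, b in zip(u_vec, v_vec))


user_tower = HashEmbedding(num_buckets=1000, dim=8, seed=1)  # plays the role of u(x, theta)
item_tower = HashEmbedding(num_buckets=1000, dim=8, seed=2)  # plays the role of v(y, theta)

s = score(user_tower("user_42"), item_tower("item_7"))
```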

Given a query $x$, we typically model the conditional probability of picking item $y$ with the softmax function, where $\theta$ parametrizes the embedding model:

$$P(y \mid x; \theta) = \frac{e^{s(x, y)}}{\sum_{j \in [M]} e^{s(x, y_j)}}$$

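We then define the loss function as a weighted negative log likelihood of the training interactions:

$$L_{\mathcal{T}}(\theta) := -\frac{1}{T} \sum_{i \in [T]} r_i \cdot \log P(y_i \mid x_i; \theta)$$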
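The full softmax and the weighted log likelihood loss can be sketched in plain Python as follows. The sketch assumes each query's scores against all $M$ items are already computed and passed in as a list. `softmax_prob` and `weighted_nll` are hypothetical helper names. Subtracting the max inside the softmax is a standard numerical-stability trick and does not change the result.

```python
import math


def softmax_prob(scores: list[float], target: int) -> float:
    """P(y_target | x) = exp(s_target) / sum_j exp(s_j), computed stably."""
    m = max(scores)  # subtract the max so exp() cannot overflow
    exps = [math.exp(s - m) for s in scores]
    return exps[target] / sum(exps)


def weighted_nll(dataset: list[tuple[list[float], int, float]]) -> float:
    """L_T = -(1/T) * sum_i r_i * log P(y_i | x_i).

    Each example is (scores of x_i against all M items, index of y_i, reward r_i).
    """
    total = 0.0
    for scores, target, reward in dataset:
        total += reward * math.log(softmax_prob(scores, target))
    return -total / len(dataset)


# Two hypothetical interactions over a catalogue of M = 3 items.
example_loss = weighted_nll([([2.0, 0.5, -1.0], 0, 1.0), ([0.1, 0.3, 0.2], 2, 0.5)])
```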

In Batch Sampling

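In practice, the denominator of $P(y \mid x; \theta)$ above is infeasible to compute when the number of items $M$ is very large. The common practice is to use only the subset of items that appear in a mini batch. Hence, given a mini batch of $B$ pairs $\{(x_i, y_i, r_i)\}_{i=1}^{B}$ and for any $i \in [B]$, the batch softmax becomes:

$$P_B(y_i \mid x_i; \theta) = \frac{e^{s(x_i, y_i)}}{\sum_{j \in [B]} e^{s(x_i, y_j)}}$$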
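A minimal sketch of the batch softmax follows. It assumes the $B$ query and item embeddings of a mini batch are given as plain Python lists, and `batch_softmax_probs` is a hypothetical name. Each query is scored against every item in the batch, and only the diagonal entry is treated as the positive.

```python
import math


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def batch_softmax_probs(user_vecs: list[list[float]], item_vecs: list[list[float]]) -> list[float]:
    """Return P_B(y_i | x_i) for every i in the batch.

    Row i of the B x B logit matrix scores query x_i against every item in the batch;
    entry (i, i) is the positive pair and the other B - 1 items act as negatives.
    """
    probs = []
    for i, u in enumerate(user_vecs):
        logits = [dot(u, v) for v in item_vecs]
        m = max(logits)  # numerical stability
        exps = [math.exp(s - m) for s in logits]
        probs.append(exps[i] / sum(exps))
    return probs
```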

Note that each $(x_i, y_i)$ refers to a positive pair. However, the batch softmax above is usually a heavily biased estimate of the full softmax. The in-batch items are drawn from the training data, which is usually skewed toward popular items. As a result, popular items appear in the denominator as negatives far more often than rare ones.

In other words, a model trained with this biased likelihood may achieve a low training loss by learning to discriminate against the popular items that keep appearing in the denominator. But when used for retrieval, the model may assign high scores to rare items that should be negatives, simply because it rarely had a chance to discriminate against them during training.

This issue underlies a common phenomenon when training such retrieval embedding models: reranking performance is good but retrieval performance is poor. Reranking is often performed against popular items that the model sees often, whereas retrieval by definition searches across the whole item catalogue. Hence retrieval is (from this perspective) a harder task than reranking. Special attention must be paid during training to ensure that the model learns to discriminate well against all items in the catalogue, and logQ correction is one of the methods at our disposal.
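To make the popularity skew concrete, one can compute how often an item shows up in a batch at all. Suppose, as an idealization, that the $B$ items of a batch are drawn independently and each draw picks item $j$ with probability $q_j$. Then item $j$ appears at least once with probability $1 - (1 - q_j)^B$. The popularity values below ($10^{-2}$ and $10^{-6}$) and the batch size of 1024 are hypothetical.

```python
def inclusion_prob(q: float, batch_size: int) -> float:
    """Probability that an item with per-draw sampling probability q appears at least
    once in a batch, assuming the batch_size draws are i.i.d. (an idealization)."""
    return 1.0 - (1.0 - q) ** batch_size


popular = inclusion_prob(1e-2, 1024)  # roughly 1.0: a negative in almost every batch
rare = inclusion_prob(1e-6, 1024)     # roughly 0.001: about one batch in a thousand
```

Under these assumptions the popular item is contrasted against almost every query, while the rare item almost never is. This is the imbalance the correction below targets.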

In Adaptive Importance Sampling to Accelerate Training of a Neural Probabilistic Language Model, the authors propose correcting the biased batch softmax by adjusting each score logit:

$$s^c(x_i, y_j) = s(x_i, y_j) - \log(p_j)$$

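where $p_j$ denotes the probability of sampling item $y_j$ in a random batch. With this correction, the batch softmax becomes:

$$P^c_B(y_i \mid x_i; \theta) = \frac{e^{s^c(x_i, y_i)}}{e^{s^c(x_i, y_i)} + \sum_{j \in [B], j \neq i} e^{s^c(x_i, y_j)}}$$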

And finally, the batch loss function is:

$$L_B(\theta) := -\frac{1}{B} \sum_{i \in [B]} r_i \cdot \log P^c_B(y_i \mid x_i; \theta)$$
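Putting the correction and the batch loss together, a plain-Python sketch might look like the following. It assumes the $B \times B$ logit matrix and the estimated sampling probabilities $p_j$ are already available. `corrected_batch_loss` is a hypothetical name. A real implementation would run the same computation on tensors so that it can be differentiated.

```python
import math


def corrected_batch_loss(
    logits: list[list[float]],  # logits[i][j] = s(x_i, y_j) for the B in-batch pairs
    item_probs: list[float],    # p_j: estimated sampling probability of in-batch item y_j
    rewards: list[float],       # r_i for each positive pair
) -> float:
    """L_B = -(1/B) * sum_i r_i * log P^c_B(y_i | x_i), with s^c = s - log(p_j)."""
    log_p = [math.log(p) for p in item_probs]
    total = 0.0
    for i, row in enumerate(logits):
        # logQ correction: subtract log p_j from every logit in column j,
        # including the positive on the diagonal.
        corrected = [s - lp for s, lp in zip(row, log_p)]
        m = max(corrected)
        log_denom = m + math.log(sum(math.exp(c - m) for c in corrected))
        total += rewards[i] * (corrected[i] - log_denom)
    return -total / len(logits)
```

Subtracting the same $\log p$ from every logit in a row leaves the softmax unchanged. The correction therefore only matters when in-batch items have different sampling probabilities. Rare items receive a larger boost $-\log p_j$, so they act as harder negatives than they would without the correction.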

Estimating Sampling Probability in Stream Setting

Notably, the batch loss function does not require holding a fixed set of items in memory to serve as negative candidates, making it suitable for use in a streaming training data setting. Thus, the authors propose a method to estimate the sampling probability in a streaming fashion as well.

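The first observation is that it is easier to track the number of steps (or batches) $\delta_y$ between two consecutive hits of item $y$. The sampling probability can then be estimated as $p_y = 1/\delta_y$. For example, if item $y$ appears only once every 50 batches, then $p_y = 1/50 = 0.02$. The proposed algorithm is as follows: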

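  1. Initialize arrays $A$ and $B$ of size $H$, and choose a learning rate $\alpha$. $A$ records the step at which an item was last seen, and $B$ holds the estimated gap $\delta$ between its hits (this $B$ is unrelated to the batch size $B$).
  2. Let $h$ be a hash function from an item ID to $[H]$.
  3. At batch $t$, sample a batch of items. For each item $y$ in the batch:
     • $B[h(y)] \leftarrow (1 - \alpha) \cdot B[h(y)] + \alpha \cdot (t - A[h(y)])$
     • $A[h(y)] \leftarrow t$
  4. At inference time, the sampling probability for item $y$ will be $1 / B[h(y)]$.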
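The streaming estimator can be sketched as a small Python class that follows the steps above. The class and method names, the use of `zlib.crc32` as $h$, initializing $A$ and $B$ to zero, and returning 0.0 for a bucket that was never hit are all assumptions made for this illustration. The arrays are named `A` and `B` to mirror the algorithm.

```python
import zlib


class StreamingFrequencyEstimator:
    """Sketch of a hashed streaming estimator of item sampling probability.

    A[k]: step at which an item hashing to bucket k was last seen.
    B[k]: moving average of the gap (in steps) between consecutive hits of bucket k.
    """

    def __init__(self, num_buckets: int, alpha: float = 0.01) -> None:
        self.num_buckets = num_buckets
        self.alpha = alpha
        self.A = [0] * num_buckets    # assumption: items are treated as last seen at step 0
        self.B = [0.0] * num_buckets  # assumption: gap estimates start at zero

    def _h(self, item_id: str) -> int:
        # Hypothetical choice of hash; crc32 is deterministic across Python runs.
        return zlib.crc32(item_id.encode("utf-8")) % self.num_buckets

    def update(self, t: int, batch_item_ids: list[str]) -> None:
        """Apply both update rules for every item in the batch seen at step t."""
        for item_id in batch_item_ids:
            k = self._h(item_id)
            self.B[k] = (1 - self.alpha) * self.B[k] + self.alpha * (t - self.A[k])
            self.A[k] = t

    def prob(self, item_id: str) -> float:
        """Estimated sampling probability 1 / B[h(y)]; 0.0 if the bucket was never hit,
        so clamp it before taking a log for the correction."""
        gap = self.B[self._h(item_id)]
        return 1.0 / gap if gap > 0 else 0.0


est = StreamingFrequencyEstimator(num_buckets=1024, alpha=0.1)
for n in range(1, 301):
    est.update(50 * n, ["a"])  # item "a" shows up once every 50 steps
```

Feeding a single item once every 50 steps drives its estimate toward $1/50 = 0.02$, matching the example above. Because $B$ starts at zero here, early estimates overshoot the true probability until the moving average has seen enough hits. Items that collide in the same bucket share one estimate.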

Other Notes

The authors note that adding L2 normalization to the embeddings improves model trainability and leads to better retrieval quality. Also, dividing each logit by a temperature $\tau$ helps to sharpen the predictions. In their experiments, the best $\tau$ is usually around 0.05 (i.e. logits get multiplied by 20x).
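Both tricks fit in a short sketch. First L2-normalize both embeddings so that their inner product is a cosine similarity in $[-1, 1]$, then divide by the temperature $\tau$. With $\tau = 0.05$ the logits span $[-20, 20]$. The helper names are hypothetical, and the default `tau=0.05` simply mirrors the value reported above.

```python
import math


def l2_normalize(vec: list[float]) -> list[float]:
    """Scale vec to unit L2 norm (a zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def scaled_logit(u_vec: list[float], v_vec: list[float], tau: float = 0.05) -> float:
    """Cosine similarity of the normalized embeddings divided by the temperature tau."""
    u = l2_normalize(u_vec)
    v = l2_normalize(v_vec)
    return sum(a * b for a, b in zip(u, v)) / tau
```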