Wang 2024 - LLM for Pinterest Search
This paper describes Pinterest's use of pre-trained LLMs to improve their search ranking model. They first fine-tune a Llama-3-8B model on human-labelled search relevance data, then use it to label large amounts of unlabelled data. The teacher-labelled data is used to train a smaller MLP student model, which leads to a significant improvement in search performance.
Background
Pinterest curates a human-labelled dataset for search relevance. For each query, a pin is classified into one of five ordered relevance levels:
- L5: Excellent / Highly Relevant
- L4: Good / Relevant
- L3: Complementary / Marginally Relevant
- L2: Poor / Irrelevant
- L1: Highly Irrelevant
Teacher Model
The teacher model takes the form of a standard cross-encoder: either a BERT-style encoder or a decoder-style LLM. The text input to the model has the following format:
[CLS] Query [SEP] Pin Text
The embedding of the [CLS] token is taken for BERT-based models, whilst the embedding of the final non-padding token is taken for decoder-based models (such as Llama). The embedding is passed through several fully-connected layers, and the final output dimension corresponds to the 5 relevance levels. The LLMs are fine-tuned during training by minimizing a pointwise multi-class cross-entropy loss.
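A minimal sketch of what such a cross-encoder scoring head could look like in PyTorch, assuming a Hugging Face decoder backbone. The class name, head sizes and pooling details here are illustrative assumptions, not taken from the paper:

```python
import torch
import torch.nn as nn
from transformers import AutoModel

# Illustrative sketch of a decoder-style teacher with a 5-way relevance head.
# Model name, head sizes and pooling details are assumptions, not from the paper.
class TeacherCrossEncoder(nn.Module):
    def __init__(self, base_name="meta-llama/Meta-Llama-3-8B", num_levels=5):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_name)
        hidden = self.backbone.config.hidden_size
        # Several fully-connected layers ending in 5 relevance logits.
        self.head = nn.Sequential(
            nn.Linear(hidden, 1024), nn.ReLU(),
            nn.Linear(1024, num_levels),
        )

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Decoder models: take the embedding of the final non-padding token.
        # (A BERT-style encoder would instead use out.last_hidden_state[:, 0],
        # i.e. the [CLS] token.)
        last_idx = attention_mask.sum(dim=1) - 1
        pooled = out.last_hidden_state[torch.arange(input_ids.size(0)), last_idx]
        return self.head(pooled)

# Fine-tuning step: pointwise multi-class cross-entropy over the 5 relevance levels,
# e.g. loss = nn.CrossEntropyLoss()(logits, relevance_labels)
```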
To enhance the representation of the pin text for the teacher model, several text features are concatenated together (a rough sketch of the concatenation follows the list):
- Pin titles and descriptions are used as-is
- Synthetic image captions are generated using an off-the-shelf model called BLIP
- High-engagement query tokens. Queries resulting in high engagement with a pin are aggregated over the past 2 years, and the top tokens are selected.
- User-curated Board Titles. Users curate their pins onto boards. The board titles are aggregated across users and top tokens are selected.
- Link Titles and Descriptions. A long click is defined as a user clicking through to a pin's linked webpage and staying there for at least 10 seconds. The titles and descriptions of these linked pages contain useful text, which is mined.
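As a rough illustration of the concatenation, the enriched pin text might be assembled by joining these sources into one string. The field names and separator below are hypothetical, not from the paper:

```python
# Hypothetical sketch of assembling the enriched pin text; field names and
# the separator are assumptions, not taken from the paper.
def build_pin_text(pin: dict) -> str:
    parts = [
        pin.get("title", ""),
        pin.get("description", ""),
        pin.get("blip_caption", ""),                 # synthetic image caption (BLIP)
        " ".join(pin.get("top_query_tokens", [])),   # high-engagement query tokens
        " ".join(pin.get("top_board_titles", [])),   # user-curated board titles
        pin.get("link_title", ""),
        pin.get("link_description", ""),
    ]
    return " [SEP] ".join(p for p in parts if p)
```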
A few off-the-shelf LLMs are tested as teacher models, and the results are reported below.
Student Model
The teacher model is used to label large amounts of unlabelled search impressions data. Specifically, for each row of data, the teacher model generates a softmax probability distribution over the 5-point relevance labels. The student model is then taught to mimic the predicted probability distribution using cross-entropy loss.
Note: The exact loss function for knowledge distillation is not covered in the paper, but the classical approach is to perform knowledge distillation with a KL-divergence loss. An implementation of the loss can be found here. The basic idea is to minimize the KL divergence between the teacher's predicted probability distribution over labels and the student's predicted probability distribution.
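A minimal sketch of that classical soft-label distillation loss in PyTorch. The temperature, scaling and reduction choices are assumptions, not from the paper:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature: float = 1.0):
    # Soften both distributions with a temperature, then minimize
    # KL(teacher || student) over the 5 relevance levels.
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2
```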
The student model is a simple feed-forward network over a diverse set of features. It appears that some feature engineering is performed to optimize performance. Features include:
- Query-side embeddings. Pinterest has classifiers that categorize the query into a query interest type, shopping interest type, etc. These categorical features are embedded using an embedding table. The SearchSage embedding for the query is also included.
- Pin-side embeddings. The PinSage pin embedding, SearchSage pin embedding, etc. are included.
- Query-pin interaction features. Standard query-pin match scores, such as the BM25 score and the percentage of query terms matched, are also included as features.
All features are presumably concatenated together and passed into the MLP network. As mentioned above, this student network is trained via knowledge distillation from the teacher's predictions.
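A rough sketch of such a student network, assuming the features arrive as pre-computed dense vectors. The feature dimensions and hidden layer sizes are illustrative, not from the paper:

```python
import torch
import torch.nn as nn

# Illustrative student MLP; feature dimensions and hidden sizes are assumptions.
class StudentRelevanceModel(nn.Module):
    def __init__(self, feature_dim: int, num_levels: int = 5):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, num_levels),
        )

    def forward(self, query_emb, pin_emb, interaction_feats):
        # Concatenate query-side embeddings, pin-side embeddings and
        # query-pin interaction features into one vector.
        x = torch.cat([query_emb, pin_emb, interaction_feats], dim=-1)
        return self.mlp(x)
```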
Results
The teacher and student models are evaluated using the human-annotated search relevance dataset. The training set for the teacher model is around 280k rows of human-annotated data, and the remaining 30k rows are used for evaluation. The teacher model is then used to run inference on 30 million rows of unlabelled data, which are used to train the student model.
For the teacher model, scaling up the base model consistently produces better test performance (accuracy):
- BERT-base: 53.5%
- T5-base: 56.9%
- DeBERTaV3-base: 58.0%
- Llama3-8b: 60.2%
Note that the Llama3-8b model was trained using QLoRA.
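For reference, QLoRA fine-tuning combines a 4-bit quantized, frozen base model with trainable low-rank adapters. A minimal sketch using Hugging Face peft/bitsandbytes follows; the specific rank, target modules and quantization settings are assumptions, not the hyperparameters used in the paper:

```python
import torch
from transformers import AutoModel, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Assumed settings for illustration only; the paper does not list its QLoRA hyperparameters.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = AutoModel.from_pretrained("meta-llama/Meta-Llama-3-8B", quantization_config=bnb_config)
base = prepare_model_for_kbit_training(base)

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(base, lora_config)  # only the LoRA adapters are trained
```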
The experiments also showed that:
- Scaling up the amount of teacher-inferred data was crucial for improving student model performance
- The teacher LLM was able to successfully transfer learning from a purely English human-annotated dataset to multiple other languages, thus allowing the student model to learn multi-lingual behaviour.
- The inclusion of the additional text features described above helped improve teacher performance significantly.
Takeaways
This is a simple paper that shows the power of using LLMs to bootstrap from a smaller product-specific dataset and distill an effective student model. This is a useful paradigm because, in search systems, we often have vast amounts of unlabelled impressions data on which a teacher model can run inference.
The paper also shows that scaling up the LLM is able to produce non-trivial increases in search relevance performance on the human-annotated dataset.