Yan 2025 - Semantic IDs
- Article: https://eugeneyan.com/writing/semantic-ids/
- Code: https://github.com/eugeneyan/semantic-ids-llm
Eugene Yan's writeup of his experience training an LLM-recommender hybrid that can recommend items using natural language.
Extend the vocab
Insert semantic ID tokens like <|sid_0|>, <|sid_1|> into the vocab. Sequences of these tokens will be used to represent the catalog. So an item may be represented like:
<|sid_start|><|sid_173|><|sid_324|><|sid_764|><|sid_768|><|sid_end|>
Data
He used the Video Games category of Amazon Reviews 2023, because it has rich product metadata and lots of behavioural data.
- Keep only products with titles longer than 20 chars and descriptions longer than 100 chars
- This leaves 66k products and 737k rows of interactions
Training Semantic IDs
Used RQ-VAE to train semantic IDs. Will not expound on RQ-VAE here. Eugene used:
- 3-level codebook
- Each level has 256 codes
- This has collisions on around 10% of the 66k products
- Added a sequentially increasing token to each ID to ensure uniqueness (a minimal dedup sketch follows below)
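The uniqueness fix can be sketched like this (my illustrative implementation, not Eugene's exact code; the 768 offset is an assumption based on the `<|sid_768|>`, `<|sid_769|>` tokens that show up as the fourth token in the article's examples):

```python
# Illustrative dedup: append a sequentially increasing 4th token to items
# that collide on the same 3-level semantic ID. Offset 768 (= 3 * 256) is an
# assumption inferred from the examples in the article.
from collections import defaultdict

def deduplicate(semantic_ids):
    seen = defaultdict(int)            # (l1, l2, l3) -> collision count so far
    unique_ids = []
    for sid in semantic_ids:           # sid is a tuple like (173, 324, 764)
        unique_ids.append(sid + (768 + seen[sid],))
        seen[sid] += 1
    return unique_ids
```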
Baselines
- Train a SASRec model on semantic IDs to compare against traditional SASRec
- Use Qwen3-Embedding-0.6B to encode product metadata into embeddings
- Finetune Qwen3-8B to recommend items via semantic IDs
Data Cleaning
- Clean item descriptions using Gemini 2.5 Flash to remove HTML, reduce verbosity
- Remove promotional text, standardize formatting for titles etc.
- Augment data by extracting structured metadata like product type, platform, genre, hardware type etc.
- Build user interaction histories
Training RQ-VAE
First, we get the 1,024-dim item embeddings using Qwen3-Embedding-0.6B, then l2-normalize them. The RQ-VAE consists of:
- An encoder
- A codebook with 3 quantization levels
- A symmetric decoder
Some tricks:
- Use rotation trick as a replacement for the straight through estimator
- Initialize codebooks with k means clustering
- Reset unused codes
- Large batch size
Metrics for measuring quality of RQ-VAE:
- `loss/reconstruction`: How well we can reconstruct the original item embeddings after compressing and decompressing
- `loss/vq`: Combined codebook and commitment loss across all levels
  - Ensures that encoder outputs and codebook vectors stay close together
- `loss/total`: sum of `loss/vq` and `loss/reconstruction`
- `loss/validation`: total loss but on the held-out validation set. Our deciding metric.
- `metrics/avg_residual_norm`: Leftover residual between the quantized embedding and the original embedding
- `metrics/unique_ids_proportion`: % of items with unique IDs in a batch (a sketch of this check follows below)
  - Helps to check against codebook collapse
  - We want this metric to be high
- Codebook distribution
  - Plot the item distribution amongst the 256 codes at each level
  - It should look like a uniform distribution
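A sketch of how that batch metric could be computed (illustrative, not the repo's exact code):

```python
import torch

def unique_ids_proportion(ids: torch.Tensor) -> float:
    # ids: (batch, num_levels) integer codes, one row of semantic IDs per item
    unique_rows = torch.unique(ids, dim=0)
    return unique_rows.shape[0] / ids.shape[0]
```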
Some hyperparameter tuning:
- Tuned the commitment weight β, which balances reconstruction accuracy against codebook commitment
- Several values of β were tried; the one with the lowest validation loss was chosen (the config below ends up with `commitment_weight = 0.25`)
- Investing in data cleaning significantly improved all the metrics
SASRec comparison
Eugene tested two SASRec variants:
- Traditional SASRec
- Each item has a unique ID with no semantic meaning
- Model uses 2 causal attention blocks and a 64-dim hidden dimension, trained with binary cross-entropy loss
- Dot product of embeddings is used to generate scores
- Semantic SASRec
- For semantic version, each item is a sequence of semantic tokens
- Hence for each position, the model needs to generate a sequence of tokens to represent an item
- Instead of binary cross-entropy loss, we sum up the cross-entropy loss at each token position (see the sketch after this list)
- Question: does it make sense to add some weights to put more weightage on the first semantic token and decay it for subsequent positions?
- Teacher forcing is used for this training
- A larger model is used here: 4 causal attention blocks and a 384-dim hidden dimension
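A hedged sketch of the summed loss (the function name and tensor shapes are illustrative; the real model also handles sequence positions and padding):

```python
import torch.nn.functional as F

def semantic_item_loss(logits, target_sids):
    # logits: (batch, num_levels, codebook_size) - one distribution per semantic token
    # target_sids: (batch, num_levels) - the ground-truth code at each level
    loss = 0.0
    for level in range(target_sids.shape[1]):   # teacher forcing: all levels predicted at once
        loss = loss + F.cross_entropy(logits[:, level], target_sids[:, level])
    return loss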
From this experiment, Eugene found that traditional SASRec is significantly better. But he puts this down to the difficulty of generating a sequence of tokens compared to directly generating one token. I also note that no pretrained LLM is used here, so we miss out on pretrained capabilities we could otherwise have tapped.
Finetuning Qwen-8B
Now finally we train the language model to converse in semantic IDs.
First, we build the training dataset of 4.2 million conversations of various task types:
- Given a semantic ID, predict the item description
- Given item description, predict the semantic ID
- Predict the next item in the user's sequence
- Understand relationships between item categories
- Multi-hop reasoning
Each of these is formatted as a conversation with a system prompt, a user prompt, and an assistant response. We then finetune the model in two phases.
In Phase 1, we focus on extending the model's vocabulary.
- We add 1,027 new vocabulary items to the model's vocabulary and resize the model's embedding matrix to fit them
- In this phase, we freeze all the model parameters except the input and output embedding layers, training 15% of total parameters
- Train for 1,000 steps with a relatively high learning rate
In Phase 2, we do the full training:
- Train for 3 epochs on the full dataset
- 3% warmup, lower learning rate, 8-bit AdamW optimizer for memory efficiency
- Monitor progress with callbacks:
- Track gradient norms
- Measure hierarchical accuracy (I guess on different codebook granularities) on recommendation tasks
- Sample model outputs on fixed set of conversational inputs
Inference
Now that the model is trained, we can have chat / recommendation conversations with it:
User: “Halo 3 Limited Edition -Xbox 360”, “Lost Planet: Extreme Condition - Xbox 360”, “Lost Planet: Extreme Condition Collector’s Edition -Xbox 360” <|rec|>
Model:
<|sid_start|><|sid_64|><|sid_351|><|sid_733|><|sid_768|><|sid_end|>
The model can also explain its recs:
User: I just finished “Dragon Quest Heroes II Explorer’s Edition - PlayStation 4”. Suggest another <|rec|> and explain why:
Model:
<|sid_start|><|sid_139|><|sid_289|><|sid_534|><|sid_768|><|sid_end|>
If you liked Dragon Quest Heroes II, you might like Nights of Azure because both are action RPGs for the PlayStation 4 with a focus on combat and character progression. Both games offer a narrative-driven experience with a strong emphasis on combat mechanics, suggesting a shared player interest in this genre.
We can even have multi-turn conversation with the model:
User: I’m looking for games similar to Mario Kart. <|rec|>
Model:
<|sid_start|><|sid_131|><|sid_503|><|sid_738|><|sid_768|><|sid_end|>, <|sid_start|><|sid_131|><|sid_260|><|sid_577|><|sid_768|><|sid_end|>, <|sid_start|><|sid_145|><|sid_416|><|sid_671|><|sid_768|><|sid_end|>
User: How about something similar but for Xbox? <|rec|>
Model:
<|sid_start|><|sid_183|><|sid_461|><|sid_517|><|sid_768|><|sid_end|>, <|sid_start|><|sid_183|><|sid_461|><|sid_562|><|sid_768|><|sid_end|>, <|sid_start|><|sid_183|><|sid_461|><|sid_616|><|sid_768|><|sid_end|>
User: Suggest a name and description for the bundle.
Model: Xbox Racing Legends: Speed & Style Pack
Code
Here we deep dive into Eugene's code and how it is implemented. Most of the code is contained in the /src directory.
device_manager.py
The DeviceManager detects the device (cpu, cuda, or mps) and is instantiated early in the scripts. The interesting part is torch.set_float32_matmul_precision("high"), performed when the device is cuda: it allows matmuls to use TF32 (or similarly reduced-precision kernels) on supported GPUs, trading a little precision for speed. A sketch of the detection logic follows below.
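A sketch of what the detection presumably looks like (assumed shape, not the file's exact contents):

```python
import torch

def detect_device() -> torch.device:
    if torch.cuda.is_available():
        # Allow TF32 matmuls on Ampere+ GPUs: faster float32 at slightly lower precision
        torch.set_float32_matmul_precision("high")
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```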
tokenize_items.py
This script tokenizes the product descriptions of video games:
- Uses `Qwen/Qwen3-Embedding-0.6B`
- Batch size = `32`, max length = `2048`
- Reads the data from `data/output/Video_Games_items_updated.parquet` into polars
- Looks for the `item_context` field (already preprocessed)
- Uses the following prompt which will be tokenized:

Instruct: Given a product description, generate a semantic embedding that captures
its key features and characteristics.
Query: {original_item_text}

- Saves the tokenized `input_ids` and `attention_masks` in `.npz` (compressed numpy) format using `np.savez_compressed` (sketch below)
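Roughly, the script does the following (the prompt, model, limits, and `.npz` output come from the notes; the exact padding/truncation arguments and output filename are my assumptions):

```python
import numpy as np
import polars as pl
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
items = pl.read_parquet("data/output/Video_Games_items_updated.parquet")

prompt = (
    "Instruct: Given a product description, generate a semantic embedding that captures\n"
    "its key features and characteristics.\nQuery: {text}"
)
texts = [prompt.format(text=t) for t in items["item_context"].to_list()]

enc = tokenizer(texts, padding="max_length", truncation=True,
                max_length=2048, return_tensors="np")
np.savez_compressed("tokenized_items.npz",
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"])
```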
embed_items.py
Takes the tokenized file and embeds it (a hedged sketch of the pooling follows below).
- Writes embedded items into a parquet file
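Qwen3-Embedding models pool the hidden state of the last non-padding token (per the model card), so I assume the embedding step looks roughly like this (right-padding assumed; not the repo's exact code):

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel

model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
model.eval()

@torch.no_grad()
def embed(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    # last-token pooling: grab the hidden state of each sequence's final real token
    last = attention_mask.sum(dim=1) - 1
    emb = out.last_hidden_state[torch.arange(input_ids.shape[0]), last]
    return F.normalize(emb, dim=-1)  # l2-normalize, as described in the notes
```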
train_rqvae.py
Compresses an item embedding into hierarchical semantic IDs.
The RQVAEConfig defines the following parameters:
- `item_embedding_dim`: embedding dim of our embedding model
- `encoder_hidden_dims`: `[512, 256, 128]`, the size of the VAE encoder
- `codebook_embedding_dim`: `32`, dimension of codebook vectors
  - Qn: this does not need to match the Qwen embedding dim?
- `codebook_quantization_levels`: `3` levels in the codebook
- `codebook_size`: `256` codes per level
- `commitment_weight`: `0.25`, commitment loss weight (beta)
- `use_rotation_trick`: `True`, use the rotation trick for better gradient flow
- `batch_size`: `32768`, training batch size (why so large?)
- `gradient_accumulation_steps`: `1`
- `num_epochs`: `20000`
- `scheduler_type`: `cosine_with_warmup`
- `warmup_start_lr`: `1e-8`, used for cosine_with_warmup
- `warmup_steps`: `200`, used for cosine_with_warmup
- `max_lr`: `3e-4`, maximum learning rate (start of cosine)
- `min_lr`: `1e-6`, minimum learning rate (end of cosine)
- `use_gradient_clipping`: `True`
- `gradient_clip_norm`: `1.0`
- `use_kmeans_init`: `True`, initializes codebook vectors using k-means
- `reset_unused_codes`: `True`, reset unused codes periodically to avoid collapse
- `steps_per_codebook_reset`: `2`, reset unused codebook codes every N steps
- `codebook_usage_threshold`: `1.0`, only reset if usage falls below this proportion (0-1)
- `val_split`: `0.05`
- `steps_per_train_log`: `10`, log every N steps
- `steps_per_val_log`: `200`, validate and checkpoint every N steps
EmbeddingDataset is a torch dataset holding the embeddings:
- Extracts all embeddings and holds them in a `torch.tensor` at init (sketch below)
  - Not worried about OOM?
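Presumably something like this (assumed shape; 66k items × 1024 floats is only ~270 MB, which probably answers the OOM question):

```python
import torch
from torch.utils.data import Dataset

class EmbeddingDataset(Dataset):
    def __init__(self, embeddings):              # e.g. a (num_items, 1024) float array
        self.embeddings = torch.as_tensor(embeddings, dtype=torch.float32)

    def __len__(self) -> int:
        return self.embeddings.shape[0]

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.embeddings[idx]
```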
QuantizationOutput
QuantizationOutput is used to hold data for one quantized item:
- Holds the local loss for one codebook layer
- Subclasses `NamedTuple`, which is more lightweight than a `dataclass` (see the sketch below)
- `quantized_st: Tensor`: the quantized vector which is passed on to the next layer
  - Has the "gradient trick" (either straight-through or rotation trick) applied
  - Allows backpropagation into the encoder even though we passed through the non-differentiable codebook layer
- `quantized: Tensor`: the raw nearest-neighbour vectors from the codebook
  - No gradients
  - Observation: `quantized` and `quantized_st` should be identical in values; one just has gradients attached
- `indices: Tensor`: integer indices which represent the semantic IDs
- `loss: Tensor`: the combined loss for this specific codebook layer
  - `loss = codebook_loss + beta * commitment_loss`
- `codebook_loss: Tensor`: measures how well the codebook vector matches the encoder output
- `commitment_loss: Tensor`: measures how well the encoder output matches the codebook vectors
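Reconstructed from the fields above (a sketch, not necessarily the exact class):

```python
from typing import NamedTuple
from torch import Tensor

class QuantizationOutput(NamedTuple):
    quantized_st: Tensor     # gradient-carrying quantized vector (straight-through / rotation)
    quantized: Tensor        # raw nearest-neighbour codebook vectors, no gradients
    indices: Tensor          # integer code indices, i.e. this level's semantic IDs
    loss: Tensor             # codebook_loss + beta * commitment_loss
    codebook_loss: Tensor    # pulls codebook vectors toward encoder outputs
    commitment_loss: Tensor  # pulls encoder outputs toward codebook vectors
```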
VectorQuantizer
The VectorQuantizer implements one layer of the codebook and is the meat of the RQ-VAE training logic. It is stacked multiple times to form the full codebook.
- At initialization:
  - Initializes with `RQVAEConfig` to hold parameters
  - Initializes `self.embedding` to an embedding of size `codebook_size=256, codebook_embedding_dim=32`
    - Uniform initialization: `self.embedding.weight.data.uniform_(-1 / self.codebook_size, 1 / self.codebook_size)`
  - Registers some buffers for tracking codebook usage:
    - `self.register_buffer("usage_count", torch.zeros(self.codebook_size))`
    - `self.register_buffer("update_count", torch.tensor(0))`
- `find_nearest_codes(x)`:
  - Takes an input vector `x`, compares it to all vectors in the codebook, and returns the nearest one
  - Simply uses `torch.cdist` to compute distances, then `torch.argmin` to get the nearest
  - Returns a tuple of torch tensors:
    - The nearest index (i.e. codeword) to `x`
    - The quantized embedding at the index position
- `forward(x) -> QuantizationOutput`:
  - Finds the nearest index and quantized embeddings for a batch of `x`
    - Calls `find_nearest_codes` to get `indices` and `quantized`
  - Applies the gradient estimator to get `quantized_st`
    - This will be used for gradient backprop to the encoder later
    - `apply_gradient_estimator` either uses the straight-through or the rotation method
  - Computes losses:
    - `codebook_loss` is the MSE loss between `x.detach()` and `quantized`
      - We want to pull the codebook embeddings toward `x`
    - `commitment_loss` is the MSE loss between `x` and `quantized.detach()`
      - We want to pull the encoder output toward the codebook embeddings
    - `loss = codebook_loss + beta * commitment_loss`
  - Everything is packaged into `QuantizationOutput` and returned
  - `self.update_usage` is also called:
    - Updates counts of which indices were the nearest to `x`
    - Updates the number of training steps
- Straight through
  - The straight-through gradient estimator simply returns `x + (quantized - x).detach()`
  - Essentially, the embedding passed forward is `quantized`
  - But the vector used for gradient backprop is `x` (hence straight-through back to the encoded `x`)
  - This is a naive method but works well enough
- Rotation
  - The problem with the straight-through estimator is that we use `quantized` for the forward pass but `x` for the backward pass
    - This can be problematic, especially if `q` and `x` are far apart
  - The rotation idea is to apply a rotation to `x` until it aligns with `q`
    - Since the rotation is differentiable, we get better gradients back to `x`
  - We compute (reconstructed from the rotation-trick paper; to check later): let `u = x / ||x||` and `q̂ = q / ||q||`; then `r = (u + q̂) / ||u + q̂||` is the halfway vector between `u` and `q̂`, the rotation is `R = I - 2rrᵀ + 2q̂uᵀ`, and `quantized_st = (||q|| / ||x||) · Rx`, with `R` and the scale treated as constants (a code sketch follows below)
- `reset_unused_codes`
  - Looks up `self.usage_count` to find unused indices (used 0 times)
  - Takes the current batch of encoded data, and randomly selects vectors from it to become the new codebook vectors
    - This makes it likely for them to be used in the next forward pass since they correspond to actual encoder outputs
  - All usage counters are reset after this
RQVAE
The RQVAE class now assembles multiple VectorQuantizers into the actual VAE that creates semantic IDs.
At initialization:
- `self.encoder`: a simple MLP that shrinks the input embedding down to the codebook dimension
  - In this code, we go from `1024 -> 512 -> 256 -> 128 -> 32`
- `self.decoder`: a simple MLP that goes backward from the quantized vector up to the embedding dimension
  - The decoder dims are just the encoder dims reversed, so we go from `32 -> 128 -> 256 -> 512 -> 1024`
- Both encoder and decoder are wrapped in `nn.Sequential`
- `self.vq_layers` contains the `VectorQuantizer`s
  - It is an `nn.ModuleList` of 3 `VectorQuantizer`s

`forward`: the main magic of this class (a condensed sketch follows after this walkthrough)
- First we encode the input item embedding: `z = self.encode(x)`
- Also init `residual = z`
- Init `quantized_out = torch.zeros_like(z)`
  - The quantization output will be the sum of the quantized vectors from each level
- Now we run a for loop through the vector quantizer layers:
  - First we compute the quantization output for this level (which contains the mapped ID for this level etc.): `vq_output: QuantizationOutput = vq_layer(residual)`
  - Then we update the residual by subtracting the nearest codebook vector: `residual -= vq_output.quantized.detach()`
  - We accumulate the codebook vectors (with gradients) into `quantized_out`: `quantized_out += vq_output.quantized_st`
    - Recall that the final representation passed to the decoder is the sum of the per-level quantized vectors
  - We also accumulate the loss for each layer: `vq_loss += vq_output.loss`
    - This is the codebook + commitment loss; the reconstruction loss comes later
- Finally we get the total loss
  - Compute the reconstruction loss: `x_recon = self.decode(quantized_out)`, `recon_loss = F.mse_loss(x_recon, x)`, `loss = recon_loss + vq_loss`

`encode_to_semantic_ids`: encodes an item embedding `x` to an integer tensor representing its semantic ID
`decode_from_semantic_ids`: decodes an integer tensor `semantic_ids` by looking up the codebook, summing up the levels, and passing the result back into the `decoder`
`kmeans_init`
- Runs k-means on one batch of embeddings to initialize the codebook vectors
- Runs k-means to get 256 centroid vectors
- Copies these vectors into the codebook directly
- Processes layer by layer
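Condensed sketch of the forward pass described above (variable names mirror the notes, not necessarily Eugene's exact code; logging and config handling omitted):

```python
import torch
import torch.nn.functional as F

def forward(self, x):
    z = self.encode(x)                       # (batch, 32)
    residual = z
    quantized_out = torch.zeros_like(z)
    vq_loss = 0.0
    all_indices = []
    for vq_layer in self.vq_layers:          # 3 quantization levels
        vq_output = vq_layer(residual)
        residual = residual - vq_output.quantized.detach()      # quantize what's left
        quantized_out = quantized_out + vq_output.quantized_st  # sum over levels
        vq_loss = vq_loss + vq_output.loss
        all_indices.append(vq_output.indices)
    x_recon = self.decode(quantized_out)
    loss = F.mse_loss(x_recon, x) + vq_loss
    return loss, torch.stack(all_indices, dim=-1)   # (batch, 3) semantic IDs
```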
finetune_qwen3_8b_vocab.py
This script performs Stage 1 of the qwen fine-tuning. It focuses on extending the vocabulary to include new semantic ID tokens and trains embeddings for these new tokens.
FineTuneConfig
Dataclass containing config for the training
- `model_name`: `unsloth/Qwen3-8B`
  - Qn: Not instruction fine-tuned?
- `load_in_4bit`: set to `False` for embedding training
- `load_in_8bit`: set to `False`
- `num_proc`: `32`
- `enable_thinking`: `False`, we don't need thinking mode
- `extend_vocabulary`: `True`
- `codebook_levels`: `4`
- `codebook_size`: `256`
- `num_semantic_tokens`: `1024`
- `system_prompt`: see below
- `max_training_samples`: `32000`, limit for training embeddings
- `learning_rate`: `1e-3`
- `batch_size`: `32`
- `max_steps`: `1000`
The system prompt is as follows:
"You are a helpful AI assistant that understands and works with semantic IDs for product recommendations. Semantic IDs are hierarchical identifiers in the format
<|sid_start|><|sid_105|><|sid_307|><|sid_705|><|sid_769|><|sid_end|>that encode product relationships and categories. /no_think"
extend_tokenizer
extend_tokenizer(model, tokenizer, config: FineTuneConfig) adds semantic ID tokens to the tokenizer using Unsloth's add_new_tokens.
- Note that the vocab size affects two places:
  - `model.get_input_embeddings().weight`: the input embeddings
  - `model.get_output_embeddings().weight`: the language model head which predicts the next token
- First, we make sure that the vocab size of the tokenizer matches the vocab size of both the input and output embeddings
  - We need to call `model.resize_token_embeddings` to get the model embedding sizes to match the `tokenizer`
  - This is because the model embeddings are padded to multiples of 128 for CUDA optimization reasons
- Next, we add new tokens using `unsloth.add_new_tokens` (a plain-transformers sketch follows below):
  - Special tokens `<|rec|>`, `<|sid_start|>`, `<|sid_end|>`
  - Semantic IDs `<|sid_0|>` to `<|sid_1023|>`
prepare_model
Prepares the model for training with some additional checks:
- Freezes gradients for all parameters
- Unfreezes only the weights of `model.get_input_embeddings()` and `model.get_output_embeddings()`
- Checks the trainable-parameter percentage (a minimal sketch follows below)
load_sid_dataset
Loads the semantic IDs training dataset:
- Checks for strings like `<|sid_start|>` to make sure processing is correct
- Applies the chat template to the rows (but keeps them as text)
There are 5 distinct categories of training data:
- SemanticID -> text:
  - Input: "Product <|sid_start|>...<|sid_end|> has title:"
  - Output: "Super Mario Bros"
  - Variations: ID to title, description, category, features, or full context
- Text -> SemanticID (an example row is sketched after this list):
  - Input: "The product Super Mario Bros has SemanticID:"
  - Output: "<|sid_start|>...<|sid_end|>"
  - Variations: similar to the above
- Sequential Recommendation:
  - Input: "Based on recent purchases etc., next item:"
  - Output: "<|sid_start|>...<|sid_end|>"
  - Variations: sequence lengths of 2, 3, or 5 items
- Semantic Understanding:
  - Input: "Products starting with <|sid_start|><|sid_64|> are typically:"
  - Output: "Nintendo Switch games"
  - Variations: prefix to category, prefix to examples, similar items
- Multi-hop Reasoning:
  - Input: "A user who bought <|sid_a|> might also buy:"
  - Output: "<|sid_b|>"
  - Variations: co-purchase patterns
DataInspectionCallback
Used to inspect training data and tokenization at each training step, by simply logging them to console.
Patterns:
- `DataInspectionCallback` subclasses `transformers.TrainerCallback` (minimal sketch below)
- `on_train_begin(self, args, state, control, **kwargs)`:
  - Checks the first batch of `train_dataloader`
  - Checks batch keys
  - Checks the shape of `batch['input_ids']`
  - Checks the shape of `batch['attention_mask']`
  - Checks tokens and the decoded text of the first row, etc.
- `on_log(self, args, state, control, logs=None, **kwargs)`:
  - Only runs if `state.global_step % args.logging_steps == 0`
  - Checks the number of SID tokens
  - Decodes the first example and checks it
EmbeddingMonitorCallback
This callback aims to check how our embeddings are shifting over time.
- At initialization (or `on_train_begin`), we clone-and-detach a copy of the initial embeddings
- At each step, we compute the mean absolute difference between the current embeddings and the initial embeddings (or the embeddings from the previous step); see the sketch below
- We also compute the per-level codebook vector means, etc.
- These are all logged to `wandb`
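The drift computation presumably boils down to something like this (assumed implementation):

```python
import wandb

# At on_train_begin: snapshot the embedding matrix
initial = model.get_input_embeddings().weight.detach().clone()

# At each logged step: mean absolute change vs. the snapshot
current = model.get_input_embeddings().weight.detach()
drift = (current - initial).abs().mean().item()
wandb.log({"embeddings/mean_abs_drift": drift})
```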
SemanticIDGenerationCallback
This is a qualitative check to answer the question "If I ask the model for a recommendation right now, does it use the semantic ID tokens or does it just output plain text?"
- A fixed set of test cases are used
- The test cases go through `apply_chat_template`, then the `tokenizer`, then `model.generate`, and the outputs are decoded with the tokenizer
- Actual completion examples are logged into wandb as well
train_embeddings
The main method. Essentially we are just using the SFTTrainer from trl (via Unsloth) to do the training.
First, we set up trl.SFTConfig with a lot of the configuration we previously defined.
- Note that `dataset_text_field="text"`
- `report_to="wandb"`
The trainer trl.SFTTrainer is initialized with the model, tokenizer, datasets, config and callbacks.
Then, trainer.train() is called.
Note that the model and tokenizer are initialized using unsloth.FastLanguageModel to use unsloth's optimized triton kernels.
finetune_qwen3_8b_full.py
The code is structurally very similar to the vocab finetuning run. The difference is that we are doing full training, so we unfreeze all parameters. Consequently, the learning rate needs to be much lower at 2e-5.
- Load the model from Stage 1, namely `models/qwen3_8b_vocab_extended/final`
- A lot of the script focuses on the callbacks that evaluate recommendation quality