
Overview

Phoenix is the machine learning system that powers both retrieval (finding relevant candidates from millions of posts) and ranking (scoring and ordering candidates by predicted engagement). The system uses transformer-based architectures adapted from the Grok-1 open source release by xAI, with custom input embeddings and attention masking designed specifically for recommendation systems.
The code is representative of the production model with the exception of specific scaling optimizations.

Two-Stage Architecture

Phoenix operates in two distinct stages:

Retrieval

Two-Tower Model: narrows millions of posts to roughly a thousand candidates using approximate nearest neighbor search
  • User Tower: Encodes user + engagement history
  • Candidate Tower: Encodes all posts
  • Similarity: Dot product for top-K selection

Ranking

Transformer with Candidate Isolation: scores the retrieved candidates using a full transformer
  • Input: User context + candidate posts
  • Attention: Candidates isolated from each other
  • Output: Probabilities for multiple engagement types
┌──────────────────────────────────────────────────────────────┐
│              PHOENIX RECOMMENDATION PIPELINE                 │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌────────┐   ┌─────────────────┐   ┌──────────────────┐    │
│  │  User  │──▶│  STAGE 1:       │──▶│  STAGE 2:        │──▶ │
│  │Request │   │  RETRIEVAL      │   │  RANKING         │    │
│  └────────┘   │  (Two-Tower)    │   │  (Transformer)   │    │
│               │  Millions→1000s │   │  1000s→Ranked    │    │
│               └─────────────────┘   └──────────────────┘    │
│                                                              │
└──────────────────────────────────────────────────────────────┘
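The two stages can be sketched end to end as a toy Python example. The corpus, `retrieve()`, and `rank()` below are illustrative stand-ins, not the Phoenix interfaces; Stage 2 here simply re-scores with the same dot product where production runs the full transformer.

```python
import numpy as np

rng = np.random.default_rng(0)
CORPUS = rng.normal(size=(10_000, 16))  # stand-in post embeddings [N, D]

def retrieve(user_vec: np.ndarray, k: int = 1000) -> np.ndarray:
    """Stage 1: cheap similarity search over the whole corpus."""
    scores = CORPUS @ user_vec          # [N]
    return np.argsort(scores)[-k:]      # ids of the top-k posts

def rank(user_vec: np.ndarray, candidate_ids: np.ndarray) -> np.ndarray:
    """Stage 2: precise scoring of the retrieved subset, best first."""
    scores = CORPUS[candidate_ids] @ user_vec  # placeholder for transformer scoring
    return candidate_ids[np.argsort(-scores)]

user_vec = rng.normal(size=16)
ranked = rank(user_vec, retrieve(user_vec))
```

The key property this illustrates is the funnel: the cheap pass touches every item once, while the expensive pass only ever sees the retrieved subset.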

Retrieval: Two-Tower Model

The retrieval stage efficiently finds relevant candidates from a massive corpus.

Architecture

User Tower encodes user features and engagement history:
phoenix/recsys_retrieval_model.py
class UserTower(hk.Module):
    """User tower that processes engagement history into user representation."""
    
    def __call__(self, batch, embeddings):
        # Combine user and history embeddings
        user_embedding, user_mask = block_user_reduce(
            batch.user_hashes,
            embeddings.user_embeddings,
            num_user_hashes=self.hash_config.num_user_hashes,
            emb_size=self.transformer_config.emb_size,
        )
        
        history_embedding, history_mask = block_history_reduce(
            batch.history_post_hashes,
            embeddings.history_post_embeddings,
            # ... encode posts, authors, actions
        )
        
        # Pass through the transformer; build the padding mask from the
        # per-block masks returned above
        transformer_input = jnp.concatenate([user_embedding, history_embedding], axis=1)
        padding_mask = jnp.concatenate([user_mask, history_mask], axis=1)
        output = self.transformer(transformer_input, padding_mask)
        
        # Extract and normalize user representation
        user_representation = output[:, 0, :]  # First position = user
        return normalize(user_representation)
Candidate Tower projects post + author embeddings:
class CandidateTower(hk.Module):
    """Candidate tower that encodes posts into shared embedding space."""
    
    def __call__(self, post_author_embedding):
        # Two-layer MLP with SiLU activation
        hidden = jnp.dot(post_author_embedding, proj_1)
        hidden = jax.nn.silu(hidden)
        candidate_embeddings = jnp.dot(hidden, proj_2)
        
        # L2 normalization
        return normalize(candidate_embeddings)
Once both towers produce normalized embeddings:
  1. Index building: All posts encoded offline into [N, D] matrix
  2. Query encoding: User tower produces [B, D] embedding at request time
  3. Top-K retrieval: Dot product similarity → select top candidates
# Similarity scores via dot product (since both sides are normalized, this is cosine similarity)
scores = user_representation @ candidate_embeddings.T  # [B, N]
top_k_indices = jnp.argsort(scores, axis=-1)[..., -K:]  # Top K (ascending; last index is the best match)
Because both representations are L2-normalized, dot product equals cosine similarity, enabling efficient approximate nearest neighbor search with libraries like FAISS or ScaNN.
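The equivalence is easy to verify numerically. The NumPy sketch below (illustrative only; the production code is JAX) normalizes random user and candidate vectors, then checks that dot-product scores match cosine similarity computed from its definition:

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize(x: np.ndarray) -> np.ndarray:
    # L2-normalize along the last axis
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

users = normalize(rng.normal(size=(2, 8)))    # [B=2, D=8] user representations
cands = normalize(rng.normal(size=(100, 8)))  # [N=100, D=8] candidate index

dot = users @ cands.T  # [B, N] dot-product scores

# Cosine similarity computed explicitly gives the same matrix,
# because every row is already a unit vector.
cos = (users @ cands.T) / (
    np.linalg.norm(users, axis=-1, keepdims=True) * np.linalg.norm(cands, axis=-1)
)
assert np.allclose(dot, cos)

top_k = np.argsort(dot, axis=-1)[..., -5:][..., ::-1]  # top-5 indices, best first
```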

Ranking: Transformer with Candidate Isolation

The ranking model scores the retrieved candidates using a full transformer architecture with a critical design choice: candidates cannot attend to each other.

Model Architecture

phoenix/recsys_model.py
class RecsysModel(hk.Module):
    """Recommendation model for ranking candidates."""
    
    def __call__(self, batch, embeddings):
        # 1. Reduce hash embeddings
        user_embedding, user_mask = block_user_reduce(...)
        history_embedding, history_mask = block_history_reduce(...)
        candidate_embedding, candidate_mask = block_candidate_reduce(...)
        
        # 2. Concatenate sequence: [user, history, candidates]
        transformer_input = jnp.concatenate([
            user_embedding,      # [B, 1, D]
            history_embedding,   # [B, S, D]
            candidate_embedding, # [B, C, D]
        ], axis=1)
        
        # 3. Create candidate isolation mask
        padding_mask = create_candidate_isolation_mask(
            user_mask, history_mask, candidate_mask
        )
        
        # 4. Transform with special attention masking
        outputs = self.transformer(
            transformer_input,
            padding_mask=padding_mask,
        )
        
        # 5. Extract candidate outputs and predict actions
        candidate_outputs = outputs[:, 1 + history_len:, :]  # [B, C, D] (skip user + history)
        logits = self.unembedding_layer(candidate_outputs)  # [B, C, num_actions]
        
        return RecsysModelOutput(logits=logits)

Candidate Isolation Mask

The attention mask ensures candidates only attend to user/history, never to each other:
                  Keys (what we attend TO)
     ────────────────────────────────────────────────▶

     │ User │  History (S)  │   Candidates (C)   │
 ┌───┼──────┼───────────────┼────────────────────┤
 │   │      │               │                    │
 │ U │  ✓   │  ✓   ✓   ✓   │  ✗   ✗   ✗   ✗     │
 │   │      │               │                    │
 ├───┼──────┼───────────────┼────────────────────┤
Q│   │      │               │                    │
u│ H │  ✓   │  ✓   ✓   ✓   │  ✗   ✗   ✗   ✗     │
e│ i │  ✓   │  ✓   ✓   ✓   │  ✗   ✗   ✗   ✗     │
r│ s │  ✓   │  ✓   ✓   ✓   │  ✗   ✗   ✗   ✗     │
i│   │      │               │                    │
e├───┼──────┼───────────────┼────────────────────┤
s│   │      │               │ Diagonal only!     │
││ C │  ✓   │  ✓   ✓   ✓   │  ✓   ✗   ✗   ✗     │
││ a │  ✓   │  ✓   ✓   ✓   │  ✗   ✓   ✗   ✗     │
▼│ n │  ✓   │  ✓   ✓   ✓   │  ✗   ✗   ✓   ✗     │
 │ d │  ✓   │  ✓   ✓   ✓   │  ✗   ✗   ✗   ✓     │
 │   │      │               │                    │
 └───┴──────┴───────────────┴────────────────────┘

 ✓ = Can attend (1)    ✗ = Cannot attend (0)
Why Candidate Isolation?

Candidates are prevented from attending to each other to ensure score independence: the score for a post doesn't depend on which other posts are in the batch. This makes scores consistent and cacheable across different batches.
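The block structure pictured above can be built in a few lines. The sketch below is a simplified NumPy version; the real `create_candidate_isolation_mask` also folds in the per-example padding masks (`user_mask`, `history_mask`, `candidate_mask`), which are omitted here.

```python
import numpy as np

def candidate_isolation_mask(history_len: int, num_candidates: int) -> np.ndarray:
    """Boolean attention mask: rows are queries, columns are keys.

    Sequence layout: [user (1), history (history_len), candidates (num_candidates)].
    """
    ctx = 1 + history_len                # user + history positions
    total = ctx + num_candidates
    mask = np.zeros((total, total), dtype=bool)
    mask[:, :ctx] = True                 # every position may attend to user/history
    diag = np.arange(ctx, total)
    mask[diag, diag] = True              # candidates attend only to themselves
    return mask

m = candidate_isolation_mask(history_len=3, num_candidates=4)
```

With `history_len=3` and `num_candidates=4` this reproduces the diagram: the candidate-to-candidate block is the identity, and no query can attend to another candidate.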

Multi-Action Prediction

The model predicts probabilities for multiple engagement types simultaneously:
# Output shape: [B, num_candidates, num_actions]
logits = model(batch, embeddings)

# Actions predicted:
P(favorite)
P(reply)
P(repost)
P(quote)
P(click)
P(profile_click)
P(video_view)
P(photo_expand)
P(share)
P(dwell)
P(follow_author)
P(not_interested)  # Negative signal
P(block_author)    # Negative signal
P(mute_author)     # Negative signal
P(report)          # Negative signal
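A sketch of how downstream code might consume these outputs, in NumPy. The independent per-action sigmoid and the example weights are illustrative assumptions, not the production scoring formula:

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 3, 15))  # [B=1, C=3 candidates, 15 actions]
probs = sigmoid(logits)               # one independent probability per action

# Hypothetical weighting: the 11 positive engagement actions add to the score,
# the 4 negative signals (not_interested, block, mute, report) subtract.
weights = np.array([1.0] * 11 + [-1.0] * 4)
final_score = (probs * weights).sum(axis=-1)  # [B, C]
ranking = np.argsort(-final_score, axis=-1)   # candidate order, best first
```

Keeping the per-action probabilities separate until this final weighting step is what lets the weights be tuned without retraining the model.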

Hash-Based Embeddings

Both retrieval and ranking use multiple hash functions for embedding lookup:
class HashConfig:
    num_user_hashes: int = 2    # User ID → 2 hash functions
    num_item_hashes: int = 2    # Post ID → 2 hash functions
    num_author_hashes: int = 2  # Author ID → 2 hash functions
Each entity is hashed multiple times, and the resulting embeddings are combined:
def block_user_reduce(user_hashes, user_embeddings, ...):
    # user_hashes: [B, num_user_hashes]
    # user_embeddings: [B, num_user_hashes, D]
    
    # Learn a projection to combine hash embeddings
    projection = hk.get_parameter(...)
    combined = apply_projection(user_embeddings, projection)
    
    return combined  # [B, 1, D]
Multiple hash functions provide better representation capacity and collision resistance compared to a single hash table.
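The lookup itself can be sketched as follows (NumPy, with toy hash functions; the production hashing scheme and the learned projection are not shown in the source):

```python
import numpy as np

NUM_HASHES, TABLE_SIZE, D = 2, 1024, 16
table = np.random.default_rng(0).normal(size=(TABLE_SIZE, D))  # shared embedding table

def hash_ids(entity_id: int) -> np.ndarray:
    # Toy stand-ins for independent hash functions: distinct odd multipliers
    seeds = np.array([0x9E3779B1, 0x85EBCA77], dtype=np.int64)
    return (entity_id * seeds[:NUM_HASHES]) % TABLE_SIZE  # [NUM_HASHES]

def embed(entity_id: int) -> np.ndarray:
    rows = table[hash_ids(entity_id)]  # [NUM_HASHES, D]
    # Mean as a simple combiner; production learns a projection instead
    return rows.mean(axis=0)

vec = embed(12345)
```

Two IDs that collide on one hash still get distinct combined embeddings unless they collide on every hash, which is what gives the scheme its collision resistance.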

Integration with Home Mixer

Phoenix Source (Retrieval)

home-mixer/sources/phoenix_source.rs
pub struct PhoenixSource {
    pub phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient>,
}

impl Source<ScoredPostsQuery, PostCandidate> for PhoenixSource {
    async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>> {
        let sequence = query.user_action_sequence.as_ref()?;
        
        let response = self.phoenix_retrieval_client
            .retrieve(query.user_id, sequence.clone(), MAX_RESULTS)
            .await?;
        
        let candidates = response.top_k_candidates
            .into_iter()
            .map(|c| PostCandidate { 
                tweet_id: c.tweet_id,
                author_id: c.author_id,
                // ...
            })
            .collect();
        
        Ok(candidates)
    }
}

Phoenix Scorer (Ranking)

home-mixer/scorers/phoenix_scorer.rs
pub struct PhoenixScorer {
    pub phoenix_client: Arc<dyn PhoenixPredictionClient>,
}

impl Scorer<ScoredPostsQuery, PostCandidate> for PhoenixScorer {
    async fn score(&self, query: &ScoredPostsQuery, candidates: &[PostCandidate]) 
        -> Result<Vec<PostCandidate>> {
        
        let tweet_infos = candidates.iter().map(|c| TweetInfo {
            tweet_id: c.tweet_id,
            author_id: c.author_id,
            // ...
        }).collect();
        
        let response = self.phoenix_client
            .predict(query.user_id, sequence, tweet_infos)
            .await?;
        
        // Extract predictions and update candidates
        let scored_candidates = candidates.iter().map(|c| {
            let phoenix_scores = extract_scores(&response, c.tweet_id);
            PostCandidate {
                phoenix_scores,
                ..c.clone()
            }
        }).collect();
        
        Ok(scored_candidates)
    }
}

Running the Code

The repository includes example code demonstrating both retrieval and ranking:
uv run run_ranker.py

Key Design Decisions

Candidate isolation prevents the score for a candidate from depending on which other candidates are in the batch. This ensures:
  • Consistent scores across different batches
  • Ability to cache predictions
  • Simpler debugging and analysis
Hash-based embeddings with multiple hash functions provide:
  • Better representation capacity than a single lookup
  • Collision resistance for large entity spaces
  • Memory efficiency compared to explicit embedding tables
Multi-action prediction means that rather than predicting a single "relevance" score, the model predicts probabilities for many actions:
  • Captures nuanced user preferences
  • Enables flexible weighting strategies
  • Incorporates negative signals (block, mute, report)
The two-stage design splits the work:
  • Retrieval: fast, approximate search over millions of items
  • Ranking: expensive, precise scoring for hundreds of items
This separation enables scaling to large corpora while maintaining quality.

Performance

Typical Latencies
  • Retrieval (Two-Tower): ~20-50ms for top-1000 from millions
  • Ranking (Transformer): ~50-100ms for scoring 500 candidates
  • Total Phoenix latency: ~70-150ms

Related Components

Home Mixer

Orchestration layer that uses Phoenix for candidate sourcing and scoring

Thunder

Provides in-network candidates to complement Phoenix’s out-of-network retrieval

Candidate Pipeline

Framework that integrates Phoenix into the overall recommendation flow
