Overview

The MHA class implements multi-head self-attention with several advanced features:
  • Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)
  • Optional 1D convolution for local context modeling
  • Rotary position embeddings (RoPE)
  • Integrated MLP for efficiency
  • Optimized KV caching for inference
  • Flash Attention support for faster computation

Class Definition

from lrnnx.layers.mha import MHA

mha = MHA(
    embed_dim=768,
    num_heads=12,
    num_heads_kv=None,
    causal=True,
    rotary_emb_dim=64
)

Parameters

embed_dim
int
required
Embedding dimension of the input.
num_heads
int
required
Number of attention heads for queries.
num_heads_kv
int
default:"None"
Number of key-value heads for Multi-Query Attention (MQA) or Grouped-Query Attention (GQA). If None, uses num_heads for standard multi-head attention. Must divide num_heads evenly.
head_dim
int
default:"None"
Dimension per attention head. If None, uses embed_dim // num_heads. Allows for non-standard head dimensions.
mlp_dim
int
default:"0"
Dimension of integrated MLP (gated MLP with SiLU activation). If 0, no MLP is used. The dimension is rounded up to the nearest multiple of 256.
qkv_proj_bias
bool
default:"True"
Whether to include bias terms in the QKV projection layer.
out_proj_bias
bool
default:"True"
Whether to include bias term in the output projection layer.
softmax_scale
float
default:"None"
Scale factor for attention scores before softmax. If None, uses 1/sqrt(head_dim) as per the standard Transformer.
causal
bool
default:"False"
Whether to use causal (masked) attention. Set to True for autoregressive models.
layer_idx
int
default:"None"
Layer index for KV caching during inference. Required when using inference mode.
d_conv
int
default:"0"
Kernel size for 1D causal convolution applied to QKV before attention. If 0, no convolution is used. Adds local inductive bias.
rotary_emb_dim
int
default:"0"
Dimension for rotary position embeddings (RoPE). If 0, no rotary embeddings are used. Typically set to head_dim or a fraction like head_dim // 2.
rotary_emb_base
float
default:"10000.0"
Base value for computing rotary embeddings frequencies. Higher values result in slower position encoding decay.
rotary_emb_interleaved
bool
default:"False"
Whether to use interleaved rotary embeddings format. If False, uses the standard format.
device
torch.device
default:"None"
Device to place tensors on (e.g., torch.device('cuda')).
dtype
torch.dtype
default:"None"
Data type for tensors (e.g., torch.float16, torch.bfloat16).

Methods

forward

output = mha.forward(x, inference_params=None)
Perform multi-head attention computation.

Parameters

x
torch.Tensor
required
Input tensor of shape (batch_size, seq_len, embed_dim).
inference_params
Any
default:"None"
Parameters for inference mode. Should contain:
  • key_value_memory_dict: Dictionary mapping layer indices to KV caches
  • seqlen_offset: Current sequence position offset
  • max_seqlen: Maximum sequence length
  • lengths_per_sample: Per-sample sequence lengths (optional)

Returns

output
torch.Tensor
Output tensor of shape (batch_size, seq_len, embed_dim).

allocate_inference_cache

kv_cache, conv_state = mha.allocate_inference_cache(
    batch_size=1,
    max_seqlen=2048,
    dtype=torch.float16
)
Allocate cache for efficient autoregressive inference.

Parameters

batch_size
int
required
Batch size for inference.
max_seqlen
int
required
Maximum sequence length for inference.
dtype
torch.dtype
default:"None"
Data type for cache tensors. If None, uses the output projection weight dtype.

Returns

kv_cache
torch.Tensor
Tensor of shape (batch_size, max_seqlen, 2, num_heads_kv, head_dim) for storing key-value states.
conv_state
torch.Tensor | None
Tensor of shape (batch_size, qkv_dim, d_conv) for convolution state, or None if d_conv=0.

Usage Examples

Basic Multi-Head Attention

import torch
from lrnnx.layers.mha import MHA

# Standard multi-head attention
mha = MHA(
    embed_dim=768,
    num_heads=12,
    causal=True
)

x = torch.randn(2, 128, 768)  # (batch, seq_len, embed_dim)
output = mha(x)

Multi-Query Attention (MQA)

# MQA with single KV head
mqa = MHA(
    embed_dim=768,
    num_heads=12,
    num_heads_kv=1,  # All query heads share one KV head
    causal=True
)

output = mqa(x)

Grouped-Query Attention (GQA)

# GQA with 4 KV heads for 12 query heads
gqa = MHA(
    embed_dim=768,
    num_heads=12,
    num_heads_kv=4,  # Each KV head shared by 3 query heads
    causal=True
)

output = gqa(x)

With Rotary Embeddings

# Attention with RoPE
mha_rope = MHA(
    embed_dim=768,
    num_heads=12,
    rotary_emb_dim=64,  # Apply RoPE to all 64 dims of each head (head_dim = 768 // 12 = 64)
    rotary_emb_base=10000.0,
    causal=True
)

output = mha_rope(x)

With Local Convolution

# Attention with local context via 1D convolution
mha_conv = MHA(
    embed_dim=768,
    num_heads=12,
    d_conv=4,  # 4-element causal convolution kernel
    causal=True
)

output = mha_conv(x)

With Integrated MLP

# Attention with fused gated MLP
mha_mlp = MHA(
    embed_dim=768,
    num_heads=12,
    mlp_dim=2048,  # Add gated MLP with 2048 hidden dim
    causal=True
)

output = mha_mlp(x)

Inference with KV Caching

from dataclasses import dataclass

@dataclass
class InferenceParams:
    key_value_memory_dict: dict
    seqlen_offset: int
    max_seqlen: int
    batch_size_offset: int = 0
    lengths_per_sample: torch.Tensor | None = None

mha = MHA(
    embed_dim=768,
    num_heads=12,
    causal=True,
    layer_idx=0  # Required for caching
)

# Allocate cache
kv_cache, _ = mha.allocate_inference_cache(
    batch_size=1,
    max_seqlen=2048,
    dtype=torch.float16
)

# Setup inference params
inference_params = InferenceParams(
    key_value_memory_dict={0: (kv_cache, None)},
    seqlen_offset=0,
    max_seqlen=2048
)

# First forward pass (prefill)
prompt = torch.randn(1, 10, 768)
output = mha(prompt, inference_params=inference_params)

# Subsequent forward passes (decode one token at a time)
inference_params.seqlen_offset = 10
for step in range(100):
    next_token = torch.randn(1, 1, 768)
    output = mha(next_token, inference_params=inference_params)
    inference_params.seqlen_offset += 1

Architecture Details

Multi-Query and Grouped-Query Attention

  • Standard MHA: Each head has its own Q, K, V (memory intensive)
  • MQA (num_heads_kv=1): All query heads share one K, V pair (memory efficient)
  • GQA (num_heads_kv=k): Query heads are grouped, each group shares K, V (balanced trade-off)
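The memory savings follow directly from the cache shape returned by allocate_inference_cache, (batch_size, max_seqlen, 2, num_heads_kv, head_dim): only num_heads_kv changes between the three variants. A quick back-of-the-envelope sketch (plain Python, using the embed_dim=768, num_heads=12 configuration from the examples above):

```python
def kv_cache_bytes(batch, max_seqlen, num_heads_kv, head_dim, bytes_per_elem=2):
    """Bytes for a (batch, max_seqlen, 2, num_heads_kv, head_dim) fp16 KV cache."""
    return batch * max_seqlen * 2 * num_heads_kv * head_dim * bytes_per_elem

# embed_dim=768, num_heads=12 -> head_dim=64
mha_size = kv_cache_bytes(1, 2048, num_heads_kv=12, head_dim=64)  # standard MHA
gqa_size = kv_cache_bytes(1, 2048, num_heads_kv=4, head_dim=64)   # GQA
mqa_size = kv_cache_bytes(1, 2048, num_heads_kv=1, head_dim=64)   # MQA

print(mha_size // gqa_size)  # 3  -> GQA cache is 3x smaller
print(mha_size // mqa_size)  # 12 -> MQA cache is 12x smaller
```

The same ratios apply to memory bandwidth during decoding, which is why MQA/GQA speed up autoregressive generation.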

Integrated MLP

When mlp_dim > 0, the module includes a gated MLP:
QKV, MLP_input = in_proj(x).split(...)
MLP_up, MLP_gate = MLP_input.chunk(2)
MLP_out = MLP_up * SiLU(MLP_gate)
final_output = out_proj(concat([attention_out, MLP_out]))
This fusion can be more efficient than separate attention and MLP blocks.
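The gated-MLP branch above can be sketched standalone in PyTorch. Note this is an illustrative reconstruction with hypothetical names (up_proj, down_proj); the actual split ordering inside MHA's fused in_proj/out_proj is internal to lrnnx:

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the gated-MLP branch (illustrative names; the real
# computation is fused into MHA's in_proj/out_proj matrices).
embed_dim, mlp_dim = 768, 2048
x = torch.randn(2, 128, embed_dim)

up_proj = torch.nn.Linear(embed_dim, 2 * mlp_dim)  # produces up and gate halves
down_proj = torch.nn.Linear(mlp_dim, embed_dim)

mlp_up, mlp_gate = up_proj(x).chunk(2, dim=-1)     # each (2, 128, mlp_dim)
mlp_out = down_proj(mlp_up * F.silu(mlp_gate))     # SiLU gating, back to embed_dim

print(mlp_out.shape)  # torch.Size([2, 128, 768])
```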

Rotary Position Embeddings

RoPE encodes position information by rotating query and key representations:
  • Applied only to the first rotary_emb_dim dimensions
  • Allows extrapolation to longer sequences than seen during training
  • No learned parameters required
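A minimal sketch of the rotation, assuming the standard non-interleaved layout (pairing the first half of the rotated dimensions with the second half). For simplicity it rotates the full last dimension of a (batch, seq_len, dim) tensor; inside MHA, only the first rotary_emb_dim dimensions of each head are rotated:

```python
import torch

def apply_rope(x, base=10000.0):
    """Rotate pairs (x[..., :d/2], x[..., d/2:]) by position-dependent angles.

    Minimal non-interleaved sketch; x has shape (batch, seq_len, dim).
    """
    _, seq_len, dim = x.shape
    half = dim // 2
    # Frequencies decay geometrically; larger `base` -> slower decay
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq  # (seq_len, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(1, 16, 64)
q_rot = apply_rope(q)
print(q_rot.shape)  # torch.Size([1, 16, 64])
# Position 0 is left unrotated (all angles are zero)
assert torch.allclose(q_rot[:, 0], q[:, 0])
```

Because the rotation depends only on position, attention scores between rotated queries and keys depend on their relative offset, which is what gives RoPE its relative-position behavior.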

Performance Considerations

  • Flash Attention: Automatically used when available for faster computation and lower memory usage
  • KV Caching: Essential for efficient autoregressive generation
  • MQA/GQA: Reduces KV cache size and memory bandwidth requirements
  • Fused Operations: Convolution and rotary embeddings can be fused with attention computation

Notes

  • Requires flash_attn package for optimal performance
  • When using rotary embeddings, flash_attn is required
  • Causal convolution requires causal_conv1d package for best performance
  • The layer_idx parameter must be set when using inference mode with caching

See Also

  • Block - Wrapper for combining MHA with normalization and residuals
  • GatedMLP - Standalone MLP implementation
