Overview

Mamba is a selective state space model that uses input-dependent dynamics to filter and process sequences. Unlike LTI models with fixed dynamics, Mamba computes its projection matrices B and C and its timestep delta from the input at each position (the state matrix A is learned but stays fixed), enabling selective memory and forgetting. Mamba achieves state-of-the-art performance on language modeling while maintaining linear scaling with sequence length.

Paper Reference

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
  • Original paper: https://arxiv.org/abs/2312.00752
  • GitHub: https://github.com/state-spaces/mamba

Import

from lrnnx.models.ltv import Mamba

Parameters

  • d_model (int, required): Model dimension; size of input and output features.
  • d_state (int, default 16): SSM state dimension (N). Typically smaller than in LTI models (16-64 vs 64-256).
  • d_conv (int, default 4): Convolution kernel size for temporal mixing before the SSM. Usually 3-4.
  • expand (int, default 2): Expansion factor for the inner dimension; d_inner = expand * d_model.
  • dt_rank (Union[int, str], default 'auto'): Rank of the delta (timestep) projection. 'auto' sets it to ceil(d_model / 16).
  • dt_min (float, default 0.001): Minimum value for delta initialization.
  • dt_max (float, default 0.1): Maximum value for delta initialization.
  • dt_init (str, default 'random'): Delta initialization method: 'random' or 'constant'.
  • dt_scale (float, default 1.0): Scale factor for dt initialization.
  • dt_init_floor (float, default 1e-4): Minimum floor value for dt initialization.
  • conv_bias (bool, default True): Whether to use bias in the Conv1d layer.
  • bias (bool, default False): Whether to use bias in the linear projections.
  • use_fast_path (bool, default True): Whether to use fused CUDA kernels when available. Significantly faster.
  • layer_idx (int, default None): Layer index for multi-layer caching during inference.
  • discretization (str, default 'mamba'): Discretization method: 'mamba', 'zoh', 'bilinear', or 'dirac'.
  • device (torch.device, default None): Device for model parameters.
  • dtype (torch.dtype, default None): Data type for model parameters.

Usage Example

Basic Usage

import torch
from lrnnx.models.ltv import Mamba

# Create Mamba model
model = Mamba(d_model=64, d_state=16, d_conv=4)

# Forward pass
x = torch.randn(2, 128, 64)  # (batch, length, features)
y = model(x)

print(y.shape)  # torch.Size([2, 128, 64])

Language Modeling Configuration

import torch
from lrnnx.models.ltv import Mamba

# Typical configuration for language modeling
model = Mamba(
    d_model=768,
    d_state=16,
    d_conv=4,
    expand=2,
    dt_rank="auto",
    use_fast_path=True,
)

x = torch.randn(4, 2048, 768)  # (batch, seq_len, d_model)
y = model(x)

Autoregressive Inference

import torch
from lrnnx.models.ltv import Mamba

model = Mamba(d_model=256, d_state=16)
batch_size = 2
max_seqlen = 1024

# Allocate inference cache
cache = model.allocate_inference_cache(
    batch_size=batch_size,
    max_seqlen=max_seqlen
)

# Initialize seqlen_offset
cache["seqlen_offset"] = 0

# Generate sequence token-by-token
for t in range(100):
    x_t = torch.randn(batch_size, 1, 256)  # (B, 1, D)
    y_t, cache = model.step(x_t, cache)
    # y_t.shape: (batch_size, 1, 256)

Event-Based Processing (Async Mode)

import torch
from lrnnx.models.ltv import Mamba

# Mamba with event-based discretization
model = Mamba(d_model=64, d_state=16, discretization="mamba")

# Input with variable timesteps
x = torch.randn(2, 128, 64)
timesteps = torch.rand(2, 128) * 0.1  # Variable time intervals

# Forward with integration_timesteps
y = model(x, integration_timesteps=timesteps)

Key Features

Selective State Space

Mamba’s core innovation is input-dependent selection:
# Project input to get B, C, and delta
B, C, delta = project(x)  # All depend on input!

# Selective scan
y = selective_scan(x, delta, A, B, C, D)
This allows the model to:
  • Focus on relevant information
  • Forget irrelevant details
  • Adapt dynamics per timestep
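The selection mechanism above can be sketched as a sequential reference scan. This is an illustrative implementation only (the function name `selective_scan_ref` and the exact shape conventions are assumptions for this sketch; the real kernel uses a fused parallel scan):

```python
import torch

def selective_scan_ref(x, delta, A, B, C, D):
    """Sequential selective scan (illustrative, not the fused kernel).

    Shapes (batch b, channels d, length l, state n):
      x, delta: (b, d, l);  A: (d, n);  B, C: (b, n, l);  D: (d,)
    """
    b, d, l = x.shape
    n = A.shape[1]
    h = torch.zeros(b, d, n, dtype=x.dtype, device=x.device)
    ys = []
    for t in range(l):
        # Input-dependent discretization: dA, dB change at every timestep
        dA = torch.exp(delta[:, :, t, None] * A)      # (b, d, n)
        dB = delta[:, :, t, None] * B[:, None, :, t]  # (b, d, n)
        h = dA * h + dB * x[:, :, t, None]            # state update
        y = (h * C[:, None, :, t]).sum(-1)            # readout, (b, d)
        ys.append(y)
    return torch.stack(ys, dim=-1) + D[None, :, None] * x  # skip connection

b, d, l, n = 2, 4, 8, 3
x = torch.randn(b, d, l)
delta = torch.rand(b, d, l) * 0.1
A = -torch.rand(d, n)
B = torch.randn(b, n, l)
C = torch.randn(b, n, l)
D = torch.randn(d)
y = selective_scan_ref(x, delta, A, B, C, D)
print(y.shape)  # torch.Size([2, 4, 8])
```

Because dA and dB depend on delta(x), the state can be held (small delta) or overwritten (large delta) on a per-token, per-channel basis.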

Hardware-Efficient Design

  1. Conv1d: Short temporal convolution for local mixing
  2. SSM: Selective state space for long-range dependencies
  3. Gating: Multiplicative gating for expressiveness
# Mamba architecture
x, z = split(input_proj(x))  # Dual pathways
x = conv1d(x)                # Local mixing
x = ssm(x)                   # Selective SSM
y = x * silu(z)              # Gating
out = output_proj(y)         # Output

S4D Initialization

Mamba uses S4D-style initialization for A:
A = -torch.arange(1, d_state + 1, dtype=torch.float32)  # Diagonal: [-1, -2, ..., -N]
A_log = torch.log(-A)  # store the log-magnitude; recover A = -torch.exp(A_log)
This encourages exponential decay patterns that work well for selective processing.

Fast Path (CUDA Kernels)

When use_fast_path=True and CUDA kernels are available:
# Fused kernel combines:
# - Conv1d
# - SSM projection
# - Selective scan  
# - Gating
# - Output projection
out = mamba_inner_fn(...)  # Single fused operation
This is significantly faster than the PyTorch fallback.
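A quick way to check whether the optional fused-kernel packages are importable (a hedged sketch; lrnnx's internal fast-path detection may differ):

```python
# Detect whether the optional CUDA-kernel packages can be imported.
try:
    import causal_conv1d  # noqa: F401
    import mamba_ssm      # noqa: F401
    fast_path_available = True
except ImportError:
    fast_path_available = False

print(fast_path_available)
```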

Architecture Details

Forward Pass Structure

  1. Input Projection: x, z = split(in_proj(x))
    • Creates two pathways: SSM and gate
  2. Conv1d: x = conv1d(x)
    • Short temporal convolution (d_conv=4)
    • Causal (no future information)
  3. SSM Projection: delta, B, C = x_proj(x)
    • Project to get input-dependent parameters
    • delta: (B, D, L) - per-channel timesteps
    • B: (B, N, L) - input matrix
    • C: (B, N, L) - output matrix
  4. Selective Scan: y = selective_scan(x, delta, A, B, C, D)
    • Core SSM computation
    • Uses input-dependent B, C, delta
    • Fixed A (learned, but not input-dependent)
  5. Gating: y = y * silu(z)
    • Multiplicative gating with SiLU
  6. Output: out = out_proj(y)
    • Final linear projection

Discretization Methods

Mamba supports multiple discretization schemes:

Mamba (Default)

dA = exp(delta * A)
dB = delta * B
Simplest, works well for most tasks.

Zero-Order Hold (ZOH)

dA = exp(delta * A)
dB = (dA - 1) / A * B
More principled, better for control theory tasks.

Bilinear

v = 0.5 * delta * A
dA = (1 + v) / (1 - v)
dB = delta / (1 - v) * B
Better frequency preservation.
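The three schemes can be compared numerically for a scalar diagonal entry of A. As this sketch shows, for small delta they agree to first order, which is why the simplified 'mamba' rule works well in practice:

```python
import torch

delta = torch.tensor(0.01)
A = torch.tensor(-2.0)  # one diagonal entry of A
B = torch.tensor(1.0)

# 'mamba': exponential for A, Euler-style for B
dA_m, dB_m = torch.exp(delta * A), delta * B
# 'zoh': exact for piecewise-constant input
dA_z, dB_z = torch.exp(delta * A), (torch.exp(delta * A) - 1) / A * B
# 'bilinear' (Tustin transform)
v = 0.5 * delta * A
dA_b, dB_b = (1 + v) / (1 - v), delta / (1 - v) * B

print(dA_m.item(), dA_z.item(), dA_b.item())  # all close to exp(-0.02)
print(dB_m.item(), dB_z.item(), dB_b.item())  # all close to 0.01
```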

Event-Based Mode (Async)

When integration_timesteps is provided:
# Asymmetric discretization
dtA = integration_timesteps * softplus(dtA_proj.bias)
dtB = softplus(dt_proj(x) + dt_proj.bias)

# Separate time scales for A and B
dA = exp(dtA * A)
dB = dtB * B
This enables event-driven processing with variable time intervals.

State Representation

Mamba maintains:
conv_state: (batch_size, d_inner, d_conv) dtype=float
lrnn_state: (batch_size, d_inner, d_state) dtype=float
seqlen_offset: int
  • conv_state: Last d_conv timesteps for causal conv
  • lrnn_state: SSM hidden state
  • seqlen_offset: Current position (for autoregressive)

Performance Tips

Always use use_fast_path=True in production. The CUDA kernels are 3-5x faster than PyTorch fallback.
Make sure to install the optional CUDA dependencies (causal-conv1d and mamba-ssm) for best performance:
pip install "causal-conv1d>=1.1.0"
pip install mamba-ssm
For language modeling, typical values are:
  • d_state=16 (small state)
  • d_conv=4 (short convolution)
  • expand=2 (2x expansion)
These balance performance and efficiency.

When to Use Mamba

Use Mamba when:
  • You need selective processing (focus on important info)
  • Working on language modeling or NLP
  • You want state-of-the-art performance
  • Long sequences are common
  • You have GPU/CUDA available
Consider alternatives when:
  • Training speed is critical → S4D (faster training)
  • Minimal parameters needed → LRU
  • Simpler model preferred → S5 or RG-LRU

Comparison with Other Models

| Model  | Type | Selective |
|--------|------|-----------|
| Mamba  | LTV  | Yes       |
| S4D    | LTI  | No        |
| RG-LRU | LTV  | Yes       |
| S7     | LTV  | Yes       |
Mamba offers the best accuracy-efficiency tradeoff for selective tasks.

See Also

  • RG-LRU - Simpler gated selective model
  • S7 - Alternative selective SSM
  • S4D - Non-selective but faster training
