CUTLASS PyTorch Integration
CUTLASS Python packages provide seamless integration with PyTorch, enabling you to use CUTLASS kernels directly with PyTorch tensors and build custom CUDA extensions.

Overview

Direct Integration

Use CUTLASS kernels with PyTorch tensors without conversion

CUDA Extensions

Export kernels as PyTorch CUDA extensions for deployment

Custom Autograd

Implement custom backward passes for training

Stream Integration

Work with PyTorch CUDA streams

Using PyTorch Tensors

With cutlass_cppgen

The CUTLASS Python interface accepts PyTorch tensors directly:
import cutlass
import torch

# Create PyTorch tensors on GPU
# A/B are the GEMM operands, C is the epilogue addend, D receives the result.
M, N, K = 2048, 2048, 2048
A = torch.randn(M, K, dtype=torch.float16, device='cuda')
B = torch.randn(K, N, dtype=torch.float16, device='cuda')
C = torch.zeros(M, N, dtype=torch.float16, device='cuda')
D = torch.zeros(M, N, dtype=torch.float16, device='cuda')

# Create and run GEMM plan
# `element` accepts a torch dtype directly; no tensor conversion is needed.
plan = cutlass.op.Gemm(
    element=torch.float16,
    layout=cutlass.LayoutType.RowMajor
)
plan.run(A, B, C, D)  # D = A @ B + C

# Verify against PyTorch
# fp16 accumulation order differs between kernels, so compare with tolerances
# rather than exact equality.
ref = torch.matmul(A, B) + C
torch.testing.assert_close(D, ref, rtol=1e-3, atol=1e-3)
print("Results match!")

With CuTe DSL

Convert PyTorch tensors using DLPack:
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch

# Create PyTorch tensors
A = torch.randn(1024, 512, device='cuda', dtype=torch.float32)
B = torch.randn(1024, 512, device='cuda', dtype=torch.float32)
C = torch.zeros(1024, 512, device='cuda', dtype=torch.float32)

# Convert to CuTe tensors
# from_dlpack shares the underlying storage via the DLPack protocol (zero-copy).
# NOTE(review): mark_layout_dynamic presumably lets one compiled kernel handle
# varying strides -- confirm against the CuTe DSL docs.
mA = from_dlpack(A).mark_layout_dynamic()
mB = from_dlpack(B).mark_layout_dynamic()
mC = from_dlpack(C).mark_layout_dynamic()

# Use in CuTe DSL kernel
@cute.jit
def elementwise_add(mA, mB, mC):
    # Placeholder body -- a real kernel would write mA + mB into mC.
    # ... kernel implementation ...
    pass

compiled = cute.compile(elementwise_add, mA, mB, mC)
compiled(mA, mB, mC)

# Result is in PyTorch tensor C
# (The CuTe tensors alias the torch storage, so kernel writes land in C.)
print(f"Result: {C[:3, :3]}")

Data Type Conversion

cutlass.torch.dtype

Convert between CUTLASS and PyTorch types:
import cutlass.torch as cutlass_torch
import cutlass
import torch

# CUTLASS -> PyTorch
# cutlass_torch.dtype maps a CUTLASS numeric type to its torch.dtype equivalent.
torch_fp16 = cutlass_torch.dtype(cutlass.Float16)
assert torch_fp16 == torch.float16

torch_fp32 = cutlass_torch.dtype(cutlass.Float32)
assert torch_fp32 == torch.float32

# Supports all types
# (including the fp8 variants introduced for Hopper-class hardware)
torch_bf16 = cutlass_torch.dtype(cutlass.BFloat16)
torch_fp8_e4m3 = cutlass_torch.dtype(cutlass.Float8E4M3FN)
torch_fp8_e5m2 = cutlass_torch.dtype(cutlass.Float8E5M2)
torch_int8 = cutlass_torch.dtype(cutlass.Int8)

# Use in kernel creation
dtype = cutlass_torch.dtype(cutlass.Float16)
tensor = torch.randn(128, 128, dtype=dtype, device='cuda')

CUDA Stream Integration

Getting Current Stream

import cutlass.torch as cutlass_torch
import cutlass.cute as cute  # required for @cute.jit / cute.compile below
import torch

# Get current PyTorch stream
torch_stream = torch.cuda.current_stream()

# Convert to CUDA driver stream for CuTe DSL
# torch exposes the raw stream handle as an int; wrap it in the driver type.
import cuda.bindings.driver as cuda
current_stream = cuda.CUstream(torch_stream.cuda_stream)

@cute.jit
def kernel_with_stream(tensor, stream):
    # `my_kernel` is a placeholder for a @cute.kernel defined elsewhere.
    my_kernel(tensor).launch(
        grid=[...],
        block=[...],
        stream=stream
    )

# Compile and run on current stream
compiled = cute.compile(kernel_with_stream, cute_tensor, current_stream)
compiled(cute_tensor, current_stream)

# Ensure completion
# Launches are asynchronous; synchronize before reading results on the host.
torch.cuda.synchronize()

Using Multiple Streams

import torch

# Create custom streams
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()

# Run operations on different streams
# plan1/plan2 and the tensors are placeholders for previously built CUTLASS
# plans and their operands.
# NOTE(review): this assumes plan.run launches on the ambient torch stream
# set by torch.cuda.stream() -- confirm; otherwise pass the stream explicitly.
with torch.cuda.stream(stream1):
    # CUTLASS operation on stream1
    plan1.run(A1, B1, C1, D1)

with torch.cuda.stream(stream2):
    # CUTLASS operation on stream2
    plan2.run(A2, B2, C2, D2)

# Synchronize both
torch.cuda.synchronize()

Building PyTorch CUDA Extensions

Method 1: JIT Compilation

Build and load extensions at runtime:
import torch
from torch.utils.cpp_extension import load_inline
import cutlass

# Define C++ wrapper code
# The sources below are skeletons: fill in the actual CUTLASS kernel
# invocation before using this pattern for real work.
cpp_source = """
torch::Tensor my_gemm_forward(
    torch::Tensor A,
    torch::Tensor B,
    torch::Tensor C
) {
    // Call CUTLASS kernel
    auto D = torch::zeros_like(C);
    // ... CUTLASS plan.run(A, B, C, D) ...
    return D;
}
"""

cuda_source = """
// CUDA kernel code generated by CUTLASS
// Or CuTe DSL compiled output
"""

# Load extension
# load_inline compiles with the host toolchain + nvcc at import time and
# caches the build, so the first call is slow and later calls are fast.
my_extension = load_inline(
    name='my_gemm',
    cpp_sources=[cpp_source],
    cuda_sources=[cuda_source],
    functions=['my_gemm_forward'],
    verbose=True,
    extra_cuda_cflags=['-O3', '--use_fast_math']
)

# Use extension
A = torch.randn(128, 128, device='cuda')
B = torch.randn(128, 128, device='cuda')
C = torch.zeros(128, 128, device='cuda')
result = my_extension.my_gemm_forward(A, B, C)

Method 2: Setuptools Extension

Build extensions with setup.py:
Step 1: Create Extension Directory

mkdir my_cutlass_extension
cd my_cutlass_extension
Step 2: Write Kernel Code

cutlass_kernel.py:
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

# Device-side kernel (placeholder).
# NOTE: the `...` in the parameter list is pseudocode -- replace it with the
# real tiling/shape parameters before this file will parse.
@cute.kernel
def my_gemm_kernel(mA, mB, mC, ...):
    # Kernel implementation
    pass

# Host-side entry point that configures and launches the kernel.
@cute.jit
def my_gemm(mA, mB, mC):
    # Launch kernel
    pass

# Compile once during build
def build_kernel():
    """Compile my_gemm against dummy operands and return exported C++ source."""
    import torch
    # Dummy tensors fix the dtypes/shapes the compiled kernel is specialized for.
    dummy_A = torch.zeros(128, 128, device='cuda', dtype=torch.float32)
    dummy_B = torch.zeros(128, 128, device='cuda', dtype=torch.float32)
    dummy_C = torch.zeros(128, 128, device='cuda', dtype=torch.float32)
    
    mA = from_dlpack(dummy_A)
    mB = from_dlpack(dummy_B)
    mC = from_dlpack(dummy_C)
    
    compiled = cute.compile(my_gemm, mA, mB, mC)
    
    # Export C++ code
    return compiled.export_cpp()
Step 3: Create Python Wrapper

extension.cpp:
#include <torch/extension.h>
#include "cutlass_kernel.h"  // Generated from CuTe DSL

// Forward entry point: D = A @ B + C via the generated CUTLASS kernel.
// Expects contiguous float32 CUDA tensors; returns a new tensor D.
torch::Tensor my_gemm_forward(
    torch::Tensor A,
    torch::Tensor B,
    torch::Tensor C
) {
    // Validate early: data_ptr<float>() below would otherwise throw on a
    // wrong dtype, and a CPU or non-contiguous tensor would feed the kernel
    // a pointer it cannot safely read.
    TORCH_CHECK(A.is_cuda() && B.is_cuda() && C.is_cuda(),
                "my_gemm_forward: all inputs must be CUDA tensors");
    TORCH_CHECK(A.scalar_type() == torch::kFloat32 &&
                B.scalar_type() == torch::kFloat32 &&
                C.scalar_type() == torch::kFloat32,
                "my_gemm_forward: all inputs must be float32");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && C.is_contiguous(),
                "my_gemm_forward: all inputs must be contiguous");
    TORCH_CHECK(A.size(1) == B.size(0),
                "my_gemm_forward: inner dimensions must match");

    auto D = torch::zeros_like(C);
    
    // Call generated kernel with raw pointers and problem size (M, K, N).
    launch_my_gemm(
        A.data_ptr<float>(),
        B.data_ptr<float>(),
        C.data_ptr<float>(),
        D.data_ptr<float>(),
        A.size(0), A.size(1), B.size(1)
    );
    
    return D;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &my_gemm_forward, "My GEMM forward");
}
Step 4: Create setup.py

setup.py:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='my_cutlass_extension',
    ext_modules=[
        CUDAExtension(
            name='my_cutlass_extension',
            sources=[
                'extension.cpp',
                'cutlass_kernel.cu',  # Generated from CuTe DSL
            ],
            # Build a fat binary for Ampere (sm_80) and Hopper (sm_90);
            # add further -gencode pairs for other target architectures.
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': [
                    '-O3',
                    '--use_fast_math',
                    '-gencode', 'arch=compute_80,code=sm_80',
                    '-gencode', 'arch=compute_90,code=sm_90',
                ]
            }
        )
    ],
    cmdclass={
        # BuildExtension injects the torch include paths and nvcc handling.
        'build_ext': BuildExtension
    }
)
Step 5: Build and Install

pip install .

# Or for development
pip install -e .
Step 6: Use Extension

import torch
import my_cutlass_extension

# Inputs must be float32 CUDA tensors -- the C++ wrapper reads them via
# data_ptr<float>(). torch.randn defaults to float32, so these qualify.
A = torch.randn(512, 512, device='cuda')
B = torch.randn(512, 512, device='cuda')
C = torch.zeros(512, 512, device='cuda')

D = my_cutlass_extension.forward(A, B, C)
print(f"Result: {D}")

Custom Autograd Functions

Implement custom forward and backward passes:
import torch
from torch.autograd import Function
import cutlass

class CUTLASSGemmFunction(Function):
    """Autograd function computing C = A @ B with CUTLASS GEMMs.

    Gradients (standard matmul calculus):
        grad_A = grad_output @ B^T
        grad_B = A^T @ grad_output

    NOTE: a fresh plan is built on every call; cache plans per dtype if this
    sits on a hot path.
    """

    @staticmethod
    def forward(ctx, A, B):
        """Forward pass: C = A @ B"""
        M, K = A.shape
        K2, N = B.shape
        assert K == K2, "Inner dimensions must match"
        
        # Save for backward
        ctx.save_for_backward(A, B)
        
        # Run CUTLASS GEMM: C_in is the (zero) epilogue addend, C_out the result.
        C_in = torch.zeros(M, N, dtype=A.dtype, device=A.device)
        C_out = torch.zeros(M, N, dtype=A.dtype, device=A.device)
        
        plan = cutlass.op.Gemm(
            element=A.dtype,
            layout=cutlass.LayoutType.RowMajor
        )
        plan.run(A, B, C_in, C_out)
        
        return C_out
    
    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass: compute gradients w.r.t. A and B (None if not needed)."""
        A, B = ctx.saved_tensors
        
        # grad_A = grad_output @ B^T
        grad_A = None
        if ctx.needs_input_grad[0]:
            grad_A_in = torch.zeros_like(A)
            grad_A_out = torch.zeros_like(A)
            
            plan_grad_A = cutlass.op.Gemm(
                element=A.dtype,
                layout=cutlass.LayoutType.RowMajor
            )
            # .t() yields a non-contiguous view; materialize it so the
            # row-major CUTLASS plan sees dense operands. grad_output may
            # also arrive non-contiguous from upstream ops.
            plan_grad_A.run(grad_output.contiguous(), B.t().contiguous(),
                            grad_A_in, grad_A_out)
            grad_A = grad_A_out
        
        # grad_B = A^T @ grad_output
        grad_B = None
        if ctx.needs_input_grad[1]:
            grad_B_in = torch.zeros_like(B)
            grad_B_out = torch.zeros_like(B)
            
            plan_grad_B = cutlass.op.Gemm(
                element=B.dtype,
                layout=cutlass.LayoutType.RowMajor
            )
            plan_grad_B.run(A.t().contiguous(), grad_output.contiguous(),
                            grad_B_in, grad_B_out)
            grad_B = grad_B_out
        
        return grad_A, grad_B

# Create functional interface
# Function subclasses are invoked via .apply, never by calling forward directly.
cutlass_gemm = CUTLASSGemmFunction.apply

# Use in training
A = torch.randn(128, 256, device='cuda', requires_grad=True)
B = torch.randn(256, 512, device='cuda', requires_grad=True)

# Forward
C = cutlass_gemm(A, B)

# Backward
# sum() makes d(loss)/dC all-ones, so grad_A = ones @ B^T and grad_B = A^T @ ones.
loss = C.sum()
loss.backward()

print(f"grad_A: {A.grad}")
print(f"grad_B: {B.grad}")

Custom nn.Module

Wrap CUTLASS operations in PyTorch modules:
import torch
import torch.nn as nn
import cutlass

class CUTLASSLinear(nn.Module):
    """Linear layer using CUTLASS GEMM"""
    
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Initialize weight and bias
        self.weight = nn.Parameter(torch.randn(
            out_features, in_features, dtype=dtype, device='cuda'
        ))
        if bias:
            self.bias = nn.Parameter(torch.zeros(
                out_features, dtype=dtype, device='cuda'
            ))
        else:
            self.register_parameter('bias', None)
        
        # Create CUTLASS plan
        self.plan = cutlass.op.Gemm(
            element=dtype,
            layout=cutlass.LayoutType.RowMajor
        )
    
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, in_features)
        Returns:
            Output tensor of shape (batch, out_features)
        """
        batch_size = x.size(0)
        
        # Prepare output tensors
        C = self.bias.unsqueeze(0).expand(batch_size, -1) if self.bias is not None \
            else torch.zeros(batch_size, self.out_features, 
                           dtype=x.dtype, device=x.device)
        D = torch.zeros_like(C)
        
        # Run CUTLASS GEMM: D = x @ weight^T + C
        self.plan.run(x, self.weight.t(), C, D)
        
        return D
    
    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'

# Usage
model = nn.Sequential(
    CUTLASSLinear(512, 256, dtype=torch.float16),
    nn.ReLU(),
    CUTLASSLinear(256, 128, dtype=torch.float16),
    nn.ReLU(),
    CUTLASSLinear(128, 10, dtype=torch.float16),
)

# Forward pass
x = torch.randn(32, 512, dtype=torch.float16, device='cuda')
y = model(x)
print(f"Output shape: {y.shape}")

# Training
# NOTE(review): CUTLASSLinear.forward calls plan.run directly, which is not
# autograd-tracked, so weight/bias gradients may not flow through this loop --
# verify before relying on it to actually train the CUTLASS layers.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    optimizer.zero_grad()
    output = model(x)
    target = torch.randint(0, 10, (32,), device='cuda')
    loss = criterion(output.float(), target)  # Convert to fp32 for loss
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Grouped GEMM Example

Export grouped GEMM as PyTorch extension:
import cutlass
import torch

# Setup grouped GEMM: one (M, N, K) problem size per entry; a single launch
# runs all problems.
problems = [
    cutlass.GemmCoord(256, 256, 128),
    cutlass.GemmCoord(512, 256, 256),
    cutlass.GemmCoord(128, 512, 256),
]

# Create grouped GEMM plan
plan = cutlass.op.GroupedGemm(
    element=torch.float16,
    layout=cutlass.LayoutType.RowMajor
)

# Prepare tensors for each problem
As, Bs, Cs, Ds = [], [], [], []
for problem in problems:
    M, N, K = problem.m(), problem.n(), problem.k()
    As.append(torch.randn(M, K, dtype=torch.float16, device='cuda'))
    Bs.append(torch.randn(K, N, dtype=torch.float16, device='cuda'))
    Cs.append(torch.zeros(M, N, dtype=torch.float16, device='cuda'))
    Ds.append(torch.zeros(M, N, dtype=torch.float16, device='cuda'))

# Run grouped GEMM
plan.run(As, Bs, Cs, Ds)

# Verify each problem
# Each C is zero, so D should equal A @ B within fp16 tolerance.
for i, (A, B, D) in enumerate(zip(As, Bs, Ds)):
    ref = torch.matmul(A, B)
    torch.testing.assert_close(D, ref, rtol=1e-3, atol=1e-3)
    print(f"Problem {i}: Success!")

Performance Optimization

Tensor Layout

Ensure optimal memory layout:
import torch

# Create contiguous tensors (torch.randn allocates dense row-major storage)
A = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
assert A.is_contiguous()

# Transpose creates non-contiguous view (strides swap; no data is moved)
A_t = A.t()
assert not A_t.is_contiguous()

# Make contiguous for optimal performance (copies into fresh dense storage)
A_t_contiguous = A_t.contiguous()
assert A_t_contiguous.is_contiguous()

# Use in CUTLASS (plan, B, C, D come from the earlier GEMM examples)
plan.run(A_t_contiguous, B, C, D)

Warm-up Compilation

Compile kernels before timing:
import torch
import cutlass
import time

# Create plan
plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor)

A = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
B = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
C = torch.zeros(1024, 1024, device='cuda', dtype=torch.float16)
D = torch.zeros(1024, 1024, device='cuda', dtype=torch.float16)

# Warm-up: the first run pays the one-time compile/JIT cost; exclude it
# from timing, then synchronize so no warm-up work leaks into the window.
plan.run(A, B, C, D)
torch.cuda.synchronize()

# Benchmark with a monotonic, high-resolution clock. time.perf_counter is
# the right tool here; time.time has coarse resolution and can jump when
# the wall clock is adjusted.
start = time.perf_counter()

for _ in range(100):
    plan.run(A, B, C, D)

# Kernel launches are asynchronous -- synchronize before stopping the clock.
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

print(f"Average time: {elapsed / 100 * 1000:.3f} ms")
# GEMM FLOPs = 2*M*N*K = 2 * 1024^3 for this square problem.
print(f"TFLOPS: {2 * 1024**3 / (elapsed / 100) / 1e12:.2f}")

Memory Pinning

Use pinned memory for faster CPU-GPU transfers:
import torch

# Create pinned (page-locked) memory on CPU; the DMA engine can transfer
# from it directly without an intermediate staging copy.
A_cpu = torch.randn(1024, 1024, dtype=torch.float16, pin_memory=True)

# Transfer to GPU (faster with pinned memory)
# Consider .cuda(non_blocking=True) to overlap the copy with compute.
A_gpu = A_cpu.cuda()

# Use in CUTLASS
# ...

Troubleshooting

Ensure PyTorch and CUTLASS types match:
# Check tensor dtype
print(f"PyTorch dtype: {tensor.dtype}")

# Convert if needed
tensor = tensor.to(torch.float16)

# Or create with correct dtype
tensor = torch.randn(..., dtype=torch.float16)
Make tensors contiguous:
if not tensor.is_contiguous():
    tensor = tensor.contiguous()
Synchronize streams when needed:
# After CUTLASS operation
torch.cuda.synchronize()

# Or synchronize specific stream
stream.synchronize()
Monitor and free memory:
# Check memory
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Free unused memory
torch.cuda.empty_cache()

# Clear specific tensors
del tensor
torch.cuda.empty_cache()

Example: Complete Training Loop

import torch
import torch.nn as nn
import cutlass

class CUTLASSModel(nn.Module):
    """Small MLP: two CUTLASS-backed hidden layers and a stock fp32 head."""

    def __init__(self):
        super().__init__()
        # Hidden layers run through CUTLASS; the classifier head stays a
        # standard nn.Linear.
        self.fc1 = CUTLASSLinear(784, 256)
        self.fc2 = CUTLASSLinear(256, 128)
        self.fc3 = nn.Linear(128, 10)  # Standard PyTorch layer
        self.relu = nn.ReLU()

    def forward(self, x):
        # Flatten everything after the batch dimension.
        flat = x.view(x.size(0), -1)
        hidden = self.relu(self.fc2(self.relu(self.fc1(flat))))
        # The final layer expects fp32 input.
        return self.fc3(hidden.float())

# Setup
device = 'cuda'
model = CUTLASSModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Dummy data (replace with real data loader)
# fp16 inputs match the CUTLASS layers; targets are int64 class indices,
# which is what CrossEntropyLoss expects.
train_loader = [
    (torch.randn(32, 1, 28, 28, dtype=torch.float16, device=device),
     torch.randint(0, 10, (32,), device=device))
    for _ in range(100)
]

# Training loop
# NOTE(review): gradients flowing into the CUTLASS layers depend on
# CUTLASSLinear being autograd-aware -- verify before trusting the loss curve.
for epoch in range(10):
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        
        output = model(data)
        loss = criterion(output, target)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

print("Training complete!")

Next Steps

CuTe DSL

Learn to write custom kernels

Examples

Explore more integration examples

API Reference

Browse the complete API

Quickstart

Get started quickly

For further details, see the CUTLASS documentation and examples.