Overview
Direct Integration
Use CUTLASS kernels with PyTorch tensors without conversion
CUDA Extensions
Export kernels as PyTorch CUDA extensions for deployment
Custom Autograd
Implement custom backward passes for training
Stream Integration
Work with PyTorch CUDA streams
Using PyTorch Tensors
With cutlass_cppgen
The CUTLASS Python interface accepts PyTorch tensors directly:

import cutlass
import torch
# Create PyTorch tensors on GPU
M, N, K = 2048, 2048, 2048
A = torch.randn(M, K, dtype=torch.float16, device='cuda')
B = torch.randn(K, N, dtype=torch.float16, device='cuda')
C = torch.zeros(M, N, dtype=torch.float16, device='cuda')
D = torch.zeros(M, N, dtype=torch.float16, device='cuda')
# Create and run GEMM plan
plan = cutlass.op.Gemm(
element=torch.float16,
layout=cutlass.LayoutType.RowMajor
)
plan.run(A, B, C, D) # D = A @ B + C
# Verify against PyTorch
ref = torch.matmul(A, B) + C
torch.testing.assert_close(D, ref, rtol=1e-3, atol=1e-3)
print("Results match!")
With CuTe DSL
Convert PyTorch tensors using DLPack:

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch
# Create PyTorch tensors
A = torch.randn(1024, 512, device='cuda', dtype=torch.float32)
B = torch.randn(1024, 512, device='cuda', dtype=torch.float32)
C = torch.zeros(1024, 512, device='cuda', dtype=torch.float32)
# Convert to CuTe tensors
mA = from_dlpack(A).mark_layout_dynamic()
mB = from_dlpack(B).mark_layout_dynamic()
mC = from_dlpack(C).mark_layout_dynamic()
# Use in CuTe DSL kernel
@cute.jit
def elementwise_add(mA, mB, mC):
    # Body elided in this example — a real kernel would compute mC = mA + mB.
    # ... kernel implementation ...
    pass
compiled = cute.compile(elementwise_add, mA, mB, mC)
compiled(mA, mB, mC)
# Result is in PyTorch tensor C
print(f"Result: {C[:3, :3]}")
Data Type Conversion
cutlass.torch.dtype
Convert between CUTLASS and PyTorch types:

import cutlass.torch as cutlass_torch
import cutlass
import torch
# CUTLASS -> PyTorch
torch_fp16 = cutlass_torch.dtype(cutlass.Float16)
assert torch_fp16 == torch.float16
torch_fp32 = cutlass_torch.dtype(cutlass.Float32)
assert torch_fp32 == torch.float32
# Supports all types
torch_bf16 = cutlass_torch.dtype(cutlass.BFloat16)
torch_fp8_e4m3 = cutlass_torch.dtype(cutlass.Float8E4M3FN)
torch_fp8_e5m2 = cutlass_torch.dtype(cutlass.Float8E5M2)
torch_int8 = cutlass_torch.dtype(cutlass.Int8)
# Use in kernel creation
dtype = cutlass_torch.dtype(cutlass.Float16)
tensor = torch.randn(128, 128, dtype=dtype, device='cuda')
CUDA Stream Integration
Getting Current Stream
import cutlass.torch as cutlass_torch
import torch
# Get current PyTorch stream
torch_stream = torch.cuda.current_stream()
# Convert to CUDA driver stream for CuTe DSL
import cuda.bindings.driver as cuda
current_stream = cuda.CUstream(torch_stream.cuda_stream)
@cute.jit
def kernel_with_stream(tensor, stream):
my_kernel(tensor).launch(
grid=[...],
block=[...],
stream=stream
)
# Compile and run on current stream
compiled = cute.compile(kernel_with_stream, cute_tensor, current_stream)
compiled(cute_tensor, current_stream)
# Ensure completion
torch.cuda.synchronize()
Using Multiple Streams
import torch
# Create custom streams
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()
# Run operations on different streams
with torch.cuda.stream(stream1):
# CUTLASS operation on stream1
plan1.run(A1, B1, C1, D1)
with torch.cuda.stream(stream2):
# CUTLASS operation on stream2
plan2.run(A2, B2, C2, D2)
# Synchronize both
torch.cuda.synchronize()
Building PyTorch CUDA Extensions
Method 1: JIT Compilation
Build and load extensions at runtime:

import torch
from torch.utils.cpp_extension import load_inline
import cutlass
# Define C++ wrapper code
cpp_source = """
torch::Tensor my_gemm_forward(
torch::Tensor A,
torch::Tensor B,
torch::Tensor C
) {
// Call CUTLASS kernel
auto D = torch::zeros_like(C);
// ... CUTLASS plan.run(A, B, C, D) ...
return D;
}
"""
cuda_source = """
// CUDA kernel code generated by CUTLASS
// Or CuTe DSL compiled output
"""
# Load extension
my_extension = load_inline(
name='my_gemm',
cpp_sources=[cpp_source],
cuda_sources=[cuda_source],
functions=['my_gemm_forward'],
verbose=True,
extra_cuda_cflags=['-O3', '--use_fast_math']
)
# Use extension
A = torch.randn(128, 128, device='cuda')
B = torch.randn(128, 128, device='cuda')
C = torch.zeros(128, 128, device='cuda')
result = my_extension.my_gemm_forward(A, B, C)
Method 2: Setuptools Extension
Build extensions with setup.py.

Step 1: Write Kernel Code
cutlass_kernel.py:
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
@cute.kernel
def my_gemm_kernel(mA, mB, mC, ...):  # "..." is a documentation placeholder, not valid Python
    # Kernel implementation
    pass

@cute.jit
def my_gemm(mA, mB, mC):
    # Launch kernel
    pass

# Compile once during build
def build_kernel():
    """Compile the CuTe DSL GEMM against dummy tensors and export C++ source.

    Returns:
        The generated C++ code string from ``compiled.export_cpp()``.
    """
    import torch
    # Dummy tensors pin down the shapes/dtypes the kernel is specialized for.
    dummy_A = torch.zeros(128, 128, device='cuda', dtype=torch.float32)
    dummy_B = torch.zeros(128, 128, device='cuda', dtype=torch.float32)
    dummy_C = torch.zeros(128, 128, device='cuda', dtype=torch.float32)
    mA = from_dlpack(dummy_A)
    mB = from_dlpack(dummy_B)
    mC = from_dlpack(dummy_C)
    compiled = cute.compile(my_gemm, mA, mB, mC)
    # Export C++ code
    return compiled.export_cpp()
Create Python Wrapper
extension.cpp:
#include <torch/extension.h>
#include "cutlass_kernel.h" // Generated from CuTe DSL
torch::Tensor my_gemm_forward(
torch::Tensor A,
torch::Tensor B,
torch::Tensor C
) {
auto D = torch::zeros_like(C);
// Call generated kernel
launch_my_gemm(
A.data_ptr<float>(),
B.data_ptr<float>(),
C.data_ptr<float>(),
D.data_ptr<float>(),
A.size(0), A.size(1), B.size(1)
);
return D;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &my_gemm_forward, "My GEMM forward");
}
Create setup.py
setup.py:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='my_cutlass_extension',
ext_modules=[
CUDAExtension(
name='my_cutlass_extension',
sources=[
'extension.cpp',
'cutlass_kernel.cu', # Generated from CuTe DSL
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': [
'-O3',
'--use_fast_math',
'-gencode', 'arch=compute_80,code=sm_80',
'-gencode', 'arch=compute_90,code=sm_90',
]
}
)
],
cmdclass={
'build_ext': BuildExtension
}
)
Custom Autograd Functions
Implement custom forward and backward passes:

import torch
from torch.autograd import Function
import cutlass
class CUTLASSGemmFunction(Function):
    """Autograd-aware GEMM that runs both forward and backward through CUTLASS."""

    @staticmethod
    def forward(ctx, A, B):
        """Forward pass: C = A @ B

        Args:
            ctx: autograd context used to stash tensors for backward.
            A: left operand of shape (M, K).
            B: right operand of shape (K, N).
        Returns:
            The product A @ B of shape (M, N).
        """
        M, K = A.shape
        K2, N = B.shape
        assert K == K2, "Inner dimensions must match"
        # Save for backward
        ctx.save_for_backward(A, B)
        # Run CUTLASS GEMM: C_in is the zero accumulator input,
        # C_out receives D = A @ B + C_in.
        C_in = torch.zeros(M, N, dtype=A.dtype, device=A.device)
        C_out = torch.zeros(M, N, dtype=A.dtype, device=A.device)
        plan = cutlass.op.Gemm(
            element=A.dtype,
            layout=cutlass.LayoutType.RowMajor
        )
        plan.run(A, B, C_in, C_out)
        return C_out

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass: compute gradients w.r.t. A and B."""
        A, B = ctx.saved_tensors
        # grad_A = grad_output @ B^T
        grad_A = None
        if ctx.needs_input_grad[0]:
            grad_A_in = torch.zeros_like(A)
            grad_A_out = torch.zeros_like(A)
            plan_grad_A = cutlass.op.Gemm(
                element=A.dtype,
                layout=cutlass.LayoutType.RowMajor
            )
            # NOTE(review): B.t() is a non-contiguous view — confirm the
            # RowMajor plan accepts transposed strides.
            plan_grad_A.run(grad_output, B.t(), grad_A_in, grad_A_out)
            grad_A = grad_A_out
        # grad_B = A^T @ grad_output
        grad_B = None
        if ctx.needs_input_grad[1]:
            grad_B_in = torch.zeros_like(B)
            grad_B_out = torch.zeros_like(B)
            plan_grad_B = cutlass.op.Gemm(
                element=B.dtype,
                layout=cutlass.LayoutType.RowMajor
            )
            plan_grad_B.run(A.t(), grad_output, grad_B_in, grad_B_out)
            grad_B = grad_B_out
        return grad_A, grad_B
# Create functional interface
cutlass_gemm = CUTLASSGemmFunction.apply
# Use in training
A = torch.randn(128, 256, device='cuda', requires_grad=True)
B = torch.randn(256, 512, device='cuda', requires_grad=True)
# Forward
C = cutlass_gemm(A, B)
# Backward
loss = C.sum()
loss.backward()
print(f"grad_A: {A.grad}")
print(f"grad_B: {B.grad}")
Custom nn.Module
Wrap CUTLASS operations in PyTorch modules:

import torch
import torch.nn as nn
import cutlass
class CUTLASSLinear(nn.Module):
    """Linear layer using CUTLASS GEMM: forward computes x @ weight^T + bias."""

    def __init__(self, in_features, out_features, bias=True, dtype=torch.float16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Initialize weight and bias (allocated directly on the GPU)
        self.weight = nn.Parameter(torch.randn(
            out_features, in_features, dtype=dtype, device='cuda'
        ))
        if bias:
            self.bias = nn.Parameter(torch.zeros(
                out_features, dtype=dtype, device='cuda'
            ))
        else:
            self.register_parameter('bias', None)
        # Create CUTLASS plan once; reused on every forward call
        self.plan = cutlass.op.Gemm(
            element=dtype,
            layout=cutlass.LayoutType.RowMajor
        )

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, in_features)
        Returns:
            Output tensor of shape (batch, out_features)
        """
        batch_size = x.size(0)
        # Prepare output tensors: C is the accumulator input (bias broadcast
        # to every row, or zeros), D receives x @ weight^T + C.
        # NOTE(review): expand() yields a stride-0 view of the bias — confirm
        # plan.run accepts a non-contiguous C operand.
        C = self.bias.unsqueeze(0).expand(batch_size, -1) if self.bias is not None \
            else torch.zeros(batch_size, self.out_features,
                             dtype=x.dtype, device=x.device)
        D = torch.zeros_like(C)
        # Run CUTLASS GEMM: D = x @ weight^T + C
        # NOTE(review): weight.t() is a non-contiguous view — verify the
        # RowMajor plan handles transposed strides.
        self.plan.run(x, self.weight.t(), C, D)
        return D

    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
# Usage
model = nn.Sequential(
CUTLASSLinear(512, 256, dtype=torch.float16),
nn.ReLU(),
CUTLASSLinear(256, 128, dtype=torch.float16),
nn.ReLU(),
CUTLASSLinear(128, 10, dtype=torch.float16),
)
# Forward pass
x = torch.randn(32, 512, dtype=torch.float16, device='cuda')
y = model(x)
print(f"Output shape: {y.shape}")
# Training
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
optimizer.zero_grad()
output = model(x)
target = torch.randint(0, 10, (32,), device='cuda')
loss = criterion(output.float(), target) # Convert to fp32 for loss
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
Grouped GEMM Example
Export grouped GEMM as PyTorch extension:

import cutlass
import torch
# Setup grouped GEMM
problems = [
cutlass.GemmCoord(256, 256, 128),
cutlass.GemmCoord(512, 256, 256),
cutlass.GemmCoord(128, 512, 256),
]
# Create grouped GEMM plan
plan = cutlass.op.GroupedGemm(
element=torch.float16,
layout=cutlass.LayoutType.RowMajor
)
# Prepare tensors for each problem
As, Bs, Cs, Ds = [], [], [], []
for problem in problems:
M, N, K = problem.m(), problem.n(), problem.k()
As.append(torch.randn(M, K, dtype=torch.float16, device='cuda'))
Bs.append(torch.randn(K, N, dtype=torch.float16, device='cuda'))
Cs.append(torch.zeros(M, N, dtype=torch.float16, device='cuda'))
Ds.append(torch.zeros(M, N, dtype=torch.float16, device='cuda'))
# Run grouped GEMM
plan.run(As, Bs, Cs, Ds)
# Verify each problem
for i, (A, B, D) in enumerate(zip(As, Bs, Ds)):
ref = torch.matmul(A, B)
torch.testing.assert_close(D, ref, rtol=1e-3, atol=1e-3)
print(f"Problem {i}: Success!")
Performance Optimization
Tensor Layout
Ensure optimal memory layout:

import torch
# Create contiguous tensors
A = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
assert A.is_contiguous()
# Transpose creates non-contiguous view
A_t = A.t()
assert not A_t.is_contiguous()
# Make contiguous for optimal performance
A_t_contiguous = A_t.contiguous()
assert A_t_contiguous.is_contiguous()
# Use in CUTLASS
plan.run(A_t_contiguous, B, C, D)
Warm-up Compilation
Compile kernels before timing:

import torch
import cutlass
import time
# Create plan
plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor)
A = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
B = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
C = torch.zeros(1024, 1024, device='cuda', dtype=torch.float16)
D = torch.zeros(1024, 1024, device='cuda', dtype=torch.float16)
# Warm-up: compile and run once
plan.run(A, B, C, D)
torch.cuda.synchronize()
# Now benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
plan.run(A, B, C, D)
torch.cuda.synchronize()
elapsed = time.time() - start
print(f"Average time: {elapsed / 100 * 1000:.3f} ms")
print(f"TFLOPS: {2 * 1024**3 / (elapsed / 100) / 1e12:.2f}")
Memory Pinning
Use pinned memory for faster CPU-GPU transfers:

import torch
# Create pinned memory on CPU
A_cpu = torch.randn(1024, 1024, dtype=torch.float16, pin_memory=True)
# Transfer to GPU (faster with pinned memory)
A_gpu = A_cpu.cuda()
# Use in CUTLASS
# ...
Troubleshooting
Data Type Mismatch
Ensure PyTorch and CUTLASS types match:
# Check tensor dtype
print(f"PyTorch dtype: {tensor.dtype}")
# Convert if needed
tensor = tensor.to(torch.float16)
# Or create with correct dtype
tensor = torch.randn(..., dtype=torch.float16)
Non-Contiguous Tensors
Make tensors contiguous:
if not tensor.is_contiguous():
tensor = tensor.contiguous()
Stream Synchronization
Synchronize streams when needed:
# After CUTLASS operation
torch.cuda.synchronize()
# Or synchronize specific stream
stream.synchronize()
GPU Memory
Monitor and free memory:
# Check memory
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
# Free unused memory
torch.cuda.empty_cache()
# Clear specific tensors
del tensor
torch.cuda.empty_cache()
Example: Complete Training Loop
import torch
import torch.nn as nn
import cutlass
class CUTLASSModel(nn.Module):
    """Small MLP mixing CUTLASS fp16 linear layers with a standard fp32 head."""

    def __init__(self):
        super().__init__()
        self.fc1 = CUTLASSLinear(784, 256)
        self.fc2 = CUTLASSLinear(256, 128)
        self.fc3 = nn.Linear(128, 10)  # Standard PyTorch layer
        self.relu = nn.ReLU()

    def forward(self, x):
        # Flatten (batch, 1, 28, 28) images to (batch, 784)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x.float())  # Convert to fp32 for final layer
        return x
# Setup
device = 'cuda'
model = CUTLASSModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# Dummy data (replace with real data loader)
train_loader = [
(torch.randn(32, 1, 28, 28, dtype=torch.float16, device=device),
torch.randint(0, 10, (32,), device=device))
for _ in range(100)
]
# Training loop
for epoch in range(10):
total_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
print("Training complete!")
Next Steps
CuTe DSL
Learn to write custom kernels
Examples
Explore more integration examples
API Reference
Browse the complete API
Quickstart
Get started quickly