GEMM Operations
GEMM (General Matrix Multiplication) is the fundamental operation at the heart of CUTLASS. This page explains how CUTLASS decomposes GEMM operations hierarchically to achieve optimal performance on NVIDIA GPUs.

What is GEMM?
GEMM computes the matrix product D = αAB + βC, where:
- A is an M×K matrix
- B is a K×N matrix
- C is an M×N matrix (source)
- D is an M×N matrix (destination)
- α and β are scalar coefficients
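As a point of reference, the computation can be sketched as a naive triple loop (plain C++, not CUTLASS code; row-major layouts assumed for illustration):

```cpp
#include <cassert>
#include <vector>

// Naive reference GEMM: D = alpha * A * B + beta * C
// A is MxK, B is KxN, C and D are MxN, all row-major.
void reference_gemm(int M, int N, int K, float alpha,
                    const std::vector<float>& A,
                    const std::vector<float>& B,
                    float beta,
                    const std::vector<float>& C,
                    std::vector<float>& D) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
                acc += A[m * K + k] * B[k * N + n];  // inner product over K
            }
            D[m * N + n] = alpha * acc + beta * C[m * N + n];
        }
    }
}
```

Every optimization described below produces the same result as this loop nest; the differences are purely in how the work and data movement are scheduled.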
CUTLASS supports various GEMM variants including batched GEMM, grouped GEMM, split-K GEMM, and sparse GEMM.
GEMM Coordinate System
CUTLASS uses a coordinate system to navigate the GEMM problem space. The GemmCoord structure represents positions within the computation:
include/cutlass/gemm_coord.h:86
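The real definition lives at the reference above; conceptually it bundles the three problem dimensions into one value, roughly along these lines (a simplified sketch, not the actual CUTLASS class):

```cpp
#include <cassert>

// Simplified sketch of a GEMM coordinate: a position (m, n, k)
// in the three-dimensional GEMM problem space.
struct GemmCoordSketch {
    int m_, n_, k_;

    GemmCoordSketch(int m, int n, int k) : m_(m), n_(n), k_(k) {}

    int m() const { return m_; }  // row dimension of the output
    int n() const { return n_; }  // column dimension of the output
    int k() const { return k_; }  // reduction dimension

    // Element-wise addition, useful for advancing tile by tile.
    GemmCoordSketch operator+(GemmCoordSketch const& o) const {
        return GemmCoordSketch(m_ + o.m_, n_ + o.n_, k_ + o.k_);
    }
};
```

The same coordinate type can describe a problem size, a tile offset, or a tile shape, which is what makes the hierarchical decomposition below easy to express.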
Hierarchical Decomposition
CUTLASS decomposes GEMM into a hierarchy of smaller operations to efficiently utilize GPU hardware:

1. Device Level
The entire GEMM problem is mapped across the GPU grid. Each CUDA threadblock processes one or more tiles of the output matrix.

2. Threadblock Level
Each threadblock computes a tile (e.g., 128×128) of the output matrix by:
- Loading tiles from global memory to shared memory
- Performing warp-level operations
- Writing results back to global memory
include/cutlass/gemm_coord.h:42
3. Warp Level
Warps (groups of 32 threads) collaborate to compute smaller tiles using Tensor Core instructions when available.

4. Thread Level
Individual threads process the smallest granularity of data, performing scalar or vector operations.

This hierarchical approach enables CUTLASS to:
- Maximize data reuse through shared memory
- Leverage Tensor Cores for accelerated computation
- Achieve high memory bandwidth utilization
- Scale across different GPU architectures
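To make the device-level decomposition concrete, here is how the grid size falls out of the problem shape and tile shape (host-side arithmetic only; the 128×128 tile is an illustrative choice, not a fixed CUTLASS default):

```cpp
#include <cassert>

// Ceiling division: how many tiles of size `tile` cover `extent` elements.
constexpr int ceil_div(int extent, int tile) {
    return (extent + tile - 1) / tile;
}

struct GridShape { int blocks_m, blocks_n; };

// Map a GEMM problem onto a grid of threadblock tiles. Edge tiles
// are partially filled when M or N is not a multiple of the tile size.
GridShape decompose(int M, int N, int tile_m, int tile_n) {
    return { ceil_div(M, tile_m), ceil_div(N, tile_n) };
}
```

For example, a 1030×512 output with 128×128 tiles needs a 9×4 grid of threadblocks, with the last row of tiles only partially covered by real data.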
Basic GEMM Example
Here’s a simple example instantiating a CUTLASS GEMM kernel:

examples/00_basic_gemm/basic_gemm.cu:79
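In outline, that example defines a single-precision, column-major GEMM type and launches it through cutlass::gemm::device::Gemm, roughly as follows (abridged sketch; see the referenced file for the complete, authoritative version):

```cuda
#include "cutlass/gemm/device/gemm.h"

// Define a CUTLASS GEMM type: single-precision, column-major A, B, and C.
using ColumnMajor = cutlass::layout::ColumnMajor;
using CutlassGemm = cutlass::gemm::device::Gemm<float, ColumnMajor,   // A
                                                float, ColumnMajor,   // B
                                                float, ColumnMajor>;  // C

cutlass::Status run_gemm(int M, int N, int K, float alpha,
                         float const* A, int lda,
                         float const* B, int ldb,
                         float beta,
                         float* C, int ldc) {
    CutlassGemm gemm_operator;
    // Arguments: problem size, tensor refs for A, B, C (source), D (destination),
    // and the epilogue scalars {alpha, beta}. Here D aliases C, so C is
    // updated in place.
    CutlassGemm::Arguments args({M, N, K},
                                {A, lda}, {B, ldb},
                                {C, ldc}, {C, ldc},
                                {alpha, beta});
    return gemm_operator(args);  // configures and launches the kernel
}
```

All tile sizes, warp arrangement, and instruction selection are taken from defaults here; the other template parameters of device::Gemm let you override each of them.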
Batched and Grouped GEMM
CUTLASS extends basic GEMM to support multiple matrix multiplications:

Batched GEMM
Computes multiple identical-sized GEMMs in parallel:

include/cutlass/gemm_coord.h:252
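Conceptually, strided batched GEMM applies one problem shape to a batch of matrices separated by fixed strides; a CPU sketch of the addressing (illustrative only, not CUTLASS's implementation):

```cpp
#include <cassert>
#include <vector>

// Strided batched GEMM sketch: `batch` independent MxN = MxK * KxN products.
// Each operand's matrices sit `stride_*` elements apart in memory; on the
// GPU the batch index typically maps to the grid's z dimension.
void batched_gemm_ref(int batch, int M, int N, int K,
                      const float* A, long stride_A,
                      const float* B, long stride_B,
                      float* D, long stride_D) {
    for (int b = 0; b < batch; ++b) {
        const float* Ab = A + b * stride_A;  // base pointers for batch item b
        const float* Bb = B + b * stride_B;
        float* Db = D + b * stride_D;
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float acc = 0.0f;
                for (int k = 0; k < K; ++k)
                    acc += Ab[m * K + k] * Bb[k * N + n];
                Db[m * N + n] = acc;
            }
    }
}
```

Because every batch item shares one problem shape, a single kernel launch covers the whole batch, amortizing launch overhead over many small GEMMs.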
Grouped GEMM
Computes multiple GEMMs with different sizes in a single kernel launch, ideal for dynamic batching scenarios.

Data Movement Strategies
Efficient GEMM requires careful orchestration of data movement:

1. Global Memory → Shared Memory
   - Use cooperative loads across the threadblock
   - Leverage asynchronous copy instructions (SM80+)
2. Shared Memory → Registers
   - Partition data across warps
   - Use swizzling to avoid bank conflicts
3. Register → Tensor Cores
   - Feed matrix fragments to MMA instructions
   - Maximize computational throughput
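The first two stages can be sketched as a classic shared-memory tiled CUDA kernel (a hand-written illustration of the pattern, far simpler than CUTLASS's pipelined implementation; the 32×32 tile size is an arbitrary choice):

```cuda
#define TILE 32

// Each threadblock computes a TILE x TILE tile of D = A * B (row-major),
// staging TILE x TILE slices of A and B through shared memory.
__global__ void tiled_gemm(int M, int N, int K,
                           const float* A, const float* B, float* D) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.0f;

    for (int k0 = 0; k0 < K; k0 += TILE) {
        // Cooperative load: every thread stages one element of each tile,
        // zero-padding past the matrix edges.
        As[threadIdx.y][threadIdx.x] =
            (row < M && k0 + threadIdx.x < K) ? A[row * K + k0 + threadIdx.x] : 0.0f;
        Bs[threadIdx.y][threadIdx.x] =
            (k0 + threadIdx.y < K && col < N) ? B[(k0 + threadIdx.y) * N + col] : 0.0f;
        __syncthreads();

        // Each staged element is reused TILE times from shared memory.
        for (int k = 0; k < TILE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();
    }
    if (row < M && col < N) D[row * N + col] = acc;
}
```

CUTLASS layers double buffering, asynchronous copies, swizzled shared-memory layouts, and Tensor Core MMA fragments on top of this same basic staging pattern.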
Memory Hierarchy Performance
Different memory levels have vastly different bandwidths:
- Registers: ~20 TB/s (per SM)
- Shared Memory: ~10 TB/s (per SM)
- L2 Cache: ~3-5 TB/s
- Global Memory (HBM): ~2-3 TB/s
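These numbers are why tiling pays off: a quick arithmetic-intensity calculation (assuming FP32 operands and an illustrative 128×128×32 threadblock tile) shows how shared-memory reuse multiplies the work done per byte of global traffic:

```cpp
#include <cassert>

// FLOPs and global-memory bytes for one threadblock tile of a GEMM:
// a bm x bn output tile accumulated over a bk-wide reduction slice,
// with `bytes` per element.
struct TileCost { long flops; long bytes_loaded; };

TileCost tile_cost(int bm, int bn, int bk, int bytes) {
    long flops = 2L * bm * bn * bk;                  // one multiply-add per (m, n, k)
    long loads = (long)(bm * bk + bk * bn) * bytes;  // A slice + B slice from global
    return { flops, loads };
}

// Arithmetic intensity = FLOPs per byte of global-memory traffic.
double intensity(TileCost c) { return double(c.flops) / c.bytes_loaded; }
```

For this tile shape each byte fetched from global memory feeds 32 FLOPs, versus 0.25 FLOPs per byte for the naive one-element-per-load loop, which is what lets a tiled kernel run near compute limits instead of the ~2-3 TB/s HBM bound.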
GEMM Variants
CUTLASS supports numerous GEMM specializations:

- Split-K: Parallelizes the K dimension across threadblocks
- Stream-K: Dynamic work distribution for improved load balancing
- Sparse GEMM: Exploits structured sparsity in matrices
- Complex GEMM: Native support for complex number arithmetic
- Mixed Precision: Different input/output data types
Next Steps
- CuTe Library: Learn about the tensor abstraction layer
- Tensor Cores: Understand hardware-accelerated matrix operations
- Memory Layouts: Explore data layout strategies
- Quick Start: Build your first CUTLASS kernel