Overview
AWS Batch is the compute engine for Metaflow workflows running on AWS. The @batch decorator allows you to execute individual steps on AWS Batch, providing access to scalable cloud compute resources including CPUs, GPUs, and specialized hardware like AWS Inferentia.
AWS Batch is used automatically when deploying to AWS Step Functions. You can also use @batch for individual steps while running flows locally.
Basic Usage
Simple Batch Step
Add the @batch decorator to any step:
```python
from metaflow import FlowSpec, step, batch

class MyFlow(FlowSpec):

    @step
    def start(self):
        self.next(self.process)

    @batch(cpu=4, memory=16000)
    @step
    def process(self):
        # This step runs on AWS Batch
        import numpy as np
        self.results = np.random.rand(1000000).mean()
        self.next(self.end)

    @step
    def end(self):
        print(f"Results: {self.results}")
```
Run with AWS Batch:
```bash
python myflow.py run --with batch
```
Resource Configuration
CPU and Memory
Specify compute resources for your step:
```python
@batch(cpu=8, memory=32000)  # 8 CPUs, 32GB RAM
@step
def heavy_computation(self):
    # Your compute-intensive code
    pass
```
GPU Resources
Request GPU instances:
```python
@batch(cpu=4, memory=16000, gpu=1)
@step
def train_model(self):
    import torch
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Your GPU training code
    pass
```
Ensure your AWS Batch compute environment has GPU instances configured.
Inferentia and Trainium
Use AWS Inferentia chips for ML inference or Trainium for training:
```python
@batch(cpu=4, memory=16000, inferentia=1)
@step
def inference(self):
    # Run inference on Inferentia
    pass

# trainium is an alias for inferentia
@batch(cpu=16, memory=32000, trainium=4)
@step
def train_large_model(self):
    # Distributed training on Trainium
    pass
```
Do not specify both the inferentia and trainium parameters; use only one.
Docker Images
Default Image
By default, Metaflow uses a Python image matching your local Python version:
```python
@batch  # Uses python:3.9 if you're running Python 3.9
@step
def my_step(self):
    pass
```
Custom Image
Specify a custom Docker image:
```python
@batch(image='my-registry.com/my-image:v1.0')
@step
def custom_env_step(self):
    # Runs in your custom container
    pass
```
Configuration-Based Image
Set a default image via environment variable:
```bash
export METAFLOW_BATCH_CONTAINER_IMAGE="my-company/ml-base:latest"
```
Or specify a registry:
```bash
export METAFLOW_BATCH_CONTAINER_REGISTRY="123456789.dkr.ecr.us-east-1.amazonaws.com"
```

If you configure a registry, images specified without a registry prefix are pulled from it:

```python
@batch(image='myimage:v1')  # Pulled from the configured registry
```
Queue and IAM Configuration
Job Queue
Specify which AWS Batch job queue to use:
```python
@batch(queue='high-priority-queue', cpu=4, memory=8000)
@step
def urgent_task(self):
    pass
```
Or set a default:
```bash
export METAFLOW_BATCH_JOB_QUEUE="default-queue"
```
IAM Roles
Configure IAM roles for container access:
```python
@batch(
    iam_role='arn:aws:iam::123456789:role/BatchJobRole',
    execution_role='arn:aws:iam::123456789:role/BatchExecutionRole'
)
@step
def secure_step(self):
    # Access AWS resources with this role
    pass
```
Environment variables:
```bash
export METAFLOW_ECS_S3_ACCESS_IAM_ROLE="arn:aws:iam::123456789:role/BatchJobRole"
export METAFLOW_ECS_FARGATE_EXECUTION_ROLE="arn:aws:iam::123456789:role/FargateRole"
```
Advanced Features
Shared Memory
Increase shared memory for /dev/shm:
```python
@batch(cpu=4, memory=16000, shared_memory=8000)  # 8GB shared memory
@step
def parallel_processing(self):
    # Useful for multiprocessing with shared memory
    pass
```
Swap Configuration
```python
@batch(
    cpu=4,
    memory=16000,
    max_swap=8000,  # 8GB swap space
    swappiness=60   # 0-100; default is 60
)
@step
def memory_intensive(self):
    pass
```
Elastic Fabric Adapter (EFA)
Enable high-performance networking for distributed workloads:
```python
@batch(cpu=16, memory=64000, efa=1)
@step
def distributed_training(self):
    # Use EFA for multi-node communication
    pass
```
Tmpfs (Temporary File System)
Mount a tmpfs volume for fast temporary storage:
```python
@batch(
    cpu=4,
    memory=16000,
    use_tmpfs=True,
    tmpfs_size=4000,             # 4GB tmpfs
    tmpfs_path='/metaflow_temp',
    tmpfs_tempdir=True           # Set METAFLOW_TEMPDIR to the tmpfs path
)
@step
def fast_io(self):
    import tempfile
    # Temp files use tmpfs for fast I/O
    with tempfile.NamedTemporaryFile() as f:
        f.write(b'fast writes to RAM')
```
Tmpfs is not available on AWS Fargate compute environments.
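You can preview the effect of `tmpfs_tempdir` locally: pointing Python's `tempfile` module at a different directory sends temporary files there, which is what setting `METAFLOW_TEMPDIR` to the tmpfs path achieves on Batch. The directory prefix below is arbitrary:

```python
import os
import tempfile

# Locally mimic tmpfs_tempdir=True: direct tempfile at a chosen directory,
# just as METAFLOW_TEMPDIR directs it at the tmpfs mount on Batch.
custom_tmp = tempfile.mkdtemp(prefix="metaflow_temp_")

with tempfile.NamedTemporaryFile(dir=custom_tmp, delete=False) as f:
    f.write(b"fast writes")
    path = f.name

print(path.startswith(custom_tmp))  # True: the temp file landed in custom_tmp
os.remove(path)
```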
Host and EFS Volumes
Mount host directories or EFS file systems:
```python
@batch(
    cpu=4,
    memory=8000,
    host_volumes=['/host/path:/container/path'],
    efs_volumes=['fs-12345:/efs/mount']
)
@step
def persistent_storage(self):
    # Access mounted volumes
    pass
```
Ephemeral Storage (Fargate)
For Fargate compute environments, configure ephemeral storage:
```python
@batch(
    cpu=4,
    memory=8000,
    ephemeral_storage=100  # 100 GiB (valid range: 21-200)
)
@step
def fargate_task(self):
    pass
```
Custom Logging
Configure custom log drivers:
```python
@batch(
    cpu=2,
    memory=4000,
    log_driver='awslogs',
    log_options=[
        'awslogs-group:my-log-group',
        'awslogs-region:us-west-2',
        'awslogs-stream-prefix:myflow'
    ]
)
@step
def custom_logging(self):
    pass
```
Privileged Mode
Run containers in privileged mode (use cautiously):
```python
@batch(cpu=2, memory=4000, privileged=True)
@step
def privileged_task(self):
    # Container runs with elevated privileges
    pass
```
Timeouts and Retries
Execution Timeout
Set a maximum runtime for your step:
```python
from metaflow import timeout

@batch(cpu=4, memory=8000)
@timeout(hours=2)
@step
def long_running_task(self):
    # Task will be terminated after 2 hours
    pass
```
Retry Logic
```python
from metaflow import retry

@batch(cpu=4, memory=8000)
@retry(times=3)
@step
def resilient_task(self):
    # Automatically retries up to 3 times on failure
    pass
```
With backoff:
```python
@batch(cpu=4, memory=8000)
@retry(times=3, minutes_between_retries=5)
@step
def retry_with_delay(self):
    # Wait 5 minutes between retries
    pass
```
Tagging
Add AWS tags for cost tracking and organization:
```python
@batch(
    cpu=4,
    memory=8000,
    aws_batch_tags={
        'project': 'recommendation-engine',
        'team': 'ml-platform',
        'environment': 'production'
    }
)
@step
def tagged_task(self):
    pass
```
Set default tags via configuration:
# In your config
BATCH_DEFAULT_TAGS = {
'company' : 'acme' ,
'cost-center' : 'ml'
}
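Conceptually, per-step `aws_batch_tags` combine with the configured defaults; the merge below assumes step-level values win on duplicate keys, which is illustrative rather than a guarantee of Metaflow's exact behavior:

```python
# Hypothetical illustration of how default and step-level tags could combine.
BATCH_DEFAULT_TAGS = {"company": "acme", "cost-center": "ml"}

step_tags = {"project": "recommendation-engine", "cost-center": "ml-research"}

# Step-level tags override defaults on key collisions (assumed behavior).
effective_tags = {**BATCH_DEFAULT_TAGS, **step_tags}
print(effective_tags)
# {'company': 'acme', 'cost-center': 'ml-research', 'project': 'recommendation-engine'}
```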
Metaflow automatically tags jobs with:
- `app: metaflow`
- `metaflow.flow_name`
- `metaflow.run_id`
- `metaflow.step_name`
- `metaflow.user`
- `metaflow.version`
To disable automatic tagging:
```bash
export METAFLOW_BATCH_EMIT_TAGS=false
```
Multi-Node Jobs
Run distributed jobs across multiple nodes:
```python
@batch(cpu=8, memory=32000)
@step
def distributed_job(self):
    from metaflow import current
    if current.parallel.node_index == 0:
        print("I'm the main node")
    print(f"Node {current.parallel.node_index} of {current.parallel.num_nodes}")
    print(f"Main node IP: {current.parallel.main_ip}")
```
Run with multiple nodes:
```bash
python myflow.py run --with batch --num-parallel 4
```
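Inside a multi-node step, `current.parallel.node_index` and `current.parallel.num_nodes` are commonly used to split work across nodes. The striding pattern itself is plain Python; `shard` below is an illustrative helper, not a Metaflow API:

```python
def shard(items, node_index, num_nodes):
    """Return the slice of work assigned to this node (simple striding)."""
    return [item for i, item in enumerate(items) if i % num_nodes == node_index]

# Example: 10 work items split across 4 nodes.
work = list(range(10))
print(shard(work, 0, 4))  # node 0 gets [0, 4, 8]
print(shard(work, 3, 4))  # node 3 gets [3, 7]
```

Every item is assigned to exactly one node, so the shards jointly cover the full work list without overlap.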
Environment Variables
Custom Environment
Set environment variables for your step:
```python
from metaflow import environment

@batch(cpu=4, memory=8000)
@environment(vars={
    'MODEL_PATH': 's3://my-bucket/models/',
    'API_KEY': 'secret-key-123'
})
@step
def configured_step(self):
    import os
    model_path = os.environ['MODEL_PATH']
```
Metaflow automatically sets:
- `METAFLOW_FLOW_NAME`
- `METAFLOW_RUN_ID`
- `METAFLOW_STEP_NAME`
- `METAFLOW_TASK_ID`
- `METAFLOW_RETRY_COUNT`
- `METAFLOW_CODE_URL`
- `AWS_BATCH_JOB_ID`
- `AWS_BATCH_JOB_ATTEMPT`
Monitoring
Metaflow automatically captures:
- AWS Batch job ID
- Job attempt number
- Compute environment name
- Job queue name
- EC2 instance metadata (type, region, availability zone)
- CloudWatch logs location
Access metadata:
```python
from metaflow import Flow

run = Flow('MyFlow').latest_run
for step in run:
    for task in step:
        print(task.metadata_dict)
```
Spot Interruption Handling
Metaflow monitors for spot instance interruptions:
```python
from metaflow import current

@batch(cpu=4, memory=8000)
@step
def spot_aware_task(self):
    # Check for a pending spot interruption
    if current.spot_termination_notice:
        print("Spot interruption detected!")
        # Save a checkpoint here
```
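What "save a checkpoint" means depends on your workload. Below is a minimal file-based sketch; `maybe_checkpoint` is an illustrative helper (not part of Metaflow's API), and the checkpoint path is arbitrary:

```python
import json

def maybe_checkpoint(state, termination_notice, path="checkpoint.json"):
    """Persist state to disk if a spot termination notice was observed."""
    if termination_notice:
        with open(path, "w") as f:
            json.dump(state, f)
        return True
    return False

# Simulate receiving a termination notice mid-task.
saved = maybe_checkpoint({"epoch": 7, "loss": 0.42}, termination_notice=True)
print(saved)  # True
```

On a subsequent retry, the step could look for this file and resume from the saved state instead of starting over.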
Local Testing
Test batch steps locally before deploying:
```bash
# Run locally (ignores @batch)
python myflow.py run

# Run all steps on AWS Batch
python myflow.py run --with batch
```
Managing Batch Jobs
List Running Jobs
```bash
python myflow.py batch list
```
Kill Running Jobs
```bash
# Kill jobs from the latest run
python myflow.py batch kill

# Kill jobs from a specific run
python myflow.py batch kill --run-id <run-id>

# Kill jobs from all of your runs
python myflow.py batch kill --my-runs
```
Configuration
Key environment variables:
| Variable | Description | Default |
| --- | --- | --- |
| `METAFLOW_BATCH_CONTAINER_IMAGE` | Default Docker image | `python:X.Y` |
| `METAFLOW_BATCH_CONTAINER_REGISTRY` | Default Docker registry | None |
| `METAFLOW_BATCH_JOB_QUEUE` | Default job queue | None |
| `METAFLOW_ECS_S3_ACCESS_IAM_ROLE` | IAM role for the container | None |
| `METAFLOW_ECS_FARGATE_EXECUTION_ROLE` | Fargate execution role | None |
| `METAFLOW_BATCH_EMIT_TAGS` | Enable automatic tagging | `true` |
| `METAFLOW_BATCH_DEFAULT_TAGS` | Default AWS tags | `{}` |
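You can inspect which of these are set in your current environment with plain `os.environ` lookups. This sketch only mirrors the variables listed above; Metaflow resolves these values internally:

```python
import os

# A value of None means the variable is not configured.
batch_config = {
    "image": os.environ.get("METAFLOW_BATCH_CONTAINER_IMAGE"),
    "registry": os.environ.get("METAFLOW_BATCH_CONTAINER_REGISTRY"),
    "queue": os.environ.get("METAFLOW_BATCH_JOB_QUEUE"),
    "emit_tags": os.environ.get("METAFLOW_BATCH_EMIT_TAGS", "true"),
}
for key, value in batch_config.items():
    print(f"{key}: {value}")
```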
See AWS Configuration for complete setup.
Best Practices
Right-size your resources

Start with minimal resources and scale up based on actual usage; over-provisioning wastes money.

```python
# Start small
@batch(cpu=1, memory=4000)
# Then profile and adjust
```

Use spot instances

Configure your Batch compute environment to use spot instances for cost savings (up to 90% cheaper than on-demand).

Keep images lean

Use smaller base images and multi-stage builds to reduce pull time:

```dockerfile
# Smaller than python:3.9
FROM python:3.9-slim
```

Use tmpfs for I/O-heavy steps

For I/O-intensive workloads, use tmpfs to avoid disk bottlenecks:

```python
@batch(use_tmpfs=True, tmpfs_size=8000)
```

Set timeouts

Always set timeouts (for example, @timeout(hours=2)) to prevent runaway jobs.
Troubleshooting
Job Stuck in RUNNABLE
Issue: Job stays in the RUNNABLE state.

Causes:

- Insufficient capacity in the compute environment
- Resource requirements too high for the available instance types
- Service limits reached

Solution: Check the AWS Batch console for specific errors and adjust your resources or compute environment.
Container Pull Failures
Issue: `CannotPullContainerError`

Solution:

- Verify the image exists and is accessible
- Check IAM permissions for ECR
- Ensure the execution role has ECR pull permissions
OOM (Out of Memory) Kills
Issue: Tasks are killed for exceeding their memory allocation.

Solution: Increase the memory request:

```python
@batch(cpu=4, memory=32000)  # Increase from the default 4096 MB
```
Timeout Errors
Issue: Job exceeds its time limit.

Solution: Increase the timeout or optimize your code:

```python
@timeout(hours=12)  # Increase the limit
```
Next Steps
- Step Functions: Deploy complete workflows to AWS Step Functions
- Configuration: Configure AWS credentials and resources
- Scaling: Learn more about scaling patterns
- Resources: Resource management best practices