
Overview

AWS Batch is the compute engine for Metaflow workflows running on AWS. The @batch decorator allows you to execute individual steps on AWS Batch, providing access to scalable cloud compute resources including CPUs, GPUs, and specialized hardware like AWS Inferentia.
AWS Batch is used automatically when deploying to AWS Step Functions. You can also use @batch for individual steps while running flows locally.

Basic Usage

Simple Batch Step

Add the @batch decorator to any step:
from metaflow import FlowSpec, step, batch

class MyFlow(FlowSpec):
    @step
    def start(self):
        self.next(self.process)

    @batch(cpu=4, memory=16000)
    @step
    def process(self):
        # This step runs on AWS Batch
        import numpy as np
        self.results = np.random.rand(1000000).mean()
        self.next(self.end)

    @step
    def end(self):
        print(f"Results: {self.results}")

if __name__ == '__main__':
    MyFlow()
Run with AWS Batch:
python myflow.py run --with batch

Resource Configuration

CPU and Memory

Specify compute resources for your step:
@batch(cpu=8, memory=32000)  # 8 CPUs, 32GB RAM
@step
def heavy_computation(self):
    # Your compute-intensive code
    pass

GPU Resources

Request GPU instances:
@batch(cpu=4, memory=16000, gpu=1)
@step
def train_model(self):
    import torch
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Your GPU training code
    pass
Ensure your AWS Batch compute environment has GPU instances configured.

Inferentia and Trainium

Use AWS Inferentia chips for ML inference or Trainium for training:
@batch(cpu=4, memory=16000, inferentia=1)
@step
def inference(self):
    # Run inference on Inferentia
    pass

# Trainium is an alias for inferentia
@batch(cpu=16, memory=32000, trainium=4)
@step
def train_large_model(self):
    # Distributed training on Trainium
    pass
Do not specify both the inferentia and trainium parameters; use only one.

Docker Images

Default Image

By default, Metaflow uses a Python image matching your local Python version:
@batch  # Uses python:3.9 if you're running Python 3.9
@step
def my_step(self):
    pass

Custom Image

Specify a custom Docker image:
@batch(image='my-registry.com/my-image:v1.0')
@step
def custom_env_step(self):
    # Runs in your custom container
    pass

Configuration-Based Image

Set a default image via environment variable:
export METAFLOW_BATCH_CONTAINER_IMAGE="my-company/ml-base:latest"
Or specify a registry:
export METAFLOW_BATCH_CONTAINER_REGISTRY="123456789.dkr.ecr.us-east-1.amazonaws.com"
If you configure a registry and the image name does not include one, the image is pulled from the configured registry:
@batch(image='myimage:v1')  # Pulled from configured registry
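Conceptually, the effective image reference is the configured registry prepended to any image name that does not already carry one. A simplified sketch of that resolution logic (the helper name is ours, not Metaflow's internals):

```python
def resolve_image(image, registry=None):
    """Illustrative only: prepend a default registry to an image
    reference that does not already name a registry."""
    # A reference like "my-registry.com/img:v1" already names a
    # registry: its first path component contains a "." or ":".
    first = image.split("/")[0]
    has_registry = "/" in image and ("." in first or ":" in first)
    if registry and not has_registry:
        return f"{registry.rstrip('/')}/{image}"
    return image

print(resolve_image("myimage:v1", "123456789.dkr.ecr.us-east-1.amazonaws.com"))
print(resolve_image("my-registry.com/my-image:v1.0"))
```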

Queue and IAM Configuration

Job Queue

Specify which AWS Batch job queue to use:
@batch(queue='high-priority-queue', cpu=4, memory=8000)
@step
def urgent_task(self):
    pass
Or set a default:
export METAFLOW_BATCH_JOB_QUEUE="default-queue"

IAM Roles

Configure IAM roles for container access:
@batch(
    iam_role='arn:aws:iam::123456789:role/BatchJobRole',
    execution_role='arn:aws:iam::123456789:role/BatchExecutionRole'
)
@step
def secure_step(self):
    # Access AWS resources with this role
    pass
Environment variables:
export METAFLOW_ECS_S3_ACCESS_IAM_ROLE="arn:aws:iam::123456789:role/BatchJobRole"
export METAFLOW_ECS_FARGATE_EXECUTION_ROLE="arn:aws:iam::123456789:role/FargateRole"

Advanced Features

Shared Memory

Increase shared memory for /dev/shm:
@batch(cpu=4, memory=16000, shared_memory=8000)  # 8GB shared memory
@step
def parallel_processing(self):
    # Useful for multiprocessing with shared memory
    pass

Swap Configuration

@batch(
    cpu=4,
    memory=16000,
    max_swap=8000,      # 8GB swap space
    swappiness=60       # 0-100, default is 60
)
@step
def memory_intensive(self):
    pass

Elastic Fabric Adapter (EFA)

Enable high-performance networking for distributed workloads:
@batch(cpu=16, memory=64000, efa=1)
@step
def distributed_training(self):
    # Use EFA for multi-node communication
    pass

Tmpfs (Temporary File System)

Mount a tmpfs volume for fast temporary storage:
@batch(
    cpu=4,
    memory=16000,
    use_tmpfs=True,
    tmpfs_size=4000,           # 4GB tmpfs
    tmpfs_path='/metaflow_temp',
    tmpfs_tempdir=True         # Set METAFLOW_TEMPDIR to tmpfs path
)
@step
def fast_io(self):
    import tempfile
    # temp files use tmpfs for fast I/O
    with tempfile.NamedTemporaryFile() as f:
        f.write(b'fast writes to RAM')
Tmpfs is not available on AWS Fargate compute environments.

Host and EFS Volumes

Mount host directories or EFS file systems:
@batch(
    cpu=4,
    memory=8000,
    host_volumes=['/host/path:/container/path'],
    efs_volumes=['fs-12345:/efs/mount']
)
@step
def persistent_storage(self):
    # Access mounted volumes
    pass

Ephemeral Storage (Fargate)

For Fargate compute environments, configure ephemeral storage:
@batch(
    cpu=4,
    memory=8000,
    ephemeral_storage=100  # 100 GiB (21-200 range)
)
@step
def fargate_task(self):
    pass

Custom Logging

Configure custom log drivers:
@batch(
    cpu=2,
    memory=4000,
    log_driver='awslogs',
    log_options=[
        'awslogs-group:my-log-group',
        'awslogs-region:us-west-2',
        'awslogs-stream-prefix:myflow'
    ]
)
@step
def custom_logging(self):
    pass
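Each log_options entry is a 'key:value' pair handed through to the container log driver. For reference, a sketch of how such pairs split into a driver-options mapping (illustrative, not Metaflow internals):

```python
def parse_log_options(options):
    """Split 'key:value' strings into a dict of log-driver options.
    Split only on the first ':' so values containing colons
    (e.g. ARNs) stay intact."""
    return dict(opt.split(":", 1) for opt in options)

opts = parse_log_options([
    "awslogs-group:my-log-group",
    "awslogs-region:us-west-2",
    "awslogs-stream-prefix:myflow",
])
print(opts["awslogs-region"])
```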

Privileged Mode

Run containers in privileged mode (use cautiously):
@batch(cpu=2, memory=4000, privileged=True)
@step
def privileged_task(self):
    # Container runs with elevated privileges
    pass

Timeouts and Retries

Execution Timeout

Set a maximum runtime for your step:
from metaflow import timeout

@batch(cpu=4, memory=8000)
@timeout(hours=2)
@step
def long_running_task(self):
    # Task will be terminated after 2 hours
    pass

Retry Logic

from metaflow import retry

@batch(cpu=4, memory=8000)
@retry(times=3)
@step
def resilient_task(self):
    # Automatically retries up to 3 times on failure
    pass
With backoff:
@batch(cpu=4, memory=8000)
@retry(times=3, minutes_between_retries=5)
@step
def retry_with_delay(self):
    # Wait 5 minutes between retries
    pass
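With times=3 and minutes_between_retries=5, a persistently failing step makes four attempts in total (the original run plus three retries), pausing five minutes before each retry. A back-of-the-envelope helper for budgeting worst-case wall-clock time (ours, for illustration):

```python
def worst_case_minutes(task_minutes, times, minutes_between_retries):
    """Upper bound on wall-clock time if every attempt fails:
    (times + 1) attempts, with a pause before each retry."""
    attempts = times + 1
    return attempts * task_minutes + times * minutes_between_retries

# A 10-minute task with @retry(times=3, minutes_between_retries=5):
print(worst_case_minutes(10, 3, 5))  # 4*10 + 3*5 = 55 minutes
```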

Tagging

AWS Batch Tags

Add AWS tags for cost tracking and organization:
@batch(
    cpu=4,
    memory=8000,
    aws_batch_tags={
        'project': 'recommendation-engine',
        'team': 'ml-platform',
        'environment': 'production'
    }
)
@step
def tagged_task(self):
    pass

Default Tags

Set default tags via configuration:
# In your config
BATCH_DEFAULT_TAGS = {
    'company': 'acme',
    'cost-center': 'ml'
}
Metaflow automatically tags jobs with:
  • app: metaflow
  • metaflow.flow_name
  • metaflow.run_id
  • metaflow.step_name
  • metaflow.user
  • metaflow.version
To disable automatic tagging:
export METAFLOW_BATCH_EMIT_TAGS=false

Multi-Node Jobs

Run distributed jobs across multiple nodes:
@batch(cpu=8, memory=32000)
@step
def distributed_job(self):
    from metaflow import current
    
    if current.parallel.node_index == 0:
        print("I'm the main node")
    
    print(f"Node {current.parallel.node_index} of {current.parallel.num_nodes}")
    print(f"Main node IP: {current.parallel.main_ip}")
Run with multiple nodes:
python myflow.py run --with batch --num-parallel 4

Environment Variables

Custom Environment

Set environment variables for your step:
from metaflow import environment

@batch(cpu=4, memory=8000)
@environment(vars={
    'MODEL_PATH': 's3://my-bucket/models/',
    'API_KEY': 'secret-key-123'
})
@step
def configured_step(self):
    import os
    model_path = os.environ['MODEL_PATH']

Metaflow Environment Variables

Metaflow automatically sets:
  • METAFLOW_FLOW_NAME
  • METAFLOW_RUN_ID
  • METAFLOW_STEP_NAME
  • METAFLOW_TASK_ID
  • METAFLOW_RETRY_COUNT
  • METAFLOW_CODE_URL
  • AWS_BATCH_JOB_ID
  • AWS_BATCH_JOB_ATTEMPT
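Inside a step you can read these variables to tag logs or metrics. A minimal sketch that degrades gracefully when run outside Batch (the helper name is ours):

```python
import os

def task_context():
    """Collect Metaflow/Batch identifiers from the environment,
    falling back to 'local' when a variable is not set."""
    keys = [
        "METAFLOW_FLOW_NAME",
        "METAFLOW_RUN_ID",
        "METAFLOW_STEP_NAME",
        "AWS_BATCH_JOB_ID",
    ]
    return {k: os.environ.get(k, "local") for k in keys}

print(task_context())
```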

Monitoring

Task Metadata

Metaflow automatically captures:
  • AWS Batch job ID
  • Job attempt number
  • Compute environment name
  • Job queue name
  • EC2 instance metadata (type, region, availability zone)
  • CloudWatch logs location
Access metadata:
from metaflow import Flow

run = Flow('MyFlow').latest_run
for step in run:
    for task in step:
        print(task.metadata_dict)

Spot Interruption Handling

Metaflow monitors for spot instance interruptions:
from metaflow import current

@batch(cpu=4, memory=8000)
@step
def spot_aware_task(self):
    # current.spot_termination_notice is set once AWS issues the
    # two-minute interruption warning for this spot instance
    if current.spot_termination_notice:
        print("Spot interruption detected!")
        # Persist progress as artifacts so a retry can resume

Local Testing

Run the flow from your machine before deploying. Note that steps decorated with @batch in code run on AWS Batch even during a plain run; only undecorated steps execute locally:
# Only @batch-decorated steps run on AWS Batch
python myflow.py run

# Run every step on AWS Batch
python myflow.py run --with batch

Managing Batch Jobs

List Running Jobs

python myflow.py batch list

Kill Running Jobs

# Kill jobs from latest run
python myflow.py batch kill

# Kill jobs from specific run
python myflow.py batch kill --run-id <run-id>

# Kill all your jobs
python myflow.py batch kill --my-runs

Configuration

Key environment variables:
| Variable                            | Description              | Default    |
|-------------------------------------|--------------------------|------------|
| METAFLOW_BATCH_CONTAINER_IMAGE      | Default Docker image     | python:X.Y |
| METAFLOW_BATCH_CONTAINER_REGISTRY   | Default Docker registry  | None       |
| METAFLOW_BATCH_JOB_QUEUE            | Default job queue        | None       |
| METAFLOW_ECS_S3_ACCESS_IAM_ROLE     | IAM role for container   | None       |
| METAFLOW_ECS_FARGATE_EXECUTION_ROLE | Fargate execution role   | None       |
| METAFLOW_BATCH_EMIT_TAGS            | Enable automatic tagging | true       |
| METAFLOW_BATCH_DEFAULT_TAGS         | Default AWS tags         | {}         |
See AWS Configuration for complete setup.

Best Practices

Start with minimal resources and scale up based on actual usage; over-provisioning wastes money.
# Start small
@batch(cpu=1, memory=4000)
# Then profile and adjust
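To "profile and adjust", log peak memory at the end of a step and size the memory parameter from real numbers plus headroom. A minimal sketch using the standard library (the resource module is Unix-only; ru_maxrss is KiB on Linux but bytes on macOS):

```python
import resource
import sys

def peak_memory_mb():
    """Peak resident set size of this process, in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        peak //= 1024  # bytes -> KiB
    return peak / 1024  # KiB -> MiB

# Print this at the end of a step, then set @batch(memory=...) accordingly
print(f"peak memory: {peak_memory_mb():.1f} MiB")
```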
Configure your Batch compute environment to use spot instances for cost savings (up to 90% cheaper).
Use smaller base images and multi-stage builds to reduce pull time:
FROM python:3.9-slim  # Smaller than python:3.9
For I/O-intensive workloads, use tmpfs to avoid disk bottlenecks:
@batch(use_tmpfs=True, tmpfs_size=8000)
Always set timeouts to prevent runaway jobs:
@batch
@timeout(hours=6)

Troubleshooting

Job Stuck in RUNNABLE

Issue: Job stays in the RUNNABLE state.
Causes:
  • Insufficient compute resources in compute environment
  • Resource requirements too high for available instances
  • Service limits reached
Solution: Check AWS Batch console for specific errors and adjust resources or compute environment.

Container Pull Failures

Issue: "CannotPullContainerError"
Solution:
  • Verify image exists and is accessible
  • Check IAM permissions for ECR
  • Ensure execution role has ECR pull permissions

OOM (Out of Memory) Kills

Issue: Tasks killed due to out-of-memory errors.
Solution: Increase memory:
@batch(cpu=4, memory=32000)  # Increase from default 4096

Timeout Errors

Issue: Job exceeds its time limit.
Solution: Increase the timeout or optimize the code:
@timeout(hours=12)  # Increase limit

Next Steps

Step Functions

Deploy complete workflows to AWS Step Functions

Configuration

Configure AWS credentials and resources

Scaling

Learn more about scaling patterns

Resources

Resource management best practices
