FlaxAutoencoderKL is a variational autoencoder (VAE) trained with a KL loss. It encodes images into latent representations and decodes latents back into images.

FlaxAutoencoderKL

Model architecture

The VAE consists of:
  • Encoder: Converts images to latent representations using downsampling blocks
  • Decoder: Reconstructs images from latents using upsampling blocks
  • Diagonal Gaussian distribution: Represents the encoder output as a mean and log-variance from which latents are sampled
  • Quantization convolutions: Optional 1x1 convolutions applied before and after the latent space

The model is defined at src/maxdiffusion/models/vae_flax.py:764.
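
To make the effect of the downsampling blocks concrete, the sketch below estimates the latent shape for a given configuration. It assumes that every encoder block except the last one halves the spatial resolution, which is the usual layout for this family of VAEs but is not confirmed by this page. `latent_shape` is a hypothetical helper, not part of maxdiffusion.

```python
def latent_shape(image_hw, block_out_channels, latent_channels=4):
    """Estimate (height, width, channels) of the latent for one image.

    Assumption: each encoder block except the final one downsamples by 2.
    """
    num_downsamples = len(block_out_channels) - 1
    factor = 2 ** num_downsamples
    height, width = image_hw
    if height % factor or width % factor:
        raise ValueError(f"image size must be divisible by {factor}")
    return height // factor, width // factor, latent_channels


# Default config: a single block, so no downsampling takes place.
print(latent_shape((512, 512), (64,)))  # (512, 512, 4)
# A hypothetical four-block config.
print(latent_shape((512, 512), (128, 256, 512, 512)))  # (64, 64, 4)
```

Under this assumption, the number of entries in block_out_channels determines how much smaller the latent is than the input image.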

Configuration parameters

  • in_channels (int, default: 3): Number of channels in the input image
  • out_channels (int, default: 3): Number of channels in the output
  • down_block_types (Tuple[str], default: ('DownEncoderBlock2D',)): Tuple of downsample block types
  • up_block_types (Tuple[str], default: ('UpDecoderBlock2D',)): Tuple of upsample block types
  • block_out_channels (Tuple[int], default: (64,)): Tuple of block output channels
  • layers_per_block (int, default: 1): Number of ResNet layers per block
  • latent_channels (int, default: 4): Number of channels in the latent space
  • norm_num_groups (int, default: 32): The number of groups for normalization
  • scaling_factor (float, default: 0.18215): The component-wise standard deviation of the trained latent space. Used to scale the latent space to have unit variance when training the diffusion model
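
As a rough illustration of how scaling_factor is typically applied, the sketch below multiplies latents by the factor before they are handed to a diffusion model and divides by it again before decoding. This convention is an assumption here, not something stated on this page. `scale_latents` and `unscale_latents` are hypothetical helpers, and the random latents only stand in for real encoder output.

```python
import jax
import jax.numpy as jnp

SCALING_FACTOR = 0.18215  # default value listed in the configuration


def scale_latents(latents, scaling_factor=SCALING_FACTOR):
    # Applied to encoder output before training or sampling with the diffusion model.
    return latents * scaling_factor


def unscale_latents(latents, scaling_factor=SCALING_FACTOR):
    # Applied to diffusion output before passing it to the decoder.
    return latents / scaling_factor


# Stand-in latents whose standard deviation is about 1 / SCALING_FACTOR.
key = jax.random.PRNGKey(0)
raw = jax.random.normal(key, (1, 64, 64, 4)) / SCALING_FACTOR

scaled = scale_latents(raw)
restored = unscale_latents(scaled)
print(float(jnp.std(raw)), float(jnp.std(scaled)))
```

After scaling, the stand-in latents have a standard deviation close to 1, which is the unit-variance property the parameter description refers to.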

Methods

encode

Encodes an image into its latent representation.

Parameters:
  • sample (jnp.ndarray): Input image tensor
  • deterministic (bool, default: True): Whether to run the encoder deterministically (in Flax modules this flag typically disables dropout)

Returns:
  • latent_dist (FlaxDiagonalGaussianDistribution): The encoded output, represented by its mean and logvar, from which latents can be sampled
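
The mean and logvar mentioned above can be pictured with a small sketch. It assumes the encoder emits 2 * latent_channels channels in a channels-last layout and that the two halves are read as mean and log-variance. The clipping range is also an assumption, and `split_moments` is a hypothetical helper rather than code taken from vae_flax.py.

```python
import jax.numpy as jnp


def split_moments(moments):
    """Split encoder output into mean and log-variance along the channel axis."""
    mean, logvar = jnp.split(moments, 2, axis=-1)
    # Assumed safeguard: keep the log-variance in a numerically safe range.
    logvar = jnp.clip(logvar, -30.0, 20.0)
    return mean, logvar


# Stand-in encoder output for latent_channels=4 (so 8 channels in total).
moments = jnp.zeros((1, 64, 64, 8))
mean, logvar = split_moments(moments)
print(mean.shape, logvar.shape)  # (1, 64, 64, 4) (1, 64, 64, 4)
```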

decode

Decodes latent representations into images.

Parameters:
  • latents (jnp.ndarray): Latent representations to decode
  • deterministic (bool, default: True): Whether to run the decoder deterministically (in Flax modules this flag typically disables dropout)

Returns:
  • sample (jnp.ndarray): The decoded output sample from the last layer of the model, with shape (batch_size, num_channels, height, width)
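
The documented output shape is channels-first, while Flax convolutions operate on channels-last arrays. The sketch below shows the kind of layout conversion this implies. It assumes the decoder works internally in (batch, height, width, channels), and both helpers are hypothetical rather than part of the maxdiffusion API.

```python
import jax.numpy as jnp


def to_channels_first(x):
    # (batch, height, width, channels) -> (batch, channels, height, width)
    return jnp.transpose(x, (0, 3, 1, 2))


def to_channels_last(x):
    # (batch, channels, height, width) -> (batch, height, width, channels)
    return jnp.transpose(x, (0, 2, 3, 1))


# Stand-in for a decoder result in channels-last layout.
decoded_nhwc = jnp.zeros((2, 256, 256, 3))
sample = to_channels_first(decoded_nhwc)
print(sample.shape)  # (2, 3, 256, 256)
```

Image libraries that expect channels-last arrays would need `to_channels_last` applied to the returned sample first.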

FlaxDiagonalGaussianDistribution

Distribution class for the VAE latent space, located at src/maxdiffusion/models/vae_flax.py:725.

Methods

sample

Samples from the distribution.

Parameters:
  • key (jax.random.PRNGKey): Random key for sampling

Returns:
  • sample (jnp.ndarray): Sampled latent tensor
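
Sampling from a diagonal Gaussian is commonly done with the reparameterization trick: draw standard normal noise and shift and scale it by the mean and standard deviation. The sketch below illustrates that idea in plain JAX. `sample_diagonal_gaussian` is a hypothetical stand-in, not the method defined in vae_flax.py.

```python
import jax
import jax.numpy as jnp


def sample_diagonal_gaussian(key, mean, logvar):
    """Draw one sample per element as mean + std * noise."""
    std = jnp.exp(0.5 * logvar)
    noise = jax.random.normal(key, mean.shape, dtype=mean.dtype)
    return mean + std * noise


mean = jnp.full((1, 8, 8, 4), 2.0)
logvar = jnp.zeros((1, 8, 8, 4))  # log-variance 0 means unit variance

key = jax.random.PRNGKey(0)
latents = sample_diagonal_gaussian(key, mean, logvar)
print(latents.shape)  # (1, 8, 8, 4)
```

Passing the same key yields the same sample, so a fresh key (for example from `jax.random.split`) is needed for each independent draw.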

mode

Returns the mode of the distribution, which for a Gaussian is its mean.

Returns:
  • mode (jnp.ndarray): The mean of the distribution

Sub-components

FlaxEncoder

VAE encoder implementation, located at src/maxdiffusion/models/vae_flax.py:483.

FlaxDecoder

VAE decoder implementation, located at src/maxdiffusion/models/vae_flax.py:603.

FlaxResnetBlock2D

2D ResNet block with group normalization, located at src/maxdiffusion/models/vae_flax.py:130.
