FlaxAutoencoderKL is a variational autoencoder (VAE) trained with a KL loss. It encodes images into latent representations and decodes latents back into images.
FlaxAutoencoderKL
Model architecture
The VAE consists of:
- Encoder: Converts images to latent representations using downsampling blocks
- Decoder: Reconstructs images from latents using upsampling blocks
- Diagonal Gaussian distribution: Represents the encoder output as a mean and log-variance, from which latents are sampled
- Quantization convolutions: Optional 1x1 convolutions before/after latent space
Located at src/maxdiffusion/models/vae_flax.py:764.
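To make the "1x1 convolution" item above concrete, the following is a minimal sketch of what such a layer computes: an independent linear map over the channels of every pixel. The function name `conv1x1`, the channels-last layout, and the channel count (twice an illustrative `latent_channels`, on the assumption that the encoder emits a mean and a log-variance per latent channel) are choices made for this example and are not taken from `vae_flax.py`.

```python
import jax
import jax.numpy as jnp


def conv1x1(x, kernel, bias):
    """Apply a 1x1 convolution to a channels-last (N, H, W, C_in) tensor.

    kernel has shape (C_in, C_out), so every pixel is mapped independently.
    """
    return jnp.einsum("nhwc,cd->nhwd", x, kernel) + bias


key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
latent_channels = 4  # illustrative value, not read from any model config
channels = 2 * latent_channels  # assumption: mean and log-variance per latent channel

x = jax.random.normal(key_x, (1, 8, 8, channels))  # stand-in for encoder output
kernel = jax.random.normal(key_k, (channels, channels))
bias = jnp.zeros((channels,))

moments = conv1x1(x, kernel, bias)

# Cross-check against a general convolution with a 1x1 spatial window.
reference = jax.lax.conv_general_dilated(
    x,
    kernel[None, None, :, :],  # HWIO layout with H = W = 1
    window_strides=(1, 1),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
```

Because the spatial window is a single pixel, the layer changes only the channel dimension and leaves height and width untouched.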
Configuration parameters
Number of channels in the input image
Number of channels in the output
Tuple of downsample block types
Tuple of upsample block types
Tuple of block output channels
Number of ResNet layers per block
Number of channels in the latent space
The number of groups for normalization
The component-wise standard deviation of the trained latent space. It is used to scale the latent space to unit variance when training the diffusion model
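The scaling described above might be applied as in the following sketch. It assumes the convention common to Stable Diffusion style pipelines, where latents are multiplied by the factor before they reach the diffusion model and divided by it again before decoding. The value `0.18215` and the helper names `to_diffusion_space` and `to_vae_space` are assumptions for this example; in practice the factor would come from the model configuration.

```python
import jax
import jax.numpy as jnp

# Assumed value (a common Stable Diffusion v1 convention), not read from a config.
scaling_factor = 0.18215


def to_diffusion_space(latents, scaling_factor):
    """Scale VAE latents so they have roughly unit variance."""
    return latents * scaling_factor


def to_vae_space(latents, scaling_factor):
    """Undo the scaling before handing latents back to the decoder."""
    return latents / scaling_factor


key = jax.random.PRNGKey(0)
# Synthetic stand-in for sampled VAE latents with std of about 1 / scaling_factor.
vae_latents = jax.random.normal(key, (2, 4, 32, 32)) / scaling_factor

scaled = to_diffusion_space(vae_latents, scaling_factor)
restored = to_vae_space(scaled, scaling_factor)

print(float(vae_latents.std()), float(scaled.std()))
```

The two helpers are exact inverses, so a round trip through them leaves the latents unchanged up to floating-point error.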
Methods
encode
Encodes an image into its latent representation.
Parameters:
Input image tensor
Whether to use deterministic encoding
Returns:
The encoded output, represented as the mean and log-variance of a distribution from which latents can be sampled
decode
Decodes latent representations into images.
Parameters:
Latent representations to decode
Whether to use deterministic decoding
Returns:
The decoded output sample from the last layer of the model, with shape (batch_size, num_channels, height, width)
FlaxDiagonalGaussianDistribution
Distribution class for the VAE latent space, located at src/maxdiffusion/models/vae_flax.py:725.
Methods
sample
Samples from the distribution.
Parameters:
Random key for sampling
Returns:
Sampled latent tensor
mode
Returns the mode of the distribution, which for a Gaussian is its mean.
Returns:
The mean of the distribution
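As an illustration of how sampling and the mode relate, here is a minimal sketch of a diagonal Gaussian built from stacked mean and log-variance values. The class `DiagonalGaussian` is a hypothetical stand-in and is not the library's `FlaxDiagonalGaussianDistribution`. Splitting along the last axis, the clipping bounds, and the `kl` helper are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


class DiagonalGaussian:
    """Illustrative stand-in for a diagonal Gaussian over latents."""

    def __init__(self, moments):
        # Assumption: mean and log-variance are stacked along the last axis.
        self.mean, logvar = jnp.split(moments, 2, axis=-1)
        # Clamp the log-variance for numerical stability (bounds are illustrative).
        self.logvar = jnp.clip(logvar, -30.0, 20.0)
        self.std = jnp.exp(0.5 * self.logvar)

    def sample(self, key):
        # Reparameterization: mean + std * eps with eps drawn from N(0, I).
        eps = jax.random.normal(key, self.mean.shape)
        return self.mean + self.std * eps

    def mode(self):
        # The mode of a Gaussian is its mean.
        return self.mean

    def kl(self):
        # KL divergence to a standard normal, summed over non-batch axes.
        var = jnp.exp(self.logvar)
        return 0.5 * jnp.sum(self.mean**2 + var - 1.0 - self.logvar, axis=(1, 2, 3))


# Mean of 1 and log-variance of 0 for every latent element.
moments = jnp.concatenate([jnp.ones((2, 8, 8, 4)), jnp.zeros((2, 8, 8, 4))], axis=-1)
dist = DiagonalGaussian(moments)

z = dist.sample(jax.random.PRNGKey(0))  # stochastic latents
z_mode = dist.mode()  # deterministic latents
```

Sampling with the same key reproduces the same latents, while `mode()` skips the random draw entirely, which can be convenient when a deterministic encoding is wanted.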
Sub-components
FlaxEncoder
VAE encoder implementation located at src/maxdiffusion/models/vae_flax.py:483.
FlaxDecoder
VAE decoder implementation located at src/maxdiffusion/models/vae_flax.py:603.
FlaxResnetBlock2D
2D ResNet block with group normalization located at src/maxdiffusion/models/vae_flax.py:130.
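The idea behind a ResNet block with group normalization can be sketched as follows. This is a simplified illustration, not the code in `vae_flax.py`: the names `group_norm` and `toy_resnet_block` are hypothetical, the learnable scale and bias of the normalization are omitted, and a channel-mixing matrix stands in for the spatial convolutions a real block would typically use.

```python
import jax
import jax.numpy as jnp


def group_norm(x, num_groups, eps=1e-6):
    """Normalize a channels-last (N, H, W, C) tensor per sample and channel group."""
    n, h, w, c = x.shape
    grouped = x.reshape(n, h, w, num_groups, c // num_groups)
    # Statistics are shared across space and across the channels inside a group.
    mean = grouped.mean(axis=(1, 2, 4), keepdims=True)
    var = grouped.var(axis=(1, 2, 4), keepdims=True)
    return ((grouped - mean) / jnp.sqrt(var + eps)).reshape(n, h, w, c)


def toy_resnet_block(x, kernel, num_groups):
    """Residual update: x + mix(silu(group_norm(x)))."""
    hidden = jax.nn.silu(group_norm(x, num_groups))
    # Channel mixing stands in for the convolutions of a real block.
    hidden = jnp.einsum("nhwc,cd->nhwd", hidden, kernel)
    return x + hidden


x = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 32))
normed = group_norm(x, num_groups=8)

# With a zero kernel the residual branch vanishes and the block is the identity.
out = toy_resnet_block(x, jnp.zeros((32, 32)), num_groups=8)
```

The skip connection means the block only has to learn a correction to its input, and the normalization keeps activations in each group at zero mean and unit variance regardless of batch size.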