Building Generative AI from Scratch: A Complete Guide to Image Generation

By Admin, about 1 month ago

Generative AI has revolutionized how we create images, from generating photorealistic faces to creating artistic masterpieces. But how does it actually work? In this tutorial, we'll build image-generating AI models from scratch, understanding every component along the way.

What You'll Learn

By the end of this tutorial, you'll understand:

  • The core principles behind generative models
  • How to implement a Variational Autoencoder (VAE) from scratch
  • The architecture and training of Generative Adversarial Networks (GANs)
  • Practical techniques for generating high-quality images
  • How to train your own image generator

Prerequisites

Before diving in, you should have:

  • Basic Python programming knowledge
  • Understanding of neural networks and backpropagation
  • Familiarity with NumPy and basic linear algebra
  • PyTorch or TensorFlow installed (we'll use PyTorch)

What is Generative AI?

Generative AI refers to models that can create new data similar to their training data. Unlike discriminative models that classify or predict, generative models learn the underlying probability distribution of data and can sample from it to create entirely new examples.

For images, this means learning what makes an image look like a face, a cat, or any other category, then generating new images that belong to that distribution.

Part 1: Understanding the Mathematics

The Core Problem

Generative modeling aims to learn a probability distribution p(x) where x represents our data (images). Once we know this distribution, we can sample from it to generate new images.

The challenge: Real-world distributions like "all possible cat images" are incredibly complex and high-dimensional.

Latent Variable Models

The key insight is to introduce a latent (hidden) variable z that captures meaningful features in a lower-dimensional space. We then model:

p(x) = ∫ p(x|z) p(z) dz

Where:

  • z is our latent code (compressed representation)
  • p(z) is a simple prior (often Gaussian)
  • p(x|z) is our decoder that maps latent codes to images
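To make the integral concrete, here is a minimal sketch on a toy one-dimensional model that is not part of the image setup used later: z ~ N(0, 1) and x | z ~ N(z, 1). In that special case the marginal is known exactly, p(x) = N(x; 0, 2), so a Monte Carlo average of p(x|z) over prior samples can be checked against it. The function names are illustrative assumptions.

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a 1-D Gaussian N(mean, var) evaluated at x."""
    return math.exp(-((x - mean) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


def marginal_mc(x, num_samples=20000, seed=0):
    """Estimate p(x) = E_{z ~ p(z)}[p(x|z)] by averaging over prior samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)          # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, 1.0)   # p(x|z) = N(x; z, 1)
    return total / num_samples


def marginal_exact(x):
    """Closed-form marginal for this toy model: N(x; 0, 2)."""
    return normal_pdf(x, 0.0, 2.0)


if __name__ == "__main__":
    for x in (0.0, 1.0, 2.0):
        print(x, marginal_mc(x), marginal_exact(x))
```

For images, p(x|z) is a neural network and x has hundreds of dimensions, so this naive average needs far too many samples to be practical. That is the motivation for optimizing a bound on p(x) instead, as the VAE does.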

Part 2: Building a Variational Autoencoder (VAE)

VAEs are a foundational generative model. They learn to compress images into a latent space and reconstruct them, while ensuring the latent space is well-structured for generation.

Architecture Overview

A VAE consists of two main components:

  1. Encoder: Maps images to latent distributions q(z|x)
  2. Decoder: Reconstructs images from latent codes p(x|z)

Implementation: Setting Up

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

Building the Encoder

The encoder takes an image and outputs parameters of a distribution (mean and log-variance):

class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(Encoder, self).__init__()
        
        # Fully connected layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
    def forward(self, x):
        # Flatten image
        x = x.view(x.size(0), -1)
        
        # Hidden layer with ReLU
        h = F.relu(self.fc1(x))
        
        # Output mean and log-variance
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        
        return mu, logvar

Building the Decoder

The decoder takes a latent code and reconstructs an image:

class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super(Decoder, self).__init__()
        
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, z):
        # Hidden layer
        h = F.relu(self.fc1(z))
        
        # Output layer with sigmoid (for pixel values in [0,1])
        x_reconstructed = torch.sigmoid(self.fc2(h))
        
        return x_reconstructed

The Reparameterization Trick

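Sampling z directly from the encoder's distribution is not differentiable with respect to μ and σ, so gradients cannot reach the encoder. To backpropagate through stochastic sampling, we use the reparameterization trick:

Instead of sampling z ~ N(μ, σ²), we sample ε ~ N(0, I) and compute z = μ + σ * ε. The randomness now enters as an input, and z is a differentiable function of μ and σ.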

def reparameterize(mu, logvar):
    """
    Reparameterization trick: z = mu + sigma * epsilon
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std
    return z
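The following sketch checks that claim on small tensors: after sampling through the trick, both `mu` and `logvar` receive gradients. The tensor shapes are arbitrary assumptions, and `reparameterize` is repeated only so the snippet runs on its own.

```python
import torch


def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


torch.manual_seed(0)
mu = torch.zeros(4, 3, requires_grad=True)
logvar = torch.zeros(4, 3, requires_grad=True)

z = reparameterize(mu, logvar)
z.sum().backward()

# dz/dmu is exactly 1 for every element, regardless of the noise drawn
print(mu.grad)
# dz/dlogvar = 0.5 * sigma * eps, so it depends on the sampled noise
print(logvar.grad)
```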

Complete VAE Model

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        
    def forward(self, x):
        # Encode
        mu, logvar = self.encoder(x)
        
        # Reparameterize
        z = reparameterize(mu, logvar)
        
        # Decode
        x_reconstructed = self.decoder(z)
        
        return x_reconstructed, mu, logvar
    
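    def generate(self, num_samples=1):
        """Generate new images by sampling from the prior"""
        # Create the noise on the same device as the model's weights
        device = next(self.parameters()).device
        latent_dim = self.decoder.fc1.in_features
        with torch.no_grad():
            # Sample from standard normal
            z = torch.randn(num_samples, latent_dim, device=device)
            
            # Decode
            samples = self.decoder(z)
            
            return samples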

Loss Function: The ELBO

VAEs are trained by maximizing the Evidence Lower Bound (ELBO):

ELBO = E[log p(x|z)] - KL(q(z|x) || p(z))

Optimizers minimize, so in code we minimize the negative ELBO. It consists of:

  1. Reconstruction loss: How well we reconstruct the input (the negative of E[log p(x|z)])
  2. KL divergence: How close our encoded distribution is to the prior

def vae_loss(x_reconstructed, x, mu, logvar):
    """
    Negative ELBO = reconstruction loss + KL divergence, summed over the batch
    """
    # Reconstruction loss (binary cross-entropy)
    reconstruction_loss = F.binary_cross_entropy(
        x_reconstructed, 
        x.view(-1, 784), 
        reduction='sum'
    )
    
    # KL divergence
    # KL(N(mu, sigma^2) || N(0, 1))
    # = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return reconstruction_loss + kl_divergence
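As a sanity check on the closed-form KL term used in the loss above, this sketch compares it with a direct numerical integration for a single latent dimension. The helper names, integration bounds, and step count are assumptions made for illustration.

```python
import math


def kl_closed_form(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)) for a single latent dimension."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))


def kl_numeric(mu, logvar, lo=-12.0, hi=12.0, steps=48000):
    """Same KL as the integral of q(z) * log(q(z) / p(z)), via the midpoint rule."""
    var = math.exp(logvar)
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        log_q = -0.5 * (math.log(2 * math.pi * var) + (z - mu) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total


if __name__ == "__main__":
    print(kl_closed_form(0.7, -0.4), kl_numeric(0.7, -0.4))
```

The loss function sums this per-dimension quantity over all latent dimensions and all samples in the batch.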

Training the VAE

def train_vae(model, train_loader, optimizer, epoch):
    model.train()
    train_loss = 0
    
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        
        # Forward pass
        x_reconstructed, mu, logvar = model(data)
        
        # Compute loss
        loss = vae_loss(x_reconstructed, data, mu, logvar)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]'
                  f'\tLoss: {loss.item() / len(data):.4f}')
    
    avg_loss = train_loss / len(train_loader.dataset)
    print(f'====> Epoch {epoch} Average loss: {avg_loss:.4f}')
    
    return avg_loss

Putting It All Together

# Hyperparameters
batch_size = 128
latent_dim = 20
epochs = 10
learning_rate = 1e-3

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize model
model = VAE(input_dim=784, hidden_dim=400, latent_dim=latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(1, epochs + 1):
    train_vae(model, train_loader, optimizer, epoch)
    
    # Generate samples
    if epoch % 2 == 0:
        samples = model.generate(num_samples=16)
        # Visualize samples (implementation omitted for brevity)
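One possible way to fill in the visualization step is to tile the flat samples into a single image. This is a sketch, not the author's omitted implementation: `tile_images` is a hypothetical helper, and random tensors stand in for the output of `model.generate`.

```python
import torch


def tile_images(samples, nrow, side=28):
    """Arrange flat images of shape (nrow * ncol, side * side) into one 2-D grid."""
    ncol = samples.size(0) // nrow
    grid = samples[: nrow * ncol].reshape(nrow, ncol, side, side)
    # Reorder to (row, pixel_row, col, pixel_col) so a reshape stitches tiles together
    grid = grid.permute(0, 2, 1, 3).reshape(nrow * side, ncol * side)
    return grid


# Stand-in for model.generate(num_samples=16): 16 random "images" in [0, 1]
samples = torch.rand(16, 784)
grid = tile_images(samples, nrow=4)
print(grid.shape)  # torch.Size([112, 112])
```

The resulting tensor can be shown with `plt.imshow(grid.cpu().numpy(), cmap='gray')`.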

Part 3: Generative Adversarial Networks (GANs)

GANs take a different approach: they pit two networks against each other in a game-theoretic framework.

The GAN Framework

  • Generator (G): Creates fake images from random noise
  • Discriminator (D): Tries to distinguish real from fake images

The generator tries to fool the discriminator, while the discriminator tries to correctly classify images. When the two networks stay roughly balanced, this adversarial training can produce highly realistic images.
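This game is usually written as the minimax objective from the original GAN paper (Goodfellow et al.), where p_data denotes the data distribution and p(z) the noise prior:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

In practice the generator is often trained to maximize log D(G(z)) rather than to minimize log(1 - D(G(z))), because the latter gives weak gradients while the discriminator still rejects early samples with high confidence. A binary cross-entropy loss on generated images with "real" targets implements that alternative.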

Building the Generator

class Generator(nn.Module):
    def __init__(self, latent_dim=100, output_dim=784):
        super(Generator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, output_dim),
            nn.Tanh()  # Output in [-1, 1], so train on images scaled to [-1, 1] as well
        )
        
    def forward(self, z):
        return self.model(z)

Building the Discriminator

class Discriminator(nn.Module):
    def __init__(self, input_dim=784):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output probability
        )
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.model(x)

GAN Training Loop

Training GANs is tricky because we need to balance two competing objectives:

def train_gan(generator, discriminator, train_loader, g_optimizer, d_optimizer, epoch):
    generator.train()
    discriminator.train()
    
    for batch_idx, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)
        
        # Labels
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        
        # ============================================
        # Train Discriminator
        # ============================================
        d_optimizer.zero_grad()
        
        # Real images
        real_outputs = discriminator(real_images)
        d_loss_real = F.binary_cross_entropy(real_outputs, real_labels)
        
        # Fake images
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = F.binary_cross_entropy(fake_outputs, fake_labels)
        
        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()
        
        # ============================================
        # Train Generator
        # ============================================
        g_optimizer.zero_grad()
        
        # Generate fake images
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        
        # Try to fool discriminator
        fake_outputs = discriminator(fake_images)
        g_loss = F.binary_cross_entropy(fake_outputs, real_labels)
        
        g_loss.backward()
        g_optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}] Batch [{batch_idx}/{len(train_loader)}] '
                  f'D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')
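The `.detach()` call in the discriminator step is easy to overlook, so here is a small sketch of what it does. The two networks are tiny stand-ins with made-up sizes, chosen only to keep the demo fast; they are not the Generator and Discriminator defined above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Tiny stand-in networks (hypothetical sizes)
generator = nn.Linear(8, 16)
discriminator = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

z = torch.randn(4, 8)
fake_images = generator(z)
fake_labels = torch.zeros(4, 1)
real_labels = torch.ones(4, 1)

# Discriminator step: detach() cuts the graph, so the generator gets no gradient
d_loss_fake = F.binary_cross_entropy(discriminator(fake_images.detach()), fake_labels)
d_loss_fake.backward()
grad_after_d_step = generator.weight.grad  # still None

# Generator step: no detach(), so gradients flow through D back into G
discriminator.zero_grad()
g_loss = F.binary_cross_entropy(discriminator(fake_images), real_labels)
g_loss.backward()
grad_after_g_step = generator.weight.grad  # now a tensor

print(grad_after_d_step is None, grad_after_g_step is not None)
```

The generator step also leaves gradients on the discriminator's parameters. They are harmless in the loop above because only `g_optimizer.step()` runs afterwards and `d_optimizer.zero_grad()` clears them at the start of the next iteration.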

Part 4: Advanced Techniques

Convolutional Architectures

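For better image quality, use convolutional layers instead of fully connected ones. The generator below starts from a 7x7 feature map and upsamples it twice to reach MNIST's 28x28 resolution: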

class ConvGenerator(nn.Module):
    def __init__(self, latent_dim=100):
        super(ConvGenerator, self).__init__()
        
        self.init_size = 7  # Initial spatial size
        self.fc = nn.Linear(latent_dim, 128 * self.init_size ** 2)
        
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh()
        )
        
    def forward(self, z):
        out = self.fc(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
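The spatial sizes in this generator can be traced with the standard Conv2d output-size formula. The sketch below uses plain Python and hypothetical helper names; it assumes dilation 1 and an integer upsampling factor, as in the layers above.

```python
def conv2d_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def upsample_out(size, scale_factor=2):
    """Spatial output size of nn.Upsample with an integer scale factor."""
    return size * scale_factor


size = 7                      # init_size after reshaping the linear layer's output
trace = [size]
for _ in range(2):            # two Upsample + Conv2d stages
    size = conv2d_out(upsample_out(size))
    trace.append(size)
size = conv2d_out(size)       # final Conv2d keeps the size and maps 64 -> 1 channel
trace.append(size)
print(trace)  # [7, 14, 28, 28]
```

With kernel 3, stride 1, and padding 1, each convolution preserves the spatial size, so only the two upsampling layers change the resolution.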

Stabilizing GAN Training

GANs are notoriously difficult to train. Here are key techniques:

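  1. Use label smoothing: Instead of hard 1s and 0s, use soft targets such as 0.9 and 0.1 (a common variant smooths only the real labels)
  2. Add noise to discriminator inputs: Makes real and fake batches harder to separate, which helps keep the discriminator from overpowering the generator
  3. Use different learning rates: Often slower for discriminator
  4. Monitor both losses: If one dominates, adjust learning rates

# Label smoothing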
real_labels = torch.ones(batch_size, 1) * 0.9
fake_labels = torch.zeros(batch_size, 1) + 0.1

# Different learning rates
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
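The noise technique from the list can be sketched as follows. The helper name `add_instance_noise`, the linear decay schedule, and the starting standard deviation of 0.1 are all assumptions for illustration; suitable values depend on the dataset.

```python
import torch


def add_instance_noise(images, epoch, total_epochs, start_std=0.1):
    """Add Gaussian noise whose standard deviation decays linearly to zero."""
    std = start_std * max(0.0, 1.0 - epoch / total_epochs)
    if std == 0.0:
        return images
    return images + std * torch.randn_like(images)


torch.manual_seed(0)
real_images = torch.zeros(4, 1, 28, 28)
noisy_early = add_instance_noise(real_images, epoch=0, total_epochs=10)
noisy_late = add_instance_noise(real_images, epoch=10, total_epochs=10)
print(noisy_early.std().item(), noisy_late.std().item())
```

In a training loop, the same function would be applied to both the real and the generated batch just before each is passed to the discriminator.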

Part 5: Evaluating Generative Models

Visual Inspection

The most basic evaluation: do the generated images look good?

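def visualize_generations(model, num_images=16, latent_dim=100):
    """Plot a square grid of images sampled from a trained generator."""
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        generated = model(z)

    # Lay the images out on the smallest square grid that fits them all
    grid = int(np.ceil(np.sqrt(num_images)))
    fig, axes = plt.subplots(grid, grid, figsize=(10, 10), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_images:
            img = generated[i].view(28, 28).cpu().numpy()
            ax.imshow(img, cmap='gray')
        ax.axis('off')
    plt.tight_layout()
    plt.show()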

Quantitative Metrics

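Inception Score (IS): Uses a pretrained Inception classifier to reward images that are confidently classified (quality) and spread across many classes (diversity); higher is better

Fréchet Inception Distance (FID): Compares the mean and covariance of Inception features for generated and real images (lower is better)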
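For reference, the standard definitions of both metrics are given below. Here p(y|x) is the Inception classifier's label distribution for a generated image x, p(y) is its average over generated images, and (μ_r, Σ_r) and (μ_g, Σ_g) are the mean and covariance of Inception features for real and generated images.

```latex
\mathrm{IS} = \exp\Big( \mathbb{E}_{x \sim p_g}
  \big[ \mathrm{KL}\big( p(y \mid x) \,\|\, p(y) \big) \big] \Big)

\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \mathrm{Tr}\big( \Sigma_r + \Sigma_g - 2 (\Sigma_r \Sigma_g)^{1/2} \big)
```

Both rely on a network pretrained on ImageNet, so scores on very different data such as MNIST digits should be read with caution.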

Part 6: Practical Tips and Tricks

Data Preprocessing

# Normalize to [-1, 1] so real images match the generator's Tanh output range
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
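The arithmetic behind this transform is (x - mean) / std per channel. The sketch below reproduces it on a few pixel values and shows the inverse mapping, which is handy before plotting generator outputs. The function names are hypothetical, and plain tensor math stands in for the torchvision transform.

```python
import torch


def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize([0.5], [0.5]): (x - mean) / std."""
    return (x - mean) / std


def denormalize(x, mean=0.5, std=0.5):
    """Inverse mapping from [-1, 1] back to [0, 1]."""
    return x * std + mean


pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
scaled = normalize(pixels)
print(scaled)               # values: -1.0, -0.5, 0.0, 1.0
print(denormalize(scaled))  # values: 0.0, 0.25, 0.5, 1.0
```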

Latent Space Interpolation

Explore what your model has learned by interpolating between latent codes:

def interpolate(model, z1, z2, steps=10):
    """Generate images interpolating between two latent codes"""
    model.eval()
    interpolations = []
    
    with torch.no_grad():
        for alpha in np.linspace(0, 1, steps):
            z = (1 - alpha) * z1 + alpha * z2
            img = model(z)
            interpolations.append(img)
    
    return interpolations
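A batched variant of the same idea builds all intermediate codes at once and runs a single forward pass. This is a sketch: `interpolate_batch` is a hypothetical name, and the one-layer decoder is a stand-in for a trained generator or VAE decoder.

```python
import torch
import torch.nn as nn


def interpolate_batch(model, z1, z2, steps=10):
    """Decode `steps` evenly spaced points on the line from z1 to z2."""
    model.eval()
    alphas = torch.linspace(0, 1, steps).view(-1, 1)   # (steps, 1)
    z = (1 - alphas) * z1 + alphas * z2                # broadcasts to (steps, latent_dim)
    with torch.no_grad():
        return model(z)


torch.manual_seed(0)
decoder = nn.Sequential(nn.Linear(20, 784), nn.Sigmoid())  # stand-in model
z1 = torch.randn(1, 20)
z2 = torch.randn(1, 20)
frames = interpolate_batch(decoder, z1, z2, steps=10)
print(frames.shape)  # torch.Size([10, 784])
```

The first and last frames correspond to decoding z1 and z2 themselves, so smooth changes in between suggest the latent space is well structured.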

Conditional Generation

Generate specific types of images by conditioning on labels:

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim=100, num_classes=10):
        super(ConditionalGenerator, self).__init__()
        
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim * 2, 256),
            # ... rest of architecture
        )
        
    def forward(self, z, labels):
        # Concatenate noise with label embedding
        label_input = self.label_emb(labels)
        gen_input = torch.cat([z, label_input], dim=1)
        return self.model(gen_input)
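The first linear layer takes `latent_dim * 2` inputs because the label embedding has the same width as the noise vector and the two are concatenated. The shape check below illustrates this with arbitrary example labels; it is a sketch of the input path only and does not fill in the rest of the architecture.

```python
import torch
import torch.nn as nn

latent_dim, num_classes = 100, 10
label_emb = nn.Embedding(num_classes, latent_dim)

z = torch.randn(4, latent_dim)
labels = torch.tensor([3, 3, 7, 0])          # one digit class per sample

label_input = label_emb(labels)              # (4, latent_dim)
gen_input = torch.cat([z, label_input], dim=1)
print(gen_input.shape)  # torch.Size([4, 200])

# Samples that share a label share the same embedding row
same_label = torch.equal(label_input[0], label_input[1])
```

At generation time, fixing `labels` to one class while resampling `z` should produce different images of that class once the model is trained.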

Conclusion

You've now learned the fundamentals of building generative AI from scratch! We covered:

  • The mathematical foundations of generative modeling
  • Implementing VAEs for learning structured latent spaces
  • Building GANs for high-quality image synthesis
  • Advanced techniques for better training and evaluation

Next Steps

  • Experiment with different architectures (DCGAN, StyleGAN)
  • Try diffusion models, the current state-of-the-art
  • Apply these techniques to your own datasets
  • Explore conditional generation and controllability

Resources

  • Papers: "Auto-Encoding Variational Bayes" (Kingma & Welling), "Generative Adversarial Networks" (Goodfellow et al.)
  • Code: PyTorch Examples
  • Datasets: MNIST, CelebA, ImageNet

Remember: generative AI is an active research area. The models we built here are foundational, but modern architectures like diffusion models and transformers are pushing the boundaries even further. Use this knowledge as a springboard to explore cutting-edge techniques!

Complete Example Code

Here's a minimal working example you can run:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Simple VAE for MNIST
class SimpleVAE(nn.Module):
    def __init__(self):
        super(SimpleVAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = torch.relu(self.fc1(x.view(-1, 784)))
        return self.fc21(h1), self.fc22(h1)

    def decode(self, z):
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mu + eps*std
        return self.decode(z), mu, logvar

# Train it!
model = SimpleVAE()
optimizer = optim.Adam(model.parameters())
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True
)

for epoch in range(10):
    for data, _ in train_loader:
        optimizer.zero_grad()
        recon, mu, logvar = model(data)
        loss = nn.functional.binary_cross_entropy(recon, data.view(-1, 784), reduction='sum')
        loss += -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1} complete')

print('Training finished! Generate images with model.decode(torch.randn(16, 20))')

Happy generating! 🎨

Tags: machine learning, generative ai, deep learning, computer vision, tutorial
