Learning GANs with basic code provided by MIT
Abstract
This is a simple application of GANs to reconstruct digits from the MNIST dataset.
What follows is a conversation with GPT / Bing to make sense of why the same loss function is used for the discriminator and the generator. It turns out that the LLM provides a decent explanation based on Goodfellow et al.'s original GAN paper.
Hands-on training at MIT: GANs applied to reconstruct MNIST digits
Consider the code below, which is executed inside two nested for loops, one over epochs and one over mini-batches:
import torch

# get_device, efficient_zero_grad, forward_and_backward and generate_noise
# are helper functions defined elsewhere in the original MIT code.

def perform_train_step(generator, discriminator, real_data,
                       loss_function, generator_optimizer, discriminator_optimizer,
                       device=get_device()):
    """ Perform a single training step. """
    # 1. PREPARATION
    # Set real and fake labels.
    real_label, fake_label = 1.0, 0.0
    # Get images on CPU or GPU as configured and available.
    # Also set 'actual batch size', which can be smaller than BATCH_SIZE
    # in some cases (e.g. the last mini-batch of an epoch).
    real_images = real_data[0].to(device)
    actual_batch_size = real_images.size(0)
    label = torch.full((actual_batch_size, 1), real_label, device=device)
    # 2. TRAINING THE DISCRIMINATOR
    # Zero the gradients for discriminator.
    efficient_zero_grad(discriminator)
    # Forward + backward on real images, flattened to one vector per image.
    real_images = real_images.view(real_images.size(0), -1)
    error_real_images = forward_and_backward(discriminator, real_images,
                                             loss_function, label)
    # Forward + backward on generated images, labelled as fake.
    # detach() keeps this backward pass from reaching the generator.
    noise = generate_noise(actual_batch_size, device=device)
    generated_images = generator(noise)
    label.fill_(fake_label)
    error_generated_images = forward_and_backward(discriminator,
                                                  generated_images.detach(),
                                                  loss_function, label)
    # Optim for discriminator.
    discriminator_optimizer.step()
    # 3. TRAINING THE GENERATOR
    # Forward + backward + optim for generator, including zero grad.
    # Labels are set back to 'real': the generator is rewarded when the
    # discriminator classifies its images as real.
    efficient_zero_grad(generator)
    label.fill_(real_label)
    error_generator = forward_and_backward(discriminator, generated_images,
                                           loss_function, label)
    generator_optimizer.step()
    # 4. COMPUTING RESULTS
    # Compute loss values in floats for discriminator, which is joint loss.
    error_discriminator = error_real_images + error_generated_images
    # Return generator and discriminator loss so that they can be printed.
    return error_generator, error_discriminator
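The article does not show the helpers that perform_train_step relies on. The following is one plausible version, assuming a latent vector of size 100 (NOISE_DIM is a hypothetical name and value) and assuming forward_and_backward returns the loss as a Python float, which would match the "loss values in floats" comment. The names mirror the calls above, but the bodies are guesses, not MIT's implementation.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the helpers called by perform_train_step.
# The real MIT versions are not shown in the article; these are guesses.

NOISE_DIM = 100  # assumption: size of the latent vector z


def get_device():
    """Use the GPU when one is available, otherwise fall back to the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def efficient_zero_grad(model):
    """Reset gradients by setting them to None instead of filling them with zeros."""
    for param in model.parameters():
        param.grad = None


def forward_and_backward(model, data, loss_function, targets):
    """Forward pass, loss, backward pass; returns the loss as a Python float."""
    outputs = model(data)
    error = loss_function(outputs, targets)
    error.backward()  # accumulates gradients into every parameter upstream of data
    return error.item()


def generate_noise(number_of_images, noise_dimension=NOISE_DIM, device=None):
    """Sample latent vectors z from a standard normal distribution."""
    return torch.randn(number_of_images, noise_dimension, device=device)
```

Setting param.grad to None rather than zero-filling it is also what PyTorch's own model.zero_grad(set_to_none=True) does, which is one likely reading of "efficient" in efficient_zero_grad.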
Question: why is the error function the same for the generator and the discriminator, while in the Goodfellow paper the formulas for the error functions differ between the generator and the discriminator? https://arxiv.org/pdf/1406.2661.pdf
The paper you are referring to is the original paper by Goodfellow et al. (2014) that introduced generative adversarial networks (GANs). In it, the authors define a single value function that the discriminator maximizes and the generator minimizes, and then suggest an alternative objective for the generator to use in practice. The first formulation is the standard GAN objective, also known as the minimax loss:
$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$$
This loss function corresponds to a minimax game between the generator and the discriminator, where the generator tries to minimize the probability of the discriminator correctly classifying the fake samples, and the discriminator tries to maximize the probability of correctly classifying both the real and the fake samples.
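To make the value function concrete, here is a small plain-Python sketch that estimates V(D,G) from made-up discriminator outputs on a few real and generated samples (the numbers are illustrative, not from the MIT run). A discriminator that cannot tell the two sets apart outputs 0.5 everywhere and gets V = -log 4, while one that separates them scores higher, which is why the discriminator maximizes V.

```python
import math


def value_function(d_real, d_fake):
    """Sample estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples x ~ p_data
    d_fake: discriminator outputs D(G(z)) on generated samples, z ~ p_z
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term


# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
undecided = value_function([0.5, 0.5], [0.5, 0.5])
# A discriminator that separates the two sets well earns a higher value.
confident = value_function([0.9, 0.95], [0.1, 0.05])
print(undecided, -math.log(4))
print(confident)
```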
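The second formulation, now usually called the non-saturating GAN loss, keeps the discriminator's objective unchanged and changes only the generator's:
$$\max_D \; \mathbb{E}_{x\sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))], \qquad \max_G \; \mathbb{E}_{z\sim p_z(z)}[\log D(G(z))]$$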
This loss function is a slight variation of the standard one, where the generator tries to maximize the probability of the discriminator being fooled by the fake samples, instead of minimizing the probability of the discriminator being correct. This formulation has the advantage of providing stronger gradients for the generator, especially in the early stages of training, when the discriminator can easily reject the fake samples.
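The claim about stronger gradients can be checked by hand. Writing the discriminator's output on a fake as D(G(z)) = sigmoid(a), where a is its pre-sigmoid logit, the derivative of log(1 - D) with respect to a is -D, while the derivative of -log D is -(1 - D). This plain-Python sketch evaluates both at an illustrative logit of -5: the discriminator then rejects the fake confidently (D(G(z)) is about 0.0067), so the saturating loss gives a gradient of about 0.0067 in magnitude and the non-saturating loss one of about 0.99.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def generator_loss_gradients(logit):
    """Return D(G(z)) and the gradients of both generator losses with respect to
    the discriminator's logit a, where D(G(z)) = sigmoid(a)."""
    d = sigmoid(logit)
    # Saturating loss log(1 - D): d/da log(1 - sigmoid(a)) = -sigmoid(a) = -D
    saturating_grad = -d
    # Non-saturating loss -log D: d/da [-log sigmoid(a)] = -(1 - D)
    non_saturating_grad = -(1.0 - d)
    return d, saturating_grad, non_saturating_grad


# Early in training the discriminator rejects fakes confidently: D(G(z)) near 0.
d, sat, non_sat = generator_loss_gradients(-5.0)
print(f"D(G(z)) = {d:.4f}, saturating grad = {sat:.4f}, non-saturating grad = {non_sat:.4f}")
```

Both gradients are negative, so a gradient-descent step on either loss pushes the logit, and hence D(G(z)), upward; the difference is only how hard it pushes when the generator is still poor.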
The authors note that the alternative objective results in the same fixed point of the training dynamics: the generator perfectly replicates the data distribution ($p_g = p_{\text{data}}$) and the optimal discriminator outputs 0.5 for any input. However, the original generator objective is less practical: early in training, when the generator is poor, the discriminator rejects the fake samples with high confidence, $\log(1 - D(G(z)))$ saturates, and the generator's gradient vanishes. They therefore suggest training the generator to maximize $\log D(G(z))$ instead.
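The shared fixed point comes from the paper's analysis: for a fixed generator, the optimal discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)), and the criterion C(G) = max_D V(D,G) reaches its global minimum of -log 4 only when p_g = p_data. A toy check on a three-point discrete distribution (the probabilities are arbitrary, chosen only for illustration):

```python
import math


def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) on a discrete support.
    Every point must satisfy p_data(x) + p_g(x) > 0."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]


def criterion_at_optimal_d(p_data, p_g):
    """C(G) = max_D V(D, G) = E_pdata[log D*(x)] + E_pg[log(1 - D*(x))]."""
    d_star = optimal_discriminator(p_data, p_g)
    # Points with zero probability contribute nothing to their expectation.
    real_term = sum(pd * math.log(d) for pd, d in zip(p_data, d_star) if pd > 0)
    fake_term = sum(pg * math.log(1.0 - d) for pg, d in zip(p_g, d_star) if pg > 0)
    return real_term + fake_term


p_data = [0.2, 0.3, 0.5]
print(optimal_discriminator(p_data, p_data))              # generator matches the data
print(criterion_at_optimal_d(p_data, p_data), -math.log(4))
print(criterion_at_optimal_d(p_data, [0.5, 0.3, 0.2]))    # mismatched generator
```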
So is that the reason why the code uses the same loss for both the generator and the discriminator?
import torch.nn as nn

def initialize_loss():
    """ Initialize loss function. """
    return nn.BCELoss()
Yes, that’s correct. The nn.BCELoss() function in PyTorch is used to compute the Binary Cross Entropy (BCE) loss, which is suitable for binary classification problems. In the context of GANs, the discriminator is essentially a binary classifier that tries to distinguish between real and fake (generated) samples. Therefore, BCE loss is a natural choice for the discriminator.
As for the generator, although its goal is different (i.e., to fool the discriminator), it minimizes the same BCE loss, computed through the discriminator. The generator wants the discriminator to output a value close to 1 (real) for the fake samples it generates, so its loss is the BCE of $D(G(z))$ against the target label 1, minimized with respect to the generator's parameters. With target 1, BCE reduces to $-\log D(G(z))$, so minimizing it is exactly the non-saturating objective of maximizing $\log D(G(z))$.
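The label trick can be verified with the per-element BCE formula that nn.BCELoss applies, -[y log p + (1 - y) log(1 - p)]. This plain-Python sketch uses made-up discriminator outputs for one real and one generated image:

```python
import math


def bce(prediction, target):
    """Per-element binary cross-entropy: -[y * log(p) + (1 - y) * log(1 - p)]."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


d_real = 0.8  # D(x): made-up discriminator output on a real image
d_fake = 0.3  # D(G(z)): made-up discriminator output on a generated image

# Discriminator step: real image labelled 1, generated image labelled 0.
discriminator_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
# Generator step: the same generated image, but labelled 1 ("real").
generator_loss = bce(d_fake, 1.0)

print(discriminator_loss, -(math.log(d_real) + math.log(1 - d_fake)))
print(generator_loss, -math.log(d_fake))
```

Summed as in error_discriminator, the discriminator's two BCE terms equal -[log D(x) + log(1 - D(G(z)))], the negative of the value function, so minimizing them maximizes V; the generator's term is -log D(G(z)). nn.BCELoss additionally averages over the batch and clamps its log outputs at -100 for numerical stability.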
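So, while the generator and the discriminator have different objectives, both are expressed with the same BCE loss function, seen from different perspectives. This is why a single nn.BCELoss() returned by initialize_loss() is used for both the generator and the discriminator in the code you provided. The difference lies in how the loss is used in training: the discriminator step labels generated images as fake (0.0) and detaches them so that only the discriminator is updated, while the generator step labels the same images as real (1.0) and lets the gradient flow back into the generator, whose optimizer then takes the step.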
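A small PyTorch sketch of this difference, using tiny toy networks (generator, discriminator and the layer sizes are assumptions standing in for the MIT MNIST models): one nn.BCELoss object serves both steps, and only the labels and the detach() call change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumption): tiny networks instead of the MIT MNIST models.
generator = nn.Linear(2, 4)
discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
loss_function = nn.BCELoss()  # one loss object, as returned by initialize_loss()

noise = torch.randn(8, 2)
generated = generator(noise)
real_labels = torch.ones(8, 1)
fake_labels = torch.zeros(8, 1)

# Discriminator side: generated samples labelled 0 and detached,
# so this backward pass leaves the generator's gradients untouched.
d_loss = loss_function(discriminator(generated.detach()), fake_labels)
d_loss.backward()
generator_grad_after_d_step = generator.weight.grad  # still None

# Generator side: the same samples labelled 1; the gradient flows into the generator.
g_loss = loss_function(discriminator(generated), real_labels)
g_loss.backward()
generator_grad_after_g_step = generator.weight.grad  # now a tensor
```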