AI Art with Generative Adversarial Networks (GANs) in Python

Introduction

Generative Adversarial Networks (GANs) are a class of generative models, built from neural networks, that learn to produce realistic new data such as images. In the art world, GANs have opened new avenues for creativity, enabling artists to create unique visuals. In this article, we walk through building a simple GAN in Python that learns to generate handwritten-digit images, a small first step toward AI-generated art.

What are GANs?

GANs consist of two neural networks: a generator and a discriminator. The generator maps random noise to synthetic samples. The discriminator estimates the probability that a given sample came from the real dataset rather than from the generator. The two networks are trained together as adversaries. The generator tries to produce data the discriminator cannot tell apart from real data, and the discriminator tries to correctly classify real and fake samples.
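In the original formulation (Goodfellow et al., 2014), this contest is a two-player minimax game over a value function V(D, G). Here D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for a noise vector z, and p_z is the noise distribution (a 100-dimensional standard normal in this article):

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

In practice, the generator is usually trained with the "non-saturating" variant. Instead of minimizing log(1 - D(G(z))), it labels its own samples as real and minimizes binary cross-entropy, which amounts to maximizing log D(G(z)) and gives stronger gradients early in training. The training loop in this article follows that variant by calling gan.train_on_batch(noise, valid).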

Setting Up the Environment

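To get started, we need to set up our Python environment. We will use TensorFlow, which ships with the Keras API as tf.keras, so a separate Keras install is not needed. We also use matplotlib to plot the generated images. In a Jupyter notebook, prefix each command with !.

pip install tensorflow
pip install matplotlib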

Building the GAN

First, let's import the necessary libraries.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

Creating the Generator

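The generator maps a 100-dimensional random noise vector to a 28x28 grayscale image. This simple version is a stack of fully connected (Dense) layers with LeakyReLU activations and batch normalization. The final tanh activation keeps every pixel value in [-1, 1].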

def build_generator():
    model = Sequential()
    model.add(tf.keras.Input(shape=(100,)))  # 100-dimensional noise vector
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))  # slope 0.2 for negative inputs
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(28 * 28 * 1, activation='tanh'))  # one value in [-1, 1] per pixel
    model.add(Reshape((28, 28, 1)))  # flat vector -> 28x28x1 image
    return model
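To see what this stack does to the data, here is a minimal NumPy sketch of the same forward pass, from noise to image. It is illustrative only and makes several assumptions. The weights are random and untrained, initialized with a Glorot-style normal scale. Batch normalization is omitted. The helper names make_generator_weights and generator_forward are hypothetical. The point is the shape flow (100 -> 256 -> 512 -> 1024 -> 784 -> 28x28x1) and the tanh bound on the output.

```python
import numpy as np

def make_generator_weights(rng, sizes=(100, 256, 512, 1024, 28 * 28 * 1)):
    # One (kernel, bias) pair per Dense layer; random, untrained values
    return [
        (rng.normal(0.0, np.sqrt(2.0 / (n_in + n_out)), (n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

def leaky_relu(x, slope=0.2):
    # Same negative slope as the LeakyReLU layers in build_generator
    return np.where(x > 0, x, slope * x)

def generator_forward(noise, weights):
    h = noise
    for kernel, bias in weights[:-1]:
        h = leaky_relu(h @ kernel + bias)  # BatchNormalization omitted in this sketch
    kernel, bias = weights[-1]
    out = np.tanh(h @ kernel + bias)  # every pixel ends up in [-1, 1]
    return out.reshape(-1, 28, 28, 1)

rng = np.random.default_rng(0)
weights = make_generator_weights(rng)
images = generator_forward(rng.normal(0.0, 1.0, (16, 100)), weights)
print(images.shape)  # (16, 28, 28, 1)
```

Keras performs the same computation in its own layers and learns the kernel and bias values during training.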

Creating the Discriminator

The discriminator takes a 28x28 image, either real or generated, and outputs a single sigmoid probability that the image is real.

def build_discriminator():
    model = Sequential()
    model.add(tf.keras.Input(shape=(28, 28, 1)))  # one grayscale image
    model.add(Flatten())  # 28x28x1 -> 784-dimensional vector
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(Dense(1, activation='sigmoid'))  # probability that the input is real
    return model
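A quick way to check the sizes of both networks is to count their parameters by hand. A Dense layer with n_in inputs and n_out units has n_in * n_out kernel weights plus n_out biases. Each BatchNormalization layer holds four values per unit: gamma and beta, which are trainable, and a moving mean and variance, which are not. The sketch below assumes exactly the layer sizes used in build_generator and build_discriminator. The totals should agree with the Param # column of generator.summary() and discriminator.summary(), counted before any weights are frozen.

```python
def dense_params(n_in, n_out):
    # Kernel of shape (n_in, n_out) plus one bias per output unit
    return n_in * n_out + n_out

GEN_DENSE = [(100, 256), (256, 512), (512, 1024), (1024, 28 * 28 * 1)]
GEN_BN_UNITS = [256, 512, 1024]  # one BatchNormalization after each hidden Dense layer
DISC_DENSE = [(28 * 28 * 1, 512), (512, 256), (256, 1)]

# Dense weights plus gamma and beta of each BatchNormalization layer
gen_trainable = sum(dense_params(i, o) for i, o in GEN_DENSE) + sum(2 * u for u in GEN_BN_UNITS)
# Moving mean and moving variance of each BatchNormalization layer
gen_non_trainable = sum(2 * u for u in GEN_BN_UNITS)
disc_trainable = sum(dense_params(i, o) for i, o in DISC_DENSE)

print(gen_trainable, gen_non_trainable, disc_trainable)  # 1489936 3584 533505
```

Most of the generator's parameters sit in its last Dense layer (1024 x 784 weights), which maps the hidden features onto individual pixels.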

Compiling the GAN

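Next, we compile the discriminator on its own. Then we build a combined model that feeds the generator's output into the discriminator. Inside the combined model the discriminator is frozen (trainable = False), so training the combined model updates only the generator's weights.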

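def compile_gan(generator, discriminator):
    # The discriminator is compiled while trainable, so it can be trained on its own
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                          metrics=['accuracy'])
    # Freeze the discriminator inside the combined model: only the generator learns there
    discriminator.trainable = False
    gan_input = tf.keras.Input(shape=(100,))
    img = generator(gan_input)
    validity = discriminator(img)
    gan = tf.keras.Model(gan_input, validity)
    gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return gan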
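Both models use binary cross-entropy, and only the target labels differ. The NumPy sketch below uses a hypothetical helper, binary_crossentropy, and made-up discriminator outputs for three generated images. It shows what each loss measures. With target 0, the discriminator is penalized by -log(1 - D(G(z))) for believing fakes. With target 1, as in gan.train_on_batch(noise, valid), the generator is penalized by -log D(G(z)). That penalty grows large exactly when the discriminator confidently rejects its images.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean binary cross-entropy; predictions are clipped to avoid log(0)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))))

# Hypothetical discriminator outputs D(G(z)) for three generated images
d_on_fakes = np.array([0.05, 0.2, 0.6])

# Discriminator step on fakes: target 0, so the loss is mean(-log(1 - D(G(z))))
d_loss_fake = binary_crossentropy(np.zeros(3), d_on_fakes)

# Generator step through the combined model: target 1 ("valid"),
# so the loss is mean(-log D(G(z))), large while the fakes are easily spotted
g_loss = binary_crossentropy(np.ones(3), d_on_fakes)

print(round(d_loss_fake, 3), round(g_loss, 3))
```

Here the discriminator mostly spots the fakes, so its loss is small and the generator's loss is large. That imbalance is the signal that pushes the generator to improve.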

Training the GAN

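We train on the MNIST dataset of handwritten digits, rescaled to [-1, 1] to match the generator's tanh output. Each iteration does two things. First, it trains the discriminator on a batch of real images (label 1) and a batch of generated images (label 0). Second, it trains the generator through the combined model using "real" labels, which rewards the generator for fooling the discriminator.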

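def train_gan(gan, generator, discriminator, iterations=10000, batch_size=64, log_interval=1000):
    # Load MNIST and scale pixels from [0, 255] to [-1, 1] to match the tanh output
    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    X_train = X_train.astype('float32') / 127.5 - 1.0
    X_train = np.expand_dims(X_train, axis=3)  # (60000, 28, 28) -> (60000, 28, 28, 1)
    valid = np.ones((batch_size, 1))  # label for real images
    fake = np.zeros((batch_size, 1))  # label for generated images
    # Each iteration is one training step on a single batch, not a full pass over the data
    for iteration in range(iterations):
        # Train the discriminator on one batch of real and one batch of generated images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise, verbose=0)
        # Set the trainable flag explicitly before each step: Keras 3 (TensorFlow 2.16+)
        # does not fix it at compile time the way Keras 2 does
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # average [loss, accuracy]
        # Train the generator through the combined model with the discriminator frozen;
        # "valid" labels reward the generator for fooling the discriminator
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, valid)
        if iteration % log_interval == 0:
            print(f"{iteration} [D loss: {d_loss[0]:.4f}, acc.: {100 * d_loss[1]:.2f}%] [G loss: {g_loss}]")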

generator = build_generator()
discriminator = build_discriminator()
gan = compile_gan(generator, discriminator)
train_gan(gan, generator, discriminator)
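The pixel scaling used during training and the rescaling used for display are inverses of each other. Training maps [0, 255] to [-1, 1] with X / 127.5 - 1.0. Display maps [-1, 1] to [0, 1] with 0.5 * x + 0.5. The sketch below wraps both formulas in two hypothetical helpers, to_tanh_range and to_display_range, and checks that the round trip equals the familiar pixel / 255 normalization.

```python
import numpy as np

def to_tanh_range(pixels):
    # uint8 pixels in [0, 255] -> float32 in [-1, 1], matching the generator's tanh output
    return pixels.astype(np.float32) / 127.5 - 1.0

def to_display_range(images):
    # [-1, 1] -> [0, 1], the rescaling applied before plotting
    return 0.5 * images + 0.5

pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)
restored = to_display_range(scaled)
print(scaled, restored)
```

Black pixels (0) map to -1 and white pixels (255) map to 1, which matches the two ends of tanh, so the generator can in principle produce every intensity found in the training data.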

Visualizing the Results

Finally, let's plot some of the generated images. The generator outputs values in [-1, 1], so they are rescaled to [0, 1] before display.

import matplotlib.pyplot as plt

def visualize_generated_images(generator, examples=10, dim=(1, 10), figsize=(10, 1)):
    # dim is the (rows, cols) subplot grid; rows * cols must be at least examples
    noise = np.random.normal(0, 1, (examples, 100))
    generated_images = generator.predict(noise, verbose=0)
    generated_images = 0.5 * generated_images + 0.5  # rescale from [-1, 1] to [0, 1]
    plt.figure(figsize=figsize)
    for i in range(examples):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

visualize_generated_images(generator)
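To keep a record of progress instead of only showing a figure, you can tile a batch of generated images into a single array and save it as one picture. The sketch below is an assumption, not part of the original code. The helper name tile_images is hypothetical, and it expects images shaped (n, 28, 28, 1) with values in [-1, 1], as the generator produces.

```python
import numpy as np

def tile_images(images, rows, cols):
    # images: (n, h, w, 1) in [-1, 1]; returns one (rows*h, cols*w) array in [0, 1]
    n, h, w, _ = images.shape
    if n < rows * cols:
        raise ValueError("not enough images for the requested grid")
    grid = images[: rows * cols, :, :, 0].reshape(rows, cols, h, w)
    # (rows, cols, h, w) -> (rows, h, cols, w) so each image row lines up horizontally
    grid = grid.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)
    return 0.5 * grid + 0.5
```

For example, plt.imsave("samples.png", tile_images(generator.predict(noise, verbose=0), 2, 5), cmap="gray") writes a 2x5 grid to a PNG file. The file name is a placeholder. Calling this inside the training loop at each logging step gives a visual history of how the generated digits improve.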

Conclusion

Generative Adversarial Networks (GANs) are powerful tools for creating AI-generated art. By following this guide, you can start experimenting with GANs to create unique and compelling visuals. As you gain experience, you can refine your models and explore more complex architectures, such as convolutional GANs (DCGANs), which typically produce sharper images than the fully connected networks used here.