Introduction
Generative Adversarial Networks (GANs) are a class of neural networks that learn to generate new data resembling a training set. In the art world, GANs have opened new avenues for creativity, enabling artists to create unique visuals. In this article, we walk through building and training a simple GAN in Python with TensorFlow and Keras, using the MNIST handwritten-digit dataset as a small starting point.
What are GANs?
GANs consist of two neural networks: a generator and a discriminator. The generator creates fake data, while the discriminator evaluates the authenticity of the data. These networks are trained together in a process where the generator aims to create data indistinguishable from real data, and the discriminator aims to correctly identify real vs. fake data.
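To make the two competing objectives concrete, here is a minimal NumPy sketch of the binary cross-entropy losses each network tries to minimize. The discriminator scores below are made-up values for illustration, not the output of a trained model, and the helper name binary_cross_entropy is our own.

```python
import numpy as np

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(y_true * np.log(y_pred)
                          + (1.0 - y_true) * np.log(1.0 - y_pred)))

# Hypothetical discriminator outputs: estimated probability that an image is real.
d_on_real = np.array([0.9, 0.8, 0.95])  # scores for real images
d_on_fake = np.array([0.1, 0.3, 0.2])   # scores for generated images

# Discriminator objective: real images are labeled 1, generated images 0.
d_loss = 0.5 * (binary_cross_entropy(np.ones(3), d_on_real)
                + binary_cross_entropy(np.zeros(3), d_on_fake))

# Generator objective: it wants the discriminator to answer 1 on generated images.
g_loss = binary_cross_entropy(np.ones(3), d_on_fake)

print(f"d_loss={d_loss:.3f}, g_loss={g_loss:.3f}")
```

With these scores the discriminator is winning: its loss is small while the generator's is large. Training pushes the two losses against each other.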
Setting Up the Environment
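To get started, we need to set up our Python environment. We will build the GAN with TensorFlow and its bundled Keras API (tf.keras), so a separate Keras install is not needed. We also install matplotlib, which is used later to display the generated images. The leading ! runs the command from a notebook cell; drop it when installing from a terminal.
!pip install tensorflow matplotlib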
Building the GAN
First, let's import the necessary libraries.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
Creating the Generator
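The generator creates images from random noise. The simple generator below maps a 100-dimensional noise vector through three fully connected layers to a 28x28 grayscale image; the tanh output keeps pixel values in the range [-1, 1].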
def build_generator():
    model = Sequential()
    # Input: a 100-dimensional noise vector
    model.add(Dense(256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # Output: 784 values in [-1, 1], reshaped to a 28x28 single-channel image
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model
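To get a feel for the size of this network, the parameter count can be worked out by hand. The sketch below is plain Python and assumes the layer widths shown above, with each BatchNormalization layer holding four values per feature (scale, offset, moving mean, moving variance). The helper names are our own.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def batchnorm_params(n_features):
    """Scale, offset, moving mean and moving variance for each feature."""
    return 4 * n_features

# Layer widths of the generator: 100 -> 256 -> 512 -> 1024 -> 784
widths = [100, 256, 512, 1024, 28 * 28 * 1]

total = 0
for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
    total += dense_params(n_in, n_out)
    is_output_layer = i == len(widths) - 2
    if not is_output_layer:  # only hidden layers are followed by BatchNormalization
        total += batchnorm_params(n_out)

print(total)  # 1493520
```

If the arithmetic is right, this should agree with the total reported by build_generator().summary(). Most of the parameters sit in the last two Dense layers.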
Creating the Discriminator
The discriminator takes a 28x28 image, real or generated, and outputs a single sigmoid score: its estimated probability that the image is real.
def build_discriminator():
    model = Sequential()
    # Flatten the 28x28x1 image into a 784-dimensional vector
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    # Single sigmoid output: probability that the input image is real
    model.add(Dense(1, activation='sigmoid'))
    return model
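To show what this stack of layers computes, here is a rough NumPy sketch of the same forward pass: flatten, two LeakyReLU layers, then a sigmoid. The weights are random, untrained stand-ins, and the function names are our own, so the scores carry no meaning beyond illustrating shapes and ranges.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """LeakyReLU: keep positive values, scale negative values by alpha."""
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)

# Random stand-in weights with the discriminator's layer widths (784 -> 512 -> 256 -> 1).
sizes = [28 * 28, 512, 256, 1]
weights = [rng.normal(0, 0.05, (n_in, n_out))
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n_out) for n_out in sizes[1:]]

def discriminator_forward(images):
    """Flatten the images, apply two LeakyReLU layers, then a sigmoid output."""
    h = images.reshape(len(images), -1)
    for w, b in zip(weights[:-1], biases[:-1]):
        h = leaky_relu(h @ w + b)
    return sigmoid(h @ weights[-1] + biases[-1])

batch = rng.uniform(-1, 1, (4, 28, 28, 1))  # made-up batch in the [-1, 1] range
scores = discriminator_forward(batch)
print(scores.shape)  # (4, 1)
```

Each row of scores is a value strictly between 0 and 1, which is what lets binary cross-entropy be used as the loss.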
Compiling the GAN
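Next, we compile the discriminator on its own and then stack the generator and discriminator into a combined model. The discriminator is frozen inside the combined model, so training the combined model is meant to update only the generator.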
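def compile_gan(generator, discriminator):
    # Compile the discriminator on its own first, while it is still trainable
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
    # Freeze the discriminator inside the combined model
    discriminator.trainable = False
    # Combined model: noise -> generator -> discriminator -> validity score
    gan_input = tf.keras.Input(shape=(100,))
    img = generator(gan_input)
    validity = discriminator(img)
    gan = tf.keras.Model(gan_input, validity)
    gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return gan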
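The two positional arguments in Adam(0.0002, 0.5) are the learning rate and beta_1. As a rough illustration of what they control, here is the textbook Adam update written in NumPy. This is a sketch of the published update rule, not Keras's exact implementation, and the gradient values are made up.

```python
import numpy as np

def adam_step(w, grad, m, v, t, lr=0.0002, beta_1=0.5, beta_2=0.999, eps=1e-7):
    """One textbook Adam update; returns new weights and moment estimates."""
    m = beta_1 * m + (1 - beta_1) * grad        # running mean of gradients
    v = beta_2 * v + (1 - beta_2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - beta_1 ** t)               # bias correction for step t
    v_hat = v / (1 - beta_2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

w = np.array([1.0, -2.0])
grad = np.array([0.5, -4.0])   # hypothetical gradient values
m = np.zeros_like(w)
v = np.zeros_like(w)

w_new, m, v = adam_step(w, grad, m, v, t=1)
print(w - w_new)
```

On the first step each weight moves by roughly the learning rate in the direction opposite to its gradient, regardless of the gradient's magnitude. A beta_1 of 0.5 makes the gradient average forget older steps faster than the more common 0.9.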
Training the GAN
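Let's train the GAN with MNIST data. Each iteration of the loop trains the discriminator on one batch of real images and one batch of generated images, then trains the generator through the combined model. Note that the epochs argument counts these single-batch iterations, not full passes over the dataset.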
def train_gan(gan, generator, discriminator, epochs=10000, batch_size=64, save_interval=1000):
    # Load MNIST and scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    X_train = X_train / 127.5 - 1.0
    X_train = np.expand_dims(X_train, axis=3)
    # Labels: 1 for real images, 0 for generated images
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # Train the discriminator on one real batch and one generated batch
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, 100))
        # verbose=0 suppresses the progress bar that predict() prints on every call
        gen_imgs = generator.predict(noise, verbose=0)
        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        # Train the generator through the combined model, using "real" labels as the target
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, valid)
        if epoch % save_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]}%] [G loss: {g_loss}]")
generator = build_generator()
discriminator = build_discriminator()
gan = compile_gan(generator, discriminator)
train_gan(gan, generator, discriminator)
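The preprocessing and batch sampling steps inside train_gan can be checked in isolation. The sketch below uses random uint8 arrays as a stand-in for MNIST (an assumption, to avoid the download), and the helper names to_tanh_range and to_unit_range are our own.

```python
import numpy as np

def to_tanh_range(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return pixels.astype(np.float32) / 127.5 - 1.0

def to_unit_range(scaled):
    """Map values in [-1, 1] back to [0, 1] for display."""
    return 0.5 * scaled + 0.5

rng = np.random.default_rng(0)

# Stand-in for MNIST: 1000 random 28x28 uint8 images.
fake_dataset = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)

# Scale the pixels and add a channel axis, as train_gan does.
x = np.expand_dims(to_tanh_range(fake_dataset), axis=3)  # shape (1000, 28, 28, 1)

# Sample one batch of indices with replacement.
batch_size = 64
idx = rng.integers(0, x.shape[0], batch_size)
batch = x[idx]

print(batch.shape, x.min(), x.max())
```

Scaling to [-1, 1] matters because the generator ends in tanh: real and generated images then live in the same range, so the discriminator cannot tell them apart by brightness alone.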
Visualizing the Results
Finally, let's visualize some of the generated images. The generator outputs values in [-1, 1], so they are rescaled to [0, 1] before plotting.
import matplotlib.pyplot as plt
def visualize_generated_images(generator, examples=10, dim=(1, 10), figsize=(10, 1)):
    # Generate one image per noise vector
    noise = np.random.normal(0, 1, (examples, 100))
    generated_images = generator.predict(noise, verbose=0)
    # Rescale from [-1, 1] to [0, 1] for display
    generated_images = 0.5 * generated_images + 0.5
    plt.figure(figsize=figsize)
    for i in range(examples):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
visualize_generated_images(generator)
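If you would rather keep the results as a single image than a row of subplots, the generated batch can be tiled into one array first. This is a minimal NumPy sketch; tile_images is a hypothetical helper, and random values stand in for generator output that has already been rescaled to [0, 1].

```python
import numpy as np

def tile_images(images, rows, cols):
    """Arrange (N, H, W, 1) images into one (rows*H, cols*W) array."""
    n, h, w, _ = images.shape
    assert n == rows * cols, "need exactly rows * cols images"
    grid = images[..., 0].reshape(rows, cols, h, w)  # split the batch into a grid
    grid = grid.transpose(0, 2, 1, 3)                # reorder to (rows, h, cols, w)
    return grid.reshape(rows * h, cols * w)

# Stand-in for generator output rescaled to [0, 1].
rng = np.random.default_rng(0)
images = rng.uniform(0, 1, (10, 28, 28, 1))

sheet = tile_images(images, rows=2, cols=5)
print(sheet.shape)  # (56, 140)
```

The resulting sheet can be shown with plt.imshow(sheet, cmap='gray') or written to disk with plt.imsave, which makes it easy to compare snapshots taken at different points in training.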
Conclusion
Generative Adversarial Networks (GANs) are powerful tools for creating AI-generated art. By following this guide, you can start experimenting with GANs to create unique and compelling visuals. As you gain more experience, you can further refine your models and explore more complex architectures.