# Notebook setup cell: install TF-GAN and import everything used below.
import tensorflow.compat.v1 as tf

# TF-GAN is not preinstalled in Colab; install it before importing it.
!pip install tensorflow-gan
import tensorflow_gan as tfgan

import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

# Allow matplotlib images to render immediately.
%matplotlib inline

tf.logging.set_verbosity(tf.logging.ERROR)  # Disable noisy outputs.
import tensorflow_datasets as tfds
import tensorflow.compat.v1 as tf


def input_fn(mode, params):
    """tf.data input pipeline for the MNIST GANEstimator.

    Args:
        mode: A tf.estimator.ModeKeys value (TRAIN / EVAL / PREDICT).
        params: Dict with required keys 'batch_size' and 'noise_dims'.

    Returns:
        In PREDICT mode, a Dataset of noise batches only; otherwise a
        Dataset of (noise, images) pairs, with images scaled to [-1, 1].
    """
    assert 'batch_size' in params
    assert 'noise_dims' in params
    batch_size = params['batch_size']
    noise_dims = params['noise_dims']

    is_train = mode == tf.estimator.ModeKeys.TRAIN

    def _make_noise(_):
        return tf.random.normal([batch_size, noise_dims])

    # An endless stream of noise batches.
    noise = tf.data.Dataset.from_tensors(0).repeat().map(_make_noise)

    # Prediction needs only noise for the generator.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return noise

    def _scale(element):
        # Map pixel values from [0, 255] to [-1, 1].
        return (tf.cast(element['image'], tf.float32) - 127.5) / 127.5

    images = tfds.load('mnist:3.*.*', split='train' if is_train else 'test')
    images = images.map(_scale).cache().repeat()
    if is_train:
        images = images.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
    images = images.batch(batch_size, drop_remainder=True)
    images = images.prefetch(tf.data.experimental.AUTOTUNE)

    return tf.data.Dataset.zip((noise, images))
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
import numpy as np

# Sanity-check the input pipeline: pull one training batch and show it
# as a 10x10 image grid.
params = {'batch_size': 100, 'noise_dims': 64}
with tf.Graph().as_default():
    dataset = input_fn(tf.estimator.ModeKeys.TRAIN, params)
    # Each element is a (noise, images) pair; keep the images half.
    image_batch = next(iter(tfds.as_numpy(dataset)))[1]
    grid = tfgan.eval.python_image_grid(image_batch, grid_shape=(10, 10))
    plt.axis('off')
    plt.imshow(np.squeeze(grid))
    plt.show()
运行结果如下(显示一个 10×10 的 MNIST 训练图像网格):
def _dense(inputs, units, l2_weight):
    """Fully connected layer: Glorot init, L2 regularization, no activation."""
    l2 = tf.keras.regularizers.l2(l=l2_weight)
    return tf.layers.dense(
        inputs, units, None,
        kernel_initializer=tf.keras.initializers.glorot_uniform,
        kernel_regularizer=l2,
        bias_regularizer=l2)


def _batch_norm(inputs, is_training):
    """Batch normalization; `is_training` selects batch vs moving statistics."""
    return tf.layers.batch_normalization(
        inputs, momentum=0.999, epsilon=0.001, training=is_training)


def _deconv2d(inputs, filters, kernel_size, stride, l2_weight):
    """Transposed (upsampling) convolution with a ReLU activation."""
    l2 = tf.keras.regularizers.l2(l=l2_weight)
    return tf.layers.conv2d_transpose(
        inputs, filters, [kernel_size, kernel_size],
        strides=[stride, stride],
        activation=tf.nn.relu, padding='same',
        kernel_initializer=tf.keras.initializers.glorot_uniform,
        kernel_regularizer=l2,
        bias_regularizer=l2)


def _conv2d(inputs, filters, kernel_size, stride, l2_weight):
    """Plain convolution with no activation (caller applies its own)."""
    l2 = tf.keras.regularizers.l2(l=l2_weight)
    return tf.layers.conv2d(
        inputs, filters, [kernel_size, kernel_size],
        strides=[stride, stride],
        activation=None, padding='same',
        kernel_initializer=tf.keras.initializers.glorot_uniform,
        kernel_regularizer=l2,
        bias_regularizer=l2)
def unconditional_generator(noise, mode, weight_decay=2.5e-5):
    """Generator that maps a noise batch to unconditional MNIST images.

    Args:
        noise: Noise tensor of shape [batch, noise_dims].
        mode: A tf.estimator.ModeKeys value; TRAIN enables batch-norm updates.
        weight_decay: L2 regularization weight for all but the output layer.

    Returns:
        A [batch, 28, 28, 1] image tensor with values in [-1, 1].
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    x = tf.nn.relu(_batch_norm(_dense(noise, 1024, weight_decay), is_training))
    x = tf.nn.relu(
        _batch_norm(_dense(x, 7 * 7 * 256, weight_decay), is_training))
    x = tf.reshape(x, [-1, 7, 7, 256])
    x = _deconv2d(x, 64, 4, 2, weight_decay)
    x = _deconv2d(x, 64, 4, 2, weight_decay)
    # Final conv (no weight decay) followed by tanh keeps the output in the
    # same range as the preprocessed inputs, i.e. [-1, 1].
    return tf.tanh(_conv2d(x, 1, 4, 1, 0.0))
上述代码是一个无条件生成器(unconditional_generator),用于生成MNIST图像。
def _leaky_relu(net):
    """Leaky ReLU with slope 0.01 (was a lambda assignment; PEP 8 E731)."""
    return tf.nn.leaky_relu(net, alpha=0.01)


def unconditional_discriminator(img, unused_conditioning, mode,
                                weight_decay=2.5e-5):
    """Discriminator producing one logit per image (real vs. generated).

    Args:
        img: Image tensor in [-1, 1], shape [batch, 28, 28, 1].
        unused_conditioning: Required by the TF-GAN discriminator signature
            but ignored in the unconditional setting.
        mode: A tf.estimator.ModeKeys value; TRAIN enables batch-norm updates.
        weight_decay: L2 regularization weight for all layers.

    Returns:
        A [batch, 1] logits tensor (no sigmoid applied).
    """
    del unused_conditioning
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    net = _conv2d(img, 64, 4, 2, weight_decay)
    net = _leaky_relu(net)
    net = _conv2d(net, 128, 4, 2, weight_decay)
    net = _leaky_relu(net)
    net = tf.layers.flatten(net)
    net = _dense(net, 1024, weight_decay)
    net = _batch_norm(net, is_training)
    net = _leaky_relu(net)
    net = _dense(net, 1, weight_decay)
    return net
上述代码是一个无条件的判别器(discriminator)函数,用于生成对抗网络(GAN)。
TF-GAN provides some standard methods of evaluating generative models. In this example, we measure: the MNIST Classifier Score (an Inception-Score analogue computed with an MNIST classifier), the MNIST Frechet distance between real and generated images, and the discriminator's mean logits on real and generated data.
TF-GAN 提供了一些评估生成模型的标准方法。在这个例子中,我们测量:MNIST 分类器得分(类似 Inception Score)、真实图像与生成图像之间的 MNIST Frechet 距离,以及判别器在真实数据和生成数据上的平均输出(logits)。
from tensorflow_gan.examples.mnist import util as eval_util
import os


def get_eval_metric_ops_fn(gan_model):
    """Builds the eval metric ops dict for the GANEstimator.

    Args:
        gan_model: A tfgan GANModel tuple holding real/generated data and the
            corresponding discriminator outputs.

    Returns:
        Dict mapping metric names to tf.metrics ops: mean discriminator
        logits on real and generated data, MNIST classifier scores for both,
        and the MNIST Frechet distance between them.
    """
    mean_real_logits = tf.reduce_mean(gan_model.discriminator_real_outputs)
    mean_gen_logits = tf.reduce_mean(gan_model.discriminator_gen_outputs)

    score_on_real = eval_util.mnist_score(gan_model.real_data)
    score_on_gen = eval_util.mnist_score(gan_model.generated_data)
    fid = eval_util.mnist_frechet_distance(
        gan_model.real_data, gan_model.generated_data)

    return {
        'real_data_logits': tf.metrics.mean(mean_real_logits),
        'gen_data_logits': tf.metrics.mean(mean_gen_logits),
        'real_mnist_score': tf.metrics.mean(score_on_real),
        'mnist_score': tf.metrics.mean(score_on_gen),
        'frechet_distance': tf.metrics.mean(fid),
    }
上述代码是一个用于评估生成对抗网络(GAN)模型的评估指标函数。其目的是计算模型在MNIST数据集上的一些度量值。
The GANEstimator assembles and manages the pieces of the whole GAN model.
GANEstimator 对整个 GAN 模型的各个部分进行组装和管理。
The GANEstimator constructor takes the following components for both the generator and discriminator:
GANEstimator 构造函数的生成器和判别器都需要以下成分:
# Hyperparameters for training.
train_batch_size = 32  #@param
noise_dimensions = 64  #@param
generator_lr = 0.001  #@param
discriminator_lr = 0.0002  #@param


def gen_opt():
    """Builds the generator's Adam optimizer with a stepped LR schedule."""
    gstep = tf.train.get_or_create_global_step()
    base_lr = generator_lr
    # Halve the learning rate at 1000 steps.
    lr = tf.cond(gstep < 1000,
                 lambda: base_lr,
                 lambda: base_lr / 2.0)
    return tf.train.AdamOptimizer(lr, 0.5)


# Assemble the GAN: model functions, Wasserstein losses, and per-network
# optimizers. `generator_optimizer` is passed as a callable so the LR
# schedule is built inside the estimator's graph.
gan_estimator = tfgan.estimator.GANEstimator(
    generator_fn=unconditional_generator,
    discriminator_fn=unconditional_discriminator,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    params={'batch_size': train_batch_size, 'noise_dims': noise_dimensions},
    generator_optimizer=gen_opt,
    discriminator_optimizer=tf.train.AdamOptimizer(discriminator_lr, 0.5),
    get_eval_metric_ops_fn=get_eval_metric_ops_fn)
上述代码定义了一个使用TensorFlow-GAN库的GAN Estimator(生成对抗网络估计器)。
接下来,定义了一个名为gen_opt的函数。该函数创建了一个Adam优化器,用于生成器,其学习率在1000个步骤之后减半。使用tf.cond函数实现了学习率调度。
然后,代码使用tfgan.estimator.GANEstimator类创建了一个GAN估计器。
# Disable noisy output.
tf.autograph.set_verbosity(0, False)

import itertools
import time

steps_per_eval = 500  #@param
max_train_steps = 5000  #@param
batches_for_eval_metrics = 100  #@param

# Used to track metrics.
steps = []
real_logits, fake_logits = [], []
real_mnist_scores, mnist_scores, frechet_distances = [], [], []

cur_step = 0
start_time = time.time()
while cur_step < max_train_steps:
    next_step = min(cur_step + steps_per_eval, max_train_steps)

    # Train up to the next evaluation checkpoint.
    start = time.time()
    gan_estimator.train(input_fn, max_steps=next_step)
    steps_taken = next_step - cur_step
    time_taken = time.time() - start
    print('Time since start: %.2f min' % ((time.time() - start_time) / 60.0))
    print('Trained from step %i to %i in %.2f steps / sec' % (
        cur_step, next_step, steps_taken / time_taken))
    cur_step = next_step

    # Calculate some metrics.
    metrics = gan_estimator.evaluate(input_fn, steps=batches_for_eval_metrics)
    steps.append(cur_step)
    real_logits.append(metrics['real_data_logits'])
    fake_logits.append(metrics['gen_data_logits'])
    real_mnist_scores.append(metrics['real_mnist_score'])
    mnist_scores.append(metrics['mnist_score'])
    frechet_distances.append(metrics['frechet_distance'])
    print('Average discriminator output on Real: %.2f Fake: %.2f' % (
        real_logits[-1], fake_logits[-1]))
    print('Inception Score: %.2f / %.2f Frechet Distance: %.2f' % (
        mnist_scores[-1], real_mnist_scores[-1], frechet_distances[-1]))

    # Visualize some images.
    # BUG FIX: the original assigned `imgs` inside a try/except StopIteration;
    # if the iterator stopped before yielding 20 items, `imgs` was never
    # defined and python_image_grid raised NameError. islice collects at most
    # 20 items without raising.
    iterator = gan_estimator.predict(
        input_fn, hooks=[tf.train.StopAtStepHook(num_steps=21)])
    imgs = np.array(list(itertools.islice(iterator, 20)))
    tiled = tfgan.eval.python_image_grid(imgs, grid_shape=(2, 10))
    plt.axis('off')
    plt.imshow(np.squeeze(tiled))
    plt.show()

# Plot the metrics vs step.
plt.title('MNIST Frechet distance per step')
plt.plot(steps, frechet_distances)
plt.figure()
plt.title('MNIST Score per step')
plt.plot(steps, mnist_scores)
plt.plot(steps, real_mnist_scores)
plt.show()
它是一个使用TensorFlow的GAN(生成对抗网络)进行训练和评估的过程。
所以我们在这里主要展示了如何使用TensorFlow构建和训练GAN模型,并使用评估指标来监控训练进度。