konchangakita

KPSを一番楽しんでいたブログ 会社の看板を背負いません 転載はご自由にどうぞ

【DeepLearning特訓】GAN 敵対的生成ネットワーク

E資格向けの自習アウトプット
自分用メモ

GAN 敵対的生成ネットワークは、2014年にイアン・グッドフェロー氏らが「Generative Adversarial Network」という論文で発表
生成モデルと識別モデルの組み合わせ、敵対させ競い合わせることで精度を上げていく手法

 ・生成(Generator)モデル:見破られないようなニセモノを作る
 ・識別(Discriminator)モデル:訓練データを使ってニセモノを見破ろうとする

「偽造犯と警察」とか「怪盗と探偵」とか「ルパンと銭形」とかそんなライバル同士が切磋琢磨していって、結果として騙す側(トリック)巧妙になっていくような関係

GAN のアーキテクチャ

f:id:konchangakita:20210207211937p:plain
1.ノイズを乱数からサンプリングする
2.Generator で Fakeデータを生成
3.Realデータと Fakeデータを Discriminator に識別させる
4.学習方針
  - Generator(生成モデル)は、識別判定のロスが大きくなるように ➔うまく騙せた
  - Discriminator(識別モデル)は、識別判定のロスが小さくなるように ➔みやぶった

GAN の目的関数(損失関数)

いつものごとく導出の過程の理解は一旦おいておいて

以下をセットで覚えておく
Generator ネットワーク 𝐺:𝑧→𝑥'
Discriminator ネットワーク 𝐷:𝑥→(0,1)
f:id:konchangakita:20210211163038p:plain

Discriminator:Realデータ真(1)、Fakeデータ偽(0)の時の、Discriminatorの予測に対する交差エントロピー
見破れてるのか
Generator:Realデータ真(1)、Fakeデータ真(1)の時の、Discriminatorの予測に対する交差エントロピー
うまく騙せてるのか

このあたりは、考えると沼っていくので、実装しながら理解することにする

GAN の種類

DCGAN(Deep Convolutinal GAN):CNNを使う
LAPGAN(Laplacian Pyramid):低解像度と高解像度の画像の差を比較
Conditional GAN:訓練時に教師データのラベル情報を用いて、生成するクラスを指定できる
StarGAN:マルチドメインに適用できるように拡張

この他にもたくさん研究されているらしい

TensorFlow で実装しながら頭を整理

TensorFlow公式に DCGAN のチュートリアルがあったので、こちらで参考に実装してみる
www.tensorflow.org

今回はこのあたりを使う

import tensorflow as tf
from tensorflow.keras import layers 
from tensorflow.keras.datasets import mnist, fashion_mnist
import matplotlib.pyplot as plt
データセット

データセットの準備、今回は Fashion MNIST を使ってみる

# データのロード
(x_train, t_train), (x_test, t_test) = fashion_mnist.load_data()

# 設定
BUFFER_SIZE = x_train.shape[0]  # 60000
BATCH_SIZE = 256

x_train = x_train.reshape(BUFFER_SIZE, 28, 28, 1).astype('float32')
x_train = (x_train - 127.5) / 127.5     # Normalize [-1, 1]
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


どんな画像か確認しておく

fig = plt.figure(figsize=(4,4))

for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(x_train[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')

plt.show()

f:id:konchangakita:20210207233715p:plain

Generator ネットワーク

ノイズを受け取って、CNN の逆畳み込みで画像を作っていく

# Generator
# noise から画像を作る
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Reshape((7,7,256)))
    assert model.output_shape == (None, 7,7,256)    # Noneはバッチサイズ

    model.add(layers.Conv2DTranspose(128, (5,5), strides=1, padding='same', use_bias=False))
    assert model.output_shape == (None, 7,7,128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5,5), strides=2, padding='same', use_bias=False))
    assert model.output_shape == (None, 14,14,64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5,5), strides=2, padding='same', use_bias=False))
    assert model.output_shape == (None, 28,28,1)

    return model


None はバッチサイズがはいる

generator = make_generator_model()
generator.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_2 (Dense)              (None, 12544)             1254400   
_________________________________________________________________
batch_normalization_3 (Batch (None, 12544)             50176     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 12544)             0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 7, 7, 128)         819200    
_________________________________________________________________
batch_normalization_4 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 14, 14, 64)        204800    
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 28, 28, 1)         1600      
=================================================================
Total params: 2,330,944
Trainable params: 2,305,472
Non-trainable params: 25,472
_________________________________________________________________


試し画像を一枚だけ作ってみる

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')

学習してないので、当然わけわからん Fake 画像が表示される
f:id:konchangakita:20210207235519p:plain

Discriminatorネットワーク

CNN で畳み込んで特徴量を抽出していく

# Discriminator
# 真の場合には正の数値を、偽の場合は負の数値を返す
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5,5), strides=2, padding='same', input_shape=[28,28,1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5,5), strides=2, padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model
discriminator = make_discriminator_model()
discriminator.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_2 (Conv2D)            (None, 14, 14, 64)        1664      
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 128)         204928    
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 6273      
=================================================================
Total params: 212,865
Trainable params: 212,865
Non-trainable params: 0
_________________________________________________________________


適当な Fake 画像の適当な特徴量

decision = discriminator(generated_image)
decision
<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.2810372]], dtype=float32)>


損失関数とオプティマイザー

損失関数はクロスエントロピーオプティマイザーにはAdam

【損失関数の考え方】
Generatorでは、fake_loss を Fakeデータを Real と判定できたか(ちゃんと騙せたか)
Discriminatorでは、Realロス と Fakeロス の和
 - real_loss は 訓練データを Real 判定できたか
 - fake_loss は Fakeデータを Fake 判定できたか(見破ったか)

# 損失関数とオプティマイザー
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

レーニングループ

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        dis_loss = discriminator_loss(real_output, fake_output)
        gen_loss = generator_loss(fake_output)

    # 勾配の保存
    gradients_of_discriminator = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)

    # パラメータ更新
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

# エポック数分のループ
def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            gen_loss, dis_loss = train_step(image_batch)

        print('EPOCH:{}, {}, {},'.format(epoch, gen_loss, dis_loss))

        generate_and_save_images(generator, epoch+1, seed)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix = checkpoint_dir)

# Fake画像(生成画像)の表示と保存
def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4,4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('output/dcgan_image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
# トレーニング設定
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

seed = tf.random.normal([num_examples_to_generate, noise_dim])

レーニング実行

train(train_dataset, EPOCHS)

エポック 10、エポック 50、エポック 100 のそれぞれの Fake画像
f:id:konchangakita:20210208004241p:plain:w180f:id:konchangakita:20210208004308p:plain:w180f:id:konchangakita:20210208004328p:plain:w180

それっぽいものが出来てきています
構造としてはシンプルなのに、オリジナルの画像が作られてくるのすごい

さいごに

構造をざっくり理解するために書いてきましたが、つまるところ Generate された分布と入力されたデータの分布を近づけていくということが重要なようです
(訓練データそのものを作るわけではない)
生成データは実用的実におもしろい、試験が終わったら実装してみたい
これからは強化学習