konchangakita

KPSを一番楽しんでいたブログ 会社の看板を背負いません 転載はご自由にどうぞ

PyTorchを使ったDeep Learningのお勉強 画像処理編【ノイズ除去実験】

基本的な画像認識はなんとなくできたので、ここからは応用編です
せっかく実装してみたCNNを応用して、オートエンコーダ(自己符号化器)にチャレンジしてみたいと思います

というわけで、今回はDAE(Denoising Autoencoder)とよばれる、画像からノイズ除去に挑戦です

ⅰ)入力された画像をCNN畳み込み処理で重要な特徴をとりだし、
ⅱ)重要な特徴を捉ええたフィルタが作られ
ⅲ)このフィルタで復元する過程で邪魔な画素(ノイズ)は取り除かれ、
ⅳ)ノイズの無い画像が完成!

f:id:konchangakita:20200516153846p:plain
(元の画像荒いのは、32x32 pixelで軽いから使っているのです)

※先に言い訳を書いておきますが、今回試す手法以外にに良い手法(or学習モデル)はたくさんあります!
※今回はあくまで独学でできるDAEに挑戦し、 Python / PyTorch の Deep Learning コーディング能力向上のお勉強用です

チャレンジ環境

Windows 10
・anaconda:4.8.3
・PyTorch:1.5.0
・PyTorch_Lighgning:0.7.5
GPUNVIDIA Geforce GTX 1070
  Cuda cuda_10.2.89_441.22_win10
  CuDNN v7.6.5.32

DAE(Denoising Autoencoder)に挑戦するための準備

こんな感じの流れで進めていきます
 1.画像データセットの用意
 2.ノイズ注入
 3.学習モデルの定義
 4.入力データ ノイズ画像、教師データ 元画像で学習
 5.ノイズ入りテストデータで確認

DataLoaderで読み込むためのデータセット作りが結構てこずりました
もっと効率の良い方法があるはず。。。

1.画像データセットの用意

大元になる画像は、画素数の低い(負担が少ない)CIFAR-10を利用します
まずは、普通にデータセットとして読み込みます

import torch
import torchvision

transform = transforms.Compose([
    transforms.ToTensor()
])

# CIFAR-10
cifa10_train = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
cifar10_test = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

データセットには、画像ごとに画像データとラベルが対になって入っています

cifa10_train.[0][0].shape, cifa10_train.[0][1].shape

Output:


(torch.Size([3, 32, 32]), torch.Size([1]))


2.ノイズ入り画像のデータセット作成

データセット画像データを抜き出して、ノイズ入り画像を自作していきたいと思います

# 画像データだけとりだす
dataloader = torch.utils.data.DataLoader(cifa10_train, batch_size=len(cifa10_train), shuffle=True)
train_img_origin, train_label = next(iter(dataloader))
dataloader = torch.utils.data.DataLoader(cifar10_test, batch_size=len(cifa10_train), shuffle=True)
test_img_origin, test_label = next(iter(dataloader))

train_img_origin.shape, test_img_origin.shape

学習用に50,000枚、テスト用に10,000枚分の画像データが抜き出せました

Output:


(torch.Size([50000, 3, 32, 32]), torch.Size([10000, 3, 32, 32]))

ノイズ注入

あとは、この画像にノイズを入れていくわけですが、試行錯誤の結果
 ⅰ)画像枚数分 forで回す
 ⅱ)乱数でノイズの場所を指定
 ⅲ)0~1の乱数のノイズを注入
な、なんとも力わざです

# ノイズ入り画像を返す
def noise_injection(img_set, noise_num=100):
    n, ch, h, w = img_set.shape
    img_set_noise = torch.zeros(n, ch, h, w) # 初期化
    
    # for文使わずに、画像枚数分を一気に計算した方が早いだろうな
    for i in range(n):
        img_set_noise[i] = img_set[i]

        # ノイズの場所
        ch = torch.randint(0, 3, (noise_num,))
        h = torch.randint(0, len(img_set[i][0]), (noise_num,))
        w = torch.randint(0, len(img_set[i][0][0]), (noise_num,))

        # ノイズ注入
        noise = torch.rand(1, noise_num)
        img_set_noise[i, ch, h, w] = noise
        
    return img_set_noise

# ノイズ入り画像を作成
train_img_noise = noise_injection(train_img_origin, noise_num=300)
test_img_noise = noise_injection(test_img_origin, noise_num=300)

300個のノイズを入れてみました(画像サイズ:3ch x 32(H) x 32(W))
f:id:konchangakita:20200516173028p:plain

データセットを作る

次にノイズ入り画像と元画像が対になったデータセット作ります

# Dataset自作クラス
class MyDataset(torch.utils.data.Dataset):

    def __init__(self, data, label, transform=None):
        self.transform = transform
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label =  self.label[idx]

        if self.transform:
            out_data = self.transform(out_data)

        return out_data, out_label

# datasetに突っ込む
train_val = MyDataset(train_img_noise, train_img_origin)
test = MyDataset(test_img_noise, test_img_origin)

これはきっとこの MyDataset クラスで、一緒にノイズ注入してデータセットに突っ込むのも併せた方が、良いんだろうな
次がんばろう

ノイズ入り画像の確認

とりあえず元画像と並べて表示してみる

# ノイズがちゃんとはいっているか画像確認
sample = train_val
#sample = test

n = 5
m = np.random.randint(0, len(sample), n)

plt.figure(figsize=(10, 4))
for i in range(m.shape[0]):
    # ノイズ入り画像
    img = np.transpose(sample[m[i]][0], (1, 2, 0))
    plt.subplot(2, 5, i+1)
    plt.axis("off")
    plt.imshow(img)

    # 元画像
    img = np.transpose(sample[m[i]][1], (1, 2, 0))    
    plt.subplot(2, 5, i+1+5)
    plt.axis("off")
    plt.imshow(img)
plt.show()

f:id:konchangakita:20200516172534p:plain


3.学習モデルの定義

何が正解か分かりませんので、とりあえずシンプルに 5層の畳み込み層モデルを作ってみました
f:id:konchangakita:20200516183637p:plain

各層の誤差関数はMSE(平均二乗誤差)、画像の評価にはPSNR(ピーク信号対雑音比)を使ってみましたが、、
このPNSRの算出方法が自信ない

TensorBoardに画像を出力する

boardtag = "NoiseRemove"

# 学習データ用クラス
class TrainNet(pl.LightningModule):
    
    @pl.data_loader
    def train_dataloader(self):
        return torch.utils.data.DataLoader(train, self.batch_size, shuffle=True)
    
    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self.forward(x)      
        loss = self.lossfun(y, t)
        psnr = 10 * torch.log10(1 / loss)
        #y_label = torch.argmax(y, dim=1)
        #acc = torch.sum(t == y_label) * 1.0 / len(t)
        tensorboard_logs = {boardtag+'/train-loss': loss, boardtag+'/train-psnr': psnr} # tensorboard

        
        # Train画像をTensorBoardに出力
        if batch_idx == 0:
            grid = torchvision.utils.make_grid(x[:25],5)
            self.logger.experiment.add_image(
                tag = boardtag+'/train-img_origin',
                img_tensor=grid,
                global_step=self.global_step
            )           

            grid = torchvision.utils.make_grid(y[:25],5)
            self.logger.experiment.add_image(
                tag = boardtag+'/train-img_train',
                img_tensor=grid,
                global_step=self.global_step
            )            
        
        results = {'loss': loss, 'log': tensorboard_logs}
        #results = {'loss': loss}
        return results

    
# 検証データ用クラス
class ValidationNet(pl.LightningModule):

    @pl.data_loader
    def val_dataloader(self):
        return torch.utils.data.DataLoader(val, self.batch_size)

    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)

         
        # 検証画像をTensorBoardに出力
        if batch_idx == 0:
            grid = torchvision.utils.make_grid(t[:25],5)
            self.logger.experiment.add_image(
                tag = boardtag+'/val-img_origin',
                img_tensor=grid,
                global_step=self.global_step
            ) 
            
            grid = torchvision.utils.make_grid(y[:25],5)
            self.logger.experiment.add_image(
                tag = boardtag+'/val-img_val',
                img_tensor=grid,
                global_step=self.global_step
            )    
        
        
        #y_label = torch.argmax(y, dim=1)
        #acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'val_loss': loss}
        return results

    
    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        psnr = 10 * torch.log10(1 / avg_loss)
        #avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        tensorboard_logs = {boardtag+'/val-avg_loss': avg_loss, boardtag+'/val-avg_psnr': psnr}
        results = {'val_loss': avg_loss, 'log': tensorboard_logs}        
        return results

    
# テストデータ用クラス
class TestNet(pl.LightningModule):

    @pl.data_loader
    def test_dataloader(self):
        return torch.utils.data.DataLoader(test, self.batch_size)

    def test_step(self, batch, batch_idx):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        #y_label = torch.argmax(y, dim=1)
        #acc = torch.sum(t == y_label) * 1.0 / len(t)        

        # 検証画像をTensorBoardに出力
        if batch_idx == 0:
            grid = torchvision.utils.make_grid(t[:25],5)
            self.logger.experiment.add_image(
                tag = boardtag+'/ztest-img_origin',
                img_tensor=grid,
                global_step=self.global_step
            )
            grid = torchvision.utils.make_grid(y[:25],5)
            self.logger.experiment.add_image(
                tag = boardtag+'/ztest-img_test',
                img_tensor=grid,
                global_step=self.global_step
            )    

        results = {'test_loss': loss}
        return results

    def test_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        #avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        tensorboard_logs = {boardtag+'/ztest_loss': avg_loss}
        results = {'test_loss': avg_loss, 'log': tensorboard_logs}            
        return results

# 学習データ、検証データ、テストデータクラスの継承クラス
class Net(TrainNet, ValidationNet, TestNet):
    def __init__(self, batch_size=256):
        super(Net, self).__init__()
        self.batch_size = batch_size
        
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1) 
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1) 
        self.conv4 = nn.Conv2d(32, 32, 3, padding=1) 
        self.conv5 = nn.Conv2d(32, 3, 3, padding=1) 
        
        
    def forward(self, x):
        # 3ch > 32ch, shape 32 x 32 > 16 x 16
        x = F.relu(self.conv1(x)) # [32,32,32]
        x = F.max_pool2d(x, 2, 2) # [32,16,16]
        
        # 32h > 32ch, shape 16 x 16 > 8 x 8
        x = F.relu(self.conv2(x)) # [128,16,16]
        x = F.max_pool2d(x, 2, 2) # [128,8,8]
        
        # 32h > 32ch, shape 8 x 8 > 16 x 16 
        x = F.relu(self.conv3(x))
        x = F.interpolate(x, size=[16,16], mode='nearest')
        
        # 32h > 32ch, shape 16 x 16 > 32 x 32
        x = F.relu(self.conv4(x))
        x = F.interpolate(x, size=[32,32], mode='nearest')
        
        x = torch.sigmoid(self.conv5(x))
        
        return x
    
    def lossfun(self, y, t):
        #return F.binary_cross_entropy(y, t)
        return F.mse_loss(y, t)
    
    def configure_optimizers(self):
        #return torch.optim.SGD(self.parameters(), lr=0.01)
        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-3)


4.学習してみる

入力データ:ノイズ入り画像
教師データ:ノイズ入れる前の元画像

epoch 100で 10分くらいでした、GPUさまさま
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning import loggers

# モデルのインスタンス化
net = Net()
net

# TensorBoard用のログの場所を指定
tb_logger = loggers.TensorBoardLogger(save_dir='lightning_logs', name='noise_remove')

trainer = Trainer(max_epochs=100, gpus=1, logger=tb_logger)
trainer.fit(net)


学習後にパラメータ保存

# パラメータの保存
net = net.to('cpu')
torch.save(net.state_dict(), 'pytorch_noise_remove.pt')
TensorBoard で学習の状況を確認

なんだか良い感じに学習している様子です
f:id:konchangakita:20200516222027p:plain
画像もTensorBoardで表示できています、だいたい履歴も見れるしこれいいな
f:id:konchangakita:20200516222330p:plain

学習していくさま→→
f:id:konchangakita:20200516222903p:plain:w180f:id:konchangakita:20200516222947p:plain:w180f:id:konchangakita:20200516223012p:plain:w180

5.テストデータで確認

では、テスト用のデータを使って実際に画像処理してみましょう

# 保存したパラメータの読み込み
net = Net()
net.load_state_dict(torch.load('pytorch_noise_remove_SRCNNbn.pt'))

# 推論モード
net.eval()
net.freeze()

# ランダムで5枚を抽出する
n = 5
m = torch.randint(0, 10000, (n,))

t = torch.zeros([len(m), 3, 32, 32], dtype=torch.float32)
x = torch.zeros([len(m), 3, 32, 32], dtype=torch.float32)

for i in range(len(m)):
    t[i] = test[m[i]][1]
    x[i] = test[m[i]][0]
y = test

# 開始!
y = net(x)

# それぞれの画像表示
plt.figure(figsize=(10, 6))
for i in range(n):
    # 結果
    img = np.transpose(y[i], (1, 2, 0))
    plt.subplot(3, 5, i+1)
    plt.axis("off")
    plt.imshow(img)

    # ノイズ入り画像
    img = np.transpose(x[i], (1, 2, 0))    
    plt.subplot(3, 5, i+1+5)
    plt.axis("off")
    plt.imshow(img)
    
    # 元画像
    img = np.transpose(t[i], (1, 2, 0))    
    plt.subplot(3, 5, i+1+10)
    plt.axis("off")
    plt.imshow(img)    
plt.show()

1段目:出力された画像
2段目:入力された画像
3段目:元画像
f:id:konchangakita:20200516225240p:plain
一番上の段がノイズ入り画像に処理を施した画像なわけですが
とりあえずノイズは消えた、、、が、全然ぼやけてるなぁ


考察してみる

だいたい目星はついているのですが

ⅰ)アップサンプリングの手法
プーリングで小さくした画像を大きくする際に、単純に"nearest"で、やったからでしょうなぁ

x = F.interpolate(x, size=[32,32], mode='nearest')

ⅱ)画像が小さい
そもそも画素数が低いので、抽出すべき特徴が少ないんですよね

では、対処方法は?

BachNomalization を試してみる

ReLu の代わりに投入

# 学習データ、検証データ、テストデータクラスの継承クラス
class NetBn(TrainNet, ValidationNet, TestNet):
    def __init__(self, batch_size=256):
        super(NetBn, self).__init__()
        self.batch_size = batch_size
        
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, stride=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1) 
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(32)
        self.conv5 = nn.Conv2d(32, 3, 3, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        
        
    def forward(self, x):
        # 3ch > 32ch, shape 32 x 32 > 16 x 16
        x = self.bn1(self.conv1(x)) # [32,32,32]
        x = F.max_pool2d(x, 2, 2) # [32,16,16]
        
        # 32h > 32ch, shape 16 x 16 > 8 x 8
        x = self.bn2(self.conv2(x)) # [128,16,16]
        x = F.max_pool2d(x, 2, 2) # [128,8,8]
        
        # 32h > 32ch, shape 8 x 8 > 16 x 16 
        x = self.bn3(self.conv3(x))
        x = F.interpolate(x, size=[16,16], mode='nearest')
        
        # 32h > 32ch, shape 16 x 16 > 32 x 32
        x = self.bn4(self.conv4(x))
        x = F.interpolate(x, size=[32,32], mode='nearest')

        x = torch.sigmoid(self.conv5(x))
        
        return x
    
    def lossfun(self, y, t):
        #return F.binary_cross_entropy(y, t)
        return F.mse_loss(y, t)
    
    def configure_optimizers(self):
        #return torch.optim.SGD(self.parameters(), lr=0.01)
        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-3)

アップサンプリングの方法は変えていないので、ぼやけた感じは変わらずですが
色がはっきり出るようになったので、特徴量をうまく抽出できているのかな
f:id:konchangakita:20200516225907p:plain

超解像に挑戦?

超解像とは、簡単にいうと画像の解像度を上げる手法です
画像がぼやけてるってことには、これが有効なような気がします

SRCNN、SRGAN(SRはSuper Resolution)という手法(論文)があるので
これを参考にしてみようかな