PyTorchを使ったDeep Learningのお勉強 画像処理編【ノイズ除去実験】
基本的な画像認識はなんとなくできたので、ここからは応用編です
せっかく実装してみたCNNを応用して、オートエンコーダ(自己符号化器)にチャレンジしてみたいと思います
というわけで、今回はDAE(Denoising Autoencoder)とよばれる、画像からノイズ除去に挑戦です
ⅰ)入力された画像をCNN畳み込み処理で重要な特徴をとりだし、
ⅱ)重要な特徴を捉ええたフィルタが作られ
ⅲ)このフィルタで復元する過程で邪魔な画素(ノイズ)は取り除かれ、
ⅳ)ノイズの無い画像が完成!
(元の画像荒いのは、32x32 pixelで軽いから使っているのです)
※先に言い訳を書いておきますが、今回試す手法以外にに良い手法(or学習モデル)はたくさんあります!
※今回はあくまで独学でできるDAEに挑戦し、 Python / PyTorch の Deep Learning コーディング能力向上のお勉強用です
チャレンジ環境
・Windows 10
・anaconda:4.8.3
・PyTorch:1.5.0
・PyTorch_Lighgning:0.7.5
・GPU:NVIDIA Geforce GTX 1070
Cuda cuda_10.2.89_441.22_win10
CuDNN v7.6.5.32
DAE(Denoising Autoencoder)に挑戦するための準備
こんな感じの流れで進めていきます
1.画像データセットの用意
2.ノイズ注入
3.学習モデルの定義
4.入力データ ノイズ画像、教師データ 元画像で学習
5.ノイズ入りテストデータで確認
DataLoaderで読み込むためのデータセット作りが結構てこずりました
もっと効率の良い方法があるはず。。。
1.画像データセットの用意
大元になる画像は、画素数の低い(負担が少ない)CIFAR-10を利用します
まずは、普通にデータセットとして読み込みます
import torch import torchvision transform = transforms.Compose([ transforms.ToTensor() ]) # CIFAR-10 cifa10_train = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform) cifar10_test = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)
データセットには、画像ごとに画像データとラベルが対になって入っています
cifa10_train.[0][0].shape, cifa10_train.[0][1].shape
(torch.Size([3, 32, 32]), torch.Size([1]))
2.ノイズ入り画像のデータセット作成
データセット画像データを抜き出して、ノイズ入り画像を自作していきたいと思います
# 画像データだけとりだす dataloader = torch.utils.data.DataLoader(cifa10_train, batch_size=len(cifa10_train), shuffle=True) train_img_origin, train_label = next(iter(dataloader)) dataloader = torch.utils.data.DataLoader(cifar10_test, batch_size=len(cifa10_train), shuffle=True) test_img_origin, test_label = next(iter(dataloader)) train_img_origin.shape, test_img_origin.shape
学習用に50,000枚、テスト用に10,000枚分の画像データが抜き出せました
(torch.Size([50000, 3, 32, 32]), torch.Size([10000, 3, 32, 32]))
ノイズ注入
あとは、この画像にノイズを入れていくわけですが、試行錯誤の結果
ⅰ)画像枚数分 forで回す
ⅱ)乱数でノイズの場所を指定
ⅲ)0~1の乱数のノイズを注入
な、なんとも力わざです
# ノイズ入り画像を返す def noise_injection(img_set, noise_num=100): n, ch, h, w = img_set.shape img_set_noise = torch.zeros(n, ch, h, w) # 初期化 # for文使わずに、画像枚数分を一気に計算した方が早いだろうな for i in range(n): img_set_noise[i] = img_set[i] # ノイズの場所 ch = torch.randint(0, 3, (noise_num,)) h = torch.randint(0, len(img_set[i][0]), (noise_num,)) w = torch.randint(0, len(img_set[i][0][0]), (noise_num,)) # ノイズ注入 noise = torch.rand(1, noise_num) img_set_noise[i, ch, h, w] = noise return img_set_noise # ノイズ入り画像を作成 train_img_noise = noise_injection(train_img_origin, noise_num=300) test_img_noise = noise_injection(test_img_origin, noise_num=300)
300個のノイズを入れてみました(画像サイズ:3ch x 32(H) x 32(W))
データセットを作る
次にノイズ入り画像と元画像が対になったデータセット作ります
# Dataset自作クラス class MyDataset(torch.utils.data.Dataset): def __init__(self, data, label, transform=None): self.transform = transform self.data = data self.label = label def __len__(self): return len(self.data) def __getitem__(self, idx): out_data = self.data[idx] out_label = self.label[idx] if self.transform: out_data = self.transform(out_data) return out_data, out_label # datasetに突っ込む train_val = MyDataset(train_img_noise, train_img_origin) test = MyDataset(test_img_noise, test_img_origin)
これはきっとこの MyDataset クラスで、一緒にノイズ注入してデータセットに突っ込むのも併せた方が、良いんだろうな
次がんばろう
ノイズ入り画像の確認
とりあえず元画像と並べて表示してみる
# ノイズがちゃんとはいっているか画像確認 sample = train_val #sample = test n = 5 m = np.random.randint(0, len(sample), n) plt.figure(figsize=(10, 4)) for i in range(m.shape[0]): # ノイズ入り画像 img = np.transpose(sample[m[i]][0], (1, 2, 0)) plt.subplot(2, 5, i+1) plt.axis("off") plt.imshow(img) # 元画像 img = np.transpose(sample[m[i]][1], (1, 2, 0)) plt.subplot(2, 5, i+1+5) plt.axis("off") plt.imshow(img) plt.show()
3.学習モデルの定義
何が正解か分かりませんので、とりあえずシンプルに 5層の畳み込み層モデルを作ってみました
各層の誤差関数はMSE(平均二乗誤差)、画像の評価にはPSNR(ピーク信号対雑音比)を使ってみましたが、、
このPNSRの算出方法が自信ない
TensorBoardに画像を出力する
boardtag = "NoiseRemove" # 学習データ用クラス class TrainNet(pl.LightningModule): @pl.data_loader def train_dataloader(self): return torch.utils.data.DataLoader(train, self.batch_size, shuffle=True) def training_step(self, batch, batch_idx): x, t = batch y = self.forward(x) loss = self.lossfun(y, t) psnr = 10 * torch.log10(1 / loss) #y_label = torch.argmax(y, dim=1) #acc = torch.sum(t == y_label) * 1.0 / len(t) tensorboard_logs = {boardtag+'/train-loss': loss, boardtag+'/train-psnr': psnr} # tensorboard # Train画像をTensorBoardに出力 if batch_idx == 0: grid = torchvision.utils.make_grid(x[:25],5) self.logger.experiment.add_image( tag = boardtag+'/train-img_origin', img_tensor=grid, global_step=self.global_step ) grid = torchvision.utils.make_grid(y[:25],5) self.logger.experiment.add_image( tag = boardtag+'/train-img_train', img_tensor=grid, global_step=self.global_step ) results = {'loss': loss, 'log': tensorboard_logs} #results = {'loss': loss} return results # 検証データ用クラス class ValidationNet(pl.LightningModule): @pl.data_loader def val_dataloader(self): return torch.utils.data.DataLoader(val, self.batch_size) def validation_step(self, batch, batch_idx): x, t = batch y = self.forward(x) loss = self.lossfun(y, t) # 検証画像をTensorBoardに出力 if batch_idx == 0: grid = torchvision.utils.make_grid(t[:25],5) self.logger.experiment.add_image( tag = boardtag+'/val-img_origin', img_tensor=grid, global_step=self.global_step ) grid = torchvision.utils.make_grid(y[:25],5) self.logger.experiment.add_image( tag = boardtag+'/val-img_val', img_tensor=grid, global_step=self.global_step ) #y_label = torch.argmax(y, dim=1) #acc = torch.sum(t == y_label) * 1.0 / len(t) results = {'val_loss': loss} return results def validation_end(self, outputs): avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() psnr = 10 * torch.log10(1 / avg_loss) #avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean() tensorboard_logs = {boardtag+'/val-avg_loss': avg_loss, boardtag+'/val-avg_psnr': psnr} results = {'val_loss': avg_loss, 'log': tensorboard_logs} return results # テストデータ用クラス class TestNet(pl.LightningModule): @pl.data_loader def test_dataloader(self): return torch.utils.data.DataLoader(test, self.batch_size) def test_step(self, batch, batch_idx): x, t = batch y = self.forward(x) loss = self.lossfun(y, t) #y_label = torch.argmax(y, dim=1) #acc = torch.sum(t == y_label) * 1.0 / len(t) # 検証画像をTensorBoardに出力 if batch_idx == 0: grid = torchvision.utils.make_grid(t[:25],5) self.logger.experiment.add_image( tag = boardtag+'/ztest-img_origin', img_tensor=grid, global_step=self.global_step ) grid = torchvision.utils.make_grid(y[:25],5) self.logger.experiment.add_image( tag = boardtag+'/ztest-img_test', img_tensor=grid, global_step=self.global_step ) results = {'test_loss': loss} return results def test_end(self, outputs): avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() #avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean() tensorboard_logs = {boardtag+'/ztest_loss': avg_loss} results = {'test_loss': avg_loss, 'log': tensorboard_logs} return results # 学習データ、検証データ、テストデータクラスの継承クラス class Net(TrainNet, ValidationNet, TestNet): def __init__(self, batch_size=256): super(Net, self).__init__() self.batch_size = batch_size self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, stride=1) self.conv2 = nn.Conv2d(32, 32, 3, padding=1) self.conv3 = nn.Conv2d(32, 32, 3, padding=1) self.conv4 = nn.Conv2d(32, 32, 3, padding=1) self.conv5 = nn.Conv2d(32, 3, 3, padding=1) def forward(self, x): # 3ch > 32ch, shape 32 x 32 > 16 x 16 x = F.relu(self.conv1(x)) # [32,32,32] x = F.max_pool2d(x, 2, 2) # [32,16,16] # 32h > 32ch, shape 16 x 16 > 8 x 8 x = F.relu(self.conv2(x)) # [128,16,16] x = F.max_pool2d(x, 2, 2) # [128,8,8] # 32h > 32ch, shape 8 x 8 > 16 x 16 x = F.relu(self.conv3(x)) x = F.interpolate(x, size=[16,16], mode='nearest') # 32h > 32ch, shape 16 x 16 > 32 x 32 x = F.relu(self.conv4(x)) x = F.interpolate(x, size=[32,32], mode='nearest') x = torch.sigmoid(self.conv5(x)) return x def lossfun(self, y, t): #return F.binary_cross_entropy(y, t) return F.mse_loss(y, t) def configure_optimizers(self): #return torch.optim.SGD(self.parameters(), lr=0.01) return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-3)
4.学習してみる
入力データ:ノイズ入り画像
教師データ:ノイズ入れる前の元画像
epoch 100で 10分くらいでした、GPUさまさま
import pytorch_lightning as pl from pytorch_lightning import Trainer from pytorch_lightning import loggers # モデルのインスタンス化 net = Net() net # TensorBoard用のログの場所を指定 tb_logger = loggers.TensorBoardLogger(save_dir='lightning_logs', name='noise_remove') trainer = Trainer(max_epochs=100, gpus=1, logger=tb_logger) trainer.fit(net)
学習後にパラメータ保存
# パラメータの保存 net = net.to('cpu') torch.save(net.state_dict(), 'pytorch_noise_remove.pt')
TensorBoard で学習の状況を確認
なんだか良い感じに学習している様子です
画像もTensorBoardで表示できています、だいたい履歴も見れるしこれいいな
学習していくさま→→
5.テストデータで確認
では、テスト用のデータを使って実際に画像処理してみましょう
# 保存したパラメータの読み込み net = Net() net.load_state_dict(torch.load('pytorch_noise_remove_SRCNNbn.pt')) # 推論モード net.eval() net.freeze() # ランダムで5枚を抽出する n = 5 m = torch.randint(0, 10000, (n,)) t = torch.zeros([len(m), 3, 32, 32], dtype=torch.float32) x = torch.zeros([len(m), 3, 32, 32], dtype=torch.float32) for i in range(len(m)): t[i] = test[m[i]][1] x[i] = test[m[i]][0] y = test # 開始! y = net(x) # それぞれの画像表示 plt.figure(figsize=(10, 6)) for i in range(n): # 結果 img = np.transpose(y[i], (1, 2, 0)) plt.subplot(3, 5, i+1) plt.axis("off") plt.imshow(img) # ノイズ入り画像 img = np.transpose(x[i], (1, 2, 0)) plt.subplot(3, 5, i+1+5) plt.axis("off") plt.imshow(img) # 元画像 img = np.transpose(t[i], (1, 2, 0)) plt.subplot(3, 5, i+1+10) plt.axis("off") plt.imshow(img) plt.show()
1段目:出力された画像
2段目:入力された画像
3段目:元画像
一番上の段がノイズ入り画像に処理を施した画像なわけですが
とりあえずノイズは消えた、、、が、全然ぼやけてるなぁ
考察してみる
だいたい目星はついているのですが
ⅰ)アップサンプリングの手法
プーリングで小さくした画像を大きくする際に、単純に"nearest"で、やったからでしょうなぁ
x = F.interpolate(x, size=[32,32], mode='nearest')
ⅱ)画像が小さい
そもそも画素数が低いので、抽出すべき特徴が少ないんですよね
では、対処方法は?
BachNomalization を試してみる
ReLu の代わりに投入
# 学習データ、検証データ、テストデータクラスの継承クラス class NetBn(TrainNet, ValidationNet, TestNet): def __init__(self, batch_size=256): super(NetBn, self).__init__() self.batch_size = batch_size self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1, stride=1) self.bn1 = nn.BatchNorm2d(32) self.conv2 = nn.Conv2d(32, 32, 3, padding=1) self.bn2 = nn.BatchNorm2d(32) self.conv3 = nn.Conv2d(32, 32, 3, padding=1) self.bn3 = nn.BatchNorm2d(32) self.conv4 = nn.Conv2d(32, 32, 3, padding=1) self.bn4 = nn.BatchNorm2d(32) self.conv5 = nn.Conv2d(32, 3, 3, padding=1) self.bn5 = nn.BatchNorm2d(32) def forward(self, x): # 3ch > 32ch, shape 32 x 32 > 16 x 16 x = self.bn1(self.conv1(x)) # [32,32,32] x = F.max_pool2d(x, 2, 2) # [32,16,16] # 32h > 32ch, shape 16 x 16 > 8 x 8 x = self.bn2(self.conv2(x)) # [128,16,16] x = F.max_pool2d(x, 2, 2) # [128,8,8] # 32h > 32ch, shape 8 x 8 > 16 x 16 x = self.bn3(self.conv3(x)) x = F.interpolate(x, size=[16,16], mode='nearest') # 32h > 32ch, shape 16 x 16 > 32 x 32 x = self.bn4(self.conv4(x)) x = F.interpolate(x, size=[32,32], mode='nearest') x = torch.sigmoid(self.conv5(x)) return x def lossfun(self, y, t): #return F.binary_cross_entropy(y, t) return F.mse_loss(y, t) def configure_optimizers(self): #return torch.optim.SGD(self.parameters(), lr=0.01) return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-3)
アップサンプリングの方法は変えていないので、ぼやけた感じは変わらずですが
色がはっきり出るようになったので、特徴量をうまく抽出できているのかな