画像処理関連のディープラーニングぽいものの構築を通して、PyTorchの理解を深めてきましたが
(決して学習自体はうまくいってませんがw)
これからもディープラーニング自体は勉強を続けていくわけですが、PyTorch(に限らない?)でコーディングしていく上で、理解するのに一番時間を使った (苦労した)DatasetとDataLoaderについて、自分の理解を整理する為に書いてみる
「正直初めのうちは、コピペで使いまわして、動けばよかった。。。」
だったのですが、
世の中にいろんな賢い人たちが作ったライブラリや学習モデルが溢れだしている現在では、ディープラーニングはデータの前処理が 7-8割(あいまい) なんてどっかで聞いた気がするので、画像処理ではData Augmentation(水増し)なんかもそれにあたるのかなぁ
と、ふと思い
独学でPyTorch + Dataset/Dataloader 周りを調べる上で苦労したポイントを備忘録的にまとめておこう思います
Dataset と DataLoarder の関係
文字で書かれると理解できなかったので、そんな時は絵にしてみる
ざっくりいうと
- テーマを決めて、学習させたい素材を集めて問題と答えをセットに発送
- 小分けに箱詰めして保管
- 箱詰めされた単位で学習していく
になります
Dataset 丸っと一気に学習すれば良いんでない?なんでこんな面倒なことをするかというと、ミニバッチサイズごとに区切って学習させた方が効率がよかったり、コンピュータリソース(CPU/GPU/メモリ)の観点なんかもあるそうです
それぞれの役割をもう少しだけ細かく書くと
Dataset
➀公開されているデータセット(MNIST、CIFAR、Kaggleなど)を使う
②自作データセットを作る
クローリング、スクレイピングなどで自分でデータを集めて、数字データのクレンジングや画像データのAugmentationしてやったりの後に入力データと教師データ(レベル)をセットで格納
DataLoader
Dataset からバッチサイズやシャッフルの有無を指定して格納する
・Dataset内に1,000個データセットがあった場合
・バッチサイズ が100だっとすると
⇒DataLoader内には、バッチサイズ100が x10個
取り出しの単位は、バッチサイズごとになります
学習モデル
ステップ:バッチサイズごとの学習データを取り込んで、損失関数の計算、パラメータの更新が行われます
Epoch:DataLoaderのバッチを全部取り出すとEpoch一回分になります
では、ここから実戦形式でPyTorchで書いていきます
Dataset にデータを突っ込む
torchvision.datasetsには、いくつかの公開されているデータセットを読み込むサブセットが用意されています
まずは一番単純な例はこんな感じ
import torch import torchvision from torchvision import transforms # 取り込んだデータを torch.tensorに変換する transform = transforms.Compose([ transforms.ToTensor() ]) # CIFAR-10データセットの読み込み cifa10_train = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform) cifa10_train
Dataset CIFAR10
Number of datapoints: 50000
Root location: data
Split: Train
StandardTransform
Transform: Compose(
ToTensor()
)
表示するとこんな風にデータセットの概要がなんとなく表示されます
(Datasetの名前とかデータ数とか)
まずは、この部分から説明です
transform
transform = transforms.Compose([ transforms.ToTensor() ])
データセットを読み込む際に、データの変換をかましてやります
この例では、CIFAR-10 のデータセットを読み込む際に PyTorch Tensorに変換かましています、変換していないと画像データそのまんまが入っているので学習できません
ここにいろいろと仕込むことができるので、自作 Dataset を作る際に大活躍します
では、Dataset の中身を確認していきましょう
Dataset の数を出力
len(cifa10_train)
50000
Dataset は入力データと教師データ(ラベル)が、5万個のセットになっています
1個目のデータセットは
cifa10_train[0][0]:入力データ
cifa10_train[0][1]:教師データ
2個目なら cifa10_train[1][0] という感じなります
入力データは torch.Size([3, 32, 32]) の torch.tensor形式の画像データが入っています
cifa10_train[0][0].shape, cifa10_train[0][0]
(torch.Size([3, 32, 32]),
tensor([[[0.2314, 0.1686, 0.1961, ..., 0.6196, 0.5961, 0.5804],
[0.0627, 0.0000, 0.0706, ..., 0.4824, 0.4667, 0.4784],
[0.0980, 0.0627, 0.1922, ..., 0.4627, 0.4706, 0.4275],
...,
...,
[0.3765, 0.1333, 0.1020, ..., 0.2745, 0.0275, 0.0784],
[0.3765, 0.1647, 0.1176, ..., 0.3686, 0.1333, 0.1333],
[0.4549, 0.3686, 0.3412, ..., 0.5490, 0.3294, 0.2824]]]))
教師データには 6 というラベルが入ってます
cifa10_train[0][1] 6
画像の内容を確認しておきましょう
# 画像を表示してチェック import numpy as np import matplotlib.pyplot as plt sample = test[0][0] img = np.transpose(sample, (1, 2, 0)) plt.imshow(img)
32 x 32 pixel のカエルちゃんの画像ですね
DataLoader でデータを読み込む
結構分かりにくいのがDataLoaderです
簡単に理解するならば、バッチサイズごとに段ボールに詰めて、倉庫に保管するような作業
さっそく、実例ですが
dataloader = torch.utils.data.DataLoader(cifa10_train, batch_size=128, shuffle=True)
バッチサイズ:128
シャッフル:ON
で、格納されています
とりだしてみましょう
data = next(iter(dataloader)) x, t = data
data の中にバッチサイズ1回分(128個)のデータセットが読み込まれました
入力データ、教師データ(ラベル)が対になっているので、x, tに分けています
学習する場合は、このバッチサイズごとに
こんな感じで
y = f(x)
loss = loss(y, t)
optimizer(parameter)
みたいに学習モデルへ突っ込んで、loss の計算、パラメータの更新みたいに学習させていきます
そして次の data(バッチ) を DataLoader から引っ張ってきます
この考え方自体は非常に重要なのですが、
実は PyTorch Lightning では使い方が若干異なります
そこで出てくるのが @dataloader です
@dataloaderってなんぞ
PyTorch Lightning では、dataset のままでOKです
train, val, test という名前つけておけば、勝手にDataLoaderに突っ込んでくれます
バッチサイズも学習モデル内で指定しておくだけです
その役割を果たしているのが、@dataloader と続く関数になるのですが
@ホニャララは、pythonの機能でデコレータと呼ばれるもので、続く関数に細工してくれるものになります
(デコレータ自体はなんだかややこしいので割愛)
@pl.data_loader def train_dataloader(self): return torch.utils.data.DataLoader(train, self.batch_size, shuffle=True)
PyTorch Lightning で大事なこととしては、train, val, test のデータセットを用意しておけば、勝手にDataLoaderを利用してくれます
(関数の中のdataset名は変えたら、一緒に変えてください)
自作のDatasetを作る
画像データを取扱うには、Data Augmentation(水増し)というのが出てきます
一枚の画像から、小さく切り出して、回転させたり、反転さたりして、画像の水増しをしてやります
ここで Dataset の自作が必要になってきます
torch.utils.data.Dataset を継承して、オリジナルのDataset クラスを作ってみます
使うメソッドは
def __init__ :初期化用
def __len__(self):データの数
def __getitem__:dataset からデータを呼び出すたびに実行される
重要なのは __getitem__ が、毎回実行されることになります
ランダムな処理をこのメソッド内で呼び出すことで、簡単に水増しすることができます
では、早速実践で画像水増し用のクラスとして、学習用の input と target の画像でータセットをつくります(inputデータの方をわざと低画質にして比較して学習させる超解像で使われる手法)
[target_data]
・ランダムな位置を切り出し
・ランダムな回転・反転など行う
[input_data]
・target_data の画像を低画質に落とす
>|python|
# Dataset自作クラス
# オリジナル画像からランダムクロップして、低画質画像とデータセットを作る
class MyDataset_Resize(torch.utils.data.Dataset):
def __init__(self, data, scale=4):
self.data = data
origin_size = data.size(2)
scale_size = origin_size // scale # 切り出す画像サイズ
down_size = scale_size //scale # 低画質画像用のサイズ
# ランダムな場所を切り出しつつ、ランダムに回転、反転などを加える
self.trans_random = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomCrop(scale_size, pad_if_needed=True, padding_mode='reflect'),
transforms.RandomApply([
functools.partial(TF.rotate, angle=0),
functools.partial(TF.rotate, angle=90),
functools.partial(TF.rotate, angle=180),
functools.partial(TF.rotate, angle=270),
]),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor()
])
# 一旦低画質にする
self.trans_quarter = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(size=down_size, interpolation=Image.BICUBIC),
transforms.ToTensor()
])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
origin_data = self.data[idx]
target_data = self.trans_random(origin_data)
input_data = self.trans_quarter(target_data)
return input_data, target_data
|
入力画像を torch.tensor で準備
train_img_origin.shape, train_img_origin.dtype >> (torch.Size([50, 3, 128, 128]), torch.float32)
データセットの作成
dataset = MyDataset_Resize(train_img_origin) len(dataset), dataset[0][0].shape, >> (50, torch.Size([3, 8, 8]), torch.Size([3, 32, 32]))
これで dataset を参照するたびに画像データの中身が変わるデータセットが完成です
import numpy as np import matplotlib.pyplot as plt sample = dataset[0][0] sample_img = np.transpose(sample, (1,2,0)) plt.imshow(sample_img)
落とし穴
参照するたびにランダムに処理が施されるので、対になっている input_data と target_data の中身を参照したい場合に、こんな風に連続で呼び出してしまうと、
sample = dataset[0][0]
sample = dataset[0][1]
ランダム処理が2回走るので、別の画像データになってしまいます
<悪い例> import matplotlib.pyplot as plt sample = dataset[0][0] sample_img = np.transpose(sample, (1,2,0)) plt.imshow(sample_img) sample = dataset[0][1] sample_img = np.transpose(sample, (1,2,0)) plt.imshow(sample_img)
確認したい場合は、DataLoader 使いましょう
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1) data = next(iter(dataloader)) x, t = data sample_img = np.transpose(x[0], (1,2,0)) plt.imshow(sample_img) sample_img = np.transpose(t[0], (1,2,0)) plt.imshow(sample_img)
左:input_data 右:target_data
以上