konchangakita

KPSを一番楽しんでいたブログ 会社の看板を背負いません 転載はご自由にどうぞ

【DeepLearning特訓】RNN応用 GRU編

E資格向けの自習アウトプット
自分用メモ

GRU(Gated Recurrent Unit)は、LSTM を単純化したモデル
LSTMはパラメータが多く学習に時間がかかので、パラメータを減らして計算量を減らす工夫
ゲートを2つに減らし、内部状態をなくした
 ・reset ゲート (r):過去の隠れ状態をどれだけ無視するか
 ・update ゲート (z):過去の隠れ状態を更新する役割

f:id:konchangakita:20210130193220p:plain:w300

GRUの構造

行列の内積を計算するところと、要素積(アダマール積)が混ざっているの要注意
f:id:konchangakita:20210130193333p:plain
f:id:konchangakita:20210202110524p:plain:w350

Python に落とし込むと理解しやすい(順伝播まで)
コピペで試せます

# シグモイド
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# GRUモデル
# N:バッチサイズ、D:入力単語数、H:中間層の出力次元数
class GRU:
    def __init__(self, wx, wh, b):
        self.params = wx, wh, b     # # wx[D,3H], wh[H,3H], b[3H]
        
    def forward(self, x, h_prev):
        wx, wh, b = self.params
        H = wh.shape[0]

        wxz, wxr, wxh = wx[:, :H], wx[:, H:2*H], wx[:, 2*H:]    # 入力用重み
        whz, whr, whh = wh[:, :H], wh[:, H:2*H], wh[:, 2*H:]    # 前の時刻出力用重み
        bz, br, bh = b[:H], b[H:2*H], b[2*H:]                   # バイアス

        z = sigmoid(np.dot(h_prev, whz) + np.dot(x, wxz) + bz)  # updateゲート
        r = sigmoid(np.dot(h_prev, whr) + np.dot(x, wxr) + br)  # resetゲート
        h_hat = sigmoid(np.dot(r*h_prev, whh) + np.dot(x, wxh) + bh )
        h_next = (1-z) * h_prev + z * h_hat

        return h_next



適当な入力とパラメータで実行

import numpy as np

# 入力を適当に定義
x = np.arange(25).reshape(5,5)
h_prev = np.ones((5,10))

# 重みを初期化
wx = np.random.randn(5, 30)
wh = np.random.randn(10, 30)
b = np.zeros(30)

# モデルインスタンス
gru = GRU(wx, wh, b)

# 順伝播
gru.forward(x, h_prev)
array([[7.96586166e-01, 9.79622537e-03, 9.99877107e-01, 9.76598775e-01,
        9.98707164e-01, 3.56091373e-01, 9.99988035e-01, 2.55284290e-01,
        9.89543388e-01, 9.96706424e-01],
       [2.47451855e-03, 5.64910120e-03, 9.99999999e-01, 9.99999983e-01,
        1.00000000e+00, 4.40713164e-03, 1.00000000e+00, 9.99995693e-01,
        9.82514705e-01, 1.00000000e+00],
       [5.47933744e-04, 3.69982598e-03, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 3.66723137e-05, 1.00000000e+00, 1.00000000e+00,
        9.70604673e-01, 1.00000000e+00],
       [1.66534473e-04, 2.42152843e-03, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 3.04634406e-07, 1.00000000e+00, 1.00000000e+00,
        9.50828177e-01, 1.00000000e+00],
       [5.06729678e-05, 1.58418307e-03, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 2.53141448e-09, 1.00000000e+00, 1.00000000e+00,
        9.19010082e-01, 1.00000000e+00]])

おわりに

LSTM と GRU どちらを使うべきか?ですが、文献によると、タスクやハイパーパラメータの調整によりケースバイケースの模様
最近では LSTM の方(+LSTM改良版)が多く使われているようですが、GRUは計算量が少なくて済むので、小さいデータセットやモデルの設計を繰り返し修正しながら作っていくには向いているらしい
LSTM のその応用はここでおしまい
他の自然言語処理モデルへつづく

LSTM応用編はおわり