konchangakita

KPSを一番楽しんでいたブログ 会社の看板を背負いません 転載はご自由にどうぞ

【DeepLearning特訓】RNN応用 LSTM編

E資格向けの自習アウトプット
自分用メモ

LSTM:Long Short-Term Memory(長短期記憶)は、RNNで系列が長くなっていった時におきてしまう「長期依存性の課題」への解決アプローチの一つです

長期依存性の課題

RNNの弱点として多くの時間ステップでわたって伝播されて勾配が消失してしまう「勾配消失問題」があります。

【課題解決アプローチ方法】
スキップ接続:時刻をスキップして直接続、粗い時間スケールを得る
Leakyユニット:前の時刻から接続をα倍、現時刻の入力(1-α)倍する
接続の削除
LSTM:ゲート付きRNN
GRU

LSTM の手法

内部状態という考え方を新しく付与
f:id:konchangakita:20210130113536p:plain:w300

内部状態に影響を与えるj3つ のゲートを導入して、勾配消失問題に対応
忘却ゲート(f):前の状態の影響率
入力ゲート(i):現時刻の入力データの影響率
出力ゲート(o):現在の状態の影響率



LSTM 順伝播の構造

各ゲートを追加してイメージはこんな感じ
かなりややこしくなってきた
f:id:konchangakita:20210130124147p:plain
f:id:konchangakita:20210130001813p:plain:w400


整理のためにも順伝播までを Pythonコードの実装

# シグモイド
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# LSTMモデル
# N:バッチサイズ、D:入力単語数、H:中間層の出力次元数
class LSTM:
    def __init__(self, wx, wh, b):
        self.params = [wx, wh, b]   # wx[D,4H], wh[H,4H], b[4H]
        self.grads = [np.zeros_like(wx), np.zeros_like(wh), np.zeros_like(b)]
        self.cache = None
    
    def forward(self, x, h_prev, c_prev):
        wx, wh, b = self.params 
        N, H = h_prev.shape

        A = np.dot(h_prev, wh) + np.dot(x, wx) + b  # [N,4H]

        # slice
        f = sigmoid(A[:, :H])       # 忘却ゲート [N,H]
        i = sigmoid(A[:, 2*H:3*H])  # 入力ゲート [N,H]
        g = np.tanh(A[:, H:2*H])  
        o = sigmoid(A[:, 3*H:])     # 出力ゲート [N,H]

        c_next = f * c_prev + i * g # 現時刻の状態 [N,H]
        h_next = o * np.tanh(c_next) # 現時刻の出力 [N,H]

        self.cache = (x, h_prev, c_prev, f, i, g, o, c_next)
        return h_next, c_next



適当な入力値で実行してみる

import numpy as np

# 入力を適当に定義
x = np.arange(25).reshape(5,5)
h_prev = np.ones((5,10))
c_prev = np.zeros((5,10))

# 重みを初期化
wx = np.random.randn(5, 40)
wh = np.random.randn(10, 40)
b = np.zeros(40)

# モデルインスタンス
lstm = LSTM(wx, wh, b)

# 順伝播
lstm.forward(x, h_prev, c_prev)


Peephole Connections

前の内部状態が各ゲートに影響与えるコネクションを追加
f:id:konchangakita:20210130124306p:plain
f:id:konchangakita:20210130013304p:plain:w480


さいごに

ここからが自然言語処理の本番だ
実践で使えるRNN応用技術にとりくんでいく