ITエンジニア ノイのブログ

ITエンジニアのノイです。 YouTubeで ITエンジニアのお勉強という学習用の動画を公開しています。チャンネル登録お願いします!https://m.youtube.com/channel/UCBKfJIMVWXd3ReG_FDh31Aw/playlists

逆伝搬の実装

計算グラフと逆伝搬

ディープラーニングにおける計算グラフと逆伝搬(バックプロパゲーション)は、ニューラルネットワークの学習アルゴリズムの重要な要素です。計算グラフは、ネットワーク内の演算やデータの流れを視覚的に表現するための方法です。逆伝搬は、誤差を後ろ向きに伝えながら、ネットワークのパラメータを更新する方法です。

計算グラフ

計算グラフは、ニューラルネットワークの演算をノードとエッジで表現します。ノードは演算を表し、エッジはデータの流れを表します。ノードには、入力データを受け取り、それを使って演算を行い、出力を生成します。エッジは、ノード間をデータが伝わる経路を示します。 ちなみに計算グラフについてはこの本が一番わかりやすいと思いました。りんごの例で説明しています。

ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装 [ 斎藤 康毅 ]

価格:3,740円
(2023/6/17 12:43時点)
感想(8件)

逆伝搬

逆伝搬は、計算グラフを使って誤差を後ろ向きに伝えるアルゴリズムです。まず、ネットワークの出力と正解の間の誤差を計算します。その後、誤差を逆向きに伝播させながら、各ノードでの微分値を計算します。微分値は、各ノードにおける演算の局所的な感度を表します。逆伝搬の過程で、微分値は連鎖率を利用して計算されます。

ノードの実装

加算ノードの逆伝播

𝑧 = 𝑥 + 𝑦 という数式について考えます。

加算ノードの逆伝播
連鎖率の考え方を適用すると∂L/∂xと∂L/∂yは簡単に求められます。

class AddLayer:
    def __init__(self):
        pass 
        
    def forward(self, x, y):
        return x + y
    
    def backward(self, dout):
        """
        doutは上流(出力)側の勾配
        """
        dLdx = dout
        dLdy = dout
        return dLdx, dLdy
al = AddLayer()
x = np.array([1])
y = np.array([2])
dout = np.array([5])
print(x)
print(y)
print("forward=", al.forward(x, y))
dLdx , dLdy = al.backward(dout)
print("dLdx=", dLdx)
print("dLdy=", dLdy)

結果

[1]
[2]
forward= [3]
dLdx= [5]
dLdy= [5]

乗算ノードの逆伝播

次は𝑧 = 𝑥×𝑦という乗算について考えてみます。

乗算ノードの逆伝搬
今度は×1ではなく、x側は×y、y側は×xとなっています。ここでも連鎖率の考え方を使ってますが、∂z/∂xと∂z/∂yが偏微分であるため、xで微分するとyが残り、yで微分するとxが残るためです。

class MultiLayer:
    def __init__(self):
        self.x = None
        self.y = None
        
    def forward(self, x, y):
        self.x = x #記憶しておく
        self.y = y #記憶しておく
        return x * y
    
    def backward(self, dout):  
        dLdx = dout * self.y
        dLdy = dout * self.x
        return dLdx, dLdy
ml = MultiLayer()
x = np.array([1])
y = np.array([2])
print(x)
print(y)

dout = np.array([5])
print("forward=", ml.forward(x, y))
dLdx , dLdy = ml.backward(dout)
print("dLdx=", dLdx)
print("dLdy=", dLdy)

結果

[1]
[2]
forward= [2]
dLdx= [10]
dLdy= [5]

逆伝搬の実装は連鎖率と偏微分

このように連鎖率を用いることで簡単に表現することができます。何を掛けるかというのが重要ですが、これは偏微分を計算することになります。頭の中で考えると紛らわしくなってしまうので、そんな時は計算グラフを書くと整理できます。

youtu.be