ITエンジニア ノイのブログ

ITエンジニアのノイです。 YouTubeで ITエンジニアのお勉強という学習用の動画を公開しています。チャンネル登録お願いします!https://m.youtube.com/channel/UCBKfJIMVWXd3ReG_FDh31Aw/playlists

勾配確認

数値微分誤差逆伝播によって求められる勾配は、計算方法や精度などの観点から異なる特徴を持ちます。そのため、両者に違いがあるのか確認する必要があります。差異がある場合には実装が間違っていることがあるので、デバッグにも使えます。

数値微分

微小な変化量(例えば、0.0001など)を入力変数に加えて関数を評価し、その結果から勾配を求める方法です。具体的には、中心差分法や前方差分法などが一般的に使用されます。数値微分の利点は、実装が比較的容易であり、任意の関数に対して使用できることです。ただし、数値微分は計算コストが高く、微小な変化量の選択や丸め誤差の影響を受ける可能性があるため、精度の面で制約があります。

誤差逆伝播

ニューラルネットワークなどのモデルにおいて、目的関数(損失関数)とモデルのパラメータの間の勾配を効率的に計算する手法です。誤差逆伝播では、順方向の計算と逆方向の計算を組み合わせることで、パラメータごとの勾配を効率的に求めることができます。誤差逆伝播の利点は、計算コストが数値微分に比べて低く、高速なパラメータ更新が可能であることです。また、誤差逆伝播はモデルの構造を利用しているため、数値微分よりも高い精度を実現できます。

数値微分はあくまで近似的な勾配を求める手法であり、計算コストが高い一方、誤差逆伝播は厳密な勾配を効率的に求めることができます。ただし、誤差逆伝播ニューラルネットワークなどの特定のモデルに適用される手法であるため、他の関数には直接適用することができません。

勾配確認

勾配確認(Gradient Checking)は、数値微分誤差逆伝播によって求めた勾配が一致しているかを確認する手法です。勾配確認は主に、実装した機械学習モデルの勾配計算部分が正確かつ適切に行われているかを検証するために使用されます。

MNISTでの確認

MNISTのデータセットを用意します

import tensorflow as tf
mnist = tf.keras.datasets.mnist
(X_train, y_train),(X_test, y_test) = mnist.load_data()

from sklearn.preprocessing import LabelBinarizer
lb = LabelBinarizer()

train = X_train/255
test = X_test/255
train = train.reshape(-1, 28*28)
test = test.reshape(-1, 28*28)
train_labels = lb.fit_transform(y_train)
test_labels = lb.fit_transform(y_test)

差異があるか確認する ここでNNetは別にニューラルネットを実装しています。

# データの読み込み
x_train, t_train = train, train_labels
x_test, t_test = test, test_labels

net = NNet(input_size=784, hidden_size=50, output_size=10)

x_batch = x_train[:3]
t_batch = t_train[:3]

# 数値微分で求めた勾配
grad_numerical = net.numerical_gradient_(x_batch, t_batch)

# 誤差逆伝播で求めた勾配
grad_backprop = net.gradient(x_batch, t_batch)

# 両者の勾配が概ね一致していることを確認する
for key in grad_numerical.keys():
    diff = np.average( np.abs(grad_backprop[key] - grad_numerical[key]) )
    print(key + ":" + str(diff))

出力結果

W1:4.330064037918453e-10
b1:2.6988644159537267e-09
W2:7.2531332852898484e-09
b2:1.4052603425768194e-07

W1,b1,W2,b2共に両者の誤差は小さいことがわかります。 数値微分誤差逆伝播によって求めた勾配を比較し、差異がある場合は問題がある可能性があるため、実装を見直す必要があります。

youtu.be