確率的勾配降下法
確率的勾配降下法(Stochastic Gradient Descent, SGD)は、機械学習や深層学習において最適化アルゴリズムとして広く使われています。SGDは、勾配降下法(Gradient Descent)を基にしており、大規模なデータセットを扱う際に効果的です。
通常の勾配降下法では、トレーニングデータセット全体の勾配を計算して、パラメータを更新します。しかし、大規模なデータセットでは、全データを使って勾配を計算することは計算コストが高くなります。そのため、SGDではランダムに選ばれたサンプル(ミニバッチ)の勾配を使用してパラメータを更新します。
SGDの手順
SGDはランダム性を持つため、収束までの時間が不確定であるという特徴があります。一方で、ミニバッチごとの勾配計算が高速であるため、大規模なデータセットでも比較的高速に学習を進めることができます。
SGDは機械学習モデルの学習において広く使用される手法であり、ニューラルネットワークのトレーニングにもよく利用されます。また、SGDにはいくつかの派生手法も存在し、例えば、モーメンタム法(Momentum)、Nesterov accelerated gradient(NAG)、Adagrad、Adamなどがあります。これらの手法はSGDの収束性や学習速度の改善を目指しています。
コードの概要(実装)
SGDの更新式はこのようになっています。
class SGD: def __init__(self, lr=0.001): self.lr = lr def update(self, params, grads): for key in params.keys(): params[key] -= self.lr * grads[key] #更新前のパラーメータから勾配を持っているgradsにlrを掛けて引く
SGD
というクラスが定義されています。__init__
関数では、学習係数(learning rate)を指定できるようにしています。デフォルト値は0.001です。update
関数は、重みの更新を行います。この関数には2つの引数があります。params
とgrads
。params
は重みの辞書であり、各重みの値を格納しています。grads
は勾配の辞書であり、各重みに対応する勾配の値を格納しています。update
関数では、重みの更新を行うために、params
の各要素に対して学習係数と対応する勾配を掛けて減算しています。
この実装では、各重みごとに学習係数を乗じた勾配を引くことで、重みの更新を行っています。SGDは、一度に全てのトレーニングデータを用いて勾配を計算するのではなく、ランダムに選ばれたデータ(またはデータのミニバッチ)の勾配を利用するため、「確率的」と呼ばれます。
ポイント
lr
(学習係数)は、SGDのパフォーマンスに大きな影響を与えます。学習係数が小さいと収束までのステップ数が増えますが、大きすぎると発散する可能性があります。適切な学習係数を選ぶことが重要です。params
とgrads
は辞書形式で与えられることが前提です。これにより、ネットワークの各層の重みや勾配をまとめて扱うことができます。
重みの更新
params[key] -= self.lr * grads[key]
この式は重みの更新を行うための式です。
重みの更新には、勾配降下法の基本的な原理が使われています。勾配降下法では、最小化したい目的関数の勾配(導関数)を用いて、各パラメータを更新して最適な値に近づけることを目指します。具体的には、目的関数の勾配の逆方向にパラメータを移動させることで、目的関数の値を小さくすることを目指します。
式の解説
params[key]
は、重みの値を表しています。grads[key]
は、対応する重みに対する勾配を表しています。(dL/dθ)self.lr
は学習係数(learning rate)であり、重みの更新のステップサイズを制御します。
式の意味
self.lr * grads[key]
は、学習係数と対応する重みの勾配の積です。この値は、重みの更新の大きさを制御します。学習係数が大きい場合は大きな更新が行われ、学習係数が小さい場合は小さな更新が行われます。params[key] -= self.lr * grads[key]
は、重みの現在の値から学習係数と勾配の積を引くことで、重みを更新します。これにより、重みが目的関数の最小値に向かって更新されます。params[key]
とgrads[key]
は、同じキー(重みの名前やインデックス)に対応している必要があります。つまり、params
とgrads
は、同じ構造を持った辞書として与えられる必要があります。- この更新式は、SGDによる重みの更新を実現するための基本的な式です。他の最適化アルゴリズム(例:Adam、RMSpropなど)では、より複雑な更新式が使われる場合がありますが、基本的な考え方は同じです。