BatchNormalization
BatchNormalization(バッチ正規化)は、ニューラルネットワークの学習を安定化し、収束を速めるための手法の一つです。これは、各ミニバッチ内での入力データの平均をゼロにし、標準偏差を1に調整することによって、学習の安定性を向上させるものです。 平たく言うと、各層に入れるデータを正規化してあげないと、いきなりとんでもない値が飛び込んできて、変数が安定して収束しなくなっちゃうよ!だから、データを綺麗にしておこう!ってことです。
BatchNormalizationレイヤーの2つのパラメータ
gamma (γ): スケールを調整するパラメータ。学習時に更新される重みで、デフォルトでは1です。 beta (β): シフトを調整するパラメータ。学習時に更新される重みで、デフォルトでは0です。 これらのパラメータは、各特徴量に対して1つずつ存在し、学習中に最適な値に調整されます。gammaは特徴量のスケールを制御し、betaは特徴量のシフトを制御します。これにより、ネットワークが柔軟に学習できるようになり、勾配消失や爆発の問題を軽減します。
BatchNormalizationの実装例
class BatchNormalization: def __init__(self, gamma, beta, rho=0.9, moving_mean=None, moving_var=None): self.gamma = gamma self.beta = beta self.rho = rho # 予測時に使用する平均と分散 self.moving_mean = moving_mean # muの移動平均 self.moving_var = moving_var # varの移動平均 # 計算中に算出される値を保持しておく変数群 self.batch_size = None self.x_mu = None self.x_std = None self.std = None self.dgamma = None self.dbeta = None def forward(self, x, train_flg=True): if x.ndim == 4: # 画像形式の場合 N, C, H, W = x.shape x = x.transpose(0, 2, 3, 1) # NHWCに入れ替え x = x.reshape(N*H*W, C) # (N*H*W,C)の2次元配列に変換 out = self.__forward(x, train_flg) out = out.reshape(N, H, W, C)# 4次元配列に変換 out = out.transpose(0, 3, 1, 2) # 軸をNCHWに入れ替え elif x.ndim == 2: # 画像形式以外の場合 out = self.__forward(x, train_flg) return out def __forward(self, x, train_flg, epsilon=1e-8): """ x : 入力. n×dの行列. nはあるミニバッチのバッチサイズ. dは手前の層のノード数 """ if (self.moving_mean is None) or (self.moving_var is None): N, D = x.shape self.moving_mean = np.zeros(D) self.moving_var = np.zeros(D) if train_flg: # 入力xについて、nの方向に平均値を算出. mu = np.mean(x, axis=0) # 要素数d個のベクトル # 入力xから平均値を引く x_mu = x - mu # n*d行列 # 入力xの分散を求める var = np.mean(x_mu**2, axis=0) # 要素数d個のベクトル # 入力xの標準偏差を求める(epsilonを足してから標準偏差を求める) std = np.sqrt(var + epsilon) # 要素数d個のベクトル # 標準化 x_std = x_mu / std # n*d行列 # 値を保持しておく self.batch_size = x.shape[0] self.x_mu = x_mu self.x_std = x_std self.std = std self.moving_mean = self.rho * self.moving_mean + (1-self.rho) * mu self.moving_var = self.rho * self.moving_var + (1-self.rho) * var else: # 予測時 x_mu = x - self.moving_mean # n*d行列 x_std = x_mu / np.sqrt(self.moving_var + epsilon) # n*d行列 # gammaでスケールし、betaでシフトさせる out = self.gamma * x_std + self.beta # n*d行列 return out def backward(self, dout): """ 逆伝播計算 dout : CNNの場合は4次元、全結合層の場合は2次元 """ if dout.ndim == 4: # 画像形式の場合 N, C, H, W = dout.shape dout = dout.transpose(0, 2, 3, 1) # NHWCに入れ替え dout = dout.reshape(N*H*W, C) # (N*H*W,C)の2次元配列に変換 dx = self.__backward(dout) dx = dx.reshape(N, H, W, C)# 4次元配列に変換 dx = dx.transpose(0, 3, 1, 2) # 軸をNCHWに入れ替え elif dout.ndim == 2: # 画像形式以外の場合 dx = self.__backward(dout) return dx def __backward(self, dout): # betaの勾配 dbeta = np.sum(dout, axis=0) # gammaの勾配(n方向に合計) dgamma = np.sum(self.x_std * dout, axis=0) # Xstdの勾配 a1 = self.gamma * dout # Xmuの勾配(1つ目) a2 = a1 / self.std # 標準偏差の逆数の勾配(n方向に合計) a3 = np.sum(a1 * self.x_mu, axis=0) # 標準偏差の勾配 a4 = -(a3) / (self.std * self.std) # 分散の勾配 a5 = 0.5 * a4 / self.std # Xmuの2乗の勾配 a6 = a5 / self.batch_size # Xmuの勾配(2つ目) a7 = 2.0 * self.x_mu * a6 # muの勾配 a8 = np.sum(-(a2+a7), axis=0) # Xの勾配 dx = a2 + a7 + a8 / self.batch_size # 第3項はn方向に平均 self.dgamma = dgamma self.dbeta = dbeta return dx
このBatchNormalization
クラスは、バッチ正規化の順伝播と逆伝播を実装しています。
__init__
メソッド:
このメソッドは BatchNormalization
クラスのコンストラクタで、バッチ正規化のパラメータを初期化します。
gamma
: スケール調整のためのパラメータで、学習中に更新されます。beta
: シフト調整のためのパラメータで、学習中に更新されます。rho
: 移動平均を算出する際に使用する係数です。moving_mean
,moving_var
: 予測時に使用する平均と分散を保存する変数です。
forward
メソッド:
バッチ正規化の順伝播計算を行います。
- 入力
x
は、4次元の画像形式または2次元の形式に対応しています。 train_flg=True
の場合は学習時、train_flg=False
の場合は予測時の計算を行います。- 入力
x
に対して平均、分散、標準化を計算し、それに対してgamma
でスケールし、beta
でシフトさせた結果を出力します。 - 学習時には、移動平均と移動分散も更新します。
__forward
メソッド:
実際に順伝播計算を行う内部メソッドです。
train_flg=True
の場合は学習時、train_flg=False
の場合は予測時の計算を行います。
backward
メソッド:
バッチ正規化の逆伝播計算を行います。
- 勾配として受け取った
dout
に対して、gamma
とbeta
の勾配を計算します。 X_std
の勾配を計算し、それを用いてX_mu
、std
、gamma
の勾配を求めます。- 最終的には、
X_mu
、std
、gamma
の勾配を用いて入力x
の勾配を計算します。
深層学習教科書 ディープラーニング G検定(ジェネラリスト)公式テキスト 第2版 (EXAMPRESS) [ 一般社団法人日本ディープラーニング協会 ] 価格:3,080円 |