ITエンジニア ノイのブログ

ITエンジニアのノイです。 YouTubeで ITエンジニアのお勉強という学習用の動画を公開しています。チャンネル登録お願いします!https://m.youtube.com/channel/UCBKfJIMVWXd3ReG_FDh31Aw/playlists

BatchNormalization(バッチ正規化)のパラメータgammaとbeta

BatchNormalization

BatchNormalization(バッチ正規化)は、ニューラルネットワークの学習を安定化し、収束を速めるための手法の一つです。これは、各ミニバッチ内での入力データの平均をゼロにし、標準偏差を1に調整することによって、学習の安定性を向上させるものです。 平たく言うと、各層に入れるデータを正規化してあげないと、いきなりとんでもない値が飛び込んできて、変数が安定して収束しなくなっちゃうよ!だから、データを綺麗にしておこう!ってことです。

BatchNormalizationレイヤーの2つのパラメータ

gamma (γ): スケールを調整するパラメータ。学習時に更新される重みで、デフォルトでは1です。 beta (β): シフトを調整するパラメータ。学習時に更新される重みで、デフォルトでは0です。 これらのパラメータは、各特徴量に対して1つずつ存在し、学習中に最適な値に調整されます。gammaは特徴量のスケールを制御し、betaは特徴量のシフトを制御します。これにより、ネットワークが柔軟に学習できるようになり、勾配消失や爆発の問題を軽減します。

BatchNormalizationの実装例

class BatchNormalization:
    def __init__(self, gamma, beta, rho=0.9, moving_mean=None, moving_var=None):
        self.gamma = gamma 
        self.beta = beta 
        self.rho = rho 

        # 予測時に使用する平均と分散
        self.moving_mean = moving_mean   # muの移動平均
        self.moving_var = moving_var     # varの移動平均

        # 計算中に算出される値を保持しておく変数群
        self.batch_size = None
        self.x_mu = None
        self.x_std = None
        self.std = None
        self.dgamma = None
        self.dbeta = None

    def forward(self, x, train_flg=True):
        if x.ndim == 4:
            # 画像形式の場合
            N, C, H, W = x.shape
            x = x.transpose(0, 2, 3, 1) # NHWCに入れ替え
            x = x.reshape(N*H*W, C) # (N*H*W,C)の2次元配列に変換
            out = self.__forward(x, train_flg)
            out = out.reshape(N, H, W, C)# 4次元配列に変換
            out = out.transpose(0, 3, 1, 2) # 軸をNCHWに入れ替え
        elif x.ndim == 2:
            # 画像形式以外の場合
            out = self.__forward(x, train_flg)

        return out

    def __forward(self, x, train_flg, epsilon=1e-8):
        """
        x : 入力. n×dの行列. nはあるミニバッチのバッチサイズ. dは手前の層のノード数
        """
        if (self.moving_mean is None) or (self.moving_var is None):
            N, D = x.shape
            self.moving_mean = np.zeros(D)
            self.moving_var = np.zeros(D)

        if train_flg:
            # 入力xについて、nの方向に平均値を算出.
            mu = np.mean(x, axis=0) # 要素数d個のベクトル

            # 入力xから平均値を引く
            x_mu = x - mu   # n*d行列

            # 入力xの分散を求める
            var = np.mean(x_mu**2, axis=0)  # 要素数d個のベクトル

            # 入力xの標準偏差を求める(epsilonを足してから標準偏差を求める)
            std = np.sqrt(var + epsilon)  # 要素数d個のベクトル

            # 標準化
            x_std = x_mu / std  # n*d行列

            # 値を保持しておく
            self.batch_size = x.shape[0]
            self.x_mu = x_mu
            self.x_std = x_std
            self.std = std
            self.moving_mean = self.rho * self.moving_mean + (1-self.rho) * mu
            self.moving_var = self.rho * self.moving_var + (1-self.rho) * var
        else:
           #  予測時
            x_mu = x - self.moving_mean # n*d行列
            x_std = x_mu / np.sqrt(self.moving_var + epsilon) # n*d行列

        # gammaでスケールし、betaでシフトさせる
        out = self.gamma * x_std + self.beta # n*d行列
        return out

    def backward(self, dout):
        """
        逆伝播計算
        dout : CNNの場合は4次元、全結合層の場合は2次元
        """
        if dout.ndim == 4:
            # 画像形式の場合
            N, C, H, W = dout.shape
            dout = dout.transpose(0, 2, 3, 1) # NHWCに入れ替え
            dout = dout.reshape(N*H*W, C) # (N*H*W,C)の2次元配列に変換
            dx = self.__backward(dout)
            dx = dx.reshape(N, H, W, C)# 4次元配列に変換
            dx = dx.transpose(0, 3, 1, 2) # 軸をNCHWに入れ替え
        elif dout.ndim == 2:
            # 画像形式以外の場合
            dx = self.__backward(dout)

        return dx

    def __backward(self, dout):
        # betaの勾配
        dbeta = np.sum(dout, axis=0)

        # gammaの勾配(n方向に合計)
        dgamma = np.sum(self.x_std * dout, axis=0)

        # Xstdの勾配
        a1 = self.gamma * dout

        # Xmuの勾配(1つ目)
        a2 = a1 / self.std

        # 標準偏差の逆数の勾配(n方向に合計)
        a3 = np.sum(a1 * self.x_mu, axis=0)

        # 標準偏差の勾配
        a4 = -(a3) / (self.std * self.std)

        # 分散の勾配
        a5 = 0.5 * a4 / self.std

        # Xmuの2乗の勾配
        a6 = a5 / self.batch_size

        # Xmuの勾配(2つ目)
        a7 = 2.0  * self.x_mu * a6

        # muの勾配
        a8 = np.sum(-(a2+a7), axis=0)

        # Xの勾配
        dx = a2 + a7 +  a8 / self.batch_size # 第3項はn方向に平均

        self.dgamma = dgamma
        self.dbeta = dbeta

        return dx

このBatchNormalizationクラスは、バッチ正規化の順伝播と逆伝播を実装しています。

__init__ メソッド:

このメソッドは BatchNormalization クラスのコンストラクタで、バッチ正規化のパラメータを初期化します。

  • gamma: スケール調整のためのパラメータで、学習中に更新されます。
  • beta: シフト調整のためのパラメータで、学習中に更新されます。
  • rho: 移動平均を算出する際に使用する係数です。
  • moving_mean, moving_var: 予測時に使用する平均と分散を保存する変数です。

forward メソッド:

バッチ正規化の順伝播計算を行います。

  • 入力 x は、4次元の画像形式または2次元の形式に対応しています。
  • train_flg=True の場合は学習時、train_flg=False の場合は予測時の計算を行います。
  • 入力 x に対して平均、分散、標準化を計算し、それに対して gamma でスケールし、beta でシフトさせた結果を出力します。
  • 学習時には、移動平均と移動分散も更新します。

__forward メソッド:

実際に順伝播計算を行う内部メソッドです。

  • train_flg=True の場合は学習時、train_flg=False の場合は予測時の計算を行います。

backward メソッド:

バッチ正規化の逆伝播計算を行います。

  • 勾配として受け取った dout に対して、gammabeta の勾配を計算します。
  • X_std の勾配を計算し、それを用いて X_mustdgamma の勾配を求めます。
  • 最終的には、X_mustdgamma の勾配を用いて入力 x の勾配を計算します。

youtu.be

深層学習教科書 ディープラーニング G検定(ジェネラリスト)公式テキスト 第2版 (EXAMPRESS) [ 一般社団法人日本ディープラーニング協会 ]

価格:3,080円
(2023/8/16 20:42時点)
感想(5件)