im2colとは
im2colは、画像認識において利用される関数で、主に高速な行列演算を活かすために使用されます。この関数は、効率的なnumpyの操作を可能にします。ループを使用することができますが、これはnumpyの優れた特性を活かす点で劣る方法です。(numpyはforループなどで処理すると効率が低下する傾向があります)
import numpy as np def im2col(input_data, filter_h, filter_w, stride=1, pad=0, constant_values=0): # 入力データのデータ数, チャンネル数, 高さ, 幅を取得する N, C, H, W = input_data.shape # 出力データ(畳み込みまたはプーリングの演算後)の形状を計算する out_h = (H + 2*pad - filter_h)//stride + 1 # 出力データの高さ(端数は切り捨てる) out_w = (W + 2*pad - filter_w)//stride + 1 # 出力データの幅(端数は切り捨てる) # パディング処理 img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant', constant_values=constant_values) # pad=1以上の場合、周囲を0で埋める # 配列の初期化 col = np.zeros((N, C, filter_h, filter_w, out_h, out_w)) # 配列を並び替える(フィルター内のある1要素に対応する画像中の画素を取り出してcolに代入する) for y in range(filter_h): """ フィルターの高さ方向のループ """ y_max = y + stride*out_h for x in range(filter_w): """ フィルターの幅方向のループ """ x_max = x + stride*out_w # imgから値を取り出し、colに入れる col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride] # y:y_max:strideの意味はyからy_maxまでの場所をstride刻みで指定している # x:x_max:stride の意味はxからx_maxまでの場所をstride刻みで指定している # 軸を入れ替えて、2次元配列(行列)に変形する col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1) return col
コード解説
N, C, H, W = input_data.shape
input_data
は、画像データのミニバッチ(複数の画像が同時に処理される)を表す4次元のNumPy配列です。以下に、各次元の意味を示します:
- 第1次元 (N): バッチサイズ。同時に処理される画像の数を示します。
- 第2次元 (C): チャンネル数。画像ごとに異なる色や特徴マップがある場合、それらの数を示します。
- 第3次元 (H): 画像の高さ。ピクセルの行数を示します。
- 第4次元 (W): 画像の幅。ピクセルの列数を示します。
具体的な例として、input_data
が次のような場合を考えてみましょう:
import numpy as np # バッチサイズ: 2, チャンネル数: 3, 画像の高さ: 4, 画像の幅: 4 input_data = np.random.rand(2, 3, 4, 4)
この場合、input_data
は2つの画像を含み、各画像は3つのチャンネル(例: 赤、緑、青)を持ち、高さと幅がそれぞれ4ピクセル × 4ピクセルです。このような形式のデータが im2col
関数の入力として与えられます。
y_max = y + stride*out_h
この行のコード y_max = y + stride*out_h
は、im2col
関数内のループにおいて、フィルターの高さ方向における最大座標を計算しています。
まず、y
はフィルターの高さ方向におけるループ変数で、フィルター内の行を指します。stride
はフィルターの移動幅を表し、out_h
は畳み込み演算後の出力データの高さを示します。
この式 y_max = y + stride*out_h
は、フィルターの高さ方向において、現在の y
から stride * out_h
を加算した座標を y_max
として計算しています。この y_max
は、フィルターを適用する際に元画像から切り取る領域の終端(最大座標)を表します。
言い換えれば、この計算によって、現在のフィルター位置から stride * out_h
だけ移動した位置が、畳み込み演算の適用範囲における高さの最大座標を示しています。これは畳み込み演算が適用されるフィルターの範囲を指定するために使われます。
同様の考え方が横方向における x_max
の計算にも適用されています。
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
この部分のコードは、img
というパディング済みの画像データから、フィルター内のある1つの要素に対応する画像中のピクセル値を取り出し、col
配列に格納しています。これは im2col
操作の一部であり、畳み込み演算を行うためのデータの再配置を行っています。
col[:, :, y, x, :, :]
:col
配列の特定の位置に対応するサブ配列。- この位置は、フィルター内のある1つの要素に対応します。
img[:, :, y:y_max:stride, x:x_max:stride]
: パディング済みの画像データから対応する部分を切り出す。y:y_max:stride
は、y
からy_max
までの範囲をstride
刻みで指定しています。x:x_max:stride
は、x
からx_max
までの範囲をstride
刻みで指定しています。
これにより、col
配列の特定の位置には、元画像から取り出されたフィルターに相当する領域のピクセル値が格納されます。この操作をフィルター内の各位置に対して繰り返すことで、im2col
操作が行われ、畳み込み演算がより効率的に実行されるようになります。
「:」の意味
このコードの:
は、NumPyにおいて「すべての要素」を指定するための記号です。具体的には、スライシングにおいてstart:stop:step
の形式で使用され、:
が指定された場合は全要素を表します。
col[:, :, y, x, :, :]
とimg[:, :, y:y_max:stride, x:x_max:stride]
の部分では、各次元に対して:(コロン)
が使用されています。具体的な意味は次の通りです:
col[:, :, y, x, :, :]
:col
の特定の位置に対応する部分行列を指定しています。y
およびx
に対応する次元において、:
が使われているため、それぞれの次元において全ての要素を取得しています。img[:, :, y:y_max:stride, x:x_max:stride]
:img
から特定の領域を切り出すためのスライスを指定しています。ここでも、:(コロン)
を使用して各次元において全ての要素を取得していますが、y:y_max:stride
およびx:x_max:stride
の部分では、stride
刻みでサンプリングするための指定が行われています。
簡単に言うと、:
は「全ての要素」を指定するためのものであり、これにより対象の次元全体のデータを操作することができます。
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(Nout_hout_w, -1)
この行のコードは、col
配列の軸を入れ替えてから、2次元の行列に変形しています。
1.col.transpose(0, 4, 5, 1, 2, 3)
:
-
transpose
関数は、配列の軸を入れ替えます。 - 引数で指定された軸の順序に従って、
col
の軸が入れ替えられます。 - 具体的な引数
(0, 4, 5, 1, 2, 3)
では、元の軸の順序を以下のように変更しています:0
から0
1
から4
2
から5
3
から1
4
から2
5
から3
- この操作により、
col
の次元の順序が変更されます。
2.reshape(N*out_h*out_w, -1)
:
-
reshape
関数は、配列の形状を変更します。 - 第1引数で指定された形状に変換しますが、
-1
を指定することで、残りの次元は自動的に計算されます。 - 具体的な引数
(N*out_h*out_w, -1)
では、col
を 2 次元の行列に変形しています。 - 行数は
N*out_h*out_w
となり、列数は元の次元数に応じて自動的に計算されます。
これらの操作は、通常は畳み込みニューラルネットワーク (CNN) の処理において、im2col
操作を経て得られたデータを、通常の行列形式に変換するために行われます。
畳み込み演算のサイズ
よく見る公式です。
このコードは、畳み込み演算やプーリング演算を適用した後の出力データの形状(高さと幅)を計算しています。
# 出力データ(畳み込みまたはプーリングの演算後)の形状を計算する out_h = (H + 2*pad - filter_h)//stride + 1 # 出力データの高さ(端数は切り捨てる) out_w = (W + 2*pad - filter_w)//stride + 1 # 出力データの幅(端数は切り捨てる)
H
: 入力データの高さW
: 入力データの幅filter_h
: フィルター(カーネル)の高さfilter_w
: フィルターの幅stride
: ストライド(フィルターを適用する間隔)pad
: パディング(入力データの周囲に追加される枠の数)
1.出力データの高さ (out_h
) の計算:
-
(H + 2*pad - filter_h)
は、パディングを考慮したフィルターを適用するための縦方向の空間です。 -
//stride
はストライドによってフィルターを適用する間隔で割る操作です。 -
+1
は、最後のフィルターが入力データの最後まで適用されるように、1 を加えています。 - 端数は切り捨てるため、整数除算 (
//
) が使われています。
2.出力データの幅 (out_w
) の計算:
- 同様に、
(W + 2*pad - filter_w)
は、パディングを考慮したフィルターを適用するための横方向の空間です。 -
//stride
はストライドによってフィルターを適用する間隔で割る操作です。 -
+1
は、最後のフィルターが入力データの最後まで適用されるように、1 を加えています。 - 端数は切り捨てるため、整数除算 (
//
) が使われています。
深層学習教科書 ディープラーニング G検定(ジェネラリスト)公式テキスト 第2版 (EXAMPRESS) [ 一般社団法人日本ディープラーニング協会 ] 価格:3,080円 |