im2colとは
im2colは、画像認識において利用される関数で、主に高速な行列演算を活かすために使用されます。この関数は、効率的なnumpyの操作を可能にします。ループを使用することができますが、これはnumpyの優れた特性を活かす点で劣る方法です。(numpyはforループなどで処理すると効率が低下する傾向があります)
import numpy as np
def im2col (input_data, filter_h, filter_w, stride=1 , pad=0 , constant_values=0 ):
N, C, H, W = input_data.shape
out_h = (H + 2 *pad - filter_h)//stride + 1
out_w = (W + 2 *pad - filter_w)//stride + 1
img = np.pad(input_data, [(0 ,0 ), (0 ,0 ), (pad, pad), (pad, pad)],
'constant' , constant_values=constant_values)
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
for y in range (filter_h):
"""
フィルターの高さ方向のループ
"""
y_max = y + stride*out_h
for x in range (filter_w):
"""
フィルターの幅方向のループ
"""
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
col = col.transpose(0 , 4 , 5 , 1 , 2 , 3 ).reshape(N*out_h*out_w, -1 )
return col
コード解説
N, C, H, W = input_data.shape
input_data
は、画像データのミニバッチ(複数の画像が同時に処理される)を表す4次元のNumPy配列です。以下に、各次元の意味を示します:
第1次元 (N): バッチサイズ。同時に処理される画像の数を示します。
第2次元 (C): チャンネル数。画像ごとに異なる色や特徴マップがある場合、それらの数を示します。
第3次元 (H): 画像の高さ。ピクセル の行数を示します。
第4次元 (W): 画像の幅。ピクセル の列数を示します。
具体的な例として、input_data
が次のような場合を考えてみましょう:
import numpy as np
input_data = np.random.rand(2 , 3 , 4 , 4 )
この場合、input_data
は2つの画像を含み、各画像は3つのチャンネル(例: 赤、緑、青)を持ち、高さと幅がそれぞれ4ピクセル × 4ピクセル です。このような形式のデータが im2col
関数の入力として与えられます。
y_max = y + stride*out_h
この行のコード y_max = y + stride*out_h
は、im2col
関数内のループにおいて、フィルターの高さ方向における最大座標を計算しています。
まず、y
はフィルターの高さ方向におけるループ変数で、フィルター内の行を指します。stride
はフィルターの移動幅を表し、out_h
は畳み込み演算後の出力データの高さを示します。
この式 y_max = y + stride*out_h
は、フィルターの高さ方向において、現在の y
から stride * out_h
を加算した座標を y_max
として計算しています。この y_max
は、フィルターを適用する際に元画像から切り取る領域の終端(最大座標)を表します。
言い換えれば、この計算によって、現在のフィルター位置から stride * out_h
だけ移動した位置が、畳み込み演算の適用範囲における高さの最大座標を示しています。これは畳み込み演算が適用されるフィルターの範囲を指定するために使われます。
同様の考え方が横方向における x_max
の計算にも適用されています。
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
この部分のコードは、img
というパディング済みの画像データから、フィルター内のある1つの要素に対応する画像中のピクセル 値を取り出し、col
配列に格納しています。これは im2col
操作の一部であり、畳み込み演算を行うためのデータの再配置を行っています。
col[:, :, y, x, :, :]
: col
配列の特定の位置に対応するサブ配列。
この位置は、フィルター内のある1つの要素に対応します。
img[:, :, y:y_max:stride, x:x_max:stride]
: パディング済みの画像データから対応する部分を切り出す。
y:y_max:stride
は、y
から y_max
までの範囲を stride
刻みで指定しています。
x:x_max:stride
は、x
から x_max
までの範囲を stride
刻みで指定しています。
これにより、col
配列の特定の位置には、元画像から取り出されたフィルターに相当する領域のピクセル 値が格納されます。この操作をフィルター内の各位置に対して繰り返すことで、im2col
操作が行われ、畳み込み演算がより効率的に実行されるようになります。
「:」の意味
このコードの:
は、NumPyにおいて「すべての要素」を指定するための記号です。具体的には、スライシングにおいてstart:stop:step
の形式で使用され、:
が指定された場合は全要素を表します。
col[:, :, y, x, :, :]
とimg[:, :, y:y_max:stride, x:x_max:stride]
の部分では、各次元に対して:(コロン)
が使用されています。具体的な意味は次の通りです:
col[:, :, y, x, :, :]
: col
の特定の位置に対応する部分行列を指定しています。y
およびx
に対応する次元において、:
が使われているため、それぞれの次元において全ての要素を取得しています。
img[:, :, y:y_max:stride, x:x_max:stride]
: img
から特定の領域を切り出すためのスライスを指定しています。ここでも、:(コロン)
を使用して各次元において全ての要素を取得していますが、y:y_max:stride
およびx:x_max:stride
の部分では、stride
刻みでサンプリングするための指定が行われています。
簡単に言うと、:
は「全ての要素」を指定するためのものであり、これにより対象の次元全体のデータを操作することができます。
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(Nout_h out_w, -1)
この行のコードは、col
配列の軸を入れ替えてから、2次元の行列に変形しています。
1.col.transpose(0, 4, 5, 1, 2, 3)
:
transpose
関数は、配列の軸を入れ替えます。
引数で指定された軸の順序に従って、col
の軸が入れ替えられます。
具体的な引数 (0, 4, 5, 1, 2, 3)
では、元の軸の順序を以下のように変更しています:
0
から 0
1
から 4
2
から 5
3
から 1
4
から 2
5
から 3
この操作により、col
の次元の順序が変更されます。
2.reshape(N*out_h*out_w, -1)
:
reshape
関数は、配列の形状を変更します。
第1引数で指定された形状に変換しますが、-1
を指定することで、残りの次元は自動的に計算されます。
具体的な引数 (N*out_h*out_w, -1)
では、col
を 2 次元の行列に変形しています。
行数は N*out_h*out_w
となり、列数は元の次元数に応じて自動的に計算されます。
これらの操作は、通常は畳み込みニューラルネットワーク (CNN) の処理において、im2col
操作を経て得られたデータを、通常の行列形式に変換するために行われます。
畳み込み演算のサイズ
よく見る公式です。
畳み込み演算のサイズ
このコードは、畳み込み演算やプーリング演算を適用した後の出力データの形状(高さと幅)を計算しています。
out_h = (H + 2 *pad - filter_h)//stride + 1
out_w = (W + 2 *pad - filter_w)//stride + 1
H
: 入力データの高さ
W
: 入力データの幅
filter_h
: フィルター(カーネル )の高さ
filter_w
: フィルターの幅
stride
: ストライド (フィルターを適用する間隔)
pad
: パディング(入力データの周囲に追加される枠の数)
1.出力データの高さ (out_h
) の計算:
(H + 2*pad - filter_h)
は、パディングを考慮したフィルターを適用するための縦方向の空間です。
//stride
はストライド によってフィルターを適用する間隔で割る操作です。
+1
は、最後のフィルターが入力データの最後まで適用されるように、1 を加えています。
端数は切り捨てるため、整数除算 (//
) が使われています。
2.出力データの幅 (out_w
) の計算:
同様に、(W + 2*pad - filter_w)
は、パディングを考慮したフィルターを適用するための横方向の空間です。
//stride
はストライド によってフィルターを適用する間隔で割る操作です。
+1
は、最後のフィルターが入力データの最後まで適用されるように、1 を加えています。
端数は切り捨てるため、整数除算 (//
) が使われています。
VIDEO youtu.be