バッチ処理とは
バッチ処理は、複数のタスクやデータを一括で処理する手法です。一つのまとまり(バッチ)に含まれるタスクやデータは、同様の処理を受けることが多いです。意味合いとしては一個ずつ処理をすると大変だから、まとめて処理をしよう!という意味になります。
predictのバッチ化
機械学習モデルにテストデータを入力するとき、データを1つずつ入力するのではなく、まとめて1000個入力してしまおう!という考えです。 リアルタイムで1つずつ推論しなければならないこともありますが、ソリューションによってはバッチ化することもできます。
テストデータを1つずつ入力する。
def predict(network, x): W1, W2, W3 = network["W1"],network["W2"],network["W3"] b1, b2, b3 = network["b1"],network["b2"],network["b3"] a1 = np.dot(x,W1) + b1 z1 = sigmoid(a1) a2 = np.dot(z1,W2) + b2 z2 = sigmoid(a2) a3 = np.dot(z2,W3) + b3 return softmax(a3)
NUM = 1000 for i in range(NUM): y = predict(network, test[i]) p = np.append(p, np.argmax(y))
1000個のテストデータを1つずつPredict関数に入れています。 forで回すのは時間もかかるので大変ですね。。。
テストデータを1000個まとめて入力する(バッチ化)
NUM = 1000
y = predict(network, test[:NUM])
ある意味スッキリしていますね。この方が当然という気もしますが。。。
print(test[:NUM].shape, W1.shape, W2.shape, W3.shape, test_labels[:NUM].shape) print(b1.shape, b2.shape, b3.shape)
結果
(1000, 784) (784, 50) (50, 100) (100, 10) (1000, 10) (50,) (100,) (10,)
テストデータをtest[:NUM]
として入力していますが、(1000, 784)と1000枚分の28*28(mnistなので)となっています。
ラベルのtest_labels[:NUM]
についても1000枚分に対して0~9のラベルが対応していることがわかります。