ITエンジニア ノイのブログ

ITエンジニアのノイです。 YouTubeで ITエンジニアのお勉強という学習用の動画を公開しています。チャンネル登録お願いします!https://m.youtube.com/channel/UCBKfJIMVWXd3ReG_FDh31Aw/playlists

predictのバッチ化

バッチ処理とは

バッチ処理は、複数のタスクやデータを一括で処理する手法です。一つのまとまり(バッチ)に含まれるタスクやデータは、同様の処理を受けることが多いです。意味合いとしては一個ずつ処理をすると大変だから、まとめて処理をしよう!という意味になります。

predictのバッチ化

機械学習モデルにテストデータを入力するとき、データを1つずつ入力するのではなく、まとめて1000個入力してしまおう!という考えです。 リアルタイムで1つずつ推論しなければならないこともありますが、ソリューションによってはバッチ化することもできます。

テストデータを1つずつ入力する。

def predict(network, x):
    W1, W2, W3 = network["W1"],network["W2"],network["W3"]
    b1, b2, b3 = network["b1"],network["b2"],network["b3"]    
    
    a1 = np.dot(x,W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1,W2) + b2 
    z2 = sigmoid(a2)
    a3 = np.dot(z2,W3) + b3
    return softmax(a3)
NUM = 1000
for i in range(NUM):
    y = predict(network, test[i])
    p = np.append(p, np.argmax(y))

1000個のテストデータを1つずつPredict関数に入れています。 forで回すのは時間もかかるので大変ですね。。。

テストデータを1000個まとめて入力する(バッチ化)

NUM = 1000
y = predict(network, test[:NUM])

ある意味スッキリしていますね。この方が当然という気もしますが。。。

print(test[:NUM].shape, W1.shape, W2.shape, W3.shape, test_labels[:NUM].shape)
print(b1.shape, b2.shape, b3.shape)

結果

(1000, 784) (784, 50) (50, 100) (100, 10) (1000, 10)
(50,) (100,) (10,)

テストデータをtest[:NUM]として入力していますが、(1000, 784)と1000枚分の28*28(mnistなので)となっています。 ラベルのtest_labels[:NUM]についても1000枚分に対して0~9のラベルが対応していることがわかります。

おわりに

今回はpredictでのバッチ処理の例を紹介しました。学習時には当然のようにバッチ処理をしますね。

youtu.be