ITエンジニア ノイのブログ

ITエンジニアのノイです。 YouTubeで ITエンジニアのお勉強という学習用の動画を公開しています。チャンネル登録お願いします!https://m.youtube.com/channel/UCBKfJIMVWXd3ReG_FDh31Aw/playlists

学習済みの重みの読み込み

やりたいこと

sample_weight.pklの事前学習の重みを読み込みたい。

sample_weight.pklの例としてオライリーさんのGitHubにあるものを使います。 github.com

ファイルから読み込む

def init_network():
    with open("sample_weight.pkl", "rb") as f:
        network = pickle.load(f)
    return network

network = init_network()
print(network.keys())

結果

dict_keys(['b2', 'W1', 'b1', 'W2', 'W3', 'b3'])

GoogleColabで実行する場合には図のように配置をしてください。

sample_weight.pklの配置

pickle.load

pickle.load関数は、pickle形式のファイルからオブジェクトを読み込むために使用されます。具体的には、pickle.dump関数によって保存されたオブジェクトを読み込み、Pythonのオブジェクトとして復元します。 pickle形式はPython固有のものであり、他のプログラミング言語との互換性はありません。また、信頼できないソースからのpickleファイルの読み込みは、セキュリティ上のリスクを引き起こす可能性があるため、注意が必要です。

補足 オブジェクトの直列化(シリアライズ)と非直列化(デシリアライズ

直列化(シリアライズは、オブジェクトをそのままの形式ではなく、バイト列やテキストなどの直列化された形式に変換することです。このプロセスにより、オブジェクトの状態やデータが保存や転送に適した形式に変換されます。直列化は、オブジェクトをファイルに保存したり、ネットワーク経由で送信したり、データベースに格納したりする際に使用されます。

非直列化(デシリアライズは、直列化されたデータを元のオブジェクトの形式に戻すプロセスです。直列化されたデータは、保存や転送のために変換された形式なので、非直列化を行うことで元のオブジェクトの状態やデータを復元することができます。

重みを用いる

W1, W2, W3 = network["W1"],network["W2"],network["W3"]
b1, b2, b3 = network["b1"],network["b2"],network["b3"] 

print(network["W1"])
print(network["b1"])
print(W1.shape, W2.shape, W3.shape)
print(b1.shape, b2.shape, b3.shape)

結果

[[-0.00741249 -0.00790439 -0.01307499 ...  0.01978721 -0.04331266
  -0.01350104]
 [-0.01029745 -0.01616653 -0.01228376 ...  0.01920228  0.02809811
   0.01450908]
 [-0.01309184 -0.00244747 -0.0177224  ...  0.00944778  0.01387301
   0.03393568]
 ...
 [ 0.02242565 -0.0296145  -0.06326169 ... -0.01012643  0.01120969
   0.01027199]
 [-0.00761533  0.02028973 -0.01498873 ...  0.02735376 -0.01229855
   0.02407041]
 [ 0.00027915 -0.06848375  0.00911191 ... -0.03183098  0.00743086
  -0.04021148]]
[-0.06750315  0.0695926  -0.02730473  0.02256093 -0.22001474 -0.22038847
  0.04862635  0.13499236  0.23342554 -0.0487357   0.10170191 -0.03076038
  0.15482435  0.05212503  0.06017235 -0.03364862 -0.11218343 -0.26460695
 -0.03323386  0.13610415  0.06354368  0.04679805 -0.01621654 -0.05775835
 -0.03108677  0.10366164 -0.0845938   0.11665157  0.21852103  0.04437255
  0.03378392 -0.01720384 -0.07383765  0.16152057 -0.10621249 -0.01646949
  0.00913961  0.10238428  0.00916639 -0.0564299  -0.10607515  0.09892716
 -0.07136887 -0.06349134  0.12461706  0.02242282 -0.00047972  0.04527043
 -0.15179175  0.10716812]
(784, 50) (50, 100) (100, 10)
(50,) (100,) (10,)

層の数

この時、ニューラルネットワークの層は次のようになります。

  • 入力層のノード数:784
  • 1つ目の中間層のノード数:50
  • 2つ目の中間層のノード数:100
  • 出力層のノード数:10

入力層のノード数

1つ目の重み行列W1の行数と一致するため、784です。これは、mnistなので28*28となっています。

1つ目の中間層のノード数

1つ目の重み行列W1の列数と2つ目の重み行列W2の行数と一致するため、50です。

2つ目の中間層のノード数

2つ目の重み行列W2の列数と3つ目の重み行列W3の行数と一致するため、100です。

出力層のノード数

3つ目の重み行列W3の列数と一致するため、10です。

youtu.be