やりたいこと
sample_weight.pklの事前学習の重みを読み込みたい。
sample_weight.pklの例としてオライリーさんのGitHubにあるものを使います。 github.com
ファイルから読み込む
def init_network(): with open("sample_weight.pkl", "rb") as f: network = pickle.load(f) return network network = init_network() print(network.keys())
結果
dict_keys(['b2', 'W1', 'b1', 'W2', 'W3', 'b3'])
GoogleColabで実行する場合には図のように配置をしてください。
pickle.load
pickle.load
関数は、pickle形式のファイルからオブジェクトを読み込むために使用されます。具体的には、pickle.dump
関数によって保存されたオブジェクトを読み込み、Pythonのオブジェクトとして復元します。
pickle形式はPython固有のものであり、他のプログラミング言語との互換性はありません。また、信頼できないソースからのpickleファイルの読み込みは、セキュリティ上のリスクを引き起こす可能性があるため、注意が必要です。
補足 オブジェクトの直列化(シリアライズ)と非直列化(デシリアライズ)
直列化(シリアライズ)は、オブジェクトをそのままの形式ではなく、バイト列やテキストなどの直列化された形式に変換することです。このプロセスにより、オブジェクトの状態やデータが保存や転送に適した形式に変換されます。直列化は、オブジェクトをファイルに保存したり、ネットワーク経由で送信したり、データベースに格納したりする際に使用されます。
非直列化(デシリアライズ)は、直列化されたデータを元のオブジェクトの形式に戻すプロセスです。直列化されたデータは、保存や転送のために変換された形式なので、非直列化を行うことで元のオブジェクトの状態やデータを復元することができます。
重みを用いる
W1, W2, W3 = network["W1"],network["W2"],network["W3"] b1, b2, b3 = network["b1"],network["b2"],network["b3"] print(network["W1"]) print(network["b1"]) print(W1.shape, W2.shape, W3.shape) print(b1.shape, b2.shape, b3.shape)
結果
[[-0.00741249 -0.00790439 -0.01307499 ... 0.01978721 -0.04331266 -0.01350104] [-0.01029745 -0.01616653 -0.01228376 ... 0.01920228 0.02809811 0.01450908] [-0.01309184 -0.00244747 -0.0177224 ... 0.00944778 0.01387301 0.03393568] ... [ 0.02242565 -0.0296145 -0.06326169 ... -0.01012643 0.01120969 0.01027199] [-0.00761533 0.02028973 -0.01498873 ... 0.02735376 -0.01229855 0.02407041] [ 0.00027915 -0.06848375 0.00911191 ... -0.03183098 0.00743086 -0.04021148]] [-0.06750315 0.0695926 -0.02730473 0.02256093 -0.22001474 -0.22038847 0.04862635 0.13499236 0.23342554 -0.0487357 0.10170191 -0.03076038 0.15482435 0.05212503 0.06017235 -0.03364862 -0.11218343 -0.26460695 -0.03323386 0.13610415 0.06354368 0.04679805 -0.01621654 -0.05775835 -0.03108677 0.10366164 -0.0845938 0.11665157 0.21852103 0.04437255 0.03378392 -0.01720384 -0.07383765 0.16152057 -0.10621249 -0.01646949 0.00913961 0.10238428 0.00916639 -0.0564299 -0.10607515 0.09892716 -0.07136887 -0.06349134 0.12461706 0.02242282 -0.00047972 0.04527043 -0.15179175 0.10716812] (784, 50) (50, 100) (100, 10) (50,) (100,) (10,)
層の数
この時、ニューラルネットワークの層は次のようになります。
- 入力層のノード数:784
- 1つ目の中間層のノード数:50
- 2つ目の中間層のノード数:100
- 出力層のノード数:10
入力層のノード数
1つ目の重み行列W1の行数と一致するため、784です。これは、mnistなので28*28となっています。
1つ目の中間層のノード数
1つ目の重み行列W1の列数と2つ目の重み行列W2の行数と一致するため、50です。
2つ目の中間層のノード数
2つ目の重み行列W2の列数と3つ目の重み行列W3の行数と一致するため、100です。
出力層のノード数
3つ目の重み行列W3の列数と一致するため、10です。