Pylearn2 のお勉強 - 6時間目
Pylearn2 の tutorial でお勉強 - 5時間目 - まんぼう日記 のつづき.自分で Theano 使って convolutional net のプログラムを書くことを視野に入れて,Pylearn2 で学習したネットワークのパラメータを取り出して,自前でネットワーク出力を計算してみることにします.
学習したネットワークの重みを取り出す
例えば Pylearn2 の tutorial でお勉強 - 2時間目 - まんぼう日記 の Part 4 では,学習後のMLPの重み等は mlp_3_best.pkl という名前の pickle ファイルに保存されていて,print_model.py というスクリプトでその概要を得ることができます.このスクリプトの中身と Models — Pylearn2 dev documentation の記述から,MLP の場合は次のようなプログラムによって重みとバイアス項の値を Numpy の array として取り出せることが分かりました.
上記を実行すると,次のような出力が得られます.
$ python getweights1005.py h0 Input space: VectorSpace(dim=784, dtype=float64) Total input dimension: 784 h1 Input space: VectorSpace(dim=500, dtype=float64) Total input dimension: 500 y Input space: VectorSpace(dim=1000, dtype=float64) Total input dimension: 1000 layer 0 (784, 500) (500,) layer 1 (500, 1000) (1000,) layer 2 (1000, 10) (10,)
w と b は Numpy の array です.試しに GPU で学習して得られた pickle ファイルを読んでみたら,このプログラムでは読み出せませんでした.そちらはまた別の手を考えないといけないようです.
MNISTデータを得る
重みを取り出す方法がわかったので,次はデータの準備です.上記のネットワークは MNIST の手書き数字データを使って学習しています.自分でMNISTのデータを扱うプログラムを書いてもよいわけですが,Datasets — Pylearn2 dev documentation を見たら簡単に MNIST のデータが得られそうだったので,そっちで行くことにしました.
MNIST のデータ(テストデータの方)は次のようなプログラムで得られます.
実行結果はこんな感じ.
$ python data1005.py 10000 True data1005.py:10: UserWarning: Usage of `topo` and `target` arguments are being deprecated, and will be removed around November 7th, 2013. `data_specs` should be used instead. itr = data.iterator( 'sequential', batch_size = 3000, targets = True ) (3000, 784) (3000, 1) (3000, 784) (3000, 1) (3000, 784) (3000, 1) (1000, 784) (1000, 1)
変数 X が入力データ,Y が出力の正解.MNISTの場合,Yの中身は0から9の整数でした.上記ウェブページには targets = True とすると入力とともに出力の正解も取り出せるって書いてあったからそう書いたのに,警告されちゃった.ウェブページの記述は古いままみたいで,data_specs の使い方の説明が見つからへん.それから,1万x784くらいのサイズなら,num_batches = 1 として全体をひとかたまりで取り出したって構わへんわけですが,ここでは,小さいバッチに切り分けて取り出す練習のため,batch_size = 3000 とかしてみました.
自分でネットワーク出力を計算してみる
ここまでの結果から,次のようなプログラムを書いてみました.
実行結果は次の通り.
$ python hoge1005.py # num_layers = 3 num_units = 784 - 500 - 1000 - 10 hoge1005.py:71: UserWarning: Usage of `topo` and `target` arguments are being deprecated, and will be removed around November 7th, 2013. `data_specs` should be used instead. itr = data.iterator( 'sequential', num_batches = 1, targets = True ) # ndat = 10000 ndim = 784 # test_y_misclass: 0.0165
テストデータでの誤識別率 test_y_misclass の値は,print_monitor.py が出力する値と同じになりました.めでたしめでたし.
次は Theano 使ってネットワーク出力を計算するプログラム書いてみるかな.