Tensorflowで学習した結果を保存して、続きを学習することがあると思います。なかなかその内容が理解できなかったので、できるだけシンプルにそれを実現する方法について、紹介したいと思います。
まず、今回はTensorflowのチュートリアルで紹介されているmnistのサンプルを利用して、保存・再読み込みを実装します。mnistとは、手書きの数字を分類するためのデータセットです。
また、今回は分類するサンプルコードの中で、mnist_deep.pyを利用することにしました。このファイルの158行目以降に次のような変更を加えます。
import os MODEL_PATH = './tensorflowModel.ckpt' saver = tf.train.Saver() if os.path.exists(MODEL_PATH+'.meta'): saver.restore(sess, MODEL_PATH) saver.save(sess, MODEL_PATH)
ここでは、まず、モデルのパスをMODEL_PATHとして定義します。次にモデルを保存・再読み込みするインスタンスとして、saverを定義します。
次に保存したモデルが存在する場合、それを読み込む処理を追加します。さらに、学習終了時に保存する処理を追加します。
ここまでの変更結果はここにアップロードしています。