Tensorflowでの学習結果の保存と再読み込みについて

Tensorflowで学習した結果を保存して、続きを学習することがあると思います。なかなかその内容が理解できなかったので、できるだけシンプルにそれを実現する方法について、紹介したいと思います。
まず、今回はTensorflowのチュートリアルで紹介されているmnistのサンプルを利用して、保存・再読み込みを実装します。mnistとは、手書きの数字を分類するためのデータセットです。
また、今回は分類するサンプルコードの中で、mnist_deep.pyを利用することにしました。このファイルの158行目以降に次のような変更を加えます。

ここでは、まず、モデルのパスをMODEL_PATHとして定義します。次にモデルを保存・再読み込みするインスタンスとして、saverを定義します。

次に保存したモデルが存在する場合、それを読み込む処理を追加します。さらに、学習終了時に保存する処理を追加します。
ここまでの変更結果はここにアップロードしています。
参考サイト
スポンサーリンク
スポンサーサイト
スポンサーサイト

シェアする

  • このエントリーをはてなブックマークに追加

フォローする

スポンサーリンク
スポンサーサイト