HaneCa

独り立ちへ向けた長い道

Tensorflowでの学習結果の保存と再読み込みについて

投稿日: 2018年10月9日 最終更新日: 2019年1月1日
Tensorflowで学習した結果を保存して、続きを学習することがあると思います。なかなかその内容が理解できなかったので、できるだけシンプルにそれを実現する方法について、紹介したいと思います。
まず、今回はTensorflowのチュートリアルで紹介されているmnistのサンプルを利用して、保存・再読み込みを実装します。mnistとは、手書きの数字を分類するためのデータセットです。
また、今回は分類するサンプルコードの中で、mnist_deep.pyを利用することにしました。このファイルの158行目以降に次のような変更を加えます。
  import os
  MODEL_PATH = './tensorflowModel.ckpt'
  saver = tf.train.Saver()
    if os.path.exists(MODEL_PATH+'.meta'):
      saver.restore(sess, MODEL_PATH)

    saver.save(sess, MODEL_PATH)

ここでは、まず、モデルのパスをMODEL_PATHとして定義します。次にモデルを保存・再読み込みするインスタンスとして、saverを定義します。

次に保存したモデルが存在する場合、それを読み込む処理を追加します。さらに、学習終了時に保存する処理を追加します。
ここまでの変更結果はここにアップロードしています。

参考サイト

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください