1. High level API checkpoints
只针对与 estimator
设置检查点的时间频率和总个数
my_checkpointing_config = tf.estimator.RunConfig( save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes. keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints.)
实例化时传递给 estimator 的 config 参数
model_dir 设置存储路径
classifier = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, hidden_units=[10, 10], n_classes=3, model_dir='models/iris', config=my_checkpointing_config)
一旦检查点文件存在,TensorFlow 总会在你调用 train()
、 evaluation()
或 predict()
时重建模型
------------------------------------------------------------------------------------------------------------
2.Low level API tf.train.Saver
-------------------------------------------------------------------------------------------------------------
Saver.save 存储 model 中的所有变量
import tensorflow as tf# 创建变量var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)# 添加初始化变量的操作init_op = tf.global_variables_initializer()# 添加保存和恢复这些变量的操作saver = tf.train.Saver()# 然后,加载模型,初始化变量,完成一些工作,并保存这些变量到磁盘中with tf.Session() as sess: sess.run(init_op) # 使用模型完成一些工作 var.op.run() # 将变量保存到磁盘中 save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in path: %s" % save_path)
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)# tf.get_variable: Gets an existing variable with these parameters or create a new one.# shape: Shape of the new or existing variable# initializer: Initializer for the variable if one is created. tf.zeros_initializer 赋值为0 [0 0 0]
saver = tf.train.Saver() # Saver 来管理模型中的所有变量,注意是所有变量
tf.Session() # A class for running TensorFlow operations.
with...as...#执行 with 后面的语句,如果可以执行则将赋值给 as 后的语句。如果出现错误则执行 with 后语句中的 __exit__#来报错。类似与 try if,但是更方便
Saver.save 选择性的存储变量
saver = tf.train.Saver({'var2':var2})
-------------------------------------------------------------------------------------------------------------
Saver.restore 加载路径中的所有变量
import tensorflow as tftf.reset_default_graph()# 创建一些变量var = tf.get_variable("var", shape=[3])# 添加保存和恢复这些变量的操作saver = tf.train.Saver()# 然后,加载模型,使用 saver 从磁盘中恢复变量,并使用变量完成一些工作with tf.Session() as sess: # 从磁盘中恢复变量 saver.restore(sess, "/tmp/model.ckpt") print("Model restored.") # 检查变量的值 print("var : %s" % var.eval())
-------------------------------------------------------------------------------------------------------------
inspector_checkpoint 检查存储的变量
加载 inspect_checkpoints
from tensorflow.python.tools import inspect_checkpoint as chkp
打印存储起来的所有变量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True, all_tensor_names=False)
注意其中的参数 all_tensor_names 教程中并未添加这个参数,运行时持续报错 missing
打印制定的变量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='var1', all_tensors=False, all_tensor_names=False)