博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow 学习初步- 变量,模型的存储和读取
阅读量:6906 次
发布时间:2019-06-27

本文共 2861 字,大约阅读时间需要 9 分钟。

hot3.png

1. High level API checkpoints

只针对与 estimator

设置检查点的时间频率和总个数

my_checkpointing_config = tf.estimator.RunConfig(    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.)

实例化时传递给 estimator 的 config 参数

model_dir 设置存储路径

classifier = tf.estimator.DNNClassifier(    feature_columns=my_feature_columns,    hidden_units=[10, 10],    n_classes=3,    model_dir='models/iris',    config=my_checkpointing_config)

一旦检查点文件存在,TensorFlow 总会在你调用 train() 、 evaluation() 或 predict() 时重建模型

------------------------------------------------------------------------------------------------------------

2.Low level API tf.train.Saver

-------------------------------------------------------------------------------------------------------------

Saver.save 存储 model 中的所有变量

import tensorflow as tf# 创建变量var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)# 添加初始化变量的操作init_op = tf.global_variables_initializer()# 添加保存和恢复这些变量的操作saver = tf.train.Saver()# 然后,加载模型,初始化变量,完成一些工作,并保存这些变量到磁盘中with tf.Session() as sess:  sess.run(init_op)  # 使用模型完成一些工作  var.op.run()  # 将变量保存到磁盘中  save_path = saver.save(sess, "/tmp/model.ckpt")  print("Model saved in path: %s" % save_path)
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)# tf.get_variable: Gets an existing variable with these parameters or create a new one.# shape: Shape of the new or existing variable# initializer: Initializer for the variable if one is created. tf.zeros_initializer 赋值为0 [0 0 0]
saver = tf.train.Saver() # Saver 来管理模型中的所有变量,注意是所有变量
tf.Session() # A class for running TensorFlow operations.
with...as...#执行 with 后面的语句,如果可以执行则将赋值给 as 后的语句。如果出现错误则执行 with 后语句中的 __exit__#来报错。类似与 try if,但是更方便

Saver.save 选择性的存储变量

saver = tf.train.Saver({'var2':var2})

-------------------------------------------------------------------------------------------------------------

Saver.restore 加载路径中的所有变量

import tensorflow as tftf.reset_default_graph()# 创建一些变量var = tf.get_variable("var", shape=[3])# 添加保存和恢复这些变量的操作saver = tf.train.Saver()# 然后,加载模型,使用 saver 从磁盘中恢复变量,并使用变量完成一些工作with tf.Session() as sess:  # 从磁盘中恢复变量  saver.restore(sess, "/tmp/model.ckpt")  print("Model restored.")  # 检查变量的值  print("var : %s" % var.eval())

-------------------------------------------------------------------------------------------------------------

inspector_checkpoint 检查存储的变量

加载 inspect_checkpoints

from tensorflow.python.tools import inspect_checkpoint as chkp

打印存储起来的所有变量

chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True, all_tensor_names=False)

注意其中的参数 all_tensor_names 教程中并未添加这个参数,运行时持续报错 missing

打印制定的变量

chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='var1', all_tensors=False, all_tensor_names=False)

 

转载于:https://my.oschina.net/u/2362565/blog/1802226

你可能感兴趣的文章
新手文档 - 你应该知道什么
查看>>
模拟——软件——认知——终极无线电简述
查看>>
Pro*C 内嵌SQL
查看>>
spring cloud config client refresh过程
查看>>
深入浅出Future Pattern
查看>>
微信公众平台企业号回调模式的URL验证
查看>>
平台常用函数介绍
查看>>
公司讲座
查看>>
惆怅,诸事不顺
查看>>
Lambda架构与推荐在电商网站实践
查看>>
Docker Swarm与Apache Mesos的区别
查看>>
消息中间件保证消息一致性解决方案
查看>>
java内嵌浏览器DJNativeSwing
查看>>
【Git入门之七】Git和Github
查看>>
ActiveMQ学习笔记(4)——通过ActiveMQ收发消息
查看>>
Spring3 MyBatis3 日志配置
查看>>
Php学习
查看>>
寓意很深刻的故事
查看>>
Confluence 6 权限概述
查看>>
Android小白的探索:2D绘图之Android简易版Microsoft Visio学习之路 三、装饰者模式...
查看>>