频道栏目
首页 > 资讯 > 其他 > 正文

tensorflow API:tf.train.Saver 与 NotFoundError: "x_x" not found in checkpoint

18-04-26        来源:[db:作者]  
收藏   我要投稿

保存

import所需的模块, 然后建立神经网络当中的 W 和 b, 并初始化变量.

import tensorflow as tf
import numpy as np
## Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法
# 替换成下面的写法:
init = tf.global_variables_initializer()

保存时, 首先要建立一个 tf.train.Saver() 用来保存, 提取变量.
Saver的例子:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# dict形式传递
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# list形式传递
saver = tf.train.Saver([v1, v2])
#等价于以创建变量时取的op.name名字做为dict的key:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

再创建一个名为my_net的文件夹, 用这个 saver 来保存变量到这个目录 “my_net/save_net.ckpt”.

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("Save to path: ", save_path)

"""    
Save to path:  my_net/save_net.ckpt
"""

直接这样运行报错:
原因: 保存和加载 在前后进行,在前后两次定义了

W = tf.Variable(xxx,name=”weight”)

相当于 在TensorFlow 图的堆栈创建了两次 name = “weight” 的变量,第二个(第n个)的实际 name 会变成 “weight_1” (“weight_n-1”),之后我们在保存 checkpoint 中实际搜索的是 “weight_n-1” 这个变量 而不是 “weight” ,因此就会出错。

解决方案:
(1)在加载过程中,定义 name 相同的变量前面加
tf.reset_default_graph() 清除默认图的堆栈,并设置全局图为默认图

相关TAG标签
上一篇:HTML5下绘制动画代码教程
下一篇:HTML5下Canvas save怎么保存恢复状态?
相关文章
图文推荐

关于我们 | 联系我们 | 广告服务 | 投资合作 | 版权申明 | 在线帮助 | 网站地图 | 作品发布 | Vip技术培训 | 举报中心

版权所有: 红黑联盟--致力于做实用的IT技术学习网站