当前位置:首页 » 《休闲阅读》 » 正文

Tensorflow 模型的保存和载入_tugouxp的专栏

17 人参与  2022年01月01日 08:02  分类 : 《休闲阅读》  评论

点击全文阅读


Minst训练模型源码:

import tensorflow as tf #导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util  
import pylab 

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784]) # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10]) # 0-9 数字=> 10 classes

# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# 构建模型
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分类

# Minimize error using cross entropy
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

#参数设置
learning_rate = 0.01
# 使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

training_epochs = 25
batch_size = 100
display_step = 1
saver = tf.train.Saver()
model_ckpt_path = "model/521model.ckpt"
model_pb_path = "model/521model.pb"

# 启动session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())# Initializing OP

    # 启动循环开始训练
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples/batch_size)
        # 遍历全部数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Compute average loss
            avg_cost += c / total_batch
        # 显示训练中的详细信息
        if (epoch+1) % display_step == 0:
            print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))
            
        # print(sess.run(W))
        # print(sess.run(b))

    print( " Finished!")

    # 测试 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save model weights to disk
    save_path = saver.save(sess, model_ckpt_path)
    print("Model saved in file: %s" % save_path)
    
    graph_def = tf.get_default_graph().as_graph_def()
    # 保存对应计算节点的名称。
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    # 将导出的模型存入文件中
    with tf.gfile.GFile(model_pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())


#读取模型
print("Starting 2nd session...")
with tf.Session() as sess:
    # Initialize variables
    sess.run(tf.global_variables_initializer())
    # Restore model weights from previously saved model
    saver.restore(sess, model_ckpt_path)
    
     # 测试 model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
    
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})
    print(outputval,predv,batch_ys)

    im = batch_xs[0]
    print(im.shape)
    im = im.reshape(-1,28)
    print(im.shape)
    pylab.imshow(im)
    pylab.show()
    
    im = batch_xs[1]
    print(im.shape)
    im = im.reshape(-1,28)
    print(im.shape)
    pylab.imshow(im)
    pylab.show()

运行结果

执行完后,可以看到,保存的模型


保存模型的方法:

保存ckpt格式的模型:

    # Save model weights to disk
    save_path = saver.save(sess, model_ckpt_path)
    print("Model saved in file: %s" % save_path)

保存pb格式的模型

    graph_def = tf.get_default_graph().as_graph_def()
    # 保存对应计算节点的名称。
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    # 将导出的模型存入文件中
    with tf.gfile.GFile(model_pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

结束!


点击全文阅读


本文链接:http://m.zhangshiyu.com/post/32481.html

模型  保存  计算  
<< 上一篇 下一篇 >>

  • 评论(0)
  • 赞助本站

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。

最新文章

  • 捧一片星空新鲜出炉林溪傅迟宴完本_捧一片星空新鲜出炉(林溪傅迟宴)
  • 她在春日里沉眠高口碑(乔清瑜季泽珩)
  • 半堂花夜渡空城裴砚泽沈诺柠结局+番外(裴砚泽沈诺柠)_(半堂花夜渡空城裴砚泽沈诺柠结局+番外)列表_笔趣阁(裴砚泽沈诺柠)
  • 半堂花夜渡空城精编之作(裴砚泽沈诺柠)全书免费_(裴砚泽沈诺柠)半堂花夜渡空城精编之作后续(裴砚泽沈诺柠)
  • 半堂花夜渡空城结局+番外(裴砚泽沈诺柠)_(半堂花夜渡空城结局+番外)列表_笔趣阁(裴砚泽沈诺柠)
  • (番外)+(全书)顾裴延江照璃(长叹雁归难留+后续+结局)_顾裴延江照璃免费列表_笔趣阁(长叹雁归难留+后续+结局)
  • 全书浏览九幽不渡卿结局+番外(孟卿卿谢昭远)_九幽不渡卿结局+番外(孟卿卿谢昭远)全书结局
  • 往梦难复温+后续+结局(沈淮霆宋思予)_(沈淮霆宋思予)往梦难复温+后续+结局列表_笔趣阁(沈淮霆宋思予)
  • 她在春日里沉眠书结局+番外优质全章(乔清瑜季泽珩)_她在春日里沉眠书结局+番外优质全章乔清瑜季泽珩
  • 往梦难复温+后续+结局(沈淮霆宋思予)结局_(沈淮霆宋思予往梦难复温+后续+结局全书结局)结局列表_笔趣阁(沈淮霆宋思予)
  • 「世子养花,自当娇贵」节选角色羁绊特辑‌_沈梨裴辞最新章节免费阅读
  • 江照璃的长叹雁归难留后续在线顾裴延江照璃全书在线

    关于我们 | 我要投稿 | 免责申明

    Copyright © 2020-2022 ZhangShiYu.com Rights Reserved.豫ICP备2022013469号-1