频道栏目
首页 > 资讯 > 其他综合 > 正文

机器学习笔记6:TensorFlow入门之MNIST数据集训练

17-08-30        来源:[db:作者]  
收藏   我要投稿

机器学习笔记6:TensorFlow 入门之MNIST数据集训练

  本文的主要内容参考于TensorFlow中文社区内容,并在下面的文章中测试其中的样例代码。

  关于TensorFlow的基本使用方法以及综述和相关概念,在TensorFlow的中文社区的教程中的1.3节有详细的介绍。所以就在此不再记录。

  MNIST数据集的训练类似于编程过程中的 Hello World,关于MNIST是入门级别的计算机视觉数据集,其中包含各种手写数字图片:

 


这里写图片描述

 

  该数据集中每一个图片都有其中对应的标签(Label),用于告诉使用者对应的数字几,如上图中对应的数字为5,0,4,1。

  在此教程中,我们将训练一个机器学习模型用于预测图片里面的数字。我们的目的不是要设计一个世界一流的复杂模型 – 尽管我们会在之后给你源代码去实现一流的预测模型 – 而是要介绍下如何使用TensorFlow。所以,我们这里会从一个很简单的数学模型开始,它叫做Softmax Regression。

  提到Softmax回归函数,有个文章介绍了常用的激活函数及其比较:常用激活函数比较,其中较为详细的分析了Sigmoid,ReLU,Softmax三种函数以及其比较。

  对应这个教程的实现代码很短,而且真正有意思的内容只包含在三行代码里面。但是,去理解包含在这些代码里面的设计思想是非常重要的:TensorFlow工作流程和机器学习的基本概念。因此,这个教程会很详细地介绍这些代码的实现原理。

MNIST数据集

  社区提供了一份python源代码用于自动下载和安装这个数据集。

import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

整体代码及注释

# coding=utf-8
'''
Author:Chen hao
Description: Simple MINIST
Date: August 22 , 2017
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def main(_):
  # 导入数据
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])
        # a 2-D tensor of floating-point numbers
        # None means that a dimension can be of any length
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b
        # It only takes one line to define it

  # Define loss and optimizer
  # y_表示一个样本的实际label
  y_ = tf.placeholder(tf.float32, [None, 10])

  # The raw formulation of cross-entropy,
  #
  #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
  #                                 reduction_indices=[1]))
                # tf.reduce_sum adds the elements in the second dimension of y,
                # due to the reduction_indices=[1] parameter.
                # tf.reduce_mean computes the mean over all the examples in the batch.
  #
  # can be numerically unstable.
  #
  # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
  # outputs of 'y', and then average across the batch.

  # 用cross-entropy作为损失来衡量模型的误差
  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  # 然后使用梯度下降的方式来训练模型使得loss达到最小
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
        # apply your choice of optimization algorithm to modify the variables and reduce the loss.

  sess = tf.InteractiveSession()
        # launch the model in an InteractiveSession
  tf.global_variables_initializer().run()
        # create an operation to initialize the variables

  # Train~~stochastic training
  for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
            # Each step of the loop,
            # we get a "batch" of one hundred random data points from our training set.
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

  # Test trained model
  # argmax函数可以给出某个tensor对象在某一维熵的其数据最大值所在的索引值
  # 标签向量是0,1组成,因此最大值1所在的索引位置就是类别标签值
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            # use tf.equal to check if our prediction matches the truth
            # tf.argmax(y,1) is the label our model thinks is most likely for each input,
            # while tf.argmax(y_,1) is the correct label.
  # 将布尔值转化为0和1的表现形式
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            # [True, False, True, True] would become [1,0,1,1] which would become 0.75.
  print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      y_: mnist.test.labels}))
            # ask for our accuracy on our test data,about 92%

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

  代码运行的结果为0.9195,每个人运行结果可能都有所不同。

  这个最终结果值应该大约是91%。

  这个结果好吗?嗯,并不太好。事实上,这个结果是很差的。这是因为我们仅仅使用了一个非常简单的模型。不过,做一些小小的改进,我们就可以得到97%的正确率。最好的模型甚至可以获得超过99.7%的准确率!

相关TAG标签
上一篇:Python 安装 第三方库的安装技巧
下一篇:类-成员变量和局部变量
相关文章
图文推荐

关于我们 | 联系我们 | 广告服务 | 投资合作 | 版权申明 | 在线帮助 | 网站地图 | 作品发布 | Vip技术培训 | 举报中心

版权所有: 红黑联盟--致力于做实用的IT技术学习网站