卷积神经网络CNN识别MNIST数据集

这次我们将建立一个卷积神经网络,它可以把MNIST手写字符的识别准确率提升到99%,读者可能需要一些卷积神经网络的基础知识才能更好的理解本节的内容。

程序的开头是导入TensorFlow:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist  import input_data

接下来载入MNIST数据集,并建立占位符。占位符x的含义为训练图像,y_为对应训练图像的标签。

# 读入数据
mnist  = input_data.read_data_sets( "MNIST_data/" , one_hot = True )
# x为训练图像的占位符,y_为训练图像标签的占位符
x  = tf.placeholder(tf.float32, [ None ,  784 ])
y_  = tf.placeholder(tf.float32, [ None ,  10 ])

运行后会在当前目录下得到一个名为MINST_data的数据集。如下图所示

由于使用的是卷积神经网络对图像进行分类,所以不能再使用784维的向量表示输入的x,而是将其还原为28*28的图片形式。[-1,28,28,1]中的-1表示形状第一维的大小是根据x自动确定的。

# 将单张图片从784维向量重新还原为28*28的矩阵图片
x_image  = tf.reshape(x, [ - 1 ,  28 ,  28 ,  1 ])

x_image就是输入的训练图像,接下来,我们对训练图像进行卷积计算,第一层卷积的代码如下:

def weight_variable(shape):
    initial  = tf.truncated_normal(shape, stddev = 0.1 )
    return tf.Variable(initial)

def bias_variable(shape):
    initial  = tf.constant( 0.1 , shape = shape)
    return tf.Variable(initial)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides = [ 1 ,  1 ,  1 ,  1 ], padding = 'SAME' )

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize = [ 1 ,  2 ,  2 ,  1 ], strides = [ 1 ,  2 ,  2 ,  1 ], padding = 'SAME' )

# 第一层卷积层
W_conv1  = weight_variable([ 5 ,  5 ,  1 ,  32 ])
b_conv1  = bias_variable([ 32 ])
h_conv1  = tf.nn.relu(conv2d(x_image, W_conv1)  + b_conv1)
h_pool1  = max_pool_2x2(h_conv1)

首先定义了四个函数,函数weight_variable可以返回一个给定形状的变量,并自动以截断正态分布初始化,bias_variable同样返回一个给定形状的变量,初始化所有值是0.1,可分别用这两个函数创建卷积的核(kernel)与偏置(bias)。h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)是真正进行卷积运算,卷积计算后选用ReLU作为激活函数。h_pool1 = max_pool_2x2(h_conv1)是调用函数max_pool_2x2进行一次池化操作。卷积、激活函数、池化,可以说是一个卷积层的“标配”,通常一个卷积层都会包含这三个步骤,有时也会去掉最后的池化操作。

对第一次卷积操作后产生的h_pool1再做一次卷积计算,使用的代码与上面类似。

# 第二层卷积
W_conv2  = weight_variable([ 5 ,  5 ,  32 ,  64 ])
b_conv2  = bias_variable([ 64 ])
h_conv2  = tf.nn.relu(conv2d(h_pool1, W_conv2)  + b_conv2)
h_pool2  = max_pool_2x2(h_conv2)

两层卷积层之后是全连接层:

# 全连接层,输出为1024维的向量
W_fc1  = weight_variable([ 7 * 7 * 64 ,  1024 ])
b_fc1  = bias_variable([ 1024 ])
h_pool2_flat  = tf.reshape(h_pool2, [ - 1 ,  7 * 7 * 64 ])
h_fc1  = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1)  + b_fc1)
# 使用Dropout,keep_prob是一个占位符,训练时为0.5,测试时为1
keep_prob  = tf.placeholder(tf.float32)
h_fc1_drop  = tf.nn.dropout(h_fc1, keep_prob)

在全连接层中加入了Dropout,它是防止神经网络过拟合的一种手段。在每一步训练时,以一定概率“去掉”网络中的某些连接,但这种去除不是永久性的,只是在当前步骤中去除,并且每一步去除的连接都是随机选择的。在这个程序中,选择的Dropout概率是0.5,也就是说训练时每一个连接都有50%的概率被去除。在测试时保留所有连接。

最后,再加入一层全连接,把上一步得到的h_fc1_drop转换为10个类别的打分。

# 把1024维的向量转换为10维,对应10个类别
W_fc2  = weight_variable([ 1024 ,  10 ])
b_fc2  = bias_variable([ 10 ])
y_conv  = tf.matmul(h_fc1_drop, W_fc2)  + b_fc2

y_conv相当于Softmax模型中的Logit,当然可以使用Softmax函数将其转换为10个类别的概率,再定义交叉熵损失。但其实TensorFlow提供了一个更直接的tf.nn.softmax_cross_entropy_with_logits函数,它可以直接对Logit定义交叉熵损失,写法为:

# 不采用先softmax再计算交叉熵的方法
# 而是采用tf.nn.softmax_cross_entropy_with_logits直接计算
cross_entropy  = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y_, logits = y_conv))
# 同样定义train_step
train_step  = tf.train.AdamOptimizer( 1e - 4 ).minimize(cross_entropy)

定义测试的准确率

# 定义测试的准确率
correct_prediction  = tf.equal(tf.argmax(y_conv,  1 ), tf.argmax(y_,  1 ))
accuracy  = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

在控制台显示在验证集上训练时模型的准确度,方便监控训练的进度,也可以据此来调整模型的参数。

# 创建Session,对变量初始化
sess  = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# 训练20000步
for i  in range ( 20000 ):
    batch  = mnist.train.next_batch( 50 )
    # 每100步报告一次在验证集上的准确率
    if i  % 100 = = 0 :
        train_accuracy  = accuracy. eval (feed_dict = {
            x: batch[ 0 ], y_: batch[ 1 ], keep_prob:  1.0
        })
        print ( "step %d,training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict = {x: batch[ 0 ], y_: batch[ 1 ], keep_prob:  0.5 })

训练结束后,打印在全体测试集上的准确率:

# 训练结束后报告在测试集上的准确率
print ( "test accuracy %g" % accuracy. eval (feed_dict = {
    x: mnist.test.images, y_: mnist.test.labels, keep_prob:  1.0
}))

最后得到的结果在控制台显示为

可以最终测试得到的准确率结果应该在99%左右。与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率有非常大的提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/1166.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言老题新解16-20 用命令行打印一些图案

文章目录11 打印字母C12 输出国际象棋棋盘。13 打印楼梯,同时在楼梯上方打印两个笑脸。14 输出9*9 口诀。15 有一道题要输出一个图形,然后Very Beautiful。11 打印字母C 11 用*号输出字母C的图案。 讲道理这绝对不该是个新人能整出来的活儿&#xff0c…

TCP/IP协议栈之数据包如何穿越各层协议(绝对干货)

所有互联网服务,均依赖于TCP/IP协议栈。懂得数据是如何在协议栈传输的,将会帮助你提升互联网程序的性能和解决TCP相关问题的能力。 我们讲述在Linux场景下数据包是如何在协议层传输的。 1、发送数据 应用层发送数据的过程大致如下: 我们把…

蓝桥杯嵌入式第五课--输入捕获

前言输入捕获的考题十分明确,就是测量输入脉冲波形的占空比和频率,对我们的板子而言,就是检测板载的两个信号发生器产生的信号:具体来说就是使用PA15和PB4来做输入捕获。输入捕获原理简介输入捕获能够对输入信号的上升沿和下降沿进…

WorkTool企微机器人接入智能问答

一、前言 最新版的企微机器人已经集成 Chat ,无需开发可快速搭建智能对话机器人。 从官方介绍看目前集成版本使用模型为 3.5-turbo。 二、入门 创建 WorkTool 机器人 你可以通过这篇快速入门教程,来快速配置一个自己的企微机器人。 实现的流程如图&…

Windows与Linux端口占用、查看的方法总结

Windows与Linux端口占用、查看的方法总结 文章目录Windows与Linux端口占用、查看的方法总结一、Windows1.1Windows查看所有的端口1.2查询指定的端口占用1.3查询PID对应的进程1.4查杀死/结束/终止进程二、Linux2.1lsof命令2.2netstat命令一、Windows 1.1Windows查看所有的端口 …

基于GPT-4的免费代码生成工具

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

SpringCloud五大核心组件

Consul 等,提供了搭建分布式系统及微服务常用的工具,如配置管理、服务发现、断路器、智能路由、微代理、控制总线、一次性token、全局锁、选主、分布式会话和集群状态等,满足了构建微服务所需的所有解决方案。 服务发现——Netflix Eureka …

7个最受欢迎的Python库,大大提高开发效率

当第三方库可以帮我们完成需求时,就不要重复造轮子了 整理了GitHub上7个最受好评的Python库,将在你的开发之旅中提供帮助 PySnooper 很多时候时间都花在了Debug上,大多数人呢会在出错位置的附近使用print,打印某些变量的值 这个…

算法竞赛必考算法——动态规划(01背包和完全背包)

动态规划(一) 目录动态规划(一)1.01背包问题1.1题目介绍1.2思路一介绍(二维数组)1.3思路二介绍(一维数组) 空间优化1.4思路三介绍(输入数据优化)2.完全背包问题2.1题目描述:2.2思路一(朴素算法)2.3思路二(将k优化处理掉)2.4思路三(优化j的初始条件)总结1.01背包问题…

Spring Cloud Alibaba全家桶(四)——微服务调用组件Feign

前言 本文小新为大家带来 微服务调用组件Feign 的相关知识,具体内容包含什么是Feign,Spring Cloud Alibaba快速整合OpenFeign,Spring Cloud Feign的自定义配置及使用(包括:日志配置、契约配置、自定义拦截器实现认证逻…

Autosar-ComM浅谈

文章目录 一、ComM概述二、和其他模块的依赖关系三、ComM通道状态机ComM模式与通讯能力关系表四、ComM中的PNC一、ComM概述 ComM全称是Communication Manager,顾名思义就是通信的管理,是BSW(基本软件)服务层的一个组件。 ComM的作用: 为用户简化Communication Stack的使用…

中断控制器

在Linux内核中,各个设备驱动可以简单地调用request_irq()、enable_irq()、disable_irq()、 local_irq_disable()、local_irq_enable()等通用API来…

STM32----MPU6050

前言:最近几个月没有写文章了,因为这学期的事情真的有点多,但是想了想,文章还是要更新,总结自己学习的知识,真的很重要!!! 废话不多说,正文开始:…

【vue.js】在网页中实现一个金属抛光质感的按钮

文章目录前言效果电脑效果手机效果说明完整代码index.html前言 诶?这有一个按钮(~ ̄▽ ̄)~,这是一个在html中实现的具有金属质感并且能镜面反射的按钮~ 效果 电脑效果 手机效果 说明 主要思路是使用 navig…

【算法基础】二分图(染色法 匈牙利算法)

一、二分图 1. 染色法 一个图是二分图,当且仅当,图中不含奇数环。在判别一个图是否为二分图⑩,其实相当于染色问题,每条边的两个点必须是不同的颜色,一共有两种颜色,如果染色过程中出现矛盾,则说明不是二分图。 for i = 1 to n:if i 未染色DFS(i, 1); //将i号点染色未…

Leetcode138. 复制带随机指针的链表

复制带随机指针的链表 第一步 拷贝节点链接在原节点的后面 第二步拷贝原节点的random , 拷贝节点的 random 在原节点 random 的 next 第三步 将拷贝的节点尾插到一个新链表 ,并且将原链表恢复 从前往后遍历链表 ,将原链表的每个节点进行复制,并l链接到原…

【STL二】STL序列式容器(array、vector、deque、list、forward_list)

【STL二】STL序列式容器&#xff08;array、vector、deque、list、forward_list&#xff09;1.array<T,N>&#xff08;数组容器&#xff09;2.vector<T>&#xff08;向量容器&#xff09;3.deque<T>&#xff08;双端队列容器&#xff09;&#xff1a;4.list&…

第一个 Qt 程序

第一个 Qt 程序 “hello world ”的起源要追溯到 1972 年&#xff0c;贝尔实验室著名研究员 Brian Kernighan 在撰写 “B 语言教程与指导(Tutorial Introduction to the Language B)”时初次使用&#xff08;程序&#xff09;&#xff0c;这是目前已 知最早的在计算机著作中将…

用sql计算两个经纬度坐标距离(米数互转)

目录 一、sql示例&#xff08;由近到远&#xff09; 二 、参数讲解 三、查询效果 - 距离&#xff08;公里 / 千米&#xff09; 四、查询效果 - 距离&#xff08;米&#xff09; 五、距离四舍五入保留后2位小数&#xff08;java&#xff09; 一、sql示例&#xff08;由近到远…

2023年最新最全 VSCode 插件推荐

Visual Studio Code 是由微软开发的一款免费的、针对于编写现代Web和云应用的跨平台源代码编辑器。它包含了一个丰富的插件市场&#xff0c;提供了很多实用的插件。下面就来分享 2023 年前端必备的 VS Code 插件&#xff01; 前端框架 ES7 React/Redux/React-Native snippets …
最新文章