CNN对 MNIST 数据库中的图像进行分类

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图
from keras.datasets import mnist

# 使用 Keras 导入MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

 将前六个训练图像可视化

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np

# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):
    ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])
    ax.imshow(X_train[i], cmap='gray')
    ax.set_title(str(y_train[i]))

查看图像的更多细节 

def visualize_input(img, ax):
    ax.imshow(img, cmap='gray')
    width, height = img.shape
    thresh = img.max()/2.5
    for x in range(width):
        for y in range(height):
            ax.annotate(str(round(img[x][y],2)), xy=(y,x),
                        horizontalalignment='center',
                        verticalalignment='center',
                        color='white' if img[x][y]<thresh else 'black')

fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

 预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# 调整比例,使数值在 0 - 1 范围内 [0,255] --> [0,1]
X_train = X_train.astype('float32')/255
X_test = X_test.astype('float32')/255 

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

 对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import to_categorical

num_classes = 10 
# 打印前十个(整数值)训练标签
print('Integer-valued labels:')
print(y_train[:10])

# 对标签进行一次性编码
# 将类别向量转换为二进制类别矩阵
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# 打印前十个(单次)训练标签
print('One-hot labels:')
print(y_train[:10])

 重塑数据以适应我们的 CNN(和 input_shape)

# 输入图像尺寸为 28x28 像素的图像。
img_rows, img_cols = 28, 28

X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

print('input_shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

您必须传递以下参数:

  • filters - 滤波器的数量。
  • kernel_size - 指定(正方形)卷积窗口高度和宽度的数值。

还有一些额外的、可选的参数需要调整:

  • strides - 卷积的步长。如果不指定任何参数,strides 将设为 1。
  • padding - "有效 "或 "相同 "之一。如果不做任何指定,padding 将设置为 "有效"。
  • activation - 通常为 "relu"。如果不指定任何内容,则不会应用激活。我们强烈建议你为网络中的每个卷积层添加 ReLU 激活函数。

 需要注意的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除网络中的最后一层外,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最终层应是具有 softmax 激活函数的密集层。最终层的节点数应等于数据集中的类总数。
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

# 创建模型对象
model = Sequential()

# CONV_1: 添加 CONV 层,采用 RELU 激活,深度 = 32 内核
model.add(Conv2D(32, kernel_size=(3, 3), padding='same',activation='relu',input_shape=(28,28,1)))
# POOL_1: 对图像进行下采样,选择最佳特征
model.add(MaxPooling2D(pool_size=(2, 2)))

# CONV_2: 在这里,我们将深度增加到 64
model.add(Conv2D(64, (3, 3),padding='same', activation='relu'))
# POOL_2: more downsampling
model.add(MaxPooling2D(pool_size=(2, 2)))

# 由于维度过多,我们只需要一个分类输出
model.add(Flatten())

# FC_1: 完全连接,获取所有相关数据
model.add(Dense(64, activation='relu'))

# FC_2: 输出软最大值,将矩阵压制成 10 个类别的输出概率
model.add(Dense(10, activation='softmax'))

model.summary()

需要注意的事项:
  • 网络以两个卷积层的序列开始,然后是最大池化层。
  • 最后一层为数据集中的每个对象类别设置了一个条目,并具有软最大激活函数,因此可以返回概率。
  • Conv2D 深度从输入层的 1 增加到 32 到 64。
  • 我们还想减少高度和宽度--这就是 maxpooling 的作用所在。请注意,在池化层之后,图像尺寸从 28 减小到 14。
  • 可以看到,每个输出形状都用 None 代替了批量大小。这是为了便于在运行时更改批次大小。
  • 最后,我们会添加一个或多个全连接层来确定图像中包含的对象。例如,如果在上一个最大池化层中发现了车轮,那么这个 FC 层将转换该信息,以更高的概率预测图像中出现了一辆汽车。如果图像中有眼睛、腿和尾巴,那么这可能意味着图像中有一只狗。

编译模型

# rmsprop 和自适应学习率 (adaDelta) 是梯度下降的流行形式,仅次于 adam 和 adagrad
# 因为我们有多个类别 (10)

# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', 
              metrics=['accuracy'])

训练模型

from keras.callbacks import ModelCheckpoint   

# 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, 
                               save_best_only=True)
hist = model.fit(X_train, y_train, batch_size=32, epochs=12,
          validation_data=(X_test, y_test), callbacks=[checkpointer], 
          verbose=2, shuffle=True)

 在验证集上加载分类准确率最高的模型

# 加载能获得最佳验证精度的权重
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率 

# 评估测试的准确性
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]

# 打印测试精度
print('Test accuracy: %.4f%%' % accuracy)

 

注意事项:

MLP 和 CNN 通常不会产生可比较的结果。MNIST 数据集非常特别,因为它非常干净,而且经过了完美的预处理。例如,所有图像大小相同,并以 28x28 像素网格为中心。如果数字稍有偏斜或不居中,这项任务就会难得多。对于真实世界中杂乱无章的图像数据,CNN 将真正超越 MLP。

为了直观地了解为什么会出现这种情况,要将图像输入 MLP,首先必须将图像转换为矢量。然后,MLP 会将图像视为没有特殊结构的简单数字向量。它不知道这些数字原本是按空间网格排列的。

相比之下,CNN 的设计目的完全相同,即处理多维数据中的模式。与 MLP 不同的是,CNN 知道,相距较近的图像像素比相距较远的像素关系密切。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/193072.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

许战海战略文库|主品牌升级为产业技术品牌,引领企业全球化发展

在当今高速发展的全球经济中&#xff0c;企业品牌已经成为其核心资产之一。这不仅仅是因为品牌可以为消费者带来识别度&#xff0c;更重要的是&#xff0c;它们可以为企业带来深厚的竞争壁垒。但对于许多企业来说&#xff0c;特别是技术密集型企业&#xff0c;仅仅依靠主品牌的…

Maven总结

文章目录 为什么学习Maven?一、Maven项目架构管理工具二、Maven的下载安装及配置1.maven的下载2.maven目录结构3.配置阿里云镜像和本地仓库:4.maven配置环境变量。5.阿里云镜像和本地仓库说明 三、idea中maven的操作1.以模板的形式创建maven项目2.其他配置maven的方式3.不勾模…

竞赛选题 题目:基于机器视觉opencv的手势检测 手势识别 算法 - 深度学习 卷积神经网络 opencv python

文章目录 1 简介2 传统机器视觉的手势检测2.1 轮廓检测法2.2 算法结果2.3 整体代码实现2.3.1 算法流程 3 深度学习方法做手势识别3.1 经典的卷积神经网络3.2 YOLO系列3.3 SSD3.4 实现步骤3.4.1 数据集3.4.2 图像预处理3.4.3 构建卷积神经网络结构3.4.4 实验训练过程及结果 3.5 …

linux 搭建Nginx网页(编译安装)

♥️作者&#xff1a;小刘在C站 ♥️个人主页&#xff1a; 小刘主页 ♥️不能因为人生的道路坎坷,就使自己的身躯变得弯曲;不能因为生活的历程漫长,就使求索的 脚步迟缓。 ♥️学习两年总结出的运维经验&#xff0c;以及思科模拟器全套网络实验教程。专栏&#xff1a;云计算技…

ABAP: JSON 报文解析——/ui2/cl_json

1、JSON数组 报文格式如下&#xff0c;是JSON 数组类型的。 [{"I_TYPE":"V","I_BUSINESSSCOPE":"1001"},{"I_TYPE":"V","I_BUSINESSSCOPE":"1002"} ] json转换为SAP内表&#xff1a; TYP…

分割回文串

题目链接 分割回文串 题目描述 注意点 s 仅由小写英文字母组成返回 s 保证每个子串都是回文串所有可能的分割方案 解答思路 从左到右将字符串进行分割&#xff0c;分割左侧部分判断是否是回文子串&#xff0c;如果不是说明不满足题意可以忽略&#xff1b;如果是则可以对右…

数字营销:概述和类型

数字营销无处不在。公司已经开始采用密集的数字营销活动来接触目标受众。从社交媒体句柄到网站&#xff0c;数字营销彻底改变了互联网时代产品和服务的营销和推广方式。本文将详细讨论数字营销的范围和类型。 什么是数字营销&#xff1f; 数字营销使用社交媒体、电子邮件、网…

逆袭之战,线下门店如何在“?”萧条的情况下实现爆发增长?

未来几年&#xff0c;商业走势将受到全球经济形势、科技进步和消费者需求变化等多种因素的影响。随着经济复苏和消费者信心提高&#xff0c;消费市场将继续保持增长&#xff0c;品质化、个性化、智能化等将成为消费趋势。同时&#xff0c;线上购物将继续保持快速增长&#xff0…

Python编程基础

Python是一种简单易学的编程语言&#xff0c;广泛应用于Web开发、数据分析、人工智能等领域。无论您是初学者还是有一定编程经验的人士&#xff0c;都可以从Python的基础知识开始建立自己的编程技能。 目录 理论Python语言的发展程序设计语言的分类静态语言与脚本语言的区别 代…

高精度基准电压源测试方法有哪些

高精度基准电压源是一种能够产生稳定、可控的电压信号的设备&#xff0c;广泛应用于科学研究、工业检测和仪器仪表校准等领域。为了保证电压信号的准确性和可靠性&#xff0c;在使用高精度基准电压源进行测试时&#xff0c;需要采取一系列的测试方法和技术手段。 校准和验证是使…

使用群晖Synology Office提升生产力:如何多人同时编辑一个文件

使用群晖Synology Office提升生产力&#xff1a;多人同时编辑一个文件 正文开始前给大家推荐个网站&#xff0c;前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 文章目录 使用群晖Synol…

【虚拟机Ubuntu 18.04配置网络】

虚拟机Ubuntu 18.04配置网络 1.配置网络连接方式,查看自己网关 2.修改主机名 3.修改系统配置1.配置网络连接方式,查看自己网关 选择虚拟机镜像设置网络连接模式,可以选择桥接或者NAT连接(我这里选择是NAT连接) 确定自己网关&#xff0c;可以在虚拟机 -》 编辑 -》虚拟网络编…

vue3实现element table缓存滚动条

背景 对于后台管理系统&#xff0c;数据的展示形式大多都是通过表格&#xff0c;常常会出现的一种场景&#xff0c;从表格跳到二级页面&#xff0c;再返回上一页时&#xff0c;需要缓存当前的页码和滚动条的位置&#xff0c;以为使用keep-alive就能实现这两种诉求&#xff0c;…

Uni-app智慧工地可视化信息云平台源码

智慧工地的核心是数字化&#xff0c;它通过传感器、监控设备、智能终端等技术手段&#xff0c;实现对工地各个环节的实时数据采集和传输&#xff0c;如环境温度、湿度、噪音等数据信息&#xff0c;将数据汇集到云端进行处理和分析&#xff0c;生成各种报表、图表和预警信息&…

CTF图片隐写

1.题目给出的zip文件给出提示如下。 2.用 ARCHPR爆破出密码。 3.解压后发现1.png&#xff0c;为图片隐写。 4.使用010editor打开图片&#xff0c;发现缺少png文件头。 010editor官方下载链接&#xff1a;sweetscape.com/download/010editor/ 5.添加文件头保存。 6.使用图片隐写…

内网穿透的应用-Jupyter Notbook+cpolar内网穿透实现公共互联网访问使用数据分析工作

文章目录 1.前言2.Jupyter Notebook的安装2.1 Jupyter Notebook下载安装2.2 Jupyter Notebook的配置2.3 Cpolar下载安装 3.Cpolar端口设置3.1 Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1.前言 在数据分析工作中&#xff0c;使用最多的无疑就是各种函数、图表、…

企业软件手机app定制开发新趋势|网站小程序搭建

企业软件手机app定制开发新趋势|网站小程序搭建 随着移动互联网的快速发展和企业数字化转型的加速&#xff0c;企业软件手机App定制开发正成为一个新的趋势。这种趋势主要是由于企业对于手机App的需求增长以及现有的通用应用不能满足企业特定需求的情况下而产生的。 首先&#…

接口自动化测试很难掌握吗?不!一小时学完

一. 什么是接口测试 接口测试是一种软件测试方法&#xff0c;用于验证不同软件组件之间的通信接口是否按预期工作。在接口测试中&#xff0c;测试人员会发送请求并检查接收到的响应&#xff0c;以确保接口在不同场景下都能正常工作。 就工具而言&#xff0c;常见的测试工具有…

无公网IP下,如何实现公网远程访问MongoDB文件数据库

文章目录 前言1. 安装数据库2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射2.3 测试随机公网地址远程连接 3. 配置固定TCP端口地址3.1 保留一个固定的公网TCP端口地址3.2 配置固定公网TCP端口地址3.3 测试固定地址公网远程访问 前言 MongoDB是一个基于分布式文件存储的数…

【知网稳定检索】2024年应用经济学,管理科学与社会发展国际学术会议(AEMSS 2024)

2024年应用经济学&#xff0c;管理科学与社会发展国际学术会议&#xff08;AEMSS 2024&#xff09; 2024 International Conference on Applied Economics, Management Science and Social Development 2024年应用经济学&#xff0c;管理科学与社会发展国际学术会议&#xff…
最新文章