神经网络必备基础

和神经网络介绍相比,本文更侧重于程序实现

理解Keras中的组件

Keras是一个高级的神经网络API,用Python实现的,并且可以运行在TensorFlow、CNTK或Theano等后台之上。

model.compile()

compile(self, optimizer, loss, metrics=None, ...)
  • 该函数用于配置模型的学习过程
  • 接收3个参数
    • optimizer: 是对损失函数(Loss function)求最小值的算法,包含一系列参数。
      • 可以在compile()调用时直接指定算法的名称,比如"SGD", "adam","rmsprop"等
      • 也可以在compile之前先实例化,然后再将这个实例传给compile()
    • loss: 用于测量神经网络误差的准确度的目标函数,即损失函数,会在训练过程用于调整参数
    • metrics: 用于模型评价。metrics和loss在含义上的区别可以参考这篇博文:loss与metric的区别 以及 optimizer的介绍_metric loss-CSDN博客

常用的optimizer

  • SGD: Stochastic Gradient Descent,随机梯度下降法。支持动量(momentum),学习率衰减(learning rate decay)等超参数
  • RMSprop:常用于循环神经网络
  • Adam:Adaptive Moment Estimation,自适应矩估计。是一种基于一阶梯度的随机目标函数优化算法

常用的loss

  • mean_squared_error:均方差。用于回归问题
  • categorical_crossentropy: 分类交叉熵。计算预测和目标值之间的分类交叉熵,常用于目标有多个分类的问题。
  • binary_crossentropy: 二元交叉熵。计算预测和目标值之间的二元交叉熵,常用于目标有2个分类的问题。

以上参数详细信息均可以查看Keras手册: Optimizers, Losses, Metrics

使用Keras实现一个神经网络

下面给出一个神经网络的程序案例。该案例使用了经典的MNIST数据集。该数据集是一个用于识别手写数字的数据集,即根据手写的数字样式(图片),识别0~9共10个数字。该问题是一个多分类问题。该数据集的官网:http://yann.lecun.com/exdb/mnist/。以下程序依然是使用google Colaboratory的开发环境。

1. 导入Python的包

from keras.datasets import mnist
from keras.preprocessing.image import load_img, array_to_img
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense

import numpy as np
import matplotlib.pyplot as plt

# Allow view plots in the notebook, '%' is a magic usage for python for special
%matplotlib inline 

2. 加载数据

导入Keras的mnist数据集后,直接调用其load_data()方法

(X_train_raw, y_train_raw), (X_test_raw, y_test_raw) = mnist.load_data()
print(X_train_raw.shape, X_test_raw.shape)
print(y_train_raw.shape, y_test_raw.shape)

输出如下:

(60000, 28, 28) (10000, 28, 28)
(60000,) (10000,)

可以看到训练集中有60000组数据,每组数据包含28x28=784个像素点,即每个采样有784的维度,按照神经网络的通俗理解,有784个特性。之所以强调这个理解,是因为在随后的主题中,将讨论到卷积神经网络,在那个主题里,我们将讨论到图像识别问题中的高维问题。而本文中的问题,虽然也是一个图像识别问题,但由于像素点不是特别多,所以采用的方法还是传统的神经网络模型,即全连接的前馈神经网络。

测试集中有10000组数据。而标签y的每组数据只有一个数值,即0~9。

3. 理解图像数据

print(X_train_raw[0].shape)
plt.imshow(X_train_raw[0], cmap='gray')
print(y_train_raw[0])

输出如下:

可以看到,这个数字是5,图片显示了这个5的手写样式。

4. 预处理训练数据

对于图像数据,其每组数据(像素点)均为MxN的矩阵形式(不考虑色彩),而通常的神经网络,每层处理的实际上一个向量数据,而非矩阵形式,因此需要将MxN的矩阵转化为一个1xMN的向量。下面程序中的reshape()函数实现了这个功能,这也是一般图像数据需要进行的预处理。

每组数据的每个像素点的值是0~255(表示由黑渐变到白的每一种颜色),为了消除数值本身的影响,采用了对每个数值除以255的定标操作。

image_height, image_width = 28, 28
X_train = X_train_raw.reshape(60000, image_height*image_width)
X_test = X_test_raw.reshape(10000, image_height*image_width)
print(X_train.shape, X_test.shape)

# data value is scaled, which is divided by 255 (each pixel value is 0~255)
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
print(X_train[0])

reshape之后,X_train和X_test的shape输出如下:

(60000, 784) (10000, 784)

5. 预处理测试数据

对于二分类问题,标签一般为0,1;而对于多分类问题,我们会将标签扩张为一个由若干0和1组成的向量。

在这个问题中,总共有10个数字,我们构造一个1x10的向量,表示每一个数字。例如数字5可以表示为[0,0,0,0,0,1,0,0,0,0]。函数to_categorical()实现了该功能

# convert each class value to a vector with value of 0 or 1
# i.e. if class vector, then it will be a matrix with value of 0 or 1
# e.g. for 0~9, 10 classes in total, if value is 5, then it will be 
#      coverted to [0,0,0,0,0,1,0,0,0,0]
y_train = to_categorical(y_train_raw, 10)
y_test = to_categorical(y_test_raw, 10)
print(y_train.shape)
print(y_train[0])

输出如下:

(60000, 10)
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

6. 构建Keras模型

Keras使用Sequential()构造一个空模型,然后用add()添加每一个层。Dense()用于产生一个全连接层。其中在第一层需要指明input_shape()。input_shape中的数字由每组数据包含的数值个数决定(本例中是28x28=784)。

model = Sequential()

# 'Dense' means it's a full-connected layer
# For the 1st layer, the shape must be given, input_shape=(784,) means the 
#   the input is a matrix with N * 784 (X_train.shape is (60000, 784))
model.add(Dense(512, activation='relu', input_shape=(784,)))

# For the following layers, the shape is not necessary
model.add(Dense(512, activation='relu'))

# output layer, 'softmax' is used
model.add(Dense(10, activation='softmax'))

7. 编译训练模型

# As it is a multi-classfication problem, loss is chosen as categorical_crossentropy
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# show summary to make sure it's what we expected
model.summary()

summary()输出如下

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dense_1 (Dense)             (None, 512)               262656    
                                                                 
 dense_2 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 669706 (2.55 MB)
Trainable params: 669706 (2.55 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

每一层的参数是指计算这一层用到的权重和偏好的个数。对于全连接模型

  • 权重个数 = (上一层结点数) x (当前层结点数)
  • 偏好个数 = 当前层结点数

于是:

  • layer1 = 784 x 512 + 512
  • layer2 = 512 x 512 + 512
  • layer3 = 512 + 10 + 10

8. 训练模型

Keras使用fit()训练模型(其它很多AI的模型库也使用这个函数名)。返回的history会记录每个epoch的计算结果。

history = model.fit(X_train, y_train, epochs=20, validation_data=(X_test, y_test))

运行fit()后,结果如下:

Epoch 1/20
1875/1875 [==============================] - 25s 12ms/step - loss: 0.1821 - accuracy: 0.9446 - val_loss: 0.1131 - val_accuracy: 0.9642
Epoch 2/20
1875/1875 [==============================] - 24s 13ms/step - loss: 0.0796 - accuracy: 0.9753 - val_loss: 0.0731 - val_accuracy: 0.9768
Epoch 3/20
1875/1875 [==============================] - 30s 16ms/step - loss: 0.0558 - accuracy: 0.9826 - val_loss: 0.0712 - val_accuracy: 0.9774
Epoch 4/20
1875/1875 [==============================] - 35s 19ms/step - loss: 0.0419 - accuracy: 0.9871 - val_loss: 0.0793 - val_accuracy: 0.9755
Epoch 5/20
1875/1875 [==============================] - 30s 16ms/step - loss: 0.0353 - accuracy: 0.9886 - val_loss: 0.0808 - val_accuracy: 0.9773
Epoch 6/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0281 - accuracy: 0.9906 - val_loss: 0.0690 - val_accuracy: 0.9831
Epoch 7/20
1875/1875 [==============================] - 24s 13ms/step - loss: 0.0234 - accuracy: 0.9926 - val_loss: 0.1078 - val_accuracy: 0.9769
Epoch 8/20
1875/1875 [==============================] - 24s 13ms/step - loss: 0.0218 - accuracy: 0.9929 - val_loss: 0.0940 - val_accuracy: 0.9790
Epoch 9/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0212 - accuracy: 0.9937 - val_loss: 0.1177 - val_accuracy: 0.9777
Epoch 10/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0193 - accuracy: 0.9943 - val_loss: 0.1055 - val_accuracy: 0.9802
Epoch 11/20
1875/1875 [==============================] - 22s 12ms/step - loss: 0.0186 - accuracy: 0.9941 - val_loss: 0.0923 - val_accuracy: 0.9824
Epoch 12/20
1875/1875 [==============================] - 26s 14ms/step - loss: 0.0126 - accuracy: 0.9961 - val_loss: 0.1040 - val_accuracy: 0.9829
Epoch 13/20
1875/1875 [==============================] - 25s 13ms/step - loss: 0.0162 - accuracy: 0.9958 - val_loss: 0.1177 - val_accuracy: 0.9792
Epoch 14/20
1875/1875 [==============================] - 28s 15ms/step - loss: 0.0156 - accuracy: 0.9958 - val_loss: 0.1153 - val_accuracy: 0.9813
Epoch 15/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0159 - accuracy: 0.9959 - val_loss: 0.0955 - val_accuracy: 0.9833
Epoch 16/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0128 - accuracy: 0.9964 - val_loss: 0.1216 - val_accuracy: 0.9827
Epoch 17/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0147 - accuracy: 0.9962 - val_loss: 0.1250 - val_accuracy: 0.9807
Epoch 18/20
1875/1875 [==============================] - 22s 12ms/step - loss: 0.0096 - accuracy: 0.9975 - val_loss: 0.1409 - val_accuracy: 0.9791
Epoch 19/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0177 - accuracy: 0.9961 - val_loss: 0.1395 - val_accuracy: 0.9789
Epoch 20/20
1875/1875 [==============================] - 23s 12ms/step - loss: 0.0122 - accuracy: 0.9967 - val_loss: 0.1324 - val_accuracy: 0.9829

注意到每一行表示一个epoch的输出,这一行中记录了’loss‘, 'accuracy', 'val_loss', 'val_accuracy' 4个参数,这些信息都会保存到history中。这一行同时也记录了每个epoch运行的时间,20~30秒。

9. 浏览训练模型的准确度

查看history(fit函数的返回值)中的记录,检查模型训练的准确度

例如,查看'accuracy'

plt.plot(history.history['accuracy'])

10. 评价模型

使用evaluate()函数执行模型评价

score = model.evaluate(X_test, y_test)
print(score)

输出如下

313/313 [==============================] - 1s 4ms/step - loss: 0.1324 - accuracy: 0.9829
[0.13238213956356049, 0.9829000234603882]

后面那个数字表示准确度:98.29%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/439040.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python刘诗诗

写在前面 刘诗诗在电视剧《一念关山》中饰演了女主角任如意,这是一个极具魅力的女性角色,她既是一位有着高超武艺和智慧的女侠士,也曾经是安国朱衣卫前左使,身怀绝技且性格坚韧不屈。剧中,任如意因不满于朱衣卫的暴行…

【Spring】Spring状态机

1.什么是状态机 (1). 什么是状态 先来解释什么是“状态”( State )。现实事物是有不同状态的,例如一个自动门,就有 open 和 closed 两种状态。我们通常所说的状态机是有限状态机,也就是被描述的事物的状态的数量是有…

vue页面刷新问题:返回之前打开的页面,走了create方法(解决)

vue页面刷新问题:返回之前打开的页面,走了create方法(解决) 直接上图, 我们在开发的时候经常会复制粘贴,导致vue文件的name没有及时修改 我们需要保证name和浏览器的地址一致,这样才能实现缓…

2024 PhpStorm激活,分享几个PhpStorm激活的方案

文章目录 PhpStorm 公司简介我这边使用PhpStorm的理由PhpStorm 2023.3 最新变化AI Assistant 预览阶段结束 正式版基于 LLM 的代码补全测试代码生成编辑器内代码生成控制台中基于 AI 的错误解释 Pest 更新PHP 8.3 支持#[\Override] 特性新的 json_validate() 函数类型化类常量弃…

OpenCascade源码剖析:Standard_Transient根类

Standard_Transient是OCCT继承体系最顶层的根类,Transient在编程中具有一定的语义,与Persistent相对应,通常用于描述数据的持久性或持久性存储。 Transient,意味着数据是临时的或瞬态的,它们不会被持久化保存&#xf…

【C语言基础】:深入理解指针(三)

文章目录 深入理解指针一、冒泡排序二、二级指针三、指针数组3.1 指针数组模拟二维数组 四、字符指针变量五、数组指针变量5.1 数组指针变量是什么?5.2 数组指针变量的初始化 六、二维数组传参的本质 深入理解指针 指针系列回顾: 【C语言基础】&#xf…

Ubuntu 24.04 抢先体验换国内源 清华源 阿里源 中科大源 163源

Update 240307:Ubuntu 24.04 LTS 进入功能冻结期 预计4月25日正式发布。 Ubuntu22.04换源 Ubuntu 24.04重要升级daily版本下载换源步骤 (阿里源)清华源中科大源网易163源 Ubuntu 24.04 LTS,代号 「Noble Numbat」,即将与我们见面! Canonica…

Java代码审计安全篇-目录穿越漏洞

前言: 堕落了三个月,现在因为被找实习而困扰,着实自己能力不足,从今天开始 每天沉淀一点点 ,准备秋招 加油 注意: 本文章参考qax的网络安全java代码审计,记录自己的学习过程,还希望各…

揭示/proc/pid/pagemap的力量:在Linux中将虚拟地址映射到物理地址

pagemap的力量:在Linux中将虚拟地址映射到物理地址 一、/proc/pid/pagemap简介二、了解虚拟地址到物理地址的转换三、利用/proc/pid/pagemap进行地址转换3.1、访问/proc/pid/pagemap3.2、pagemap文件的数据和结构 四、页表、页框架的相关概念五、总结 一、/proc/pid…

信号处理-探索相邻数据点之间的变化和关联性的操作方法

当前值减去前一个值,乘上当前值与前一个值差值的绝对值 当前值减去后一个值,乘上当前值与后一个值差值的绝对值。 意义何在? 当前值减去前一个值:表示当前数据点与前一个数据点之间的变化量。当前值与前一个值差值的绝对值&…

【Linux】软件管理器yum和编辑器vim

🔥博客主页: 小羊失眠啦. 🎥系列专栏:《C语言》 《数据结构》 《C》 《Linux》 《Cpolar》 ❤️感谢大家点赞👍收藏⭐评论✍️ 文章目录 一、Linux下安装软件的方案1.1 源代码安装1.2 rpm安装1.3 yum安装 二、Linux软件…

外贸常用的出口认证 | 全球外贸数据服务平台 | 箱讯科技

出口认证是一种贸易信任背书,对许多外贸从业者而言,产品的出口认证和当前的国际贸易环境一样复杂多变,不同的目标市场、不同的产品类别,所需要的认证及标准也不同。 国际认证 01 IECEE-CB IECEE-CB体系的中文含义是“关于电工产品测试证书的相互认可体…

记一次简单的获取虚拟机|伪终端shell权限

场景描述 某个系统是ova文件,导入虚拟机启动,但是启动后只有一个伪终端权限,即权限很小,如何拿到这个虚拟机的shell权限呢? 实际操作 这次运气比较好,所遇到的系统磁盘并没有被加密,所以直接…

吴恩达深度学习笔记:神经网络的编程基础2.1-2.3

目录 第一门课:神经网络和深度学习 (Neural Networks and Deep Learning)第二周:神经网络的编程基础 (Basics of Neural Network programming)2.1 二分类(Binary Classification)2.2 逻辑回归(Logistic Regression) 第一门课:神经网络和深度学…

c++ 11 新特性 不同数据类型之间转换函数之reinterpret_cast

一.不同数据类型之间转换函数reinterpret_cast介绍 reinterpret_cast是C中的一种类型转换操作符,用于执行低级别的位模式转换。具体来说,reinterpret_cast可以实现以下功能: 指针和整数之间的转换:这种转换通常用于在指针中存储额…

如何学习、上手点云算法(三):用VsCode、Visual Studio来debug基于PCL、Open3D的代码

写在前面 本文内容 以PCL 1.14.0,Open3D0.14.1为例,对基于PCL、Open3D开发的代码进行源码debug; 如何学习、上手点云算法系列: 如何学习、上手点云算法(一):点云基础 如何学习、上手点云算法(二):点云处理相…

【数据结构】二、线性表:6.顺序表和链表的对比不同(从数据结构三要素讨论:逻辑结构、物理结构(存储结构)、数据运算(基本操作))

文章目录 6.对比:顺序表&链表6.1逻辑结构6.2物理结构(存储结构)6.2.1顺序表6.2.2链表 6.3数据运算(基本操作)6.3.1初始化6.3.2销毁表6.3.3插入、删除6.3.4查找 6.对比:顺序表&链表 6.1逻辑结构 顺…

基于pytest的证券清算系统功能测试工具开发

需求 1.造测试数据:根据测试需要,自动化构造各业务场景的中登清算数据与清算所需起来数据 2.测试清算系统功能: 自动化测试方案 工具设计 工具框架图 工具流程图 实现技术 python, pytest, allure, 多进程,mysql, 前端 效果 测…

Web开发介绍,制作小网站流程和需要的技术【详解】

1.什么是web开发 Web:全球广域网,也称为万维网(www World Wide Web),能够通过浏览器访问的网站。 所以Web开发说白了,就是开发网站的,例如网站:淘宝,京东等等 2. 网站的工作流程 1.首先我们需…

【Godot4自学手册】第二十一节掉落金币和收集

这一节我们主要学习敌人死亡后随机掉落金币,主人公可以进行拾取功能。 一、新建金币场景 新建场景,节点选择CharacterBody2D,命名为Coins,将场景保存到Scenes目录下。 1.新建节点 为根节点依次添加CollisionShape2D节点&#…