机器学习-11-卷积神经网络-基于paddle实现神经网络

文章目录

    • 总结
    • 参考
    • 本门课程的目标
    • 机器学习定义
    • 第一步:数据准备
    • 第二步:定义网络
    • 第三步:训练网络
    • 第四步:测试训练好的网络

总结

本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。

参考

MNIST 训练_副本

本门课程的目标

完成一个特定行业的算法应用全过程:

懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合
+算法评估+持续调优+工程化接口实现

机器学习定义

关于机器学习的定义,Tom Michael Mitchell的这段话被广泛引用:
对于某类任务T性能度量P,如果一个计算机程序在T上其性能P随着经验E而自我完善,那么我们称这个计算机程序从经验E中学习
在这里插入图片描述

使用MNIST数据集训练和测试模型。

第一步:数据准备

MNIST数据集

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor

train_dataset = MNIST(mode='train', transform=ToTensor())
test_dataset = MNIST(mode='test', transform=ToTensor())

展示数据集图片

import matplotlib.pyplot as plt
import numpy as np

train_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 的标签为: ' + str(train_label_0))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))


train_data0 的标签为: [5]

第二步:定义网络

import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linear

class MyModel(paddle.nn.Layer):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = Linear(in_features=16*5*5, out_features=120)
        self.linear2 = Linear(in_features=120, out_features=84)
        self.linear3 = Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x

模型可视化

import paddle
mnist = MyModel()
paddle.summary(mnist, (1, 1, 28, 28))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          156      
  MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0       
   Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416     
  MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0       
   Linear-1          [[1, 400]]            [1, 120]           48,120     
   Linear-2          [[1, 120]]            [1, 84]            10,164     
   Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
---------------------------------------------------------------------------






{'total_params': 61706, 'trainable_params': 61706}

第三步:训练网络

import paddle
from paddle.metric import Accuracy
from paddle.static import InputSpec

inputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')

# 用Model封装模型
model = paddle.Model(MyModel(), inputs, labels)

# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# 配置模型
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
# 训练模型
model.fit(train_dataset, test_dataset, epochs=3, batch_size=64, save_dir='mnist_checkpoint', verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/3
step 938/938 [==============================] - loss: 0.0208 - acc: 0.9456 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/0
Eval begin...
step 157/157 [==============================] - loss: 0.0041 - acc: 0.9777 - 19ms/step          
Eval samples: 10000
Epoch 2/3
step 938/938 [==============================] - loss: 0.0021 - acc: 0.9820 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/1
Eval begin...
step 157/157 [==============================] - loss: 2.1037e-04 - acc: 0.9858 - 19ms/step      
Eval samples: 10000
Epoch 3/3
step 938/938 [==============================] - loss: 0.0126 - acc: 0.9876 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/2
Eval begin...
step 157/157 [==============================] - loss: 4.7168e-04 - acc: 0.9884 - 19ms/step      
Eval samples: 10000
save checkpoint at /home/aistudio/mnist_checkpoint/final

第四步:测试训练好的网络

import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy
from paddle.static import InputSpec

inputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')
model = paddle.Model(MyModel(), inputs, labels)
model.load('./mnist_checkpoint/final')
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())

# results = model.evaluate(test_dataset, batch_size=64, verbose=1)
# print(results)

results = model.predict(test_dataset, batch_size=64)

test_data0, test_label_0 = test_dataset[0][0], test_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2,2))
plt.imshow(test_data0, cmap=plt.cm.binary)

print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(results[0][0])[0][-1])

Predict begin...
step 157/157 [==============================] - 27ms/step          
Predict samples: 10000
test_data0 的标签为: [7]
test_data0 预测的数值为:7


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/583213.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FreeRTOS:3.信号量

FreeRTOS信号量 参考链接:FreeRTOS-信号量详解_freertos信号量-CSDN博客 目录 FreeRTOS信号量一、信号量是什么二、 FreeRTOS信号量1、二值信号量1、获取信号量2、释放信号量 2、计数信号量3、互斥信号量1、优先级反转2、优先级继承3、源码解析1、互斥量创建2、获取…

FFmpeg常用结构体、关键函数、ffplay.c分析

一、常用结构体: 1、AVFormatContext结构体: AVFormatContext是一个贯穿全局的数据结构,很多函数都要用它作为参数。FFmpeg代码中对这个数据结构的注释是format I/O context,此结构包含了一个视频流的格式内容。其中存有AVIputFor…

电脑使用笔记

1.电脑亮度调节 亮度:50 对比度:45 暗部平衡:40

毕业设计注意事项(嘉庚学院2024届更新中)

1.开题 根据学院发的开题报告模板完成,其中大纲部分可参考资料。 2.毕设以及实习 2.1 毕设 根据资料中的毕设评价标准,对照工作量 2.2 实习材料提交 2.2.1 校外实习 实习前:根据学院要求,填写好实习承诺书,实习单位…

【数据结构初阶】时间复杂度和空间复杂度详解

今天我们来详细讲讲时间复杂度和空间复杂度,途中如果有不懂的地方可翻阅我之前文章。 个人主页:小八哥向前冲~-CSDN博客 数据结构专栏:数据结构【c语言版】_小八哥向前冲~的博客-CSDN博客 c语言专栏:c语言_小八哥向前冲~的博客-CS…

网络服务SSH-远程访问及控制

一.SSH远程管理 1.SSH介绍 SSH(Secure Shell)是一种安全通道协议,最早是由芬兰的一家公司开发出来,并且在IETF (Internet Engineering Task Force)的网络草案基础上制定而成的标准协议。主要用来实现字符…

合规基线:让安全大检查更顺利

前言 说起安全检查,安全从业人员可能都非常熟悉“安全标准”概念。所有企事业单位网络安全建设都需要满足来自于国家或监管单位的安全标准,如等保2.0、CIS安全标准等。安全标准,还有一个叫法就是“安全基线”。字典上对“基线”的解释是&…

【PostgreSQL】pg触发器介绍

注: 本文为云贝教育 刘峰 原创,请尊重知识产权,转发请注明出处,不接受任何抄袭、演绎和未经注明出处的转载。 触发器是在对指定表执行指定更改操作(SQL INSERT、UPDATE、DELETE 或 TRUNCATE 语句)时自动运行的一组操作…

Java全栈开发前端+后端(全栈工程师进阶之路)-环境搭建

在课程开始前我们要配置好我们的开发环境,这里我的电脑太乱了,我使用vm虚拟机进行搭建开发环境,如果有需要环境的或者安装包,可以私信我。 那我们开始 首先我们安装数据库 这里我们使用小皮面板 小皮面板(phpstudy) - 让天下没…

面向表格数据的大模型推理

现实世界中存在大量未利用的先验知识和未标记数据。而在医疗等高价值领域,获取足够标记数据训练机器学习模型尤其困难,这限制了传统监督学习算法的应用。尽管深度学习方法在其他领域取得了显著进展,但在表格数据分类上仍未能超越传统的机器学…

WEB攻防-PHP特性-函数缺陷对比

目录 和 MD5函数 intval ​strpos in_array preg_match str_replace 和 使用 时,如果两个比较的操作数类型不同,PHP 会尝试将它们转换为相同的类型,然后再进行比较。 使用 进行比较时,不仅比较值,还比较变量…

在Linux操作系统中关于磁盘(硬盘)管理的操作

电脑中数据存储设备:硬盘(实现数据的持久化存储),内存 在Linux操作系统中一切皆文件的思想,所有的设备在Linux操作系统中都是通过文件来标识的,所以每一个硬盘都对应一个块设备文件。 在Linux操作系统中所…

虚函数表与虚函数表指针

虚函数表与虚函数表是用来实现多态的,每一个类只有一个虚函数表 静态多态:函数重载(编译期确定) 动态多态:虚函数(运行期确定) 虚函数表的创建时机: 生成时间: 编译期…

C++——map和set的基础应用

目录 1.关联式容器 2.键值对 3.树型结构的关联式容器 4.set(key模型) 1.set性质 2.注意 3.set的使用 1.创建对象 2.插入 3.删除 4.查找 4.map(key value模型) 1.性质 2.map的使用 1.创建对象 2.插入 3.删除 4.[] 1.关联式容器 C中…

pytho爬取南京房源成交价信息并导入到excel

# encoding: utf-8 # File_name: import requests from bs4 import BeautifulSoup import xlrd #导入xlrd库 import pandas as pd import openpyxl# 定义函数来获取南京最新的二手房房子成交价 def get_nanjing_latest_second_hand_prices():cookies {select_city: 320100,li…

C语言 | Leetcode C语言题解之第50题Pow(x,n)

题目&#xff1a; 题解&#xff1a; double myPow(double x, int n){if(n 0 || x 1){return 1;}if(n < 0){return 1/(x*myPow(x,-(n1)));}if(n % 2 0){return myPow(x*x,n/2);}else{return x*myPow(x*x,(n - 1)/2);} }

selenium 4.x入门篇(环境搭建、八大元素定位)

背景 Web自动化测现状 1. 属于 E2E 测试 2. 过去通过点点点 3. 好的测试&#xff0c;还需要记录、调试网页的细节 一、selenium4.x环境搭建 一键搭建 pip3 install webdriver-helper 安装后自动的完成&#xff1a; 1. 查看浏览器的版本号 2. 查询操作系统的类型…

68.网络游戏逆向分析与漏洞攻防-利用数据包构建角色信息-自动生成CPP函数解决数据更新的问题

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 如果看不懂、不知道现在做的什么&#xff0c;那就跟着做完看效果&#xff0c;代码看不懂是正常的&#xff0c;只要会抄就行&#xff0c;抄着抄着就能懂了 内容…

0426GoodsBiddingAJAX项目

0426GoodsBiddingAJAX项目包-CSDN博客 数据库字段 ​ 管理员的登录界面 ​ 登录成功跳转在线拍卖界面&#xff0c;使用监听器拦截请求&#xff0c;只能登录管理员后访问该界面 ​ 商品竞拍列表 ​ 商品竞拍列表的竞拍操作&#xff1a; ​ 1 用户未登录跳转用户登录界面&#x…

(三十一)第 5 章 数组和广义表(稀疏矩阵的三元组行逻辑链接的顺序存储表示实现)

1. 背景说明 2. 示例代码 1)errorRecord.h // 记录错误宏定义头文件#ifndef ERROR_RECORD_H #define ERROR_RECORD_H#include <stdio.h> #include <string.h> #include <stdint.h>// 从文件路径中提取文件名 #define FILE_NAME(X) strrchr(X, \\) ? strrch…