动手学深度学习—卷积神经网络LeNet(代码详解)

1. LeNet

LeNet由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

在这里插入图片描述

  1. 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;
  2. 每个卷积层使用5×5卷积核和一个sigmoid激活函数;
  3. 这些层将输入映射到多个二维特征输出,通常同时增加通道的数量;
  4. 每个4×4池操作(步幅2)通过空间下采样将维数减少4倍。
import torch
from torch import nn
from d2l import torch as d2l

# 定义模型net
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

该模型去掉了最后一层的高斯激活,下面将一个大小为28×28的单通道(黑白)图像通过LeNet,打印每一层输出的形状。

# 观察各层的输入输出通道数,宽度和高度
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

在这里插入图片描述

  1. 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少;
  2. 第二个卷积层没有填充,因此高度和宽度都减少了4个像素;
  3. 随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个;
  4. 每个汇聚层的高度和宽度都减半;
  5. 每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

2. 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
"""
    定义精度评估函数:
    1、将数据集复制到显存中
    2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    # 判断net是否属于torch.nn.Module类
    if isinstance(net, nn.Module):
        net.eval()
        
        # 如果不在参数选定的设备,将其传输到设备中
        if not device:
            device = next(iter(net.parameters())).device
    
    # Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            # 将X, y复制到设备中
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            
            # 计算正确预测的数量,总预测的数量,并存储到metric中
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]
"""
    定义GPU训练函数:
    1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;
    2、使用Xavier随机初始化模型参数;
    3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    # 定义初始化参数,对线性层和卷积层生效
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    
    # 在设备device上进行训练
    print('training on', device)
    net.to(device)
    
    # 优化器:随机梯度下降
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    
    # 损失函数:交叉熵损失函数
    loss = nn.CrossEntropyLoss()
    
    # Animator为绘图函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    
    # 调用Timer函数统计时间
    timer, num_batches = d2l.Timer(), len(train_iter)
    
    for epoch in range(num_epochs):
        
        # Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量
        metric = d2l.Accumulator(3)
        net.train()
        
        # enumerate() 函数用于将一个可遍历的数据对象
        for i, (X, y) in enumerate(train_iter):
            timer.start() # 进行计时
            optimizer.zero_grad() # 梯度清零
            X, y = X.to(device), y.to(device) # 将特征和标签转移到device
            y_hat = net(X)
            l = loss(y_hat, y) # 交叉熵损失
            l.backward() # 进行梯度传递返回
            optimizer.step()
            with torch.no_grad():
                # 统计损失、预测正确数和样本数
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop() # 计时结束
            train_l = metric[0] / metric[2] # 计算损失
            train_acc = metric[1] / metric[2] # 计算精度
            
            # 进行绘图
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
                
        # 测试精度
        test_acc = evaluate_accuracy_gpu(net, test_iter) 
        animator.add(epoch + 1, (None, None, test_acc))
        
    # 输出损失值、训练精度、测试精度
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'
          f'test acc {test_acc:.3f}')
    
    # 设备的计算能力
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'
          f'on {str(device)}')

在这里插入图片描述

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

3. 小结

  1. 卷积神经网络(CNN)是一类使用卷积层的网络;
  2. 卷积神经网络中,可以组合使用卷积层、非线性激活函数和汇聚层;
  3. 为了构造高性能的卷积神经网络,通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数;
  4. 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/80797.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

强训第31天

选择 传输层叫段 网络层叫包 链路层叫帧 A 2^16-2 C D C 70都没收到,确认号代表你该从这个号开始发给我了,所以发70而不是71 B D C 248&123120 OSI 物理层 数据链路层 网络层 传输层 会话层 表示层 应用层 C 记一下304读取浏览器缓存 502错误网关 编…

349. 两个数组的交集 题解

题目描述:349. 两个数组的交集 - 力扣(LeetCode) 给定两个数组 nums1 和 nums2 ,返回 它们的交集 。输出结果中的每个元素一定是 唯一 的。我们可以 不考虑输出结果的顺序 。 方法一: 解题思路: 我们可以…

时序数据库influxdb笔记

官方资料 https://docs.influxdata.com/influxdb/v2.7/install/?tLinux https://www.influxdata.com/influxdb/ 安装 1、linux平台下 1)下载 2)解压 3)添加账户( adduser influx) 4)设置目录权限 5…

缺少或找不到vcruntime140_1.dll的解决方法

某天,当我准备打开电脑上的一个应用程序时,突然收到一个错误提示,显示缺少了vcruntime140_1.dll文件。这个文件是一个重要的系统组件,它的丢失导致了我无法正常运行该应用程序。于是,我开始了一场寻找和修复旅程。然而…

RFID技术助力汽车零配件装配产线,提升效率与准确性

随着科技的不断发展,越来越多的自动化设备被应用到汽车零配件装配产线中。其中,射频识别(Radio Frequency Identification,简称RFID)技术凭借其独特的优势,已经成为了这一领域的重要技术之一。本文将介绍RF…

【BUG】docker安装nacos,浏览器却无法访问到页面

个人主页:金鳞踏雨 个人简介:大家好,我是金鳞,一个初出茅庐的Java小白 目前状况:22届普通本科毕业生,几经波折了,现在任职于一家国内大型知名日化公司,从事Java开发工作 我的博客&am…

go 协程并发数控制

错误的写法&#xff1a; 这里的<-ch 是为了从channel 中读取 数据&#xff0c;为了不使channel通道被写满&#xff0c;阻塞 go 协程数的创建。但是请注意&#xff0c;go workForDraw(v, &wg) 是不阻塞后续的<-ch 执行的&#xff0c;所以就一直go workForDraw(v, &…

数学建模之“灰色预测”模型

灰色系统分析法在建模中的应用 1、CUMCM2003A SARS的传播问题 2、CUMCM2005A长江水质的评价和预测CUMCM2006A出版社的资源配置 3、CUMCM2006B艾滋病疗法的评价及疗效的预测问题 4、CUMCM2007A 中国人口增长预测 灰色系统的应用范畴大致分为以下几方面: (1&#xff09;灰色关…

js简介以及在html中的2种使用方式(hello world)

简介 javascript &#xff1a;是一个跨平台的脚本语言&#xff1b;是一种轻量级的编程语言。 JavaScript 是 Web 的编程语言。所有现代的 HTML 页面都使用 JavaScript。 HTML&#xff1a; 结构 css&#xff1a; 表现 JS&#xff1a; 行为 HTMLCSS 只能称之为静态网页&#xff0…

IronPDF for .NET Crack

IronPDF for .NET Crack ronPDF现在将等待HTML元素加载后再进行渲染。 IronPDF现在将等待字体加载后再进行渲染。 添加了在绘制文本时指定旋转的功能。 添加了在保存为PDFA时指定自定义颜色配置文件的功能。 IronPDF for.NET允许开发人员在C#、F#和VB.NET for.NET Core和.NET F…

批量提取文件名到excel,详细的提取步骤

如何批量提取文件名到excel&#xff1f;我们的电脑中可能存储着数量非常多的电子文件&#xff0c;现在需要快速将这些文件的名称全部提取到Excel中。虽然少量数据可以通过复制粘贴的方式轻松完成&#xff0c;但是对于上万个数据而言&#xff0c;复制粘贴都是行不通的&#xff0…

改进YOLO系列:2.添加ShuffleAttention注意力机制

添加ShuffleAttention注意力机制 1. ShuffleAttention注意力机制论文2. ShuffleAttention注意力机制原理3. ShuffleAttention注意力机制的配置3.1common.py配置3.2yolo.py配置3.3yaml文件配置1. ShuffleAttention注意力机制论文 论文题目:SA-NET: SHUFFLE ATTENTION …

【CSS动画02--卡片旋转3D】

CSS动画02--卡片旋转3D 介绍代码HTMLCSS css动画02--旋转卡片3D 介绍 当鼠标移动到中间的卡片上会有随着中间的Y轴进行360的旋转&#xff0c;以下是几张图片的介绍&#xff0c;上面是鄙人自己录得一个供大家参考的小视频&#x1f92d; 代码 HTML <!DOCTYPE html>…

计算机视觉--距离变换算法的实战应用

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 计算机视觉CV是人工智能一个非常重要的领域。 在本次的距离变换任务中&#xff0c;我们将使用D4距离度量方法来对图像进行处理。通过这次实验&#xff0c;我们可以更好地理解距离度量在计算机视觉中的应用。希望大家对计算…

MobaXterm网络远程工具介绍下载安装破解使用

一、介绍 obaXterm 是远程计算机的工具箱。在单个 Windows 应用程序中&#xff0c;它提供了大量为程序员、网站管理员、IT 管理员量身定制的功能。 MobaXterm 为 Windows 桌面提供了重要的远程网络工具&#xff08;SSH、X11、RDP、VNC、FTP、MOSH 等&#xff09;和Unix 命令&a…

Unity 找不到 Navigation 组件的解决

当我们想利用unity 里面的Navigation 组件来实现我们的物体的自动导航时&#xff0c;有时竟然会发现我们的菜单栏里面找不到 该组件 这时我们应该怎么办&#xff1f; 请确保你的项目中已经导入了Unity的AI模块。要导入该模块&#xff0c;请打开"Project Settings"&am…

计算机网络----CRC冗余码的运算

目录 1. 冗余码的介绍及原理2. CRC检验编码的例子3. 小练习 1. 冗余码的介绍及原理 冗余码是用于在数据链路层的通信链路和传输数据过程中可能会出错的一种检错编码方法&#xff08;检错码&#xff09;。原理&#xff1a;发送发把数据划分为组&#xff0c;设每组K个比特&#…

pytest自动化测试框架,真正做到从0到1由浅入深详细讲解【万字级】

目录 嗨咯铁汁们&#xff0c;很久不见&#xff0c;我还是你们的老朋友凡叔&#xff0c;这里也感谢各位小伙伴的点赞和关注&#xff0c;你们的三连是我最大的动力哈&#xff0c;我也不会辜负各位的期盼&#xff0c;这里呢给大家出了一个pytest自动化测试框架由浅入深详细讲解。 …

Tomcat日志中文乱码

修改安装目录下的日志配置 D:\ProgramFiles\apache-tomcat-9.0.78\conf\logging.properties java.util.logging.ConsoleHandler.encoding GBK

thinkphp6前后端验证码分离以及验证

1.验证码接口生成验证码: public function verify(){return captcha(); } 也可以自己写方法 2.验证方法和普通模式session验证有区别,需要改原文件: 修改后的代码: <?php // +---------------------------------------------------------------------- // | ThinkP…
最新文章