李沐之经典卷积神经网络

目录

1. LeNet

2. 代码实现


1. LeNet

输入是32*32图片,放到一个5*5的卷积层里面,卷积层的输出通道数是6,高宽都是28(32-5+1=28)。再经过2*2的池化层,把28*28变成14*14(28-2+2)/2=14,这里的步幅和窗口的大小一样。再经过5*5的卷积层,输出就变成10*10的(14-5+1=10)。通道数增加了从6变到16。再经过一个池化层,也是16个通道(用了16个卷积核),大小是5*5(10-2+2)/2=5。再把它拉成一个向量输入到全连接层,第一个全连接层的输出是120,第二个的输出是84,最后一个的输出是10。

2. 代码实现

import torch
from torch import nn
from d2l import torch as d2l

#原来的数据集是32*32,这里为了一样每边各padding2行(同时也是为了框住字),4个边就是32行*32列,
#为了得到非线性加一个sigmoid函数
net=nn.Sequential(
        nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2,stride=2),
        nn.Conv2d(6,16,keinel_size=5),nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),
        nn.Linear(16*5*5,120),nn.Sigmoid(),
        nn.Linear(120,84),nn.Sigmoid(),
        nn.Linear(84,10))

X=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:
    X=layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)
"""结果输出:
Conv2d output shape:         torch.Size([1, 6, 28, 28])
Sigmoid output shape:        torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:         torch.Size([1, 16, 10, 10])
Sigmoid output shape:        torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:        torch.Size([1, 400])
Linear output shape:         torch.Size([1, 120])
Sigmoid output shape:        torch.Size([1, 120])
Linear output shape:         torch.Size([1, 84])
Sigmoid output shape:        torch.Size([1, 84])
Linear output shape:         torch.Size([1, 10])
"""
#在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。 第一个卷积层使用2个像素的填充,
#来补偿卷积核导致的特征减少。 相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。 随着
#层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。 
#同时,每个汇聚层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数
#相匹配的输出。

#第一个模块卷积激活池化把1通道,28*28变成6通道14*14,高宽减半,通道数增加了6倍,信息变多了。


#看看LeNet在Fashion-MNIST数据集上的表现
batch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size=batch_size)


#evaluate_accuracy函数进行轻微的修改, 由于完整的数据集位于内存中,因此在模型使用GPU计算数
#据集之前,我们需要将其复制到显存中。
def evaluate_accuracy_gpu(net,data_iter,device=None):
    """使用GPU计算模型在数据集上的精度"""
    if isinstance(net,nn.Module):
    #isinstance() 函数来判断一个对象是否是一个已知的类型
        net.eval()        
        #设置为评估模式
        if not device:
        #如果device没有给定
            device=next(iter(net.parameters())).device
            #把net的参数构建成一个迭代器,把第一个net的参数拿出来看他的device在哪里
    #正确预测的数量,总预测的数量
    metric=d2l.Accumulator(2)
    with torch.no_grad():
        for X,y in data_iter:
            if isinstance(X,list):
            #如果X的类型是一个list,就每一个都挪到那个device上面去
                X=[x.to(device) for x in X]
            else:
            #如果是个tensor就挪一次
                X=X.to(device)
            y=y.to(device)
            metric=add(d2l.accuracy(net(X),y),y.numel())
    return metric[0]/metric[1]


def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):
    """用GPU训练模型"""
    def init_weights(m):
        if type(m)==mm.Linear or type(m)==nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    #对每一个parameter都run一下初始化权重函数
    print('training on',device)
    net.to(device)
    #把整个参数搬到GPU上
    optimizer=torch.optim.SGD(net.parameters(),lr=lr)
    loss=nn.CrossEntropyLoss()
    animator=d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train loss',
                           'train acc','test_acc'])
    timer,num_batches=d2l.Timer(),len(train_iter)
    for epoch in range(num_epochs):
        metric=d2l.Accumulator(3)
        net.train()
        for i,(X,y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X,y=X.to(device),y.to(device)
            y_hat=net(X)
            l=loss(y_hat,y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])
            timer.stop()
            train_l=metric[0]/metric[2]
            train_acc=metric[1]/metric[2]
            if(i+1)%(num_batches//5)==0 or i==num_batches-1:
                animator.add(epoch+(i+1)/num_batches),(train_l,train_acc,None)
        test_acc=evaluate_accuracy_gpu(net,test_iter)
        animator.add(epoch+1,(None,None,test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')


#训练和评估LeNet-5模型。
lr,num_epochs=0.9,10
train_ch6=(net,train_iter,num_epochs,lr,d2l.try_gpu())
"""结果输出:
loss 0.469, train acc 0.823, test acc 0.779
55296.6 examples/sec on cuda:0"""


  • 卷积神经网络(CNN)是一类使用卷积层的网络。

  • 在卷积神经网络中,我们组合使用卷积层、非线性激活函数和汇聚层。

  • 为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数。

  • 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

  • LeNet是最早发布的卷积神经网络之一。

参考:

python中isinstance()函数详解_python instance函数-CSDN博客

Pytorch torch.device()的简单用法_torch.device('cuda:0')-CSDN博客

Python迭代器基本方法iter()及其魔法方法__iter__()原理详解-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/312593.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据在内存中的存储(C语言)

​ ✨✨ 欢迎大家来到贝蒂大讲堂✨✨ ​ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 ​ 所属专栏:C语言学习 ​ 贝蒂的主页:Betty‘s blog 引言 ​ 我们早就学完基本的数据类型,那这些数据类型…

window中安装Apache http server(httpd-2.4.58-win64-VS17)

windows中安装Apache http server(httpd-2.4.58-win64-VS17) 1、下载windows版本的的httpd, https://httpd.apache.org/docs/current/platform/windows.html#down 这里选择的是Apache Lounge编译的版本 https://www.apachelounge.com/download/ 2、解压到指定目录,这…

【5】商密测评密码辅助工具

0X01 前言 最近在学了下商密测评,研究了下技术层面的测评,感觉找工具不方便,就顺手自己造了个辅助工具,都是自己遇到需要用的。 0x02 工具功能介绍 不爱打字,直接上图。后续根据技术测评层面需要继续完善和增加功能。…

Hive基础知识(九):Hive对数据库表的增删改查操作

1. 创建表 1)建表语法 CREATE [EXTERNAL] TABLE [IF NOT EXISTS] table_name #EXTERNAL:外部的 [(col_name data_type [COMMENT col_comment],...)] [COMMENT table_comment] [PARTITIONED BY (col_name data_type [COMMENT col_comment],...)]#PARTITIO…

Session与Cookie

目录 一、Session会话技术 概念 常用方法 生命周期 有效期 场景 二、Cookie技术 一、Session会话技术 概念 浏览器和服务器之间为了实现某个功能,产生了多次请求和响应,从第一次请求开始到最后一次请求结束,这期间所有的请求和响应加…

走进Docker的世界

文章目录 前言一、Docker相关概述1、什么是docker?2、为什么出现docker?2.1 容器与kvm虚拟化的对比2.2 docker的作用 二、安装docker及配置文件调整1.配置宿主机网卡转发2.yum安装docker3.修改daemon.json文件4.修改docker镜像和容器的默认存储路径5.启动…

高效构建Java应用:Maven入门和进阶(四)

高效构建Java应用:Maven入门和进阶(四) 四. Maven聚合和继承特性4.1 Maven工程继承关系4.2 Maven工程聚合关系 四. Maven聚合和继承特性 4.1 Maven工程继承关系 继承概念 Maven 继承是指在 Maven 的项目中,让一个项目从另一个项目…

GLES学习笔记---立方体贴图(一张图)

一、首先看一张效果图 立方体贴图 二、纹理坐标划分 如上图是一张2D纹理,我们需要将这个2D纹理贴到立方体上,立方体有6个面,所以上面的2D图分成了6个面,共有14个纹理坐标 三、立方体 上边的立方体一共8个顶点坐标,范围…

Redis(四)事务

文章目录 事务Redis事务 vs 数据库事务常用命令总结 事务 一个队列中、一次性、顺序性、排他性执行一系列命令 官网https://redis.io/docs/interact/transactions/ Redis事务 vs 数据库事务 概述详述1、单独的隔离操作Redis的事务仅仅是保证事务里的操作会被连续独占的执行&a…

2022 年全国职业院校技能大赛高职组云计算赛项试卷

【赛程名称】云计算赛项第一场-私有云 某企业拟使用OpenStack 搭建一个企业云平台,以实现资源池化弹性管理、企业应用集中管理、统一安全认证和授权等管理。 系统架构如图 1 所示,IP 地址规划如表 1 所示。 图 1 系统架构图 表 1 IP 地址规划 设备…

Java零基础教学文档第四篇:HTML_CSS_JavaScript(2)

【HTML】 【主要内容】WEB: 1.Web前端简介 2.创建第一个前端项目 3.相关标签详解 4.表格标签详解 5.表单标签详解 6.框架和实体字符 【学习目标】 1. Web前端简介 1.1 为什么要学习Web前端&#…

【Python机器学习】SVM——预处理数据

为了解决特征特征数量级差异过大,导致的模型过拟合问题,有一种方法就是对每个特征进行缩放,使其大致处于同一范围。核SVM常用的缩放方法是将所有的特征缩放到0和1之间。 “人工”处理方法: import matplotlib.pyplot as plt from…

Java异常处理之旅:解救迷失的程序员

目录​​​​​​​ 一、前言 二、基础知识 2.1 异常的概念 ​​​​​​2.2 异常分类 2.3 异常处理的原则 ​​​​​​三、异常处理的语法 3.1 try-catch语句 3.2 finally语句 3.3 throw语句 3.4 throws关键字 3.5 自定义异常 四、常见异常及处理方式 4.1 NullP…

【C语言】linux内核set_task_stack_end_magic函数

一、函数定义 void set_task_stack_end_magic(struct task_struct *tsk) {unsigned long *stackend;stackend end_of_stack(tsk);*stackend STACK_END_MAGIC; /* for overflow detection */ } 内核版本6.4.3、6.7。 二、代码解读 解读1 这段代码是一个在Linux内核中定…

芯课堂 | 固件升级方法及架构

本次介绍一种固件升级方法及架构。 所述方法通过运行引导加载程序,并基于引导加载程序,获取启动引导标志位; 在启动引导标志位为预设枚举标志位时,执行对应启动引导标志位的固件升级动作; 在启动引导标志位为非预设…

Cesium 模型压平

最近整理了下手上的代码,以下是对模型压平的说明。 原理是使用了customShader来重新设置了模型的着色器,通过修改模型顶点的坐标来实现了压平。 废话不多说,下面上代码: /*** class* description 3dtiles模型压平*/ class Flat…

leetcode 每日一题 2024年01月11日 构造有效字符串的最少插入数

题目 2645. 构造有效字符串的最少插入数 给你一个字符串 word ,你可以向其中任何位置插入 “a”、“b” 或 “c” 任意次,返回使 word 有效 需要插入的最少字母数。 如果字符串可以由 “abc” 串联多次得到,则认为该字符串 有效 。 示例 …

辞旧岁,赢新篇,创维汽车召开年度会议立足过去,展望未来

为辞旧迎新,再创辉煌,创维汽车于1月4日-5日召开了事业部年度会议。本次会议将23年整体运营情况作出总结并对新一年的发展作出了目标规划。创维集团、创维汽车创始人黄宏生先生,开沃新能源汽车集团执行董事兼首席运营官诸萍女士,创…

记录一次华为云服务器扩容系统磁盘

转载说明:如果您喜欢这篇文章并打算转载它,请私信作者取得授权。感谢您喜爱本文,请文明转载,谢谢。 1. 扩容步骤 1.1 在华为云控制台操作磁盘扩容 1.2 服务器上操作扩容步骤 1)fdisk -l 查看扩容情况,确认…

git: Updates were rejected because the tip of your current branch is behind

一、报错含义 由于本地分支的tip落后远程分支,push操作被拒绝。 二、产生原因 我再本地拉去了新的分支并未同步到远程仓库,在新分支进行开发,由于前几天同步也创建了该分支并同步到了远程仓库,导致我本次push失败 三、解决方…