[PyTorch][chapter 41][卷积网络实战-LeNet5]

前言

    这里结合前面学过的LeNet5 模型,总结一下卷积网络搭建,训练的整个流程

目录:

    1: LeNet-5 

    2:    卷积网络总体流程

    3:  代码


一  LeNet-5

      LeNet-5是一个经典的深度卷积神经网络,由Yann LeCun在1998年提出,旨在解决手写数字识别问题,被认为是卷积神经网络的开创性工作之一。该网络是第一个被广泛应用于数字图像识别的神经网络之一,也是深度学习领域的里程碑之一

参数

输出shape

输入层

[batch,channel,32,32]

  C1(卷积层) 

6@5x5 卷积核 ,stride=1 ,padding=0

[batch,6,28,28]

  S2(池化层) 

kernel_size=2,stride=2,padding=0

[batch,6,14,14]

  C3(卷积层)


 

16@5x5 卷积核,stride=1,padding=0

[batch,16,10,10]

 S4(池化层) 

kernel_size=2,stride=2,padding=0

[batch,16,5,5]

  C5(卷积层)


 

120@5x5卷积核,stride=1padding=0

[batch,120,1,1]

 F6-全连接层 

nn.Linear(in_features=120,  out_features=84)

[batch,120]

 Output-全连接层 

nn.Linear(in_features=120,  out_features=10)

[batch,10]


二 卷积网络的总体流程

     

2.1、nn.Module建立神经网络模型
          model = LeNet5()

          

2.2、建立此网络的可学习的参数,以及更新规则
       optimizer = optim.Adam(model.Parameters(), lr=1e-3) 

        梯度更新的公式

2.3、构建损失函数

        损失函数模型
        criteon = nn.CrossEntropyLoss() 

2.4    前向传播

      logits = model(x)

       根据现有的权重系数,预测输出

2.5   反向传播

      optimizer.zero_grad() #先将梯度归零w_grad
      loss.backward()       #反向传播计算得到每个参数的梯度值w_grad

      通过当前的loss ,计算梯度

2.6   利用optim 更新权重系数

       optimizer.step() #更新权重系数W

       利用优化器更新权重系数
          

        


  三  代码 

# -*- coding: utf-8 -*-
"""
Created on Thu Jun 15 14:32:54 2023

@author: chengxf2
"""
import torch
from torch import nn
from torch.nn import functional as F 
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim 
import ssl


class  LeNet5(nn.Module):
    
    
    """
    for cifar10 dataset
    """
    
    def __init__(self):
        
        super(LeNet5, self).__init__()
        
        self.conv_unit = nn.Sequential(
            
            #卷积层1 x:[b,3,32,32] => [b,6, 30,30]
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5,stride=1,padding=0),
            #池化层1
            nn.MaxPool2d(kernel_size=2,stride=2, padding =0),
            
            #卷积层2  
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5,stride=1, padding=0),
            #池化层2
            nn.MaxPool2d(kernel_size=2,stride=2, padding =0)
            #x:[b,16,5,5]
            )
        
        self.flatten = nn.Flatten(start_dim =1, end_dim = -1)
        
        self.fc_unit = nn.Sequential(
              nn.Linear(in_features=16*5*5, out_features=120),
              nn.ReLU(),
              nn.Linear(in_features=120, out_features=84),
              nn.ReLU(),
              nn.Linear(in_features=84, out_features=10)
              )
       
        
    
    def forward(self, x):
        '''
        

        Parameters
        ----------
        x : 
            [batch,channel=3, width=32, height=32].

        Returns
        -------
        out : 
            DESCRIPTION.

        '''
        #[b,3,32,32] =>[b,16,5,5]
        out = self.conv_unit(x)
        
        #print("\n 卷积层输出 :",out.shape)
        #[b,16,5,5]=>[b,16*5*5]
        out = self.flatten(out)
        #print("\n flatten层输出 :",out.shape)
        #[b,400]=>[b,10]
        out = self.fc_unit(out)
        #print("\n 全连接层输出 :",out.shape)
        
        #pred = F.softmax(out,dim=1)
        return out
            

            
def train():
    
    x = torch.randn(8,3,32,32)
    net = LeNet5()
    
    out = net(x)
    
    print(out.shape)
               

def main():
    
    batchSize =32 
    maxIter = 10
    dataset_trans = transforms.Compose([transforms.ToTensor(),transforms.Resize((32,32))]) 
    imgDir='./data'
    print("\n ---beg----")
    cifar_train = datasets.CIFAR10(root= imgDir,train=True, transform= dataset_trans,download =False) 
    cifar_test =  datasets.CIFAR10(root= imgDir,train=False,transform= dataset_trans,download =False) 
    train_data = DataLoader(cifar_train, batch_size=batchSize,shuffle=True)
    test_data = DataLoader(cifar_test, batch_size=batchSize,shuffle=True)
   
    print("\n --download finsh---")
    device = torch.device('cuda')
    # DataLoader迭代产生训练数据提供给模型 
    model = LeNet5().to(device)
    
    criteon = nn.CrossEntropyLoss() #前向传播计算loss
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999)) #反向传播
    
    for epoch in range(maxIter):
       
       for batchindex,(x,label) in enumerate(train_data):
          
          #x: [b,3,32,32]
          #label: [b]
          x,label = x.to(device),label.to(device)
          
          logits = model(x)
          loss = criteon(logits, label)
          
          #backpop
          optimizer.zero_grad()
          loss.backward()
          optimizer.step() #更新梯度
          
          if batchindex%500 ==0:
              print('batchindex {}, loss {}'.format(batchindex, loss.item()))
    
       model.eval()
       total_correct =0.0
       total_num = 0.0
       with torch.no_grad():
           
           for batchindex,(x,label) in enumerate(test_data):
               x,label = x.to(device),label.to(device)
               logits = model(x)
               pred = logits.argmax(dim=1)
               
               total_correct += torch.eq(pred, label).float().sum()
               total_num += x.size(0)
           acc = total_correct/total_num
           print('\n epoch: {} ,acc: {}  total_num: {}'.format(epoch, acc, total_num))
           
           
           

            
          
          
      

    
if __name__ == "__main__":
    
     main()
    
    
    

因为不是灰度图,训练10轮,acc 只有 epoch: 9 ,acc: 0.6310999989509583  total_num: 10000.0

可以把卷积核调整小一点

参考:

https://mp.csdn.net/mp_blog/creation/editor/131209651

课时79 卷积神经网络训练_哔哩哔哩_bilibili

课时77 卷积神经网络实战-1_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/34936.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

虹科教程 | Linux网络命名空间与虹科PROFINET协议栈的GOAL中间件结合使用

前言 PROFINET是由PI推出的开放式工业以太网标准,它使用TCP/IP等IT标准,并由IEC 61158和IEC 61784 标准化,具有实时功能,并能够无缝集成到现场总线系统中。凭借其技术的开放性、灵活性和性能优势,PROFINET可应用于过程…

动态规划-杨辉三角

动态规划-杨辉三角 1 [杨辉三角]1.1 给定一个非负整数 numRows,生成「杨辉三角」的前 numRows 行。1.2 示例1.2.1 示例 1:1.2.2 示例 2:1.2.3 提示: 1.3 算法解决方法1.3.1 算法解题思路1.3.1.1 确定状态1.3.1.2 转移方程1.3.1.3 初始条件以及边界情况1.3.1.4 计算顺…

【CANoe示例分析】PythonCAPL_Call_Demo

该工程由Vector官方提供,目的是演示Python如何调用CAPL文件里的自定义函数。里面除了CANoe工程文件外,还有python文件和CAPL: 提供了两种CANoe版本的工程文件,选择其中一种打开即可。 首先我们要确定CAPL文件AnalyseFunctions.can在CANoe工程内的什么地方?首先想到的是Si…

sqlserver收缩数据库

1.收缩数据库 首先收缩的前提是需要有可用空间如下图,没有可用空间无法收缩数据库 2.减小数据库大小 通过链接: 查询数据库中各表的大小 如果查询的比较大而且无用的数据可以直接把表结构给拿出来,然后删除该表空间就直接释放出来了 3.收缩文件 我…

2023年 vue使用腾讯地图搜索、关键字输入提示、地点显示

先看结果 vue 在public文件下的index.html文件中引入&#xff1a; <script src"//map.qq.com/api/js?v2.exp&key你自己的key"></script><script src"https://map.qq.com/api/gljs?v1.exp&librariesservice&key你自己的key"&…

计算机网络编程 | 多路I/O转接服务器

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和…

Django框架-6

向服务器传参 通过url - path传参 path(articles/<int:year>/<int:month>/<slug:slug>/, views.article_detail),查询字符串方式传参 http://localhost:8000?key1value1&key2value2 ;&#xff08;body&#xff09;请求体的方式传参&#xff0c;比如文…

vue2【监听器】

目录 1&#xff1a;监听器的作用 2&#xff1a;语法格式 3&#xff1a;示例 4&#xff1a;应用场景 4.1&#xff1a;axios发送请求 4.2&#xff1a;JQuery发送请求 5&#xff1a;监听器的格式&#xff1a; 5.1&#xff1a;函数格式的监听器&#xff1a; 缺点一&#x…

基于VUE3+Layui从头搭建通用后台管理系统(前端篇)三:找回密码界面及对应功能实现

一、本章内容 本章实现找回密码功能,包括短信验证码找回、邮箱验证码找回等功能,并通过node-send-email发送邮箱验证码,实现找回密码界面、接口等功能。 1. 详细课程地址: 待发布 2. 源码下载地址: 待发布 二、界面预览 三、开发视频

【高端设计】DDR4设计方法与仿真分析(一)

本文主要介绍了DDR4设计方法与仿真分析&#xff0c;并示范SIwave如何做DDR4的瞬时眼图、SSN、on-die de-cap影响、DBI耗电分析与规范性测试。 1.DDR4和DDR3的区别 1.1 DDR4传输速度与带宽增加 DDR3 1600/1866MHz -> DDR4 1866/3200MHz DDR3采用多点分支单流架构&#xff…

HTML基础

推荐W3school、developer.mozilla.org学习 文章目录 前言标签html标签标题标签h1-h6段落标签p换行标签br超链接标签a锚点定位图像标签img列表标签有序列表ol无序列表ul定义列表dl 表格标签表格列合并表单标签input&#xff1a;输入框文本框密码框单选按钮复选框上传文件按钮下拉…

Elasticsearch:文档版本控制和乐观并发控制

在今天的文章中&#xff0c;我来详细描述一下 Elasticsearch 文档的版本控制以及如何更新文档。你也可以阅读我之前的文章 “Elasticsearch&#xff1a;深刻理解文档中的 verision 及乐观并发控制”。 版本控制 我们知道 Elasticsearch 的每个文档都有一个相对应的版本。这个版…

html前端输入框模糊查询

1、一个页面内多个模糊查询情况&#xff1a; <!DOCTYPE html> <html> <head> <meta charset"UTF-8" /> <meta name"viewport" content"widthdevice-width, initial-scale1.0, user-scalable0, minimum-scale1.0, maximum-…

geoserver加载arcgis server瓦片地图显示异常问题处理

1.全能地图下载的瓦片conf.xml格式有问题首先要修改格式&#xff0c;conf.cdi文件也需要修改格式&#xff0c;修改为UTF-8或者UTF-8无BOM编码(不同的notepadd显示不同) 2. 下载的conf.xml坐标系默认从最小级别开始&#xff0c;一定要把前几级也补全&#xff0c;从0级开始 <L…

抖音SEO矩阵源码开发(一)

前言&#xff1a; 1.抖音SEO矩阵系统源码开发 是一项技术密集型工作&#xff0c;需要对大数据处理、人工智能等领域有深入了解。该系统开发过程中需要用到多种编程语言在服务器上安装LNMP环境&#xff0c;包括Linux操作系统、Nginx、MySQL、PHP等&#xff0c;如Java、Python等…

GitHub打不开的解决方案(超简单)

在国内&#xff0c;github官网经常面临打不开或访问极慢的问题&#xff0c;不挂梯子&#xff08;VPN&#xff0c;飞机&#xff0c;魔法&#xff09;使用体验极差&#xff0c;那有什么好办法解决GitHub官网访问不了的问题&#xff1f;今天小布教你几招轻松访问github官网。 git…

自定义MVC工作原理

目录 一、MVC二、MVC的演变2.1 极易MVCController层——Servletview层——JSP缺点&#xff1a;Servlet过多、代码冗余 2.2 简易MVCController层——Servletview层——JSP缺点&#xff1a;在Servlet中if语句冗余 2.3普易MVCController层——Servletview层——JSP缺点&#xff1a…

CentOS Linux的最佳替代方案(二)_AlmaLinux OS 8.6基础安装教程

文章目录 CentOS Linux的最佳替代方案&#xff08;二&#xff09;_AlmaLinux OS 8.6基础安装教程一 AlmaLinux介绍和发展历史二 AlmaLinux基础安装2.1 下载地址2.2 安装过程 三 AlmaLinux使用3.1 关闭selinux/firewalld3.2 替换默认源3.3 安装一些必要工具 CentOS Linux的最佳替…

uniapp - [全端兼容] 多选弹框选择器,弹框形式的列表多选选择器组件插件(底部弹框式列表多选功能,支持数据回显、动态数据、主题色等配置)

前言 网上的教程都太乱了,各种不兼容且 BUG 太多,注释也没有很难进行改造。 本文 实现了 uniapp 全端兼容的弹框多选选择器,从底部弹出列表项进行多选(可回显已选中和各种主题色、样式配置), 您可以直接复制代码,稍微改改样式就能用了。 如下图所示,数据列表(支持接口…

Centos7.9通过expect脚本批量修改H3C交换机配置

背景&#xff1a; 公司有几百台H3C二层交换机设备&#xff0c;当需要批量更改配置时非常的消耗工作量 解决&#xff1a; 通过一台Linux服务器&#xff0c;编写shell脚本&#xff0c;模拟Telnet至各台交换机&#xff0c;让一切变的很容易 1.首先在安装Telnet服务前需要检测centO…
最新文章