pytorch04:网络模型创建

目录

  • 一、模型创建过程
    • 1.1 以LeNet网络为例
    • 1.2 LeNet结构
    • 1.3 nn.Module
  • 二、网络层容器(Containers)
    • 2.1 nn.Sequential
      • 2.1.1 常规方法实现
      • 2.1.2 OrderedDict方法实现
    • 2.2 nn.ModuleList
    • 2.3 nn.ModuleDict
    • 2.4 三种容器构建总结
  • 三、AlexNet网络构建

一、模型创建过程

在这里插入图片描述

1.1 以LeNet网络为例

在这里插入图片描述

网络代码如下:

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()  # 调用父类方法,作用是调用nn.Module类的构造函数,
        # 确保LeNet类被正确地初始化,并继承了nn.Module 的所有属性和方法
        self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

1.2 LeNet结构

在这里插入图片描述

LeNet:conv1–>pool1–>conv2–>pool2–>fc1–>fc2–>fc3
在这里插入图片描述

1.3 nn.Module

Module是nn模块中的功能,nn模块还有Parameter、functional等模块。
在这里插入图片描述
nn.Module主要有以下参数:
• parameters : 存储管理nn.Parameter类
• modules : 存储管理nn.Module类
• buffers:存储管理缓冲属性,如BN层中的running_mean

二、网络层容器(Containers)

在这里插入图片描述

2.1 nn.Sequential

nn.Sequential 是 nn.module的容器,也是最常用的容器,用于按顺序包装一组网络层
• 顺序性:各网络层之间严格按照顺序构建
• 自带forward():自带的forward里,通过for循环依次执行前向传播运算

2.1.1 常规方法实现

LeNet网络由两部分构成,中间的卷积池化特征提取部分(features),以及最后的分类部分(classifier)。
在这里插入图片描述
具体代码如下:

class LeNetSequential(nn.Module):
    def __init__(self, classes):
        super(LeNetSequential, self).__init__()
        self.features = nn.Sequential(  #特征提取部分
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),)

        self.classifier = nn.Sequential(  #分类部分
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

打印网络层:
在这里插入图片描述

2.1.2 OrderedDict方法实现

使用有序字典的方法构建Sequential
代码如下:

class LeNetSequentialOrderDict(nn.Module):
    def __init__(self, classes):
        super(LeNetSequentialOrderDict, self).__init__()

        self.features = nn.Sequential(OrderedDict({
            'conv1': nn.Conv2d(3, 6, 5),
            'relu1': nn.ReLU(inplace=True),
            'pool1': nn.MaxPool2d(kernel_size=2, stride=2),

            'conv2': nn.Conv2d(6, 16, 5),
            'relu2': nn.ReLU(inplace=True),
            'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
        }))

        self.classifier = nn.Sequential(OrderedDict({
            'fc1': nn.Linear(16 * 5 * 5, 120),
            'relu3': nn.ReLU(),

            'fc2': nn.Linear(120, 84),
            'relu4': nn.ReLU(inplace=True),

            'fc3': nn.Linear(84, classes),
        }))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

先看一下Sequential函数中init初始化的两种方法,当我们使用OrderedDict方法时,会进行判断,使用self.add_module(key, module)方法将字典中的key和value取出来添加到Sequential中。

class Sequential(Module):
    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

通过这种方法构建可以给每一网络层添加一个名称,网络输出结果如下:
在这里插入图片描述

2.2 nn.ModuleList

nn.ModuleList是 nn.module的容器,用于包装一组网络层,以迭代方式调用网络层
主要方法:
• append():在ModuleList后面添加网络层
• extend():拼接两个ModuleList
• insert():指定在ModuleList中位置插入网络层

使用列表生成式,通过一行代码就能构建20个网络层。
代码演示:

class ModuleList(nn.Module):
    def __init__(self):
        super(ModuleList, self).__init__()
        # 使用列表生成式构建20个全连接层,每个全连接层10个神经元的网络
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
        return x


net = ModuleList()

2.3 nn.ModuleDict

nn.ModuleDict是 nn.module的容器,用于包装一组网络层,以索引方式调用网络层,可以用过参数的形式选取想要调用的网络层。
主要方法:
• clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除

代码展示,只选取conv和relu两个网络层:

class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })

        # 激活函数
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):  # 传入两个参数 用来选择网络层
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu')  #只选取conv和relu两个网络层。
print(output)

2.4 三种容器构建总结

• nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建
• nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
• nn.ModuleDict:索引性,常用于可选择的网络层

三、AlexNet网络构建

AlexNet:2012年以高出第二名10多个百分点的准确率获得ImageNet分类任务冠军,开创了卷积神经网络的新时代
AlexNet特点如下:

  1. 采用ReLU:替换饱和激活函数,减轻梯度消失
  2. 采用LRN(Local Response Normalization):对数据归一化,减轻梯度消失
  3. Dropout:提高全连接层的鲁棒性,增加网络的泛化能力
  4. Data Augmentation:TenCrop,色彩修改

网络结构图如下:
在这里插入图片描述
构建代码:

import torch.nn as nn
import torch
from torchsummary import summary
# 定义一个名为AlexNet的神经网络模型,继承自nn.Module基类
class AlexNet(nn.Module):
    # 构造函数,初始化网络的参数
    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        # 调用父类的构造函数
        super().__init__()
        # 定义神经网络的特征提取部分,包含多个卷积层和池化层
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # 输入通道3,输出通道64,卷积核大小11x11,步长4,填充2
            nn.ReLU(inplace=True),  # 使用ReLU激活函数,inplace=True表示原地操作,节省内存
            nn.MaxPool2d(kernel_size=3, stride=2),  # 最大池化层,核大小3x3,步长2
            nn.Conv2d(64, 192, kernel_size=5, padding=2),  # 输入通道64,输出通道192,卷积核大小5x5,填充2
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        
        # 定义自适应平均池化层,将输入的任意大小的特征图池化为固定大小6x6
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        
        # 定义分类器部分,包含全连接层和Dropout层
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),  # 使用Dropout进行正则化,随机丢弃一部分神经元以防止过拟合
            nn.Linear(256 * 6 * 6, 4096),  # 输入大小为256*6*6,输出大小为4096
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),  # 最后的全连接层输出类别数
        )

    # 前向传播函数,定义数据在网络中的传播过程
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)  # 特征提取
        x = self.avgpool(x)  # 平均池化
        x = torch.flatten(x, 1)  # 将特征图展平成一维向量
        x = self.classifier(x)  # 分类器
        return x
 if __name__ == '__main__':
    net = AlexNet().cuda()
    summary(net, (3, 256, 256))

打印出的网络结构图如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/288162.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

短剧分销系统搭建,打造新的蓝海项目

近一年来,短剧占据了当下大众的碎片化时间,各大影视公司也纷纷加入到了短剧行业中。2023年一整年短剧的规模已经达到了三百多亿元,发展非常快。目前,短剧作为一种新的商业模式,已经受到了广泛认可,也为创业…

Python在金融大数据分析中的AI应用实战

💂 个人网站:【 海拥】【神级代码资源网站】【办公神器】🤟 基于Web端打造的:👉轻量化工具创作平台💅 想寻找共同学习交流的小伙伴,请点击【全栈技术交流群】 随着人工智能时代的到来,Python作为…

app store里面的构建版本在线上传

开发苹果ios应用,无论是用原生开发、用hbuilderx开发还是用其他h5框架开发的app,都需要将打包好的ipa文件上传到app store。 在上架app store的过程中,我们会遇到下图的这样一个问题: 就是它要求我们上传一个构建版本&#xff0c…

基于SSM(非maven)的教室预约管理系统——有报告(Javaweb)

项目简介 本项目为基于SSM(非maven)的教室预约管理系统,本项目主要分为二种角色:用户,管理员 管理员拥有功能:教室信息管理、预约审核管理、预约记录查询、用户注册管理、修改个人信息、退出登录等 用户…

为团队进行文档赋能

大家好,才是真的好。 说来也巧,最近看一个论坛,有人问他们在公司内网管理接收到的外部发文,请问有什么办法工具能够快速的进行管理,在需要的时候供给大家搜索和查看。很多人提了不同的办法,比如说用文件共…

JavaBean

学习目的与要求 熟练掌握<jsp:useBean>、<jsp:setProperty>、<jsp:getProperty>等JSP的操作指令。 本章主要内容 编写JavaBean在JSP中使用JavaBean 一个JSP页面通过使用HTML标记为用户显示数据&#xff08;静态部分&#xff09;&#xff0c;页面中变量的…

【每日一题】2487. 从链表中移除节点-2024.1.3

题目&#xff1a; 2487. 从链表中移除节点 给你一个链表的头节点 head 。 移除每个右侧有一个更大数值的节点。 返回修改后链表的头节点 head 。 示例 1&#xff1a; 输入&#xff1a;head [5,2,13,3,8] 输出&#xff1a;[13,8] 解释&#xff1a;需要移除的节点是 5 &…

LeetCode做题总结 15. 三数之和(未完)

不会做&#xff0c;参考了代码随想录和力扣官方题解&#xff0c;对此题进行整理。 代码思路 思想&#xff1a;利用双指针法&#xff0c;对数组从小到大排序。先固定一个数&#xff0c;找到其他两个。 &#xff08;1&#xff09;首先对数组从小到大排序。 &#xff08;2&…

【vue/uniapp】使用 uni.chooseImage 和 uni.uploadFile 实现图片上传(包含样式,可以解决手机上无法上传的问题)

引入&#xff1a; 之前写过一篇关于 uview 1.x 版本上传照片 的文章&#xff0c;但是发现如果是在微信小程序的项目中嵌入 h5 的模块&#xff0c;这个 h5 的项目使用 u-upload 的话&#xff0c;图片上传功能在电脑上正常&#xff0c;但是在手机的小程序上测试就不会生效&#x…

声明式管理方(yaml)文件

声明式管理方(yaml)文件: 1、适合对资源的修改操作 2、声明式管理依赖于yaml文件&#xff0c;所有的内容都在yaml文件当中。 3、编辑好的yaml文件需要依靠陈述是还是要依靠陈述式的命令发布到k8s集群当中 create只能创建&#xff0c;不能更新。从指定yaml文件中读取配置&#…

【华为机试】2023年真题B卷(python)-考古问题

一、题目 题目描述&#xff1a; 考古问题&#xff0c;假设以前的石碑被打碎成了很多块&#xff0c;每块上面都有一个或若干个字符&#xff0c;请你写个程序来把之前石碑上文字可能的组合全部写出来&#xff0c;按升序进行排列。 二、输入输出 三、示例 示例1: 输入输出示例仅供…

java练习题之常用类Object类,包装类

常用类 应用知识点&#xff1a; Object类 包装类 习题&#xff1a; 1&#xff1a;(Object 类)仔细阅读以下代码&#xff0c;写出程序运行的结果&#xff1b;并简述 和 equals 的区别。 true false 是判断两个变量或实例是不是指向同一个内存空间。 比较两个引用类型的地址&…

声明式管理方法

声明式管理方法&#xff08;yaml&#xff09;文件&#xff1a; 1&#xff0c;适合对资源的修改操作 2&#xff0c;声明式管理依赖于yaml文件&#xff0c;所有的内容都在yamI文件当中 3&#xff0c;编辑好的yaml文件&#xff0c;还是要依靠陈述式命令发布到k8s集群当中 发布的…

Spring见解 1

1.Spring概述 1.1.Spring介绍 ​ Spring是轻量级Java EE应用开源框架&#xff08;官网&#xff1a; http://spring.io/ &#xff09;&#xff0c;它由Rod Johnson创为了解决企业级编程开发的复杂性而创建 1.2.简化应用开发体现在哪些方面&#xff1f; IOC 解决传统Web开发中…

SpringBoot—支付—微信

一、支付流程 1.1、支付准备 1.获取商户号 微信商户平台 申请成为商户 > 提交资料 > 签署协议 > 获取商户号 2.获取 AppID 微信公众平台 注册服务号 > 服务号认证 > 获取APPID > 绑定商户号 3.申请商户证书 登录商户平台 > 选择 账户中心 > 安全…

Kali Linux实现UEFI和传统BIOS(Legacy)引导启动

默认Kali linux安装会根据当前启动的引导模式进行安装 例:以UEFI引导启动安装程序,安装后仅能在UEFI引导模式下进入系统 安装Kali系统 这边基于VirtualBox虚拟机镜像实战操作 首先创建一个Kali虚拟机 这里需要注意,把启动 EFI (只针对某些操作系统)选项勾选上,内存、处理器…

I.MX6ULL_Linux_驱动篇(52)linux CAN驱动

CAN 是目前应用非常广泛的现场总线之一&#xff0c;主要应用于汽车电子和工业领域&#xff0c;尤其是汽车领域&#xff0c;汽车上大量的传感器与模块都是通过 CAN 总线连接起来的。 CAN 总线目前是自动化领域发展的热点技术之一&#xff0c;由于其高可靠性&#xff0c; CAN 总线…

酷雷曼精彩亮相CMC 2023中国元宇宙大会,助力云上VR直播

12月23日&#xff0c;2023中关村论坛系列活动——CMC 2023中国元宇宙大会在石景山首钢园冰壶馆成功举办。酷雷曼VR作为元宇宙领域代表企业之一受邀出席会议&#xff0c;分享元宇宙技术研发成果及应用方案&#xff0c;并为大会提供VR直播技术支持。 大咖云集&#xff0c;共商元宇…

数据库进阶教学——主从复制(Ubuntu22.04主+Win10从)

目录 一、概述 二、原理 三、搭建 1、备份数据 2、主库配置Ubuntu22.04 2.1、设置阿里云服务器安全组 2.2、修改配置文件 /etc/my.cnf 2.3、重启MySQL服务 2.4、登录mysql&#xff0c;创建远程连接的账号&#xff0c;并授予主从复制权限 2.5、通过指令&#xff0c;查…

进程终结之道:kill与pskill的神奇战斗

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 进程终结之道&#xff1a;kill与pskill的神奇战斗 前言基本用法kill命令&#xff1a;基础语法&#xff1a;选项&#xff1a;示例&#xff1a; pskill命令&#xff1a;基础语法&#xff1a;选项&#x…
最新文章