pytorch中nn.Sequential详解

1 nn.Sequential概述

1.1 nn.Sequential介绍

nn.Sequential是一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。

因此,Sequential可以看成是有多个函数运算对象,串联成的神经网络,其返回的是Module类型的神经网络对象。

1.2 nn.Sequential的本质作用

与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。这就意味着我们可以利用nn.Sequential() 自定义自己的网络层。

示例代码:

from torch import nn


class net(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(net, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(in_channel, in_channel / 4, kernel_size=1),
                                    nn.BatchNorm2d(in_channel / 4),
                                    nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(in_channel / 4, in_channel / 4),
                                    nn.BatchNorm2d(in_channel / 4),
                                    nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(in_channel / 4, out_channel, kernel_size=1),
                                    nn.BatchNorm2d(out_channel),
                                    nn.ReLU())
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        
        return x

上边的代码,我们通过nn.Sequential()将卷积层,BN层和激活函数层封装在一个层中,输入x经过卷积、BN和ReLU后直接输出激活函数作用之后的结果。

1.3 nn.Sequential源码

def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

nn.Sequential()首先判断接收的参数是否为OrderedDict类型,如果是的话,分别取出OrderedDict内每个元素的key(自定义的网络模块名)和value(网络模块),然后将其通过add_module方法添加到nn.Sequrntial()中。

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types).  Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        for module in self:
            input = module(input)
        return input

 调用forward()方法进行前向传播时,for循环按照顺序遍历nn.Sequential()中存储的网络模块,并以此计算输出结果,并返回最终的计算结果。

 

1.3 nn.Sequential与其它容器的区别

2 使用nn.Sequential定义网络

2.1 顺序添加网络模块到容器中

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(28 * 28, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
    nn.Softmax(dim=1)
)
print("model:", model)
print("model.parameters:", model.parameters)

x_input = torch.randn(2, 28, 28, 1)
print("x_input:", x_input)
print("x_input.shape:", x_input.shape)

y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential(
  (0): Linear(in_features=784, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=10, bias=True)
  (3): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential(
  (0): Linear(in_features=784, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=10, bias=True)
  (3): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.1127, 0.0652, 0.1399, 0.0973, 0.1085, 0.0859, 0.1193, 0.1048, 0.0865,
         0.0800],
        [0.0986, 0.0955, 0.0927, 0.0765, 0.0782, 0.1004, 0.1171, 0.1605, 0.0883,
         0.0922]], grad_fn=<SoftmaxBackward0>)

2.2 包含神经网络模块的OrderedDict传入容器中

import torch
import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
print("model:", model)
print("model.parameters:", model.parameters)

x_input = torch.randn(2, 28, 28, 1)
print("x_input.shape:", x_input.shape)

y_pred = model.forward(x_input.view(x_input.size()[0], -1))
print("y_pred:", y_pred)

运行代码显示:

model: Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)
model.parameters: <bound method Module.parameters of Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)>
x_input.shape: torch.Size([2, 28, 28, 1])
y_pred: tensor([[0.0836, 0.1185, 0.1422, 0.0801, 0.0817, 0.0870, 0.0948, 0.1099, 0.1131,
         0.0892],
        [0.0772, 0.0933, 0.1312, 0.1135, 0.1214, 0.0736, 0.1461, 0.0711, 0.0908,
         0.0818]], grad_fn=<SoftmaxBackward0>)

3 nn.Sequential网络操作

3.1 索引查看子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
print("index0:", model[0])
print("index1:", model[1])
print("index2:", model[2])

运行代码显示:

index0: Linear(in_features=784, out_features=32, bias=True)
index1: ReLU()
index2: Linear(in_features=32, out_features=10, bias=True)

3.2 修改子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
model[1] = nn.Sigmoid()
print(model)

运行代码显示:

Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): Sigmoid()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)

3.3 添加子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
model.append(nn.Linear(10, 2))
print(model)

运行代码显示:

Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (out): Linear(in_features=32, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
  (4): Linear(in_features=10, out_features=2, bias=True)
)

3.4 删除子模块

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(OrderedDict([('h1', nn.Linear(28*28, 32)),
                                     ('relu1', nn.ReLU()),
                                     ('out', nn.Linear(32, 10)),
                                     ('softmax', nn.Softmax(dim=1))]))
del model[2]
print(model)

运行代码显示:

Sequential(
  (h1): Linear(in_features=784, out_features=32, bias=True)
  (relu1): ReLU()
  (softmax): Softmax(dim=1)
)

3.5 嵌套子模块

import torch.nn as nn

seq_1 = nn.Sequential(nn.Linear(15, 10), nn.ReLU(), nn.Linear(10, 5))
seq_2 = nn.Sequential(nn.Linear(25, 15), nn.Sigmoid(), nn.Linear(15, 10))
seq_3 = nn.Sequential(seq_1, seq_2)
print(seq_3)

运行代码显示:

Sequential(
  (0): Sequential(
    (0): Linear(in_features=15, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=5, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=25, out_features=15, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=15, out_features=10, bias=True)
  )
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/257833.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AWS 知识二:AWS同一个VPC下的ubuntu实例通过ldapsearch命令查询目录用户信息

前言&#xff1a; 前提&#xff1a;需要完成我的AWS 知识一创建一个成功运行的目录。 主要两个重要&#xff1a;1.本地windows如何通过SSH的方式连接到Ubuntu实例 2.ldapsearch命令的构成 一 &#xff0c;启动一个新的Ubuntu实例 1.创建一个ubuntu实例 具体创建实例步骤我就不…

useConsole的封装,vue,react,htmlscript标签,通用

之前用了接近hack的方式实现了console的封装&#xff0c;目标是获取console.log函数的执行&#xff08;调用栈所在位置&#xff09;所在的代码行数。 例如以下代码&#xff0c;执行window.mylog(1)时候&#xff0c;console.log实际是在匿名的箭头函数()>{//这里执行的} con…

通过https协议访问Tomcat部署并使用Shiro认证的应用跳转登到录页时协议变为http的问题

问题描述&#xff1a; 在最近的一个项目中&#xff0c;有一个存在较久&#xff0c;并且只在内部城域网可访问的一个使用Shiro框架进行安全管理的Java应用&#xff0c;该应用部署在Tomcat服务器上。起初&#xff0c;应用程序可以通过HTTP协议访问&#xff0c;一切运行都没…

力扣面试题 16.19. 水域大小(java DFS解法)

Problem: 面试题 16.19. 水域大小 文章目录 题目描述思路解题方法复杂度Code 题目描述 思路 该问题可以归纳为一类遍历二维矩阵的题目&#xff0c;此类中的一部分题目可以利用DFS来解决&#xff0c;具体到本题目&#xff08;该题目可以的写法大体不变可参看前面几个题目&#…

XZ_iOS 之 M1 M2 M3的M系列芯片的Mac苹果电脑安装cocoapods

安装的前提&#xff0c;应用程序->终端->右键-显示简介->勾选 使用Rosetta打开&#xff0c;如下图&#xff0c;然后重启终端 安装的顺序如下&#xff1a;Homebrew->rvm->ruby->cocoapods 1、安装Homebrew /bin/bash -c "$(curl -fsSL https://raw.git…

淘宝类目信息API接口获取淘宝商品分类信息API调用说明(含APIkey密钥)

cat_get-获得淘宝分类详情 item_cat_get-获得淘宝商品类目 公共参数 名称类型必须描述keyString是调用key&#xff08;点此获取&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地址中&#xff09;[item_search,item_get,item_search_…

【Mac】flutter项目集成高德定位SDK,获取key

一、获取调试版安全码SHA1 1.进入当前用户文件夹下的~/.android目录 cd ~/.android2.查看 debug.keystore ls3.运行 debug.keystore keytool -list -v -keystore debug.keystore这里报错&#xff1a; The operation couldn’t be completed. Unable to locate a Java Runt…

docker 安装及配置 nginx + tomcat(四):高可用

文章目录 1. 引言2. 高可用架构3. 实际步骤3.1 虚拟机新建系统3.2 安装 keepalived3.3 配置 keepalived3.4 启动 keepalived3.5 验证高可用3.5.1 查看当前效果3.5.2 模拟灾难 4 参考 1. 引言 前情提要&#xff1a; 《docker 安装及配置 nginx tomcat&#xff08;一&#xff0…

安全运营之安全加固和运维

安全运营是一个将技术、流程和人有机结合的复杂系统工程&#xff0c;通过对已有安全产品、工具和服务产出的数据进行有效的分析&#xff0c;持续输出价值&#xff0c;解决安全问题&#xff0c;以确保网络安全为最终目标。 安全加固和运维是网络安全运营中的两个重要方面。 安全…

在本地通过 k8s 部署一个 nginx 镜像

目标 目标:通过 deployment 启动一个 nginx,并且通过浏览器访问。 目的,熟悉并学习一下 k8s 的一些特性,毕竟看文档和实操是两码事。 本地部署 k8s 简单点,也不用 minikube 和 kubeadmin,直接通过 docker desktop 部署 k8s。 下载 docker desktop 下载完成后会自动…

Linux系统之部署Linux管理面板1Panel

一、介绍 1.1简介 1Panel 是一个现代化、开源的 Linux 服务器运维管理面板。 1.2特点 快速建站&#xff1a;深度集成 Wordpress 和 Halo&#xff0c;域名绑定、SSL 证书配置等一键搞定&#xff1b; 高效管理&#xff1a;通过 Web 端轻松管理 Linux 服务器&#xff0c;包括应用管…

ios备忘录怎么导入华为 方法介绍

作为一个常常需要在不同设备间切换的人&#xff0c;我深知备忘录的重要性。那些突如其来的灵感、重要的会议提醒、甚至是生活中的琐碎小事&#xff0c;我们都习惯性地记录在备忘录里。但当我决定从iPhone转向华为时&#xff0c;一个问题困扰了我&#xff1a;如何将那些珍贵的备…

使用Axure的中继器的交互动作解决增删改查h

&#x1f3ac; 艳艳耶✌️&#xff1a;个人主页 &#x1f525; 个人专栏 &#xff1a;《产品经理如何画泳道图&流程图》 ⛺️ 越努力 &#xff0c;越幸运 目录 一、中继器的交互 1、什么是中继器的交互 2、Axure中继器的交互 3、如何使用中继器&#xff1f; 二…

CleanMyMac X 4 for Mac(Mac优化清理工具)v4.14.6中文破解版

CleanMyMac X for Mac中文破解版只需两个简单步骤就可以把系统里那些乱七八糟的无用文件统统清理掉&#xff0c;节省宝贵的磁盘空间。cleanmymac x个人认为X代表界面上的最大升级&#xff0c;功能方面有更多增加&#xff0c;与最新macOS系统更加兼容&#xff0c;流畅地与系统性…

人工智能中不可预测的潜在错误可能是灾难性的——数学解释

一、说明 有没有人研究评估AI的错误产生的后果有多么严重&#xff0c;是否存在AI分险评估机制&#xff1f;更高维度上&#xff0c;人工智能的未来是反乌托邦还是乌托邦&#xff1f;这个问题一直是争论的话题&#xff0c;各大阵营都支持。我相信我们无法准确预测这两种结果。这是…

Mybatis复习总结

MyBatis是一款优秀的持久层框架&#xff0c;用于简化JDBC的开发 MyBatis本是Apache的一个开源项目&#xff0c;2010年这个项目由apache迁移到了Google&#xff0c;并且改名为 Mybatis&#xff0c;2013年11月迁移至Github。 持久层 指的就是数据访问层&#xff0c;用来操作数…

《点云处理》 点云去噪

前言 通常从传感器&#xff08;3D相机、雷达&#xff09;中获取到的点云存在噪点&#xff08;杂点、离群点、孤岛点等各种叫法&#xff09;。噪点产生的原因有不同&#xff0c;可能是扫描到了不想要扫描的物体&#xff0c;可能是待测工件表面反光形成的&#xff0c;也可能是相…

Axure中继器完成表格的增删改查的自定义元件(三列表格与十列表格)

目录 一、中继器 1.1 定义 1.2 特点 1.3 适用场景 二、三列表格增删改查 2.1 实现思路 2.2 效果演示 三、十列表格增删改查 3.1 实现思路 3.2 效果演示 一、中继器 1.1 定义 在Axure中&#xff0c;"中继器"通常指的是界面设计中的一个元素&#xff0c;用…

第二百一十五回 如何创建单例模式

文章目录 1. 概念介绍2. 思路与方法2.1 实现思路2.2 实现方法 3. 示例代码4. 内容总结 我们在上一章回中介绍了"分享三个使用TextField的细节"沉浸式状态样相关的内容&#xff0c;本章回中将介绍 如何创建单例模式.闲话休提&#xff0c;让我们一起Talk Flutter吧。 …

03-数据结构-栈与队列

1.栈 栈和队列是两种操作受限的线性表。如上图所示显示栈的结构 栈&#xff1a;先进后出&#xff0c;入栈&#xff08;数据进入&#xff09; 和出栈&#xff08;数据出去&#xff09;均在栈顶操作。 常见栈的应用场景包括括号问题的求解&#xff0c;表达式的转换和求值&#…