【PyTorch单点知识】深入理解与应用转置卷积ConvTranspose2d模块

文章目录

      • 0. 前言
      • 1. 转置卷积概述
      • 2. `nn.ConvTranspose2d` 模块详解
        • 2.1 主要参数
        • 2.2 属性与方法
      • 3. 计算过程(重点)
        • 3.1 基本过程
        • 3.2 调整stride
        • 3.3 调整dilation
        • 3.4 调整padding
        • 3.5 调整output_padding
      • 4. 应用实例
      • 5. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

nn.ConvTranspose2d 模块是用于实现二维转置卷积(又称为反卷积)的核心组件。本文将详细介绍 ConvTranspose2d 的概念、工作原理、参数设置以及实际应用。

本文的说明参考了PyTorch的官方文档

1. 转置卷积概述

转置卷积(Transposed Convolution),有时也被称为“反卷积”(尽管严格来说它并不是真正意义上的卷积的逆运算),是一种特殊的卷积操作,常用于从较低分辨率的特征图上采样到较高分辨率的空间维度。

在诸如深度卷积生成对抗网络(DCGAN)和条件生成对抗网络(CGANs)等任务中,转置卷积被广泛用于将网络内部的紧凑特征(较小的特征)表示恢复为与原始输入尺寸相匹配或接近的(较大的特征)输出。

2. nn.ConvTranspose2d 模块详解

nn.ConvTranspose2d 是 PyTorch 中 torch.nn 模块的一部分,专门用于定义和实例化二维转置卷积层。其构造函数接受一系列参数来配置卷积行为:

2.1 主要参数
  1. in_channels (int) - 输入特征图的通道数,即前一层的输出通道数。

  2. out_channels (int) - 输出特征图的通道数,即本层产生的新特征通道数。

  3. kernel_size (inttuple) - 卷积核大小,通常是一个整数(当使用方形卷积核时)或包含两个整数的元组(分别对应卷积核的高度和宽度)。

  4. stride (inttuple, default=1) - 卷积步长,决定了卷积核在输入特征图上滑动的距离。与 kernel_size 类似,它可以是单个整数(对所有维度相同)或一个包含两个整数的元组。

  5. padding (inttuple, default=0) - 填充量,用于控制输出尺寸和保持边界信息。

  6. output_padding (inttuple, default=0) - 用于调整输出尺寸的额外填充量,仅应用于转置卷积。它在卷积计算后增加到输出边缘的额外像素数量。

  7. groups (int, default=1) - 分组卷积参数,当大于1时,输入和输出通道将被分成若干组,每组内的卷积相互独立。

  8. bias (bool, default=True) - 表示是否为该层添加可学习的偏置项。

  9. dilation (inttuple, default=1) - 卷积核元素之间的间距(膨胀率),控制卷积核中非零元素之间的距离。

  10. padding_mode (str , default=zeros) - 填充数据方式,zeros为全部填充0

  11. device (str , default=cpu) - 处理数据的设备

  12. dtype (str, default=None ) - 数据类型

2.2 属性与方法
  • .weight (Tensor) - 存储转置卷积核的权重,形状为 (out_channels, in_channels, kernel_size[0], kernel_size[1]),是可学习的模型参数。

  • .bias (Tensor) - 若 bias=True,则包含与每个输出通道关联的偏置项,形状为 (out_channels),也是可学习的参数。

  • .forward(input) - 接受输入张量 input,执行转置卷积运算并返回输出特征图。

3. 计算过程(重点)

输入输出图像一般为4维或3维,即[B, C, H, W]或[C, H, W],其中:

  • B:Batch_size,每批的样本数
  • C:channel,通道数
  • H, W:图像的高和宽

以图像高度H为例(宽度W同理),转置卷积的输出尺寸可以通过以下公式计算:

H o u t = ( H i n − 1 ) × stride − 2 × padding + dilation × ( kernel-size − 1 ) + output-padding + 1 H_{out}=(H_{in}-1) \times \text{stride} -2 \times \text{padding} + \text{dilation} \times (\text{kernel-size}-1) + \text{output-padding}+1 Hout=(Hin1)×stride2×padding+dilation×(kernel-size1)+output-padding+1

这个公式看起来比较复杂,下面我们通过实例来理解转置卷积的计算过程。

3.1 基本过程

输入原图size为[1, 2, 2],卷积核也size也为[1, 2, 2],其余参数如下:

in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False

计算过程:

在这里插入图片描述

容易看出,经历转置卷积后特征图会扩大,即上采样。使用代码验算:

import torch

input = torch.tensor([[[[0,1],
                        [2,3]]]],dtype=torch.float32)

ConvTrans = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],
          [ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))

print(ConvTrans(input))

输出为:

tensor([[[[ 0.0000,  1.1000,  2.2000],
          [ 2.2000, 11.0000, 11.0000],
          [ 6.6000, 18.7000, 13.2000]]]], grad_fn=<ConvolutionBackward0>)
3.2 调整stride

把stride调整为2后,计算过程如下:

在这里插入图片描述
如果stride过大,则会在跳过的位置补0。例如上面的计算过程中,如果stride = 3输出则为:

在这里插入图片描述

注意,这里stride可以指定为tuple,即让横向和纵向的stride不一样,例如(1, 2),但其计算思路不变,这里直接用代码计算结果(懒得再画过程图了):

import torch

input = torch.tensor([[[[0,1],
                        [2,3]]]],dtype=torch.float32)

ConvTrans = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=(1,2), padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],
          [ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))

print(ConvTrans(input))

输出为:

tensor([[[[ 0.0000,  0.0000,  1.1000,  2.2000],
          [ 2.2000,  4.4000,  6.6000, 11.0000],
          [ 6.6000,  8.8000,  9.9000, 13.2000]]]],
       grad_fn=<ConvolutionBackward0>)
3.3 调整dilation

这个过程非常简单,可以分为2步:

  1. 把卷积核进行dilation(爆炸)处理
  2. 进行3.1基本过程

即:
在这里插入图片描述
代码验算过程如下:

import torch

input = torch.tensor([[[[0,1],
                        [2,3]]]],dtype=torch.float32)

ConvTrans_dilation2 = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=2,bias=False)
ConvTrans_dilation2.weight = torch.nn.Parameter(torch.tensor([[[[ 1.1, 2.2],
          [ 3.3, 4.4]]]], dtype=torch.float32,requires_grad=True))

print(ConvTrans_dilation2(input))

ConvTrans_dilation1 = torch.nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0, output_padding=0,dilation=1,bias=False)
ConvTrans_dilation1.weight = torch.nn.Parameter(torch.tensor([[[[1.1, 0, 2.2],
                                                                     [0, 0, 0],
                                                                     [3.3, 0, 4.4]]]], dtype=torch.float32,requires_grad=True))   #对卷积核进行dilation

print(ConvTrans_dilation1(input))
print(ConvTrans_dilation2(input) == ConvTrans_dilation1(input))

输出为:

tensor([[[[ 0.0000,  1.1000,  0.0000,  2.2000],
          [ 2.2000,  3.3000,  4.4000,  6.6000],
          [ 0.0000,  3.3000,  0.0000,  4.4000],
          [ 6.6000,  9.9000,  8.8000, 13.2000]]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[[ 0.0000,  1.1000,  0.0000,  2.2000],
          [ 2.2000,  3.3000,  4.4000,  6.6000],
          [ 0.0000,  3.3000,  0.0000,  4.4000],
          [ 6.6000,  9.9000,  8.8000, 13.2000]]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])
3.4 调整padding

这是一个下采样的过程,会减少输出size。具体计算方法也很简单:给输出数据减去padding。基于3.1基本过程举例说明padding = 1的情况如下:

在这里插入图片描述

3.5 调整output_padding

这个参数用于给最终输出补0,output_padding必须要比stride或者dilation小。需要注意的是output_padding补0只能补半圈,如下:

我也想不明白为什么不是补一整圈?

在这里插入图片描述

4. 应用实例

在实际使用中,nn.ConvTranspose2d 可以嵌入到神经网络结构中,用于实现上采样、特征图尺寸放大或生成与输入尺寸相似的输出。以下是一个简单的使用示例:

import torch
import torch.nn as nn

# 定义一个包含转置卷积层的简单模型
class TransposedConvModel(nn.Module):
    def __init__(self, in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1, output_padding=0):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=True
        )

    def forward(self, x):
        return self.conv_transpose(x)

# 实例化模型并应用到输入数据
model = TransposedConvModel()
input_tensor = torch.randn(1, 32, 16, 16)  # (batch_size, in_channels, height, width)
output = model(input_tensor)
print("Output shape:", output.shape)

输出为:

Output shape: torch.Size([1, 64, 32, 32])

5. 总结

nn.ConvTranspose2d 是 PyTorch 中用于实现二维转置卷积的关键模块,它通过逆向的卷积操作实现了特征图的上采样和空间维度的扩大。

正确理解和配置其参数(如 kernel_sizestridepaddingoutput_padding 等),可以帮助开发者构建出适应特定任务需求的神经网络架构,特别是在图像生成、超分辨率、语义分割等需要从低分辨率特征恢复到高分辨率输出的应用场景中发挥关键作用。通过实践和调整这些参数,研究人员和工程师能够灵活地设计和优化基于转置卷积的深度学习模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/598858.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

3399 ubuntu系统启动后,gpio已被初始化问题查找

问题描述: 使用cat /sys/kernel/debug/gpio后发现,gpio-55已经被设备树初始化了。 如果要找到这个引脚的设置代码,需要一点点查找。这里记录了比较快速的办法 gpio引脚变换 gpio-55需要转换成对应的引脚编号 根据https://blog.csdn.net/ch122633/article/details/120233…

C语言实现面向对象—以LED驱动为例

点亮一个LED 常见的LED代码 分层分离思想 面向对象的LED驱动 LED左边高电平。 当LED右边为低电平时&#xff0c;LED有电流通过&#xff0c;LED亮。反之&#xff0c;LED灭 GPIO功能描述&#xff1a; 点亮LED的步骤及代码&#xff1a; 开启GPIO的时钟 配置GPIO为输出模式 …

前端数据可视化基础(折线图)

目录 前言&#xff1a; 画布&#xff1a; 折线图 (Line Chart): 前言&#xff1a; 前端中的数据可视化是指将大量数据以图形或图像的形式在前端页面上展示出来&#xff0c;以便用户能够更直观地理解和分析这些数据。数据可视化是一种强大的工具&#xff0c;它利用了人类视觉…

城市二手房数据分析与房价预测

实现功能 数据分析 二手房价格-时间分析 二手房数量-时间分析 二手房分布-区域分析 二手房户型分析 二手房朝向分析 二手房价格-区域分析 二手房热词词云 房价预测 采用合适的算法模型&#xff0c;对模型进行评估。通过输入影响因素输出预测价格。 采用技术与框架 M…

在Unity中实现分页数据显示和分页控制

参考&#xff1a;用两种简单的方式实现unity的分页效果 using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.UI; using UnityEngine.Rendering.VirtualTexturing; using UnityEngine.TerrainUtils;public class PageControll…

五一反向旅游,景区“AI+视频监控”将持续助力旅游业发展

一、建设背景 每年五一劳动节出去旅游都是人挤人状态&#xff0c;这导致景区的体验感极差。今年“五一反向旅游”的话题冲上了热搜&#xff0c;好多人选择了五一之后再出去旅游&#xff0c;避开拥挤的人群&#xff0c;这个时候景区的监管力度和感知能力就更要跟上去&#xff0…

Gradio之blocks灵活搭建页面

这里写目录标题 搭建一个UI界面搭建上半部分的框架比例调节以及其他效果搭建下半部分左边部分搭建下半部分右边部分拓展-CSS的应用 使用标签搭建第二个页面示例 补充AccordionGroup() 搭建一个UI界面 搭建上半部分的框架 如下图&#xff0c;我们想要基本还原下图右边的UI界面…

AI去衣技术在动画制作中的应用

随着科技的发展&#xff0c;人工智能&#xff08;AI&#xff09;已经在各个领域中发挥了重要作用&#xff0c;其中包括动画制作。在动画制作中&#xff0c;AI去衣技术是一个重要的工具&#xff0c;它可以帮助动画师们更加高效地完成工作。 AI去衣技术是一种基于人工智能的图像…

如何自己快速的制作流程图?6个软件教你快速进行流程图制作

如何自己快速的制作流程图&#xff1f;6个软件教你快速进行流程图制作 自己制作流程图可以是项目管理、流程设计或教学展示中的重要环节。以下是六款常用的流程图制作软件&#xff0c;它们都提供了快速、简单的方式来制作流程图&#xff1a; 迅捷画图&#xff1a;这是一款非…

Azide-PEG-Azide,82055-94-5可以用于制备抗体、蛋白质、多肽等生物分子的标记物

【试剂详情】 英文名称 Azide-PEG-Azide&#xff0c;N3-PEG-N3 中文名称 叠氮-聚乙二醇-叠氮&#xff0c;聚氧乙烯二叠氮化物 CAS号 82055-94-5 外观性状 由分子量决定&#xff0c;粘稠液体或者固体。 分子量 0.4k&#xff0c;0.6k&#xff0c;1k&#xff0c;2k&#…

用友GRP A++Cloud 政府财务云 任意文件读取漏洞复现

0x01 产品简介 用友GRP A++Cloud 政府财务云系统具有多项核心功能,旨在满足各类组织的财务管理需求。首先,它提供了财务核算功能,能够全面管理企业的总账、固定资产、现金、应付应收等模块,实时掌握企业的财务状况,并通过科目管理、凭证处理、报表分析等功能为决策提供有…

启明云端ESP8266+企业微信考勤机项目,多种方式认证能防止代打

智能考勤机需要有识别功能&#xff0c;用户容量&#xff0c;记录容量限制&#xff0c;还有物联网通讯方式&#xff0c;最后衔接到云平台&#xff0c;最后就是根据具体需求来设计。 ①识别方式&#xff1a;现如今市场上的考勤机主要有人脸、指纹、IC卡和ID卡等多种识别方式。不…

虚拟机文件夹共享操作(本地访问)

新建一个文件夹 右击文件夹点击属性 找到共享 点击共享 选择本地用户共享就可以了 本地winr 输入我们图片中的格式&#xff08;IP前加 “\\” &#xff09; 会弹一个窗口&#xff0c;输入虚拟机的入户名和密码就可以共享了&#xff08;一般默认用户名都是administrator&am…

人工智能-2024期中考试

前言 人工智能期中考试&#xff0c;认真准备了但是没考好&#xff0c;结果中游偏下水平。 第4题没拿分 &#xff08;遗传算法&#xff1a;知识点在课堂上一笔带过没有细讲&#xff0c;轮盘赌算法在书本上没有提到&#xff0c;考试的时候也没讲清楚&#xff0c;只能靠猜&…

Linux进程——Linux进程与进程优先级

前言&#xff1a;在上一篇了解完一部分常见的进程状态后&#xff0c;我们先来把剩下的进程状态了解一下&#xff0c;再来进入进程优先级的学习&#xff01; 如果对前面Linux进程不太熟悉可以先阅读&#xff1a; Linux进程 本篇主要内容&#xff1a; 僵尸进程和孤儿进程 Linux进…

绘画作品3d数字云展厅提升大众的艺术鉴赏和欣赏能力

3D虚拟展厅作为未来艺术的展示途径&#xff0c;正逐渐成为文化创意产业蓬勃发展的重要引擎。这一创新形式不仅打破了传统艺术展览的局限性&#xff0c;更以其独特的魅力吸引着全球观众的目光。 3D虚拟艺术品展厅以其独特的魅力&#xff0c;助力提升大众的艺术鉴赏和欣赏能力。观…

python - rst file to html

文章目录 python - rst file to html概述笔记下载安装PyCharm最新的学习版新建虚拟环境为Conda的工程添加docutils库新建python文件&#xff0c;添加转换代码运行自己写的python文件&#xff0c;执行转换转换结果END python - rst file to html 概述 开源工程中有一个.rst文件…

自动驾驶主流芯片及平台架构(一)

零部件成本下降、中低端车竞争加剧&#xff0c;推动ADAS渗透率在中国市场快速提升&#xff0c;自主品牌ADAS装配量大幅提升 零部件成本下降、中低端车竞争加剧&#xff0c;推动ADAS渗透率在中国市场快速提升&#xff0c;自主品牌ADAS装配量大幅提升。5年前在一些高端车型上才有…

(持续更新升级)火爆的ChatGPT源码+高质量AI绘画系统+分销功能+详细图文搭建部署教程

随着人工智能技术的迅猛发展&#xff0c;智能对话和创意艺术不再是遥不可及&#xff0c;而是可以触手可及的现实。 分享一款集ChatGPT源码、高质量AI绘画系统以及强大分销功能于一体的系统源码&#xff0c;对接了大名鼎鼎的ChatGPT接口及Midjourney两个王牌接口&#xff0c;另…

C++ 函数与指针

函数内部数据是地址需要传递给调用函数&#xff0c;返回的当然是指针了&#xff01;当然&#xff0c;这个返回地址也可以通过函数参数返回&#xff01; 函数的参数是指针可以输出函数多个结果&#xff0c;返回值本身就是返回数据&#xff0c;什么时候需要返回指针呢&#xff1f…
最新文章