pytorch复现4_Resnet

ResNet在《Deep Residual Learning for Image Recognition》论文中提出,是在CVPR 2016发表的一种影响深远的网络模型,由何凯明大神团队提出来,在ImageNet的分类比赛上将网络深度直接提高到了152层,前一年夺冠的VGG只有19层。ImageNet的目标检测以碾压的优势成功夺得了当年识别和目标检测的冠军,COCO数据集的目标检测和图像分割比赛上同样碾压夺冠,可以说ResNet的出现对深度神经网络来说具有重大的历史意义。

在这里插入图片描述
在resnet出现之前,网络层数的增加会导致梯度消失或者梯度爆炸
在ResNet网络中有如下几个亮点:
(1)提出residual结构(残差结构),并搭建超深的网络结构(突破1000层)
(2)使用Batch Normalization加速训练(丢弃dropout)

残差结构(residual)

下图是论文中给出的两种残差结构。左边的残差结构是针对层数较少网络,例如ResNet18层和ResNet34层网络
右边是针对网络层数较多的网络,例如ResNet101,ResNet152等。
为什么深层网络要使用右侧的残差结构呢。因为,右侧的残差结构能够减少网络参数与运算量。同样输入、输出一个channel为256的特征矩阵,如果使用左侧的残差结构需要大约1170648个参数,但如果使用右侧的残差结构只需要69632个参数。明显搭建深层网络时,使用右侧的残差结构更合适。

在这里插入图片描述
代码:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

完整代码:

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/112878.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

正点原子嵌入式linux驱动开发——Linux USB驱动

USB是很常用的接口,目前大多数的设备都是USB接口的,比如鼠标、键盘、USB摄像 头等,在实际开发中也常常遇到USB接口的设备,本章就来学习一下如何使能Linux内核自带的USB驱动。这里不会具体学习USB的驱动开发。 USB接口简介 什么是…

常用sql语句

/*表操作*/ drop table order; create table products( product_no integer primary key default 1, name text, price numeric default 9.99 ); create table orders ( order_id integer primary key default 1, product_no int, quantity integer ); create table order_…

目标检测:Proposal-Contrastive Pretraining for Object Detection from Fewer Data

论文作者:Quentin Bouniot,Romaric Audigier,Anglique Loesch,Amaury Habrard 作者单位:Universit Paris-Saclay; Universit Jean Monnet Saint-Etienne; Universitaire de France (IUF) 论文链接:http://arxiv.org/abs/2310.16835v1 内容…

ztree调整节点间距及一般使用

1.基本介绍 树形结构菜单的功能属于非常常见的一种菜单交互,本人先后也使用过多种树形结构的插件,有 ztree、xloadtree、treeview、datagrid-tree 等等等等。近期有个功能恰好又要使用tree菜单了,由于可自行选择使用的组件,所以略…

微信小程序 人工智能志愿者服务活动报名系统uniAPP+vue

基于java语言设计并实现了人工智能志愿者服务APP。该APP基于B/S即所谓浏览器/服务器模式,应用SpringBoot框架与HBuilder X技术,选择MySQL作为后台数据库。系统主要包括用户、志愿活动、活动报名、活动签到、服务职责、服务排行等功能模块。 本文首先介绍…

【2021集创赛】Risc-v杯一等奖:自适应噪声环境的超低功耗语音关键词识别系统

本作品参与极术社区组织的有奖征集|秀出你的集创赛作品风采,免费电子产品等你拿~活动。 团队介绍 参赛单位:东南大学 队伍名称:Hey Siri 指导老师:刘波 参赛队员:钱俊逸、张人元、王梓羽 总决赛奖项:全国一等奖 摘要…

数字孪生技术:金融业合规与自动化的未来

在当今数字化时代,金融行业正积极探索数字孪生技术,以实现更高效的运营和更好的客户体验。数字孪生是一种将实体世界的对象、过程和系统数字化为虚拟模型的技术,金融机构正在充分利用它带来的众多优势。 1. 风险管理与模拟 数字孪生模型可用…

docker部署minio并使用springboot连接

需求:工作中,在微信小程序播放时,返回文件流并不能有效的使用,前端需要一个可以访问的地址,springboot默认是有资源拦截器的,但是不适合生产环境的使用 可以提供使用的有例如fastdfs或者minio,这…

Flutter 05 组件状态、生命周期、数据传递(共享)、Key

一、Android界面渲染流程UI树与FlutterUI树的设计思路对比 二、Widget组件生命周期详解 1、Widget组件生命周期 和其他的视图框架比如android的Activity一样,flutter中的视图Widget也存在生命周期,生命周期的回调函数体现在了State上面。组件State的生命…

QT实现在线流媒体播放平台

文章目录 QT实现在线流媒体播放平台简介开发视频ffmpeg下载SimpleVideoPlayer.hSimpleVideoPlayer.cpp 开发音频添加功能打开文件夹播放暂停播放上下一首选择倍速 效果展示项目下载 QT实现在线流媒体播放平台 简介 Qt是一种流行的C开发框架,它提供了用于构建图形用…

FPGA_Signal TapII 逻辑分析仪 在线信号波形抓取

FPGA_Signal TapII 逻辑分析仪 在线信号波形抓取 由于一些工程的仿真文件不易产生,所以我们可以利用 quartus 软件自带的 SignalTap 工具对波形进行抓取 对各个信号进行分析处理,让电子器件与FPGA进行正常通讯工作,也验证所绘制的波形图是否一…

Scala库用HTTP爬虫IP代码示例

根据提供的引用内容,sttp.client3和sttp.model库是用于HTTP请求和响应处理的Scala库,可以与各种Scala堆栈集成,提供同步和异步,过程和功能接口。这些库可以用于爬虫程序中,用于发送HTTP请求和处理响应。需要注意的是&a…

JavaScript从入门到精通系列第二十七篇:详解JavaScript中的包装类

大神引荐:作者有幸结识技术大神孙哥为好友获益匪浅,现在把孙哥视频分享给大家 孙哥链接:孙哥个人主页 作者简介:一个颜值99分,只比孙哥差一点的程序员 本专栏简介:话不多说,让我们一起干翻JavaS…

vue自定义组件:中线分割拖动盘

在GitHub上可以找到类似的组件,比如4年前发布的vue2版本的 Vue Split Pane, 但是我还是自己写了一个类似的: 组件效果: 特点: 不是照抄别人的。同时支持vue2、vue3(组件内部使用选项式API风格&#xff09…

不一样的网络协议-------KCP协议

1、kcp 的协议特点 1.1、RTO 不翻倍 RTO(Retransmission TimeOut),重传超时时间。tcp x 2,kcp x 1.5,提高传输速度 1.2、选择重传 TCP丢包时会全部重传从该包开始以后的数据,而KCP选择性重传,只重传真正丢失的数据包…

同为科技(TOWE)自动断电倒计时定时桌面PDU插排

在每个家庭中,插排插座都是必不可少的电源设备。随着各种电器的普及应用和生活节奏的加快,人们对插排也有着多样化的需求,比如在插排中加入定时开关、自动断电、断电记忆、倒计时等等功能,让原本不支持智能家居的用电器秒变智能。…

DL4J无法下载MNIST数据集解决 Server returned HTTP response code: 403 for URL解决方法

报错情况 报错如下: 16:45:41.463 [main] INFO org.nd4j.nativeblas.Nd4jBlas - Number of threads used for OpenMP BLAS: 6 16:45:41.497 [main] INFO org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 10] 16:4…

Android 驱动学习调试

1 Android 驱动代码编译 参考https://www.sharetechnote.com/html/Linux_DeviceDriver_Programing.html#Device_Driver_HelloWorld编译ko文件调试驱动代码&#xff0c;将ko文件push到手机上验证 相关C文件testdriver.c #include <linux/init.h> #include <linux/mod…

云安全—kubelet攻击面

0x00 前言 虽然说总结的是kubelet的攻击面&#xff0c;但是在总结攻击面之前还是需要去了解kubelet的基本原理&#xff0c;虽然说我们浅尝即止&#xff0c;但是还是要以能给别人讲出来为基本原则。 其他文章: 云安全—K8s APi Server 6443 攻击面云安全—K8S API Server 未授…

08-Docker-网络管理

Docker 在网络管理这块提供了多种的网络选择方式&#xff0c;他们分别是桥接网络、主机网络、覆盖网络、MACLAN 网络、无桥接网络、自定义网络。 1-无桥接网络&#xff08;None Network&#xff09; 当使用无桥接网络时&#xff0c;容器不会分配 IP 地址&#xff0c;也不会连…