实现pytorch版的mobileNetV1

mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客

这里是根据网络结构,搭建模型,用于图像分类任务。

1. 网络结构和基本组件

2. 搭建组件

(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;

(2)深度可分离卷积:DwCBL  = Conv dw+ Conv dp;

Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 }  + {Conv2d(1x1) + BN + ReLU6};

Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;

Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;

# 普通卷积
class CBN(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super(CBN, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super(DwCBN, self).__init__()
        # conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hw
        self.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)
        self.bn1 = nn.BatchNorm2d(in_c)
        self.relu1 = nn.ReLU6(inplace=True)
        # conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化
        self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu2 = nn.ReLU6(inplace=True)

    def forward(self, x):
        x = self.conv3x3(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv1x1(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

3. 搭建网络

class MobileNetV1(nn.Module):
    def __init__(self, class_num=1000):
        super(MobileNetV1, self).__init__()
        self.stage1 = torch.nn.Sequential(
            CBN(3, 32, 2),  # 下采样/2
            DwCBN(32, 64, 1)
        )
        self.stage2 = torch.nn.Sequential(
            DwCBN(64, 128, 2),  # 下采样/4
            DwCBN(128, 128, 1)
        )
        self.stage3 = torch.nn.Sequential(
            DwCBN(128, 256, 2),  # 下采样/8
            DwCBN(256, 256, 1)
        )
        self.stage4 = torch.nn.Sequential(
            DwCBN(256, 512, 2),  # 下采样/16
            DwCBN(512, 512, 1),  # 5个
            DwCBN(512, 512, 1),
            DwCBN(512, 512, 1),
            DwCBN(512, 512, 1),
            DwCBN(512, 512, 1),
        )
        self.stage5 = torch.nn.Sequential(
            DwCBN(512, 1024, 2),  # 下采样/32
            DwCBN(1024, 1024, 1)
        )

        # classifier
        self.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(1024, class_num, bias=True)

        # self.classifier = torch.nn.Softmax()  # 原始的softmax值
        # torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。
        # 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。
        # torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。

    def forward(self, x):
        scale1 = self.stage1(x)  # /2
        scale2 = self.stage2(scale1)
        scale3 = self.stage3(scale2)
        scale4 = self.stage4(scale3)
        scale5 = self.stage5(scale4)  # /32. 7x7

        x = self.avg_pooling(scale5)  # (b,1024,7,7)->(b,1024,1,1)
        x = torch.flatten(x, 1)  # (b,1024,1,1)->(b,1024,)
        x = self.fc(x)  # (b,1024,)  -> (b,1000,)
        return x


if __name__ == '__main__':
    m1 = MobileNetV1(class_num=1000)
    input_data = torch.randn(64, 3, 224, 224)
    output = m1.forward(input_data)
    print(output.shape)

待续。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/298483.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大模型学习第一课

学习目标: 大模型开源体系 学习内容: 大模型简述大模型性能开源体系 学习时间: 周四上午 10点 学习记录: 大模型简述 大模型是发展通用人工智能的重要途经专用模型到通用大模型实验室开源历程,大模型系列7B-20B-12…

k8s实践(14)--scheduler调度器和pod调度策略

一、scheduler调度器 1、kube-scheduler简介 k8s实践(10) -- Kubernetes集群运行原理详解 介绍过kube-scheduler。 kube-scheduler是运行在master节点上,其主要作用是负责资源的调度(Pod调度),通过API Server的Watch接口监听新建…

C++中的new和delete

相关文章 C智能指针 文章目录 相关文章前言一、new 运算符1. operator new 函数的范围2. 在类中重载new运算符3. 分配失败 二、delete 运算符1. 内存泄露统计示例2. 在类中重载delete运算符 总结 前言 在C中,new和delete是用于动态内存管理的运算符,它们…

Halcon计算一个区域的最大内接圆 inner_circle

Halcon计算一个区域的最大内接圆 该算子用于计算一个区域的最大内接圆,其原型如下: inner_circle(Regions : :: Row, Column, Radius)参数1:Regions 表示输入的区域。 参数2和3:Row、Column为输出参数,表示最大内接圆…

面试经典题---6.Z字形变换

6.Z字形变换 我的解法: 首先定义了3个变量:index、add和step。 index:当前处理字符在原字符串中的下标;add:Z字形中相邻两个字符在原字符串中的下标之差(非固定值,值随着行的改变会发生变化&am…

Linux 上 Nginx 配置访问 web 服务器及配置 https 访问配置过程记录

目录 一、前言说明二、配置思路三、开始修改配置四、结尾 一、前言说明 最近自己搭建了个 Blog 网站,想把网站部署到服务器上面,本文记录一下搭建过程中 Nginx 配置请求转发的过程。 二、配置思路 web项目已经在服务器上面运行起来了,运行的端…

EtherCAT主站SOEM -- 13 --Qt-Soem通过界面按键控制 EtherCAT IO模块的io输出

EtherCAT主站SOEM -- 13 --Qt-Soem通过界面按键控制 EtherCAT IO模块的io输出 一 mainwindow.c 文件函数:1.1 自定义PDO配置2.2 主站初始化2.3 去motrorcontrol界面二 motrorcontrol.c 文件三 allvalue.h 文件该文档修改记录:总结一 mainwindow.c 文件函数: mainwindow主界…

性能分析与调优: Linux 性能分析60秒

目录 一、实验 1.环境 2.Linux性能分析60秒 一、实验 1.环境 (1)主机 表1-1 主机 主机架构组件IP备注prometheus 监测 系统 prometheus、node_exporter 192.168.204.18grafana监测GUIgrafana192.168.204.19agent 监测 主机 node_exporter192.168…

数据分析基础之《numpy(6)—IO操作与数据处理》

了解即可,用panads 一、numpy读取 1、问题 大多数数据并不是我们自己构造的,而是存在文件当中,需要我们用工具获取 但是numpy其实并不适合用来读取和处理数据,因此我们这里了解相关API,以及numpy不方便的地方即可 2…

java解析json复杂数据的两种思路

文章目录 一、原始需求二、简单分析三、具体实现一1. api接口2. 接口返回3. json 数据解析1.)引入Jackson库2.)定义实体3.)解析json字符串4.)运行结果 4. 过程分析 四、具体实现二1. 核心代码2.运行结果 五、方案比较六、源码传送…

python数据可视化之折线图案例讲解

学习完python基础知识点,终于来到了新的模块——数据可视化。 我理解的数据可视化是对大量的数据进行分析以更直观的形式展现出来。 今天我们用python数据可视化来实现一个2023年三大购物平台销售额比重的折线图。 准备工作:我们需要下载用于生成图表的第…

MySQL之视图外连接、内连接和子查询的使用

一、视图 1.1 含义 虚拟表,和普通表一样使用 1.2 操作 创建视图 create view 视图名 as 修改视图 方式一: create or replace view 视图名 as 【查看视图相关字段】 方式二: alter view 视图名 as 【查看的SQL语句】 查看视图 方式一&…

【算法笔记】深入理解dfs(两道dp题)

DFS过程的概述 一个一个节点的搜,如果是树状结构的话,先找到最左边那一条分支搜到最后一个节点,这个时候最后一个节点(假设是b)的数据会被更新(具体看题目的要求),然后返回到上一个…

服务器终端快速下载coco数据集

######解压到当前文件夹 sudo apt-get install aria2 aria2c -c <url> #<url>即为官网下载地址# url # download images http://images.cocodataset.org/zips/train2017.zip http://images.cocodataset.org/zips/val2017.zip# download annotations http://i…

Pytest的测试报告——Allure

一、html-report测试报告 html-report测试报告。是pytest下基本的测试报告。要使用pytest-html测试报告&#xff0c;就要确保python版本在3.6及以上即可。本身pytest所提供的测试结果汇总&#xff0c;是基于控制台的文本输出形式。 pytest-html是基于HTML格式实现的测试报告的…

Spark调优解析-spark数据倾斜优化2(七)

1 数据倾斜优化 1.1为何要处理数据倾斜&#xff08;Data Skew&#xff09; 什么是数据倾斜 对Spark/Hadoop这样的大数据系统来讲&#xff0c;数据量大并不可怕&#xff0c;可怕的是数据倾斜。 何谓数据倾斜&#xff1f;数据倾斜指的是&#xff0c;并行处理的数据集中&#xf…

py的基础语法

前言:本章节主播会详细描述py的基础语法&#xff0c;其中包括语句之间的转换和拼接&#xff0c;内容较多&#xff0c;友友们加油 目录 一.字面量 1.1关于字面量 1.2举例 1.3小结 二.注释 2.1关于注释 2.2举例 2.3小结 三.变量 3.1关于变量 3.2举例 3.3小结 四.数据…

Iceberg从入门到精通系列之十九:分区

Iceberg从入门到精通系列之十九&#xff1a;分区 一、认识分区二、Iceberg的分区三、Hive 中的分区四、Hive 分区问题五、Iceberg的隐藏分区六、分区变换七、分区变换 一、认识分区 分区是一种通过在写入时将相似的行分组在一起来加快查询速度的方法。 例如&#xff0c;从日志…

LeetCode 2807.在链表中插入最大公约数

【LetMeFly】2807.在链表中插入最大公约数 力扣题目链接&#xff1a;https://leetcode.cn/problems/insert-greatest-common-divisors-in-linked-list/ 给你一个链表的头 head &#xff0c;每个结点包含一个整数值。 在相邻结点之间&#xff0c;请你插入一个新的结点&#x…

【MYSQL】MYSQL 的学习教程(十一)之 MySQL 不同隔离级别,都使用了哪些锁

聊聊不同隔离级别下&#xff0c;都会使用哪些锁&#xff1f; 1. MySQL 锁机制 对于 MySQL 来说&#xff0c;如果只支持串行访问的话&#xff0c;那么其效率会非常低。因此&#xff0c;为了提高数据库的运行效率&#xff0c;MySQL 需要支持并发访问。而在并发访问的情况下&…
最新文章