【DeepLearning-8】MobileViT模块配置

完整代码: 

import torch
import torch.nn as nn
from einops import rearrange
def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn # mg
    
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)# mg
        ) if project_out else nn.Identity()
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)
class UserDefined(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class IRBlock(nn.Module):
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup
        if expansion == 1: # 构建没有扩展层的卷积块
            self.conv = nn.Sequential(
                # 深度可分离卷积(Depthwise Convolution)
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # “线性”逐点卷积 (Pointwise-Linear Convolution)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:  # 构建包含扩展层的卷积块
            self.conv = nn.Sequential(
                # 逐点卷积 (Pointwise Convolution)
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # 深度可分离卷积 (Depthwise Convolution)
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # “线性”逐点卷积 (Pointwise-Linear Convolution)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

class MobileViTBv3(nn.Module):
    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.mv01 = IRBlock(channel, channel) 
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
    def forward(self, x):
        y = x.clone()
        x = self.conv1(x)
        x = self.conv2(x)
        z = x.clone()
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        x = self.conv3(x)
        x = torch.cat((x, z), 1)
        x = self.conv4(x)
        x = x + y
        x = self.mv01(x)
        return x

文件配置在D:\yolov5-master\models路径下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/353848.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PostgreSQL技术大讲堂 - 第43讲:流复制原理

PostgreSQL从小白到专家,是从入门逐渐能力提升的一个系列教程,内容包括对PG基础的认知、包括安装使用、包括角色权限、包括维护管理、、等内容,希望对热爱PG、学习PG的同学们有帮助,欢迎持续关注CUUG PG技术大讲堂。 第43讲&#…

每日OJ题_算法_二分查找⑦_力扣153. 寻找旋转排序数组中的最小值

目录 力扣153. 寻找旋转排序数组中的最小值 解析代码 力扣153. 寻找旋转排序数组中的最小值 153. 寻找旋转排序数组中的最小值 - 力扣(LeetCode) 难度 中等 已知一个长度为 n 的数组,预先按照升序排列,经由 1 到 n 次 旋转 后…

Java NIO初体验

概述 由于BIO(同步阻塞IO)对系统资源的浪费较大。Java1.4中引⼊了NIO框架,在java.nio包中提供了Channel、Selector、Buffer等抽象类,可以快速构建多路复⽤的IO程序,⽤于提供更接近操作系统底层的⾼性能数据操作⽅式。…

11. 双目视觉之立体视觉基础

目录 1. 深度恢复1.1 单目相机缺少深度信息1.2 如何恢复场景深度?1.3 深度恢复的思路 2. 对极几何约束2.1 直观感受2.2 数学上的描述 1. 深度恢复 1.1 单目相机缺少深度信息 之前学习过相机模型,最经典的就是小孔成像模型。我们知道相机通过小孔成像模…

零基础轻松学编程,中文编程开发工具下载及构件教程

零基础轻松学编程,中文编程开发工具下载及构件教程 一、前言 零基础学编程,中文编程工具下载,中文编程开发工具构件之选择件列表框构件教程, 想学编程可 编程交流.群.一起交流学习(关注收藏或点击最下方官网卡片进入…

【c语言】详解操作符(上)

1. 操作符的分类 2. 原码、反码、补码 整数的2进制表示方法有三种,即原码、反码、补码 有符号整数的三种表示方法均有符号位和数值位两部分,2进制序列中,最高位的1位是被当做符号位其余都是数值位。 符号位都是用0表示“正”,用…

设计模式:简介及基本原则

简介 设计模式是一套被反复使用的、多数人知晓的、经过分类编目的、代码设计经验的总结。使用设计模式是为了重用代码、让代码更容易被他人理解、保证代码可靠性。 毫无疑问,设计模式于己于他人于系统都是多赢的,设计模式使代码编制真正工程化&#xff…

Kotlin快速入门5

Kotlin的继承与重写 kotlin的继承 Kotlin中所有类都继承自Any类,Any类是所有类的超类,对于没有超类型声明的类是默认超类(Any 不是 java.lang.Object): class LearnKotlin // 默认继承自Any Any类默认提供三个函数…

Scikit-learn (sklearn)速通 -【莫凡Python学习笔记】

视频教程链接:【莫烦Python】Scikit-learn (sklearn) 优雅地学会机器学习 视频教程代码 scikit-learn官网 莫烦官网学习链接 本人matplotlib、numpy、pandas笔记 1 为什么学习 Scikit learn 也简称 sklearn, 是机器学习领域当中最知名的 python 模块之一. Sk…

IDEA搭建JDK源码学习环境(可添加注释、修改、debug)

工程详见:https://github.com/wenpanwenpan/study-source-jdk1.8.0_281 1、找到src.zip和javafx-src.zip 找到你想要调试的JDK,笔者本地电脑上装了两个版本的JDK,这里以jdk1.8.0_281为例将JDK目录下的javafx-src.zip和src.zip两个压缩包进行…

【GitHub项目推荐--不错的 TypeScript 学习项目】【转载】

在线白板工具 Excalidraw 标星 33k,是一款非常轻量的在线白板工具,可以直接在浏览器打开,轻松绘制具有手绘风格的图形。 如下图所示,Excalidraw 支持最常用的图形元素:方框、圆、菱形、线,可以方便的使用…

【Web】小白也能做的RWCTF体验赛baby题部分wp

遇到不会的题,怎么办!有的师傅告诉你完了,废了,寄了!只有Z3告诉你,稳辣!稳辣!都稳辣! 这种CVE复现的题型,不可能要求选手从0到1进行0day挖掘,其实…

12.15字符编码(血干JAVA系列)

字符编码 12.15.1 Java常见编码简介12.15.2得到本机的编码显示【例12.68】使用上述方法得到JVM的默认编码 12.15.3乱码产生【例12.69】使用1SO8859-1编码 12.15.1 Java常见编码简介 12.15.2得到本机的编码显示 在 前 面 讲 解 常 用 类 库 的 时 候 曾 经 介 绍 过 &#xff0…

HarmonyOS 沉浸式状态栏实现

比如说 我们代码是这样的 Entry Component struct Index {build() {Column() {Column() {Column() {Text(定标标题).fontSize(20).fontColor(Color.White)}.height(50).justifyContent(FlexAlign.Center)}.width(100%).backgroundColor(#FF0000)}.height(100%)} }你预览器上看…

k8s 版本发布与回滚

一、实验环境准备: kubectl get pods -o wide kubectl get nodes -o wide kubectl get svc 准备两个nginx镜像,版本号一个是V3,一个是V4 二、准备一个nginx.yaml文件 apiVersion: apps/v1 kind: Deployment metadata:name: nginx-deploylab…

AutoGen实战应用(二):多代理协作(Multi-Agent Collaboration)

AutoGen是微软推出的一个全新工具,它用来帮助开发者创建基于大语言模型(LLM)的复杂应用程序. AutoGen能让LLM在复杂工作流程启用多个角色代理来共同协作完成人类提出的任务。在我之前的一篇博客: AutoGen实战应用(一):代码生成、执行和调试 中我们通过一…

CTF CRYPTO 密码学-6

题目名称:敲击 题目描述: 方方格格,不断敲击 “wdvtdz qsxdr werdzxc esxcfr uygbn” 解题过程: step1:根据题目描述敲击,wdvtdz对应的字符为x step2:依此类推r,z,o&…

使用 create-react-app 创建 react 应用

一、创建项目并启动 第一步:全局安装:npm install -g create-react-app 第二步:切换到想创建项目的目录,使用命令create-react-app hello-react 第三步:进入项目目录,cd hello-react 第四步:启…

Pyecharts 风采:从基础到高级,打造炫酷象形柱状图的完整指南【第40篇—python:象形柱状图】

文章目录 引言安装PyechartsPyecharts象形柱状图参数详解1. Bar 类的基本参数2. 自定义图表样式3. 添加标签和提示框 代码实战:绘制多种炫酷象形柱状图进阶技巧:动态数据更新与交互性1. 动态数据更新2. 交互性设计 拓展应用:结合其他图表类型…

Android 基础技术——列表卡顿问题如何分析解决

笔者希望做一个系列,整理 Android 基础技术,本章是关于列表卡顿问题如何分析解决 onBindViewHolder 优化 是否有耗时操作、重复创建对象、设置监听器、findViewByID、局部的动画对象等操作 是否存在内存泄漏 发生内存泄露,会导致一些不再使用…
最新文章