DEiT中如何处理mask数据的?与MAE的不同

在DeiT里面,是通过mask的方式,将mask+unmasked的patches输出进ViT中,但其实在下游任务输入的patches还是和训练时patches的数量N是一致的(encoder所有的patches)。

而MAE是在encoder中只encoder未被mask的patches

通过什么方式支持的?

  • 在处理文本时,可以根据最长的句子在批次中动态padding或截断长句子
  • 而在处理图像(如使用ViT)时,可以将图像划分为大小相等的patches,数量可以根据图像的大小动态变化。

在训练阶段,部分patches被mask为0,但是处理的所有patches加起来的总长度还是一样的。被mask的位置在模型内部仍然占位,保持了输入序列的“框架”。这样,即使实际参与计算的只是部分元素,模型也能够适应在推理时使用全部元素的情况。

具体的计算步骤如下:

  1. 确定mask哪些patches
  2. 将mask的patches位置设置为0
  3. 这些被mask和未被mask的所有patches一起被输入进attention模块
  4. 将被mask的patches的注意力分数手动设置为“无穷大负数”(-inf)
  5. 这些被mask的patches的softmax值就会变为0,也就意味着这些patches并未参与注意力的计算
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(MaskedSelfAttention, self).__init__()
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # 计算自注意力得分
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(Q.size(-1), dtype=torch.float32))

        # 将mask值为0的位置在attention_scores中设置为一个非常大的负数
        attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # 使得这些位置的softmax结果接近0
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # 算最终的注意力加权和
        output = torch.matmul(attention_weights, V)

        return output


# 假设嵌入大小为512
embed_size = 512
# 创建一个mask,假设我们有4个patches,我们想要mask掉第2个和第4个patches
mask = torch.tensor([[1, 0, 1, 0]])
# 扩展mask维度以适应attention_scores的形状(假设批大小为1,序列长度为4),mask需要与attention_scores形状匹配,即(batch_size, 1, 1, seq_length)
mask = mask.unsqueeze(1).unsqueeze(2)

# 初始化模型和数据
sa = MaskedSelfAttention(embed_size)
x = torch.randn(1, 4, embed_size)  # 假设有一个批大小为1,序列长度为4的输入

output = sa(x, mask)
print(output)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/472344.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯java组 螺旋折线

题目描述 如图所示的螺旋折线经过平面上所有整点恰好一次。 对于整点(X, Y),我们定义它到原点的距离dis(X, Y)是从原点到(X, Y)的螺旋折线段的长度。 例如dis(0, 1)3, dis(-2, -1)9 给出整点坐标(X, Y),你能计算出dis(X, Y)吗? 【输入格…

直播预告|Sora 会怎样驱动视频编解码领域的突破与革新

在数字化时代,视频内容的传播与消费已成为日常生活的一部分。视频编解码技术是数字媒体领域的一项核心技术,它影响着视频质量,传输速度以及观看体验。与此同时,视频产业正在经历一场由技术驱动的变革,Sora、AIGC 等相关…

通用组件封装——iconfont 封装图标组件

文章目录 背景一、iconfont 处理1. 一键添加入库功能2. 图标项目配置 二、代码实现 背景 在项目中会使用到大量的图标,而 element 等组件库现有的 icon 图标可能无法满足项目的需要,比如很多图标没有可以替代的,或者项目中有彩色图标的需求都…

前端VUE笔记整理

一:PDA H5 1、对于PDA用到的三个命令说明: npm install: 根据package.json安装依赖文件到node_modules文件夹下(如果是第一次可以删除此文件夹下的文件,这个目录不会上传) ​ npm run serve: 运行PDA程序在本地做客户端 ​ npm run build: 打包文件到d…

【CSP】2020-12-2 期末预测之最佳阈值 排序+差分+前缀和

2020-12-2 期末预测之最佳阈值 排序差分前缀和 索引2020-12-2 期末预测之最佳阈值 排序差分前缀和思路遇到的问题完整代码 索引 历年CSP认证考试真题题解总汇持续更新 2020-12-2 期末预测之最佳阈值 排序差分前缀和 这题并不算难,但也不是直接套公式那么简单&…

SpringBoot3框架,事件和监听器、SPI

事件和监听器 生命周期监听 自定义监听器的步骤: 编写SpringApplicationRunListener实现类(各个实现方法的功能写在其sout内) public class MyAppListener implements SpringApplicationRunListener {Overridepublic void starting(Configu…

git 安装、创建仓库、常用命令、克隆下载、上传项目、删除分支 -- 一篇文章总结

一、git安装 1、git安装地址:https://git-scm.com/downloads 2、选择操作系统 3、安装自己系统对应的操作位数 4、等待下载完,一路next安装就可以了 5、安装完成后,在任意文件夹点击右键,看到下图说明安装成功 二、创建仓库 1…

法语「奶奶」明明是阴性,为什么不用配合?柯桥法语口语学习小语种学校

咦,法语中“奶奶”到底怎么写?是Grande-mre还是Grand mre?又或者 Grand-mre ? 先写下你的回答,法语君再公布答案哦! 面对这个问题,你已经开始犹豫了对不对? 那么在法语中,到底哪一个…

蓝桥杯之动态规划冲刺

文章目录 动态规划01背包小练一下01背包网格图上的DP完全背包 最长公共字符串最长递增子序列 动态规划 动态规划:确定好状态方程,我们常常是确定前 当状态来到 i 时,前 i 个物体的状态是怎么样的,我们并不是从一个点去考虑&#x…

Docker部署JumpServer3.9.0

简介 JumpServer 是什么? JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运维安全审计系统。JumpServer 帮助企业以更安全的方式管控和登录所有类型的资产,实现事前授权、事中监察、事后审计,满足等保合规要求。 Jump…

C/C++炸弹人游戏

参考书籍《啊哈,算法》,很有意思的一本算法书,小白也可以看懂,详细见书,这里只提供代码和运行结果。 这里用到的是枚举思想,还有更好地搜索做法。 如果大家有看不懂的地方或提出建议,欢迎评论区…

如何将Excel两列数据转换为统计图、曲线图、折线图?如何自定义某一列作为Excel的统计图横纵坐标?

这样,横坐标就更换为指定选中的数据了 我们还可以修改统计图的样式 也可以修改统计图的类型

cordova安装安卓版本,遇到的各种坑。折腾了两天才弄好

cordova官网地址 https://cordova.apache.org/docs/en/12.x/guide/cli/index.html 1. 输入命令 npm install -g cordova 全局安装cordova 2. 创建文件和项目以及app的应用名称 cordova create hello com.example.hello HelloWorld 我写的是这个 cordova create myApp 3.co…

深入浅出Hive性能优化策略

我们将从基础的HiveQL优化讲起,涵盖数据存储格式选择、数据模型设计、查询执行计划优化等多个方面。会的直接滑到最后看代码和语法。 目录 引言 Hive架构概览 示例1:创建表并加载数据 示例2:优化查询 Hive查询优化 1. 选择适当的文件格…

安卓findViewById 的优化方案:ViewBinding与ButterKnife(一)

好多小伙伴现在还用findViewById来获取控件的id, 在这里提供俩种替代方案:ViewBinding与ButterKnife; 先来说说ButterKnife ButterKnife ButterKnife是一个专注于Android系统的View注入框架,在过去的项目中总是需要很多的findViewById来查…

Java Day13 多线程

多线程 1、 方式一 Thread2、实现Runnable接口3、实现 Callable接口4、与线程有关的操作方法5、线程安全问题5.1 取钱案例5.2 线程同步5.2.1 同步代码块5.2.2 同步方法5.2.3 Lock锁 6、线程池6.2 创建线程池6.2.1 使用ExecutorService创建新任务策略6.2.2 使用Executors工具类创…

2024年云仓酒庄佛山发布会:赋能

原标题:2024年云仓酒庄佛山发布会圆满落幕,朱囿臻总赋能引领行业新篇章 近日,备受瞩目的云仓酒庄佛山发布会圆满落幕。此次发布会汇聚了业内精英、经销商代表以及媒体人士,共同见证了云仓酒庄在佛山市场的启航。在此,…

智慧公厕:卫生、便捷、安全的新时代厕所变革

在城市快速发展的背景下,公共厕所的建设和管理变得越来越重要。智慧公厕作为厕所变革的一项全新举措,通过建立公共厕所全面感知监测系统,以物联网、互联网、大数据、云计算、自动化控制技术为支撑,实现对公共厕所的智能化管理和运…

练习4-权重衰减(李沐函数简要解析)

环境:练习1的环境 代码详解 0.导入库 import torch from torch import nn from d2l import torch as d2l1.初始化数据 这里初始化出train_iter test_iter 可以查一下之前的获取Fashion数据集后的数据格式与此对应 n_train, n_test, num_inputs, batch_size 20, 100, 200, …

50. 【Linux教程】源码安装软件

本小节介绍如何使用软件的源码包安装软件,以安装 nginx 源码包为例。 1.下载软件源码包 使用如下命令下载 nginx 源码包: wget http://nginx.org/download/nginx-1.18.0.tar.gz执行结果如下图所示: 2.解压源码包 下载好了压缩包之后&#…
最新文章