Pytorch常用的函数(九)torch.gather()用法

Pytorch常用的函数(九)torch.gather()用法

torch.gather() 就是在指定维度上收集value。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

1、举例直观理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor(
       [[[0, 0, 0, 0],
         [2, 2, 2, 2]],
         
        [[0, 0, 0, 0],
         [2, 2, 2, 2]]]
)	

# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],
         
        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])

我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index,从input_tensor中获得8。

在这里插入图片描述

根据上图,我们很直观的了解根据 index ,在 inputdim 维度上收集 value的过程。

  • 假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可;
  • 但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报错 ;
  • 所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k],因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。
  • pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

2、反推法再理解

# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],
         
        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])# 3、如何知道dim 和 index_tensor呢? 
# 首先,我们要记住:out_tensor的shape = index_tensor的shape

# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3

# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11

# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor(
       [[[0, 0, 0, 0],
         [2, 2, 2, 2]],
         
        [[0, 0, 0, 0],
         [2, 2, 2, 2]]]
)	

3、实际案例

在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather() 函数。

  • 论文链接:https://arxiv.org/pdf/2111.06377
  • 官方源码:https://github.com/facebookresearch/mae

在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather() 函数。

# models_mae.py

import torch

def random_masking(x, mask_ratio=0.75):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))  # 计算unmasked的片数
    # 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】
    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample【核心代码】
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    # 利用torch.gather()从源tensor中获取25%的unmasked tokens
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore

if __name__ == '__main__':
    x = torch.arange(64).reshape(1, 16, 4)
    random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4) 
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15],
         [16, 17, 18, 19], # 4
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31],
         [32, 33, 34, 35],
         [36, 37, 38, 39],
         [40, 41, 42, 43], # 10
         [44, 45, 46, 47],
         [48, 49, 50, 51], # 12
         [52, 53, 54, 55], # 13
         [56, 57, 58, 59],
         [60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],
         [12, 12, 12, 12],
         [ 4,  4,  4,  4],
         [13, 13, 13, 13]]])


# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)     
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],
         [48, 49, 50, 51],
         [16, 17, 18, 19],
         [52, 53, 54, 55]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/605772.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

android图标底色问题,debug与release不一致

背景 在android 8(sdk 26)之前的版本,直接使用图片文件作为图标,开发时比较容易控制图标,但是不同的安卓定制版本就不容易统一图标风格了。 在android 8及之后的版本,图标对应的是ic_launcher.xml&#x…

android权限申请说明

theme: condensed-night-purple 在Android开发中,权限是指应用程序需要访问特定的设备功能或数据时所需的用户许可。从Android 6.0(API级别23)开始,Android引入了运行时权限模型,在应用程序运行期间向用户请求权限&am…

2024年全国五大数学建模竞赛Top榜及难度分析!推荐数维杯!!!

发现最近许多同学都陆续开始准备今年的数学建模竞赛了,但是随着数学建模领域越来越普及,影响力越来越广泛,参加的同学也越来越多,就导致有越来越多各式各样的数学建模竞赛此起彼伏出现,但其中有一些竞赛其实并不值得参…

Github 50k star!吴恩达联合OpenAi共同编写<面向开发者的LLM入门教程> PDF推荐!

今天给大家推荐一本由吴恩达和OpenAI团队共同编写的关于大型语言模型&#xff08;LLM&#xff09;的权威教程<面向开发者的LLM入门教程>&#xff01;&#xff0c;在Github上已经高达50k star了&#xff0c;这含金量不用多说&#xff0c;在这里给大家强烈推荐一波&#xf…

Wappalyzer指纹识别下载安装使用教程,图文教程(超详细)

「作者简介」&#xff1a;2022年北京冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础对安全知识体系进行总结与归纳&#xff0c;著作适用于快速入门的 《网络安全自学教程》&#xff0c;内容涵盖系统安全、信息收集等…

Spring Security初探

url说明方法/login/oauth/authorize无登录态时跳转到/authentication/require&#xff0c;有登录态时跳转到/loginorg.springframework.security.oauth2.provider.endpoint.AuthorizationEndpoint#authorize/authentication/require自己写的用于重定向到登录页面的urlcn.merryy…

【笔记】常用USB转串口芯片CH340驱动自动静默安装方法

微信关注公众号 “DLGG创客DIY” 设为“星标”&#xff0c;重磅干货&#xff0c;第一时间送达。 前言 CH340是沁恒(南京沁恒微电子股份有限公司)一款非常有名USB转串口芯片&#xff0c;很多廉价的开发板上都使用这款USB转串口芯片&#xff0c;我觉得主要原因是因为它成本最低&a…

SSM【Spring SpringMVC Mybatis】——Mybatis

目录 1、初识Mybatis 1.1Mybatis简介 1.2 官网地址 2、搭建Mybatis框架 2.1 准备 2.2 搭建Mybatis框架步骤 1. 导入jar包 2. 编写核心配置文件【mybatis-config.xml】 3. 书写相关接口及映射文件 4. 测试【SqlSession】 2.3 添加Log4j日志框架 导入jar包 编写配置文…

Nginx配置/.well-known/pki-validation/

当你需要在Nginx上配置.well-known/pki-validation/时&#xff0c;这通常是为了支持SSL证书的自动续订或其他验证目的。以下是配置步骤&#xff1a; 创建目录结构&#xff1a; 在你的网站根目录下创建一个名为.well-known的目录&#xff08;SSL证书申请之如何创建/.well-known/…

重启酱酒老品牌茅泉酒,肆拾玖坊打造“第二增长曲线”

执笔 | 尼 奥 编辑 | 扬 灵 在肆拾玖坊9周年纪念日前夕&#xff0c;来自业内外的精英朋友相聚茅台镇&#xff0c;见证肆拾玖坊“第二增长曲线”——茅泉酒的焕新出发。正如肆拾玖坊创始人、CEO张传宗所言&#xff1a;“这是重要的里程碑&#xff0c;也是小小的开始。” 从…

【excel】数据非数值导致排序失效

场景 存在待排序列的数值列&#xff0c;但排序失效&#xff0c;提示类型有问题&#xff1a; 解决 选中该列&#xff0c;数据→分列 而后发现提示消失&#xff0c;识别为数字&#xff0c;可正常排序。

cypress的安装使用

cypress npm install -g cnpm --registryhttps://registry.npm.taobao.org

力扣41. 缺失的第一个正数

Problem: 41. 缺失的第一个正数 文章目录 题目描述思路复杂度Code 题目描述 思路 1.将nums看作为一个哈希表&#xff0c;每次我们将数字n移动到nums[n - 1]的位置(例如数字1应该存在nums[0]处…),则在实际的代码操作中应该判断nums[i]与nums[nums[i] - 1]是否相等&#xff0c;若…

直播报名 | 珈和科技携手潍柴雷沃共探“现代农场”未来式

数据赋农季系列直播第四期&#xff0c;我们将以“未来农业发展趋势之农场智慧化、管理数据化”为主题展开&#xff0c;此次系列直播由珈和科技及湖北珞珈实验室共同主办&#xff0c;第四期直播很荣幸邀请到潍柴雷沃参与其中&#xff0c;双方将就智慧农服平台和数据交易SaaS平台…

交友软件源码-源码+搭建+售后,上线即可运营聊天交友源码 专业语聊交友app开发+源码搭建-快速上线

交友小程序源码是一种可以帮助开发者快速搭建交友类小程序的代码模板。它通常包括用户注册、登录、个人信息编辑、匹配推荐、好友聊天等常见功能&#xff0c;以及与后台数据交互的接口。使用这种源码可以极大地缩短开发时间&#xff0c;同时也可以根据自己的需求进行二次开发和…

Docker与Harbor:构建企业级私有Docker镜像仓库

目录 引言 一、本地私有仓库 &#xff08;一&#xff09;基本概述 &#xff08;二&#xff09;搭建本地私有仓库 1.下载registry镜像 2.启动容器 3.上传本地镜像 4.客户端下载镜像 二、Harbor简介 &#xff08;一&#xff09;什么是 Harbor &#xff08;二&#xff…

优质资料:大型制造企业等级保护安全建设整改依据,系统现状分析,网络安全风险分析

第1章 项目概述 XX 大型制造型企业是国内一家大型从事制造型出口贸易的大型综合企业集团&#xff0c;为了落实国家及集团的信息安全等级保护制度&#xff0c;提高信息系统的安全防护水平&#xff0c;细化各项信息网络安全工作措施&#xff0c;提升网络与信息系统工作的效率&am…

123.Android 简单的定位和语音识别 免费高德定位 免费语音识别 不需要接入SDK 不需要导入任何的离线包

//免费的定位 高德定位 不需要接入高德SDk也可进行高德定位&#xff1a; //免费的语音识别 不需要接入任何的SDK 也不需要导入任何的离线语音包&#xff1a; //CSDN 小妞得意 //具体代码实现 私聊 //---------------------------------------------------------------END…

驱动比例线圈功率放大器

驱动比例线圈功率放大器是一种用于控制比例电磁铁的电流大小实现被控设备的位移&#xff0c;采用高性能的嵌入式32位微处理器作为运算核心&#xff0c;这些微处理器具有高速指令运行能力&#xff0c;电源24VDC驱动&#xff0c;输入指令兼容性强&#xff0c;输出电流大小可调&am…

首席数据官CCRC-CDO如何构筑企业数据合规的坚固防线

在当今信息化快速发展的时代&#xff0c;数据已经成为企业最宝贵的资产之一。然而&#xff0c;随着数据规模的迅速增长&#xff0c;数据合规问题也日益凸显。首席数据官&#xff08;CDO&#xff09;作为企业中负责数据战略和管理的核心人物&#xff0c;构筑企业数据合规的坚固防…
最新文章