torch.gather(...)

1. Abstract

对于 pytorch 中的函数

torch.gather(
	input,  # (Tensor) the source tensor
	dim,    # (int)    the axis along which to index
	index,  # (LongTensor) the indices of elements to gather
	*,
	sparse_grad=False,
	out=None
) → Tensor

有点绕,很多博客画各种图讲各种故事来解释如何input 张量中 gather 位置 index 处的值,乱七八糟,我是都没看明白。所以去官网看了文档:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

从这三行看,意思还是很明晰的:输出 out 和输入 input 之间的差别就是,把相应位置(dim)的下标替换成 index[i][j][k]dim=0,1,2 分别对应替换的位置0,1,2。但这不够直观!

【注】从上面三行代码可以看出,indexinput 的维度尺寸是一样的,即 len(index.shape) == len(input.shape),但不一定是相同的形状:index.shape[dim] ≠ input.shape[dim](其他维度的形状必须满足 index.shape <= input.shape)。

2. 图解

2.1 一维向量

先从简单的一维向量看看:

x = torch.tensor([3, 4, 5, 6, 7])

按规则看,out[i] = input[index[i]] # dim == 0,即,从向量里选取指定位置 index[i] 处的数字,放到输出向量 out[i] 处。这个很好理解,pythonnumpypytorch 都有这样的语法:

x = torch.randn(3)
index = torch.randint(low=0, high=3, size=(5,))
y = x[index]
print(x)
print(index)
print(y)
### output ###
tensor([ 0.8797,  0.2459, -0.1312])
tensor([2, 0, 2, 2, 0])
tensor([-0.1312,  0.8797, -0.1312, -0.1312,  0.8797])

torch.gather(...) 函数,就是这样的:

x = torch.tensor([3, 4, 5, 6, 7])
index = torch.tensor([4, 4, 1, 1, 0, 3])
out = torch.gather(x, dim=0, index=index)
### output ###
tensor([7, 7, 4, 4, 3, 6])

举例来说,上面的 index[4] = 0,那么它会寻找 input[index[4]] = input[0] = 3,然后放入 out[4]。这就是英文单词 gather 的意思。

index 的长度是不受限制的,即 gather 多少元素都可以。

小结:在一维向量下,out = torch.gather(x, dim=0, index=index) 等价于 out = x[index]

2.2 二维矩阵

往上升一个维度,看看对二维矩阵实施 gather 函数的操作:

x = torch.tensor([[3, 4, 5, 6, 7], [9, 8, 7, 6, 5]])
idx = torch.randint(low=0, high=5, size=(2, 6))
y = torch.gather(x, dim=1, index=idx)
print(x)
print(idx)
print(y)
### output ###
tensor([[3, 4, 5, 6, 7],
        [9, 8, 7, 6, 5]])
tensor([[4, 4, 1, 1, 0, 3],
        [0, 1, 2, 1, 4, 1]])
tensor([[7, 7, 4, 4, 3, 6],
        [9, 8, 7, 8, 5, 8]])

按规则看,out[i][j] = input[i][index[i][j]] # dim == 1,即,从向量 input[i] 里选取指定位置 index[i][j] 处的数字,放到输出向量 out[i][j] 处。也许多了一个维度就有点绕了,但仔细观察,我们可以假定 i = 0,此时:

out[0][j] = input[0][index[0][j]]  # 对应上图的左侧

若假定 i = 1,则:

out[1][j] = input[1][index[1][j]]  # 对应上图的右侧

即,输出 out[i] 是对输入 imput[i] 执行了一次与一维向量时一样的操作,其中下标是 index[i]。在二维矩阵上的 gather 操作,不过是并行地执行了多个一维向量的 gather

上面是 dim = 1 时的情况,是沿着矩阵的进行 gather,当 dim = 0 时,就是沿着进行 gather

out[i][0] = input[index[i][0]][0]  # dim == 0
out[i][1] = input[index[i][1]][1]
...


也就是并行地执行多个列向量gather,每列 index 是一个并行分支,并行分支的数量可以小于 input 的列数,但不能超过,超过的话,它 gather 哪一列呢?

小结:二维矩阵的 gather 操作就是并行地执行了多个一维向量的 gather 操作;dim=1 按行 gatherdim=0 按列 gather

2.3 高维张量

弄懂一维到二维的 gather,更高维的操作也就清晰了,就是画图有一点难画。假设

x = tensor([[[ 0,  1,  2,  3,  4],
             [ 5,  6,  7,  8,  9]],

		    [[10, 11, 12, 13, 14],
             [15, 16, 17, 18, 19]],

		    [[20, 21, 22, 13, 24],
             [25, 26, 27, 28, 29]]])

则当 dim == 0 时,是沿着第一维进行 gather 的,那么 index.shape[0] (一个并行分支 gather 的元素的数量) 可为任意数,这里设置为 4,其他 index.shape[i≠0] <= input.shape[i≠0]

index = tensor([[[1, 2, 2],
         		 [2, 2, 0]],

       			[[0, 0, 1],
         		 [1, 0, 1]],

		        [[2, 0, 0],
        		 [0, 1, 2]],

		        [[1, 1, 0],
        		 [0, 0, 0]]])

index.shape == (4, 2, 3),执行:

y = torch.gather(x, dim=0, index=index)

的示意图如下:

只画了看得见的前两列(两个并行 gather 分支)。红色和绿色箭头表示两列下标沿着 dim=0 进行 gather 操作,每一列和一维向量的 gather 是一样的,只不过这里有 2*3 个列。

再往高维拓展,也是一样,都是从基本的一维向量 gather 拓到并行 gather

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/255734.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Webpack安装及使用

win系统 全局安装Webpack及使用 前提&#xff1a;使用Webpack必须安装node环境&#xff0c;建议使用nvm管理node版本。 1&#xff1a;查看自己电脑是否安装了node 2&#xff1a;npm install webpack版本号 -g 3&#xff1a;npm install webpack-cli -g -g:表示全局安装 4&…

ElasticSearch单机或集群未授权访问漏洞

漏洞处理方法&#xff1a; 1、可以使用系统防火墙 来做限制只允许ES集群和Server节点的IP来访问漏洞节点的9200端口&#xff0c;其他的全部拒绝。 2、在ES节点上设置用户密码 漏洞现象&#xff1a;直接访问9200端口不需要密码验证 修复过程 2.1 生成认证文件 必须要生成…

力扣225. 用队列实现栈【附进阶版】

文章目录 力扣225. 用队列实现栈示例思路及其实现两个队列模拟栈一个队列模拟栈 力扣225. 用队列实现栈 示例 思路及其实现 两个队列模拟栈 队列是先进先出的规则&#xff0c;把一个队列中的数据导入另一个队列中&#xff0c;数据的顺序并没有变&#xff0c;并没有变成先进后…

【转载】【Unity】WebSocket通信

1 前言 Unity客户端常用的与服务器通信的方式有socket、http、webSocket。本文主要实现一个简单的WebSocket通信案例&#xff0c;包含客户端、服务器&#xff0c;实现了两端的通信以及客户端向服务器发送关闭连接请求的功能。实现上没有使用Unity相关插件&#xff0c;使用的就是…

【经典LeetCode算法题目专栏分类】【第5期】贪心算法:分发饼干、跳跃游戏、模拟行走机器人

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能AI、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 分发饼干 class Solutio…

换热站数字孪生 | 图扑智慧供热 3D 可视化

换热站作为供热系统不可或缺的一部分&#xff0c;其能源消耗对城市环保至关重要。在双碳目标下&#xff0c;供热企业可通过搭建智慧供热系统&#xff0c;实现供热方式的低碳、高效、智能化&#xff0c;从而减少碳排放和能源浪费。通过应用物联网、大数据等高新技术&#xff0c;…

MongoDB中的关系

本文主要介绍MongoDB中的关系。 目录 MongoDB的关系嵌入关系引用关系 MongoDB的关系 MongoDB是一个非关系型数据库&#xff0c;它使用了键值对的方式来存储数据。因此&#xff0c;MongoDB没有像传统关系型数据库中那样的表、行和列的概念。相反&#xff0c;MongoDB中的关系是通…

美颜SDK是什么?视频美颜SDK在直播平台中的集成与接入教程详解

当下&#xff0c;主播们追求更加自然、精致的外观&#xff0c;而观众也期待在屏幕前欣赏到更为清晰、美丽的画面。为了满足这一需求&#xff0c;美颜SDK应运而生&#xff0c;成为直播平台的重要利器之一。 一、什么是美颜SDK&#xff1f; 通过美颜SDK&#xff0c;开发者可以…

docker在线安装minio

1、下载最新minio docker pull minio/minio 2、在宿主机创建 /usr/local/data/miniodocker/config 和 /usr/local/data/miniodocker/data,执行docker命令 docker run -p 9000:9000 -p 9090:9090 --name minio -d --restartalways -e MINIO_ACCESS_KEYminio -e MINIO_SECRET_K…

数据结构--图

树具有灵活性&#xff0c;并且存在许多不同的树的应用&#xff0c;但是就树本身而言有一定的局限性&#xff0c;树只能表示层次关系&#xff0c;比如父子关系。而其他的比如兄弟关系只能够间接表示。 推广--- 图 图形结构中&#xff0c;数据元素之间的关系是任意的。 一、图…

Shell编程基础 – C语言风格的Bash for循环

Shell编程基础 – C语言风格的Bash for循环 Shell Programming Essentials - C Style For Loop in Bash By JacksonML 循环是编程语言的基本概念之一&#xff0c;同样也是Bash编程的核心。当用户需要一遍又一遍地运行一系列命令直到达到特定条件时&#xff0c;例如&#xff1…

输电线路定位:精确导航,确保电力传输安全

在现代社会中&#xff0c;电力作为生活的基石&#xff0c;其安全稳定运行至关重要。而输电线路作为电力传输的重要通道&#xff0c;其故障定位和修复显得尤为重要。恒峰智慧科技将为您介绍一种采用分布式行波测量技术的输电线路定位方法&#xff0c;以提高故障定位精度&#xf…

06-部署knative-eventing

环境要求 For prototyping purposes 单节点的Kubernetes集群&#xff0c;有2个可用的CPU核心&#xff0c;以及4g内存&#xff1b; For production purposes 单节点的Kubernetes集群&#xff0c;需要至少有6个CPU核心、6G内存和30G磁盘空间多节点的Kubernetes集群中&#xff0c;…

电影小镇智慧旅游项目技术方案:PPT全文111页,附下载

关键词&#xff1a;智慧旅游项目平台&#xff0c;智慧文旅建设&#xff0c;智慧城市建设&#xff0c;智慧文旅解决方案&#xff0c;智慧旅游技术应用&#xff0c;智慧旅游典型方案&#xff0c;智慧旅游景区方案&#xff0c;智慧旅游发展规划 一、智慧旅游的起源 智慧地球是IB…

windows 安装jenkins

下载jenkins 官方下载地址&#xff1a;Jenkins 的安装和设置 清华源下载地址&#xff1a;https://mirrors.tuna.tsinghua.edu.cn/jenkins/windows-stable/ 最新支持java8的版本时2.346.1版本&#xff0c;在清华源中找不到&#xff0c;在官网中没找到windows的下载历史&#xff…

MySQL数据库 约束

目录 约束概述 外键约束 添加外键 删除外键 删除/更新行为 约束概述 概念&#xff1a;约束是作用于表中字段上的规则&#xff0c;用于限制存储在表中的数据。 目的&#xff1a;保证数据库中数据的正确、有效性和完整性。 分类: 注意&#xff1a;约束是作用于表中字段上…

java继承

1.为什么需要继承 我们编写了两个类,一个是Puppil类(小学生),一个是Graduate(大学生),问题:两个类的属性和方法有很多是相同的,怎么办&#xff1f; 把共有的属性和方法抽离出来: 父类&#xff1a; package com.hspedu.extends01;//父类,是Pupil和Graduate的父类 public cla…

50ms时延工业相机

华睿工业相机A3504CG000 参数配置&#xff1a; 相机端到端理论时延&#xff1a;80ms 厂家同步信息&#xff0c;此款设备帧率上线23fps&#xff0c;单帧时延&#xff1a;43.48ms&#xff0c;按照一图缓存加上传输显示的话&#xff0c;厂家预估时延在&#xff1a;80ms 厂家还有…

AXure的情景交互

目录 导语&#xff1a; 1.erp多样性登录界面 2.主页跳转 3.省级联动​编辑 4. 下拉加载 导语&#xff1a; Axure是一种流行的原型设计工具&#xff0c;可以用来创建网站和应用程序的交互原型。通过Axure&#xff0c;设计师可以创建情景交互&#xff0c;以展示用户与系统的交…

机器学习之线性回归(Linear Regression)

概念 线性回归(Linear Regression)是机器学习中的一种基本的监督学习算法,用于建立输入变量(特征)与输出变量(目标)之间的线性关系。它假设输入变量与输出变量之间存在线性关系,并试图找到最佳拟合线来描述这种关系。 在简单线性回归中,只涉及两个变量:一个是自变量…
最新文章