pytorch花式索引提取topk的张量

文章目录

  • pytorch花式索引提取topk的张量
    • 问题设定
    • 代码实现
      • 索引方法
      • gather方法
      • 验证
    • 补充知识
      • expand方法
      • gather方法
      • randint

pytorch花式索引提取topk的张量

问题设定

在这里插入图片描述
或者说,有一个(bs, dim, L)的大张量,索引的index形状为(bs, X),想得到一个(bs, dim, X)的reduced向量。我们在进行topk操作(以减少计算量)的时候经常碰到这种情况。
给出如下两种实现方法,分别使用花式索引(参考informer的代码)以及pytorch的gather方法

代码实现

索引方法

参考https://blog.csdn.net/qq_36560894/article/details/122005808

feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 结果和indices一致,说明在第二个channel上,每个样本的索引是一样的
bs,dim=feature.shape[:2]
bs,dim 
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape

在这里插入图片描述
在这里插入图片描述

gather方法

reduce_feature = torch.gather(feature, 2, indices_expand)

验证

两种方法得到的结果完全相同
在这里插入图片描述

补充知识

expand方法

在 PyTorch 中,expand() 方法用于扩展张量的大小。它会在不实际复制数据的情况下,重复张量的元素以填充新的形状。这个方法可以用于广播操作,以便在执行一些需要相同形状的张量之间的数学运算时,使它们具有相同的形状。

下面是使用 expand() 方法的基本用法:

import torch

# 创建一个原始张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 使用 expand 扩展张量的大小
expanded_x = x.expand(2, 3, 4)  # 扩展成维度为(2, 3, 4)的张量

print(expanded_x)

在上面的例子中,我们首先创建了一个形状为 (2, 3) 的原始张量 x。然后,我们使用 expand() 方法将其扩展成一个维度为 (2, 3, 4) 的新张量 expanded_x,该张量的形状是在原始张量形状的基础上每个维度都扩展了一倍。

需要注意的是,expand() 方法只能用于增加张量的大小,不能减小。另外,扩展后的张量与原始张量共享底层数据,因此在原始张量上进行的任何修改都会反映在扩展后的张量上,反之亦然。

gather方法

在 PyTorch 中,gather() 方法用于从输入张量中按照指定索引提取元素。这个方法通常用于根据索引收集特定的元素,例如根据类别索引从分类得分张量中获取对应类别的得分。

下面是使用 gather() 方法的基本用法:

import torch

# 创建一个输入张量
input_tensor = torch.tensor([[1, 2],
                             [3, 4],
                             [5, 6]])

# 创建一个索引张量
indices = torch.tensor([[0, 0],
                        [1, 0]])

# 使用 gather 方法根据索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)

print(output_tensor)

在上面的例子中,我们首先创建了一个形状为 (3, 2) 的输入张量 input_tensor,以及一个形状为 (2, 2) 的索引张量 indices。然后,我们使用 gather() 方法从输入张量 input_tensor 中按照索引张量 indices 收集元素。

gather() 方法中,参数 dim 指定了在哪个维度上进行收集操作,而 index 参数指定了收集元素所使用的索引张量。

需要注意的是,索引张量 indices 的形状必须与输出张量的形状一致,或者是可以广播成与输出张量形状一致的形状。

randint

torch.randint() 是 PyTorch 中用于生成随机整数张量的函数。它可以生成一个张量,其中的元素是在指定范围内随机抽样的整数。

下面是 torch.randint() 的基本用法示例:

import torch

# 生成一个形状为 (3, 3) 的随机整数张量,范围是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))

print(random_integers)

在上面的示例中,我们使用了 torch.randint() 函数来生成一个形状为 (3, 3) 的随机整数张量,其中的元素取值范围在闭区间 [low, high) 内,即从 0 到 9。

torch.randint() 函数的主要参数包括:

  • low:生成的随机整数的最小值(包含)。
  • high:生成的随机整数的最大值(不包含)。
  • size:生成的张量的形状。

你也可以不指定 low 参数,默认情况下它为 0。此外,还可以使用其他参数来控制生成的随机整数张量的设备类型、数据类型等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/385847.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java 基于 SpringBoot 的大药房管理系统

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

备战蓝桥杯---动态规划(入门1)

先补充一下背包问题: 于是,我们把每一组当成一个物品,f[k][v]表示前k组花费v的最大值。 转移方程还是max(f[k-1][v],f[k-1][v-c[i]]w[i]) 伪代码(注意循环顺序): for 所有组: for vmax.....0…

使用Word Embedding+Keras进行自然语言处理NLP

目录 介绍: one-hot: pad_sequences: 建模: 介绍: Word Embedding是一种将单词表示为低维稠密向量的技术。它通过学习单词在文本中的上下文关系,将其映射到一个连续的向量空间中。在这个向量空间中,相似的单词在空间…

实现JNDI

实现JNDI 问题陈述 Smart Software Developer Ltd.想要开发一款Web应用程序,它使用servlt基于雇员ID显示雇员信息,雇员ID由用户通过HTML用户界面传递。雇员详细信息存储在Employee_Master表中。另外,Web应用程序应显示网站被访问的次数。 解决方案 要解决上述问题,需要执…

2024.2.6 模拟实现 RabbitMQ —— 数据库操作

目录 引言 选择数据库 环境配置 设计数据库表 实现流程 封装数据库操作 针对 DataBaseManager 单元测试 引言 硬盘保存分为两个部分 数据库:交换机(Exchange)、队列(Queue)、绑定(Binding&#xff0…

腾讯云4核8G服务器够用吗?能支持多少人?

腾讯云4核8G服务器支持多少人在线访问?支持25人同时访问。实际上程序效率不同支持人数在线人数不同,公网带宽也是影响4核8G服务器并发数的一大因素,假设公网带宽太小,流量直接卡在入口,4核8G配置的CPU内存也会造成计算…

webpack面试解析

参考: 上一篇webpack相关的系列:webpack深入学习,搭建和优化react项目 爪哇教育字节面试官解析webpack-路白 1、Webpack中的module是什么? 通常来讲,一个 module 模块就是指一个文件中导出的内容,webpack…

全国计算机等级考试二级,MySQL数据库考试大纲(2023年版)

基本要求: 1.掌握数据库的基本概念和方法。 2.熟练掌握MySQL的安装与配置。 3.熟练掌握MySQL平台下使用SQL语言实现数据库的交互操作。 4.熟练掌握 MySQL的数据库编程。 5.熟悉 PHP 应用开发语言,初步具备利用该语言进…

Vue-自定义属性和插槽(五)

目录 自定义指令 基本语法 (全局&局部注册) 指令的值 练习:v-loading 指令封装 总结: 插槽(slot) 默认插槽 插槽 - 后备内容(默认值) 具名插槽 具名插槽基本语法: 具名插槽简化语法: 作…

RocksDB:高性能键值存储引擎初探

在现代的分布式系统和大数据应用中,一个高效、可靠的存储引擎是不可或缺的。RocksDB,由Facebook于2012年开发并随后开源,正是为了满足这类需求而诞生的。它是一个持久化的键值存储系统,特别适合在闪存(Flash&#xff0…

AtCoder Beginner Contest 340 C - Divide and Divide【打表推公式】

原题链接:https://atcoder.jp/contests/abc340/tasks/abc340_c Time Limit: 2 sec / Memory Limit: 1024 MB Score: 300 points 问题陈述 黑板上写着一个整数 N。 高桥将重复下面的一系列操作,直到所有不小于2的整数都从黑板上移除: 选择…

SCI论文作图规范

SCI论文作图规范包括以下几个方面: 一、图片格式 SCI论文通常接受的图片格式包括TIFF、EPS和PDF等。其中,TIFF格式是一种高质量的图像格式,适用于需要高分辨率和颜色准确性的图片;EPS格式是一种矢量图形格式,适用于需…

Leetcode 1035 不相交的线

题意理解: 在两条独立的水平线上按给定的顺序写下 nums1 和 nums2 中的整数。 现在,可以绘制一些连接两个数字 nums1[i] 和 nums2[j] 的直线,这些直线需要同时满足满足: nums1[i] nums2[j]且绘制的直线不与任何其他连线&#xff…

Spring 事务

Spring 事务传播(Propagation)特性 REQUIRED 支持一个当前的事务,如果不存在创建一个新的。SUPPORTS 支持一个当前事务,如果不存在以非事务执行。MANDATORY 支持一个当前事务,如果不存在任何抛出异常。REQUIRES_NEW 创…

百面嵌入式专栏(面试题)驱动开发面试题汇总 2.0

沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们将介绍驱动开发面试题 。 1、Linux系统的组成部分? Linux内核、Linux文件系统、Linux shell、Linux应用程序。 2、Linux内核的组成部分? (1)第一种分类方式:内存管理子系统、进程管理子系统、文件管理子系…

react【六】 React-Router

文章目录 1、Router1.1 路由1.2 认识React-Router1.3 Link和NavLink1.4 Navigate1.5 Not Found页面配置1.6 路由的嵌套1.7 手动路由的跳转1.7.1 在函数式组件中使用hook1.7.2 在类组件中封装高阶组件 1.8 动态路由传递参数1.9 路由的配置文件以及懒加载 1、Router 1.1 路由 1.…

VC++ 绘制折线学习

win32 有三个绘制折线的函数; Polyline,根据给定点数组绘制折线; PolylineTo,除了绘制也更新当前位置; PolyPolyline,绘制多条折线,第一个参数是点数组,第二个参数是一个数组、指…

论文阅读-One for All : 动态多租户边缘云平台的统一工作负载预测

论文名称:One for All: Unified Workload Prediction for Dynamic Multi-tenant Edge Cloud Platforms 摘要 多租户边缘云平台中的工作负载预测对于高效的应用部署和资源供给至关重要。然而,在多租户边缘云平台中,异构的应用模式、可变的基…

HTTP的基本格式

HTTP的基本格式 .HTTPhttp的协议格式HTTPhttp的协议格式 . HTTP 应用层,一方面是需要自定义协议,一方面也会用到一些现成的协议. HTTP协议,就是最常用到的应用层协议. 使用浏览器,打开网站,使用手机app,加载数据,这些过程大概率都是HTTP来支持的 HTTP是一个超文本传输协议, 文…

家政小程序系统源码开发:引领智能生活新篇章

随着科技的飞速发展,小程序作为一种便捷的应用形态,已经深入到我们生活的方方面面。尤其在家庭服务领域,家政小程序的出现为人们带来了前所未有的便利。它不仅简化了家政服务的流程,提升了服务质量,还为家政服务行业注…