Pytorch:张量的形状操作

文章目录

      • 一、维度改变
        • 1.flatten展开
          • a.函数的基本用法
          • b.示例
        • 2.unsqueeze增维
          • a.函数的基本用法
          • b.示例
        • 3.squeeze降维
          • a.函数的基本用法
          • b.示例
      • 二、张量变形
        • 1.view()
          • a.函数的基本用法
          • b.参数:
          • c.注意事项
          • d.示例
        • 2.reshape()
          • a.注意事项
          • b.示例
        • 3.reshape_as()
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
      • 三、维度重排
        • 1.permute
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
        • 2.transpose
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意

维度改变和张量变形都不改变内存中存储的结构,因此改变后的张量的值顺序和没改变前是一样的。

一、维度改变

1.flatten展开
  • torch.flatten(tensor)
  • tensor.flatten()

torch.flatten() 是一个在 PyTorch 中常用于张量(tensor)处理的函数,它将输入张量展开成一个一维张量。该函数通常用于准备数据,将多维数据转换为一维,以便用于机器学习模型,特别是在模型的全连接层(fully connected layers)之前。
常用于展开成一维

a.函数的基本用法

只给定一个张量,将直接展开成一维。
torch.flatten(input, start_dim=0, end_dim=-1) 的参数解释如下:

  • input: 输入的张量。
  • start_dim: 开始展开的维度,默认为 0。这意味着从哪个维度开始将张量展开。
  • end_dim: 结束展开的维度,默认为 -1,即最后一个维度。这意味着展开将持续到哪个维度。
b.示例

考虑一个三维张量,例如形状为 (2, 3, 4) 的张量。如果使用 torch.flatten() 将其展开,可以有多种方式处理:

  1. 完全展开: 将整个张量展开成一维数组。

    import torch
    x = torch.randn(2, 3, 4)
    flat_x = torch.flatten(x)
    # 结果形状为 [24]
    
  2. 从特定维度开始展开: 指定从哪个维度开始展开。例如,从第一维(索引为 0 的维度)开始展开。

    flat_x = torch.flatten(x, start_dim=1)
    # 结果形状为 [2, 12],保留了第一个维度,其余维度被展开
    
2.unsqueeze增维
  • torch.unsqueeze(tensor)
  • tensor.unsqueeze()

torch.unsqueeze() 是 PyTorch 中用来增加张量的维度的函数。该函数可以在张量的指定位置插入一个维度,它非常有用于调整张量的形状,以满足特定操作或模型的需求,例如在单样本张量上应用需要批处理的模型。
常用于在第0个维度上增加大小为1的维度

a.函数的基本用法

torch.unsqueeze(input, dim) 的参数解释如下:

  • input: 输入的张量。
  • dim: 要插入新维度的索引位置。这个位置遵循 Python 的索引规则,支持负索引。
b.示例

假设有一个二维张量 x 形状为 (3, 4),表示一个包含3个样本,每个样本4个特征的数据集。如果需要在特定维度增加一个维度,可以使用 torch.unsqueeze() 如下:

import torch
x = torch.randn(3, 4)

# 在第0维增加一个维度
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape)
# 输出: torch.Size([1, 3, 4])

# 在第1维增加一个维度
x_unsqueezed = torch.unsqueeze(x, 1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 1, 4])

# 使用负索引,在最后一个维度后增加一个维度
x_unsqueezed = torch.unsqueeze(x, -1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 4, 1])
3.squeeze降维
  • torch.squeeze(tensor)
  • tensor.squeeze()

torch.squeeze() 是 PyTorch 中的一个函数,用于减少张量的维度,特别是去除那些维度大小为1的维度。这个函数非常有用于去除由于某些操作(比如 unsqueeze)产生的单一维度,从而使张量的形状更加紧凑。

a.函数的基本用法

只给定一个张量,将直接去掉所有大小为1的维度。
torch.squeeze(input, dim=None) 的参数解释如下:

  • input: 输入的张量。
  • dim: 指定要压缩的维度。如果指定的维度大小为1,则该维度会被去除如果大小不为1,则该维度不会被压缩如果不指定 dim 参数,那么所有大小为1的维度都会被去除。
b.示例

考虑一个张量 x,其形状包括一些大小为1的维度。以下是如何使用 torch.squeeze() 来去除这些维度的示例:

import torch
x = torch.randn(1, 3, 1, 5)

# 去除所有大小为1的维度
squeezed_x = x.squeeze()
print(squeezed_x.shape)
# 输出: torch.Size([3, 5])

# 只压缩第0维(大小为1)
squeezed_x = x.squeeze(0)
print(squeezed_x.shape)
# 输出: torch.Size([3, 1, 5])

# 只压缩第2维(大小为1)
squeezed_x = torch.squeeze(x, 2)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 5])

# 尝试压缩一个不是大小为1的维度(没有变化)
squeezed_x = torch.squeeze(x, 1)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 1, 5])

二、张量变形

1.view()

在 PyTorch 中,.view() 方法是一个非常重要且常用的功能,用于改变张量的形状而不改变其数据内容。此方法提供了一种高效的方式来重新排列张量的维度,使其适应不同的需求,例如输入到一个模型或对数据进行不同的操作。
view是共享内存的!

a.函数的基本用法

.view() 方法的基本用法是 tensor.view(*shape),其中 *shape 是希望张量拥有的新形状,由一组维度大小组成。

b.参数:
  • shape: 新的形状,是一个由整数构成的元组,其中的每个整数指定相应维度的大小。你也可以在某个位置使用 -1,让 PyTorch 自动计算该维度的大小。(注意某个位置是任意的某个位置,但是只能有一个)
c.注意事项
  1. 连续性.view() 要求张量在内存中是连续的(即一维数组中的元素顺序与多维视图中的顺序相同)。如果张量不是连续的,你可能需要首先调用 .contiguous() 方法来使其连续。

  2. 自动计算维度:使用 -1 作为形状参数的一部分,PyTorch 将自动计算该维度的正确大小,以便保持元素总数不变。

  3. 大小不变.view()要求张量变换形状之后的大小和变换之前的大小是一样的。即维度大小之积相等。比如tensor.Size([2,4])tensor.Size([8])是一样的。

d.示例
import torch
x = torch.randn(4, 4)  # 创建一个 4x4 的张量

# 改变形状为 2x8
y = x.view(2, 8)
print(y.shape)
# 输出: torch.Size([2, 8])

# 改变形状为 16(一维)
z = x.view(-1)#z = x.view(16)
print(z.shape)
# 输出: torch.Size([16])

# 使用 -1 自动计算维度
w = x.view(-1, 8)
print(w.shape)
# 输出: torch.Size([2, 8])
import torch
x = torch.randn(2, 1)  # 创建一个 2×1 的张量

# 改变形状为 2x8
y = x.view(2)
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=2 #共享内存,y也会变
print(x)
print(y)
tensor([-0.5001,  0.5409])
tensor([[2.0000],
        [0.5409]])
tensor([2.0000, 0.5409])
2.reshape()

在 PyTorch 中,.reshape() 方法用于改变张量的形状而不改变其数据内容。
这一方法与 .view() 类似,都允许您重新排列张量的维度,但它们在处理非连续张量时的行为不同。
只有当非连续张量时,才会导致和.view不一样,如果是连续的,同样也是共享内存的。

a.注意事项
  1. 数据连续性:与 .view() 相比,.reshape() 可以处理非连续张量,如果必要,它会自动处理数据的内存复制。因此,如果原始张量不连续,而你尝试用 .view() 改变其形状可能会导致错误,但 .reshape() 会自动解决这个问题。

  2. 自动计算维度:使用 -1 作为形状参数的一部分时,PyTorch 会自动计算该维度的大小,以确保总元素数量与原张量相同。

b.示例
import torch
x = torch.randn(2, 3, 4)  # 创建一个 2x3x4 的张量

# 改变形状为 6x4
y = x.reshape(6, 4)
print(y.shape)
# 输出: torch.Size([6, 4])

# 改变形状为 1x24
z = x.reshape(1, 24)
print(z.shape)
# 输出: torch.Size([1, 24])

# 使用 -1 自动计算维度
w = x.reshape(-1, 2)
print(w.shape)
# 输出: torch.Size([12, 2])
import torch
x = torch.randn(2, 2)  # 创建一个 2x1 的张量
x=x.transpose(0,1)
# 改变形状为 2x8
y = x.reshape(4)#转置后的x不是连续的,使用reshape产生复制,此时不能用.view()
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=100
print(x)
print(y)
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
tensor([[100.0000,  -0.1661],
        [ -0.3646,  -0.2516]])
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
3.reshape_as()

在 PyTorch 中,.reshape_as() 是一个方便的方法,用于将一个张量重新塑形为与另一个张量相同的形状。这个方法实质上是 .reshape() 方法的一个简化版本,它以另一个张量的形状为目标形状。
换句话说,.reshape_as()相当于是省略了自指定参数的.reshape(),而可以直接用目标张量形状作为形状。

a.函数的基本用法

.reshape_as() 的基本用法非常直接:tensor1.reshape_as(tensor2)。这会将 tensor1 的形状修改为与 tensor2 相同的形状。

b.参数:
  • tensor2: 这是模型张量,tensor1 将改变形状以匹配 tensor2 的形状。
c.示例
import torch
x = torch.randn(2, 3, 4)  # 原始张量,形状为 2x3x4
y = torch.randn(6, 4)     # 目标张量,形状为 6x4

# 将 x 的形状改变为与 y 相同
z = x.reshape_as(y)
print(z.shape)
# 输出: torch.Size([6, 4])
d.注意

虽然 .reshape_as() 很方便,但使用它时应确保两个张量具有相同的元素总数,因为改变形状的操作不会改变数据的总量。如果两个张量的总元素数量不匹配,尝试使用 .reshape_as() 将抛出错误。此外,如果原始张量在内存中是非连续的,.reshape_as() 会像 .reshape() 一样处理,可能需要在内部进行数据复制以确保连续性。

三、维度重排

permute方法可以按照指定顺序重新排列维度,而transpose方法可以交换张量的两个维度。用于需要进行维度重排或转置操作。如矩阵转置。

1.permute

在 PyTorch 中,.permute() 方法用于重新排列张量的维度,这是处理多维数据时一个非常有用的功能,尤其在需要对维度进行特定的重排序操作时。

a.函数的基本用法

.permute() 方法的调用格式为 tensor.permute(*dims),其中 *dims 是一个整数序列,代表新的维度排列顺序。

b.参数:
  • dims: 这个参数定义了张量的每个维度应该如何重新排列。序列中的每个整数都代表原始张量中一个维度的索引,这些索引的排列顺序确定了输出张量的形状。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量

# 改变维度的排列顺序为 [2, 0, 1]
y = x.permute(2, 0, 1)
print(y.shape)
# 输出: torch.Size([5, 2, 3])

# 将维度的排列顺序改为 [1, 2, 0]
z = x.permute(1, 2, 0)
print(z.shape)
# 输出: torch.Size([3, 5, 2])
d.注意
import torch
x = torch.tensor([[1,2,3,4],[2,4,2,4],[5,6,7,8]]) 
x = x.permute(1,0)
'''
tensor([[1, 2, 5],
        [2, 4, 6],
        [3, 2, 7],
        [4, 4, 8]])
'''

在 PyTorch 中,当使用 .permute() 方法重排张量维度时,张量的数据实际上在内存中的位置并没有改变。更准确地说.permute() 改变的是张量访问这些数据的方式,通过调整形状(shape)步长(stride) 的元信息,而不是数据本身。

  • 步长(Stride)
    • 步长是一个定义在每一维上的整数数组,表示为了在数据中从当前维度的一个元素移动到下一个元素,需要跨过的内存位置数。对于一个连续的张量,步长决定了元素在内存中的布局。

形状(Shape)和步长的调整当调用 .permute(1,0) 时,你实际上是告诉 PyTorch 以一个新的顺序来解释原始数据的内存布局。例如:

x = torch.tensor([[1, 2, 3, 4],
                  [2, 4, 2, 4],
                  [5, 6, 7, 8]])

原始的 x 的形状为 (3, 4),即有 3 行和 4 列。在 PyTorch 中,这意味着其步长为 (4, 1),其中 4 表示要从一行的开始移动到下一行的开始,在内存中需要跨过 4 个元素位置;1 表示在同一行中从一个元素移动到下一个元素,只需要跨过 1 个元素位置。

当你调用 x.permute(1, 0) 时,你是在指示 PyTorch 将原来的列视为行,将原来的行视为列。这就改变了形状为 (4, 3)。这时,步长变为 (1, 4)。这意味着:

  • 要从列的一个元素到下一个元素(现在变成了“行”移动),你只需要移动一个数据位置(原来的行移动)。
  • 要从一行移动到下一行(现在是原来的列跨行移动),你需要跨过 4 个数据位置。
2.transpose

在 PyTorch 中,.transpose() 方法用于交换张量中的两个维度,这是处理多维数组时一个常用的功能,尤其是在需要对特定的维度进行转置操作时。

a.函数的基本用法

.transpose() 方法的调用格式为 tensor.transpose(dim0, dim1),其中 dim0dim1 是要交换的维度的索引。

b.参数:
  • dim0: 第一个要交换的维度的索引。
  • dim1: 第二个要交换的维度的索引。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量

# 交换维度 0 和 1
y = x.transpose(0, 1)
print(y.shape)
# 输出: torch.Size([3, 2, 5])

# 交换维度 1 和 2
z = x.transpose(1, 2)
print(z.shape)
# 输出: torch.Size([2, 5, 3])
d.注意

.permute() 类似,.transpose() 也是返回原始数据的一个新视图,并不复制数据。因此,输出张量与输入张量共享同一块内存空间,只是它们的形状和步长(stride)不同。同样,.transpose() 会导致张量在内存中可能变为非连续,因此在某些情况下,可能需要调用 .contiguous() 来使张量在内存中连续。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/552819.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony实战开发-如何通过分割swiper区域,实现指示器导航点位于swiper下方的效果。

介绍 本示例介绍通过分割swiper区域,实现指示器导航点位于swiper下方的效果。 效果预览图 使用说明 1.加载完成后swiper指示器导航点,位于显示内容下方。 实现思路 1.将swiper区域分割为两块区域,上方为内容区域,下方为空白区…

HAL STM32 I2C方式读取MT6701磁编码器获取角度例程

HAL STM32 I2C方式读取MT6701磁编码器获取角度例程 📍相关篇《Arduino通过I2C驱动MT6701磁编码器并读取角度数据》🎈《STM32 软件I2C方式读取MT6701磁编码器获取角度例程》📌MT6701当前最新文档资料:https://www.magntek.com.cn/u…

生产服务器变卡怎么排查

服务器变卡怎么排查,可以从以下四个方面去考虑 生产服务器变卡怎么排查 1、网络2、cpu的利用率3、io效率4、内存瓶颈 1、网络 可以使用netstat、iftop等工具查看网络流量和网络连接情况,检查是否网络堵塞、丢包等问题 2、cpu的利用率 1、用top命令定…

VMWare Ubuntu压缩虚拟磁盘

VMWare中ubuntu会越用越大,直到占满预分配的空间 即使系统里没有那么多东西 命令清理 开机->open Terminal sudo vmware-toolbox-cmd disk shrink /关机-> 编辑虚拟机设置->硬盘->碎片整理&压缩 磁盘应用 开机->disk usage analyzer(应用) …

【LLM】认识LLM

文章目录 1.LLM1.1 LLM简介1.2 LLM发展1.3 市面常见的LLM1.4 LLM涌现的能力 2.RAG2.1 RAG简介2.2 RAG 的工作流程2.3 RAG 和 Finetune 对比2.4 RAG的使用场景分析 3. LangChain3.1 LangChain简介3.2 LangChain的核心组件3.3 LangChain 入门 4.开发 RAG 应用的整体流程5. 环境配…

stm32中的中断优先级

在工作中使用到多个定时器中断,由于中断的中断优先级不熟悉导致出错,下面来写一下中断的一些注意事项。 一、中断的分类 1、EXTI外部中断:由外部设备或外部信号引发,例如按键按下、外部传感器信号变化等。外部中断用于响应外部事件,并及时处理相关任务。 2、内部中断:…

搭建Zookeeper完全分布式集群(CentOS 9 )

ZooKeeper是一个开源的分布式协调服务,它为分布式应用提供了高效且可靠的分布式协调服务,并且是分布式应用保证数据一致性的解决方案。该项目由雅虎公司创建,是Google Chubby的开源实现。 分布式应用可以基于ZooKeeper实现诸如数据发布/订阅…

Jmeter 测试-跨线程调用变量

1、Jmeter中线程运行规则 ①各个线程组是完全独立的,每个线程组是不同的业务,互不影响 ②线程组中的每个线程也是完全独立 ③线程组中的每个线程,都是从上往下执行,完成一轮循环后,继续下一轮循环 ④存在业务流或者…

考察自动化立体库应注意的几点

导语 大家好,我是智能仓储物流技术研习社的社长,老K。专注分享智能仓储物流技术、智能制造等内容。 整版PPT和更多学习资料,请球友到知识星球 【智能仓储物流技术研习社】自行下载 考察自动化立体仓库的关键因素: 仓库容量&#x…

python爬虫之爬取微博评论(4)

一、获取单页评论 随机选取一个微博,例如下面这个 【#出操死亡女生家属... - 冷暖视频的微博 - 微博 (weibo.com) 1、fnf12,然后点击网络,搜索评论内容,然后预览,就可以查看到网页内容里面还有评论内容 2、编写代码…

稀碎从零算法笔记Day51-LeetCode:最小路径和

题型:DP、数组、矩阵 链接:64. 最小路径和 - 力扣(LeetCode) 来源:LeetCode 题目描述 给定一个包含非负整数的 m x n 网格 grid ,请找出一条从左上角到右下角的路径,使得路径上的数字总和为…

适用于Windows电脑的最佳数据恢复软件是哪些?10佳数据恢复软件

丢失我们系统中可用的宝贵信息是很烦人的。我们可以尝试几种手动方法来重新获取丢失的数据。然而,当我们采用非自动方法来恢复数据时,这是一项令人厌烦和乏味的工作。在这种情况下,我们可以尝试使用一些正版硬盘恢复软件进行数据恢复。此页面…

Dual-AMN论文阅读

Boosting the Speed of Entity Alignment 10: Dual Attention Matching Network with Normalized Hard Sample Mining 将实体对齐速度提高 10 倍:具有归一化硬样本挖掘的双重注意力匹配网络 ABSTRACT 寻找多源知识图谱(KG)中的等效实体是知识图谱集成的关键步骤&…

TRIZ理论下攀爬机器人的创新设计与研究

随着科技的飞速发展,机器人技术已广泛应用于各个领域。特别是在复杂环境下的作业,如灾难救援、太空探测等,对机器人的移动能力和适应性提出了更高要求。在这样的背景下,基于TRIZ理论的攀爬机器人设计与研究应运而生,它…

分类算法——朴素贝叶斯(四)

概率基础 1概率定义 概率定义为一件事情发生的可能性 扔出一个硬币,结果头像朝上 P(X):取值在[0,1] 2女神是否喜欢计算案例 在讲这两个概率之前我们通过一个例子,来计算一些结果: 问题如下: 1、女神喜欢…

Python pyglet制作彩色圆圈“连连看”游戏

原文链接: Python 一步一步教你用pyglet制作“彩色方块连连看”游戏(续)-CSDN博客文章浏览阅读1.6k次,点赞75次,收藏55次。上期讲到相同的色块连接,链接见: Python 一步一步教你用pyglet制作“彩色方块连连看”游戏-…

Python基于Django搜索的目标站点内容监测系统设计,附源码

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

【Android GUI】FramebufferNativeWindow与Surface

文章目录 显示整体体系FramebufferNativeWindowFramebufferNativeWindow构造函数 dequeueBufferSurface总结参考 显示整体体系 native window为OpenGL与本地窗口系统之间搭建了桥梁。 这个窗口系统中,有两类本地窗口,nativewindow1是能直接显示在屏幕的…

超平实版Pytorch CNN Conv2d

torch.nn.Conv2d 基本参数 in_channels (int) 输入的通道数量。比如一个2D的图片,由R、G、B三个通道的2D数据叠加。 out_channels (int) 输出的通道数量。 kernel_size (int or tuple) kernel(也就是卷积核,也可…

selenium反反爬虫,隐藏selenium特征

一、stealth.min.js 使用 用selenium爬网页时,常常碰到被检测到selenium ,会被服务器直接判定为非法访问,这个时候就可以用stealth.min.js 来隐藏selenium特征,达到绕过检测的目的 from selenium import webdriver from seleniu…