李沐64_注意力机制——自学笔记

注意力机制

1.卷积、全连接和池化层都只考虑不随意线索

2.注意力机制则显示的考虑随意线索
(1)随意线索倍称之为查询(query)

(2)每个输入是一个值value,和不随意线索key的对

(3)通过注意力池化层来有偏向性的选择某些输入

总结

注意力机制中,通过query(随意线索)和key(不随意线索)来有偏向性的选择输入

代码实现:注意力汇聚:Nadaraya-Watson 核回归

!pip install d2l
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

在这里生成了50个训练样本和50个测试样本。

n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数
n_test
50

下面的函数将绘制所有的训练样本(样本由圆圈表示), 不带噪声项的真实数据生成函数f
(标记为“Truth”), 以及学习得到的预测函数(标记为“Pred”)。

def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

平均汇聚

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

在这里插入图片描述

现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

在这里插入图片描述

批量矩阵乘法

X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
tensor([[[ 4.5000]],

        [[14.5000]]])

使用小批量矩阵乘法, 定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

将训练数据集变换为键和值用于训练注意力模型。

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降。

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

在这里插入图片描述

如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

在这里插入图片描述

与非参数的注意力汇聚模型相比, 带参数的模型加入可学习的参数后, 曲线在注意力权重较大的区域变得更不平滑。

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/570393.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

客服话术分享:客服如何挖掘需求?

电商客服主动挖掘询问顾客需求是非常重要的,这就需要我们具备一定的沟通技巧。今天这篇客服话术分享,很适合想提升业绩的你们哦! 一、打招呼式询问需求: 1.欢迎光临,本店竭诚为您服务~请问您有什么具体想了解的问题吗&…

java-spring 06 图灵 getBean方法和 doGetBean方法

01.一般的流程是,这里是从上一章的preInstantiateSingleton方法顺序过来的。 getBean() -> doGetBean() -> createBean() -> doCreateBean() -> createBeanInstance() -> populateBean() -> initializeBean() 02.getBean方法,一般就…

C语言(1):初识C语言

0 安装vs2022 见 鹏哥视频即可 1 什么是C语言 c语言擅长的是底层开发! 现在一般用的是C89和C90的标准 主要的编辑器: 2 第一个C语言项目 .c 源文件 .h头文件 .cpp c文件 c语言代码中一定要有main函数 标准主函数的写法: int main() { …

菜鸟Java面向对象 1. Java继承

1. Java继承 Java继承 1. Java继承1. 继承的概念_简单介绍继承的用处生活中的继承: 2. 类的继承格式类的继承格式 3. 为什么需要继承企鹅类:老鼠类:公共父类:企鹅类:老鼠类: 4. 继承类型_多重继承5. 继承的…

视频怎么批量压缩?5个好用的电脑软件和在线网站

视频怎么批量压缩?有时候我们需要批量压缩视频来节省存储空间,便于管理文件和空间,快速的传输发送给他人。有些快捷的视频压缩工具却只支持单个视频导入,非常影响压缩效率,那么今天就向大家从软件和在线网站2个角度介绍…

AI建模效果到底行不行?试用这些AI工具告诉你!

当前AI大模型技术浪潮正掀起一股颠覆性的变革浪潮。诸如Midjourney、Stable Diffusion等AI绘画生成工具变得日益成熟,赋能千行百业。在之前的文章中我给大家介绍了很多Midjourney、Stable Diffusion的使用方法和对应的功能: Midjourney vs Stable Diffu…

【连接管理,三次握手,拥塞控制原理】

文章目录 连接管理TCP连接管理同意建立连接TCP3次握手3次握手解决:半连接和接受老数据问题TCP:关闭连接 拥塞控制原理拥塞控制的方法 连接管理 TCP连接管理 TCP连接管理 在正式交换数据之前,发送方和接收方握手建立通信关系: 同…

ECharts海量数据渲染解决卡顿

file模块用来写文件 我们首先使用node来生成10万条数据; 借助node的fs模块就行; 如果不会的小伙伴;也不要担心;超级简单// 引入模块 let fs = require(fs); // 数据内容 let fileCont=我是文件内容 /*** 第一个参数是文件名* 第二个参数是文件内容,这个文件的内容必须是字…

内容平台加码旅游:谁是下一个网红城市

“姐妹们,你们五一啥安排?”早在3月中旬,小威就在询问两个好朋友的行程,“不早早问,怕约不上你们。” 去年以来,国人的旅游需求快速复苏,像小威的朋友一样,之前爱玩的、不爱玩的似乎…

使用Unity扫描场景内的二维码,使用插件ZXing

使用Unity扫描场景内的二维码,使用插件ZXing 使用Unity扫描场景内的二维码,ZXing可能没有提供场景内扫描的方法,只有调用真实摄像机扫描二维码的方法。 实现的原理是:在摄像机上添加脚本,发射射线,当射线打…

世界首台能探测单个原子的量子模拟器,诞生!

量子物理学依赖于高精度的传感技术,以便深入研究材料的微观特性。近期开发的模拟量子处理器显示出量子气体显微镜在原子层面理解量子系统方面的强大潜力。这种显微镜可以生成极高分辨率的量子气体图像,甚至能够检测到单个原子。 在西班牙巴塞罗那的ICFO&…

XxlJob外网访问

Xxl-Job使用外网访问 服务注册中心配置 ### web server.port8088 server.servlet.context-path/xxl-job-admin### actuator management.server.base-path/actuator management.health.mail.enabledfalse### resources spring.mvc.servlet.load-on-startup0 spring.mvc.static…

Java练习题

打印9*9乘法口诀表 解析&#xff1a;利用for循环解决 代码如图所示&#xff1a; public class Cc {public static void main(String[] args) {for (int i 1; i < 10; i){ //从1遍历到9 for(int j 1; j < i; j){ System.out.print(j "*" i "&…

由于找不到steam_api64.dll,无法继续执行代码的解决方法

当用户在尝试启动某款基于Steam平台的游戏时&#xff0c;遇到了“游戏显示找不到steam_api64.dll”的错误提示&#xff0c;这会导致无法正常启动游戏。这究竟是什么原因导致的呢&#xff1f;本文将介绍五种解决方法&#xff0c;帮助大家解决这一问题。 一&#xff0c;了解steam…

实现ALV页眉页脚

1、文档介绍 在ALV中&#xff0c;可以通过增加页眉和页脚&#xff0c;丰富ALV的展示。除了基本的页眉和页脚&#xff0c;还可以通过插入HTML代码的方式展示更加丰富的页眉和页脚&#xff0c;本篇文章将介绍ALV和OOALV中页眉页脚的使用。 2、ALV页眉页脚 效果如下 2.1、显示内…

对于地理空间数据,PostGIS扩展如何在PostgreSQL中存储和查询地理信息?

文章目录 一、PostGIS扩展简介二、PostGIS存储地理空间数据1. 创建空间数据表2. 插入空间数据 三、PostGIS查询地理空间数据1. 查询指定范围内的地理空间数据2. 计算地理空间数据之间的距离3. 对地理空间数据进行缓冲区分析 四、总结 地理空间数据是指描述地球表面物体位置、形…

翻译《The Old New Thing》 - What‘s so special about the desktop window?

Whats so special about the desktop window? - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20040224-00/?p40493 Raymond Chen 2004年02月24日 简介 桌面窗口在 Windows 编程中具有特殊的地位&#xff0c;因为它代表整个桌面环境。滥用…

常见大厂面试题(SQL)01

知乎问答最大连续回答问题天数大于等于3天的用户及其对应等级 1.描述 现有某乎问答创作者信息表author_tb如下(其中author_id表示创作者编号、author_level表示创作者级别&#xff0c;共1-6六个级别、sex表示创作者性别)&#xff1a; author_id author_level sex 101 …

ARP 攻击神器:ARP Spoof 保姆级教程

一、介绍 arpspoof是一种网络工具&#xff0c;用于进行ARP欺骗攻击。它允许攻击者伪造网络设备的MAC地址&#xff0c;以欺骗其他设备&#xff0c;并截获其通信。arpspoof工具通常用于网络渗透测试和安全评估&#xff0c;以测试网络的安全性和漏洞。 以下是arpspoof工具的一些…

LabVIEW学习记录 - 实时显示时间

LabVIEW操作 - 实时显示时间 在程序框图&#xff0c;选择函数->定时->格式化日期/时间字符串 该函数的使用手册说明&#xff1a; 鼠标选择“格式化日期/时间字符串”->创建->输入控件->输入格式 查看时间代码格式&#xff1a; 编程->定时->获取时间日…
最新文章