Vit Transformer

一 VitTransformer 介绍

vit : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

        论文是基于Attention Is All You Need,由于图像数据和词数据数据格式不一样,经典的transformer不能处理图像数据,在视觉领域的应用有限。本文提出的方法可以将transformer直接应用图像分类任务,引入Patch Embedding,位置编码等方法,克服了Transformer在处理图像数据时的限制。整体流程如下。

从图中可以看出, Vision Transformer 主要有三个部分组成: 1 ) 第一部分是Linear Projection of Flattened Patches ,也就是 Emdedding 层,主要的工作就是将图像数据转换成transformer可以处理的数据格式。2)第二部分是Transformer Encoder部分,它是vit 最核心的组件(原始的NLP的transformer还有Decoder部分)。它主要是层归一化,多头注意力机制,MLP,Dropout/DropPath四个小block组成,用于学习图像数据。3) 第三部分就是MLP head ,用于分类。

二 PatchEmbedding & Positional Encoding

        首先,每个图像被分割成一系列不重叠的块(16x16或者 32x32),然后做一个线性的embedding ,由于这些块如果并行的输入到transformer中,不提供位置信息,模型不知道这些块的顺序。因此要加一个 positional encoding。 

        在实际的实现上,图像数据是[batch_size, C , H, W] 的格式,要将其变成[batch_size , token_len , dim],其中token_len 可以理解成图像patch token的数量。以[4,3,224,224]的图像为例子,首先我们模拟分割块,对于一个图像,我们要将其分割成 (H*w)/(patch_size*patch_size)个patches,即(224x224)// (16x16) = 196个 patches 。每个patch的大小是(3,16,16),然后我们将其flatten一个768( 3x16x16)dim的 token。这样数据格式就变成[4,196,768]。

代码分割图像块 : 

def split_patches(x, patch_size=16):
    batch_size, channels, height, width = x.shape
    x = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)
    x = x.reshape(batch_size, -1, channels * patch_size * patch_size)
    return x

当然这个过程可以通过卷积实现,官方代码其实就是用卷积来实现的。

class PatchEmbed(nn.Module) :

    def __init__(self,img_size=224,patch_size=16,in_channels=3,embed_dim=768,norm_layer=None):

        super().__init__()
        img_size = (img_size,img_size)
        patch_size = (patch_size,patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],img_size[1]//patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1] 

        # 卷积层
        self.proj = nn.Conv2d(in_channels=in_channels,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)
        # 归一化
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self,x) :
        B,C,H,W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input image size ({H}x{W} does not match input image size ({self.img_size[0]}x{self.img_size[1]}"
        x = self.proj(x).flatten(2).transpose(1,2)
        x = self.norm(x)

        return x

 positional Embedding 

由于输入的图像数据patch序列没有能够表达patch之间相对位置关系,因此需要加入位置编码(Positional encoding)这个特征,为了得到不同位置的对应的编码,Transformer模型使用不同频率的正余弦函数

PE(Pos,2i) = sin(\frac{pos}{10000^{2i/d}})

PE(Pos,2i+1) = cos(\frac{pos}{10000^{2i/d}})

 其中 pos是表示token(flattened image patch)的位置,2i和2i+1表示位置编码向量中对应的维度,d是对应位置编码的总维度。

def add_positional_encoding(x, max_len):
    batch_size, patch_numbers, dim = x.shape
    position = torch.arange(max_len).reshape(-1, 1)
    div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
    pe = torch.zeros((max_len, dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    x += pe[:patch_numbers]
    return x

三 Self-Attention 

        其中最核心的部分就是对于注意力部分了。在基于Transformer的机器翻译模型中,要建模源语言和目标语言任意两个单词的依赖关系,引入自注意力K(键)Q(查询)V(值)。这三个用来计算上下文单词所对应的权重得分,这些权重反映了在编码当前单词时,对于上下文不同部分所需要关注程度。

    同样在vision transformer中,对于一个个image patch token 来说,也需要建模任意 token之间的相互关注关系,当处理当前token时,哪些token与它有更高的关联度。

        上图是论文中的Scaled Dot-Product AttentionMulti-head Attention,我们首先定义三个矩阵 Q,K,V,这三个矩阵是由 输入X([4,196,768])分别经过三个权重矩阵$w_{q},w_{k},w_{v}$得到的。其中 Q 矩阵和K 矩阵,V矩阵是“同源”的,因为它们都是来自于同一个输入序列(图像patch token)的某种表示(线性变换的嵌入表示)。

根据Attenton分数的计算公式,Q(shape=[4,196,768])左乘一个K(shape=[4,196,768])矩阵的转置,得到一个相似度矩阵(shape=[4,196,196]),为了防止过大的相似度数值在后续Softmax计算过程中导致的梯度爆炸以及收敛效率差的问题,因此使用一个缩放因子\sqrt{d}缩放来稳定优化。放缩后的得分经过Softmax函数归一化为概率后,与其他位置的值向量相乘来聚合希望关注的上下文信息,并最小化不相关信息的干扰。

def self_attention(x, w_q, w_k, w_v):
    query = torch.matmul(x, w_q)
    key = torch.matmul(x, w_k)
    value = torch.matmul(x, w_v)
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
    attention_scores = softmax(scores)
    output = torch.matmul(attention_scores, value)
    return attention_scores, output

过程如下

四 Multi-head self-attention (MSA)

        为了进一步提升自注意力机制的全局信息聚合能力,提出了Multi-head attention机制,具体来说,上下文的每个token 向量的表示x_{i}的经过多组的线性{W_{j}^{Q},W_{j}^{K},W_{j}^{V}}映射到不同的表示空间。计算出不同子空间得到的attention score得到{Z_{j}}_{j=1}^{N},再用一个线性变换w^{o} 用于综合不同子空间中的上下文表示形成最后的输出。

import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, input_dim, num_heads, head_dim):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        assert input_dim % self.num_heads == 0
        self.projection_dim = input_dim // self.num_heads

       # 定义 权重矩阵 
        self.weight_q = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))
        self.weight_k = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))
        self.weight_v = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))

       
        self.weight_combine = torch.nn.Parameter(torch.randn(num_heads * self.projection_dim, input_dim))

    def forward(self, x):
        batch_size, seq_length, _ = x.size()

        
        queries = torch.matmul(x, self.weight_q)
        keys = torch.matmul(x, self.weight_k)
        values = torch.matmul(x, self.weight_v)

     
        queries = queries.view(batch_size, seq_length, self.num_heads, self.projection_dim)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.projection_dim)
        values = values.view(batch_size, seq_length, self.num_heads, self.projection_dim)

    
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # 计算得分
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.projection_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, values)

        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
        output = torch.matmul(attention_output, self.weight_combine)

        return output

input_dim = 64
num_heads = 8
head_dim = input_dim // num_heads
seq_length = 10
batch_size = 4


multihead_attention = MultiHeadAttention(input_dim, num_heads, head_dim)
x = torch.rand(batch_size, seq_length, input_dim)
output = multihead_attention(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)


五 代码实现

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import transforms
import numpy as np

def softmax(x):
    return torch.nn.functional.softmax(x, dim=-1)

def split_patches(x, patch_size=16):
    batch_size, channels, height, width = x.shape
    x = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)
    x = x.reshape(batch_size, -1, channels * patch_size * patch_size)
    return x

def add_positional_encoding(x, max_len):
    batch_size, patch_numbers, dim = x.shape
    position = torch.arange(max_len).reshape(-1, 1)
    div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
    pe = torch.zeros((max_len, dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    x += pe[:patch_numbers]
    return x

def plot_heatmap(scores,index,name ):
    plt.figure(figsize=(8, 6))
    plt.imshow(scores, cmap='hot', interpolation='nearest')
    plt.xlabel('Keys')
    plt.ylabel('Queries')
    plt.title(f'Attention Scores Heatmap {name}')
    plt.colorbar()
    plt.savefig(f"./attention_heatmap{index}.png")
    # plt.show()
def self_attention(x, w_q, w_k, w_v):
    query = torch.matmul(x, w_q)
    key = torch.matmul(x, w_k)
    value = torch.matmul(x, w_v)
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
    attention_scores = softmax(scores)
    output = torch.matmul(attention_scores, value)
    return attention_scores, output

def plot_heatmap_on_image(image, attention_scores, patch_size=16,index=0):
    # 对每个patch的注意力分数求平均
    attention_scores_mean = attention_scores.mean(dim=1)
    # 将注意力分数转换为与原始图像大小相匹配的热力图
    attention_map = attention_scores_mean.view(1, 1, int(224 / patch_size), int(224 / patch_size))
    attention_map = torch.nn.functional.interpolate(attention_map, size=(224, 224), mode='bilinear', align_corners=False)
    attention_map = attention_map.squeeze().cpu().detach().numpy()


    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(attention_map, cmap='jet', alpha=0.5)
    plt.axis('off')
    plt.savefig(f'attention_map{index}.png')
    # plt.show()

if __name__ == '__main__':
    batch_size = 4
    channels = 3
    height = 224
    width = 224
    input_dim = 768
    output_dim = 64

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])


    image_paths = ["./images/11.jpg",
                   "./images/15.jpg",
                   "./images/16.jpg",
                   "./images/17.jpg"]

    images = torch.zeros((4, 3, 224, 224), dtype=torch.float32)
    for i, path in enumerate(image_paths):
        img = Image.open(path).convert('RGB')
        img_tensor = transform(img)
        images[i] = img_tensor


    patch_embeddings = split_patches(images, patch_size=16)
    patch_embeddings_pe = add_positional_encoding(patch_embeddings, max_len=196)

    w_q = torch.normal(0, 0.01, size=(input_dim, output_dim))
    w_k = torch.normal(0, 0.01, size=(input_dim, output_dim))
    w_v = torch.normal(0, 0.01, size=(input_dim, output_dim))

    attention_scores, output = self_attention(patch_embeddings_pe, w_q, w_k, w_v)

    # plot_heatmap(attention_scores[0])
    # for index in range(4) :
    #
    #     name = image_paths[index].split('/')[-1].split('.')[0]
    #     plot_heatmap(attention_scores[index],index,name)  # 选择第一张图像的注意力分数进行绘制


    # 将热力图叠加到原始图像上
    for index in range(4) :
        image_path = image_paths[index]
        img = Image.open(image_path).convert('RGB')
        img_tensor = transform(img)

        img_np = np.array(img)

        plot_heatmap_on_image(img_np, attention_scores[index],16,index=index)

参考

  • LLM(廿四):Transformer 的结构改进与替代方案 - 知乎

  • 【深度学习系列】五、Self Attention_self attention 加入位置信息-CSDN博客

  • NLP(五):Transformer及其attention机制 - 知乎

  • 有关vision transformer的一个综述 - 知乎

  • 为什么 Vision transformer 训练和推理很慢? - 知乎

  • 大规模语言模型:从理论到实践 -- 张奇、桂韬、郑锐、黄萱菁

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/493326.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++电子宠物商店

一、功能描述 店内有不同类型的电子宠物 1.每种电子宠物能通过显示出来的文本提出需要或表示情绪如:饿、渴、饱涨、困、不舒服、高兴、生气、伤心、绝望、无聊等。 2.店员用户通过键盘操作“饲养”电子宠物,给它实施喂饭、喂水、带它上厕所、陪它玩耍、…

3.27作业

1、完成下面类 #include <iostream> #include <cstring> using namespace std;class myString { private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度 public://无参构造myString():size(10){str new char[size]; …

凯撒加密.

题目描述 给定一个单词&#xff0c;请使用凯撒密码将这个单词加密 凯撒密码是一种替换加密的技术&#xff0c;单词中的所有字母都在字母表上向后偏移3位后被替换成密文。即a变为 d&#xff0c;变为e.&#xff0c;w 变为z&#xff0c;2 变为a&#xff0c;y变为 6&#xff0c;z变…

javaWeb项目-大学生体质测试管理系统功能介绍

项目关键技术 开发工具&#xff1a;IDEA 、Eclipse 编程语言: Java 数据库: MySQL5.7 框架&#xff1a;ssm、Springboot 前端&#xff1a;Vue、ElementUI 关键技术&#xff1a;springboot、SSM、vue、MYSQL、MAVEN 数据库工具&#xff1a;Navicat、SQLyog 1、JSP技术 JSP(Jav…

python笔记进阶--面向对象(2)

目录 1.类和对象&#xff08;实例&#xff09; 1.1对象&#xff08;实例&#xff09; 1.1.1使用对象组织数据 1.1.2类中增加属性 1.2成员方法&#xff08;类&#xff09; 1.2.1类的定义和使用语法 1.2.2成员方法的使用 1.2.3self关键字的作用 1.3类和对象 2&#xff…

Mysql各种日志管理

文章目录 事务日志事务日志的记录过程事务日志类型事务日志的相关变量 错误日志二进制日志功能作用文件的构成日志格式查看日志删除日志 通用日志慢查询日志 Mysql日志记录着数据库在运行过程中的各种操作&#xff0c;帮助管理员定位查找问题。 事务日志 事务日志(Transaction…

方案分享 | 嵌入式指纹方案

随着智能设备的持续发展&#xff0c;指纹识别技术成为了现在智能终端市场和移动支付市场中占有率最高的生物识别技术。凭借高识别率、短耗时等优势&#xff0c;被广泛地运用在智能门锁、智能手机、智能家居等设备上。 上海航芯在2015年进入指纹识别应用领域&#xff0c;自研高性…

MySQL的安装(Linux版)

1.所需要的文件 MySQL.zip 2. 卸载自带的Mysql-libs # 查看是否存在 rpm -qa | grep mariadb# 如果存在则执行命令进行卸载 rpm -e --nodeps mariadb-libs3.在/opt目录下创建MySQL目录并上传所需要的安装包 cd /optmkdir MySQL4.按照编号顺序安装&#xff08;压缩包在解压完…

vs2010打包QT程序

一、环境 win10 、 VS2010 、 qt5.7.1 将代码在release模式下运行 运行完后会在相应的文件夹下生成exe文件&#xff0c;也会将部分dll文件拷贝到release文件夹中 二、生成可执行文件 2.1 选择“文件”->“新建”->”项目“ 2.2 在打开的对框中选择”其他类型项目…

视频素材网有哪些?7个优质无水印素材库推荐

当今社会&#xff0c;视频内容已成为最吸引眼球的媒介之一&#xff0c;无论是社交媒体的短视频还是专业制作的影片&#xff0c;都离不开高质量的视频素材。为了帮助你在茫茫网海中找到宝藏&#xff0c;下面我精选了一系列视频素材网站&#xff0c;不仅提供丰富多样的视频资源&a…

(1) 易经与命运_学习笔记

个人笔记&#xff0c;斟酌阅读 占卦的原理 三个铜板&#xff0c;正面是3&#xff0c;反面2&#xff0c;三个一起转&#xff0c;得出6,7,8,9 数字象6老阴7少阳8少阴9老阳 生数和成数 生数和成数应该说出自《河图》。其中一二三四五为生数&#xff0c;六七八九十为成数。 生…

Typst入门简明教程

文章目录 写在前面对比阅读安装Typst的安装系统路径设置VSCode的设置 用Typst写作节指令图指令表指令公式指令参考文献指令引用指令节的标签图、表、公式的标签参考文献的标签 Typst编译 Typst的<label>写法建议写在最后&#xff1a;Typst、LaTex和Word的比较 写在前面 …

面试经典150题【111-120】

文章目录 面试经典150题【111-120】67.二进制求和190.颠倒二进制位191.位1的个数136.只出现一次的数字137.只出现一次的数字II201.数字范围按位与5.最长回文子串97.交错字符串72.编辑距离221.最大正方形 面试经典150题【111-120】 六道位运算&#xff0c;四道二维dp 67.二进制…

未能加载文件或程序集socutdata或它的某一个依赖项试图加载格式不正确的程序

未能加载文件或程序集socut data或它的某一个依赖项试图加载格式不正确的程序 Socut.Data.dll找不到类型或命名空间名称 把bin目录下面 的socut.data.dll删除就行了 C#报错未能加载文件或程序集socut data或它的某一个依赖项试图加载格式不正确的程序 "/"应用程序…

逐步学习Go-并发通道chan(channel)

概述 Go的Routines并发模型是基于CSP&#xff0c;如果你看过七周七并发&#xff0c;那么你应该了解。 什么是CSP&#xff1f; "Communicating Sequential Processes"&#xff08;CSP&#xff09;这个词组的含义来自其英文直译以及在计算机科学中的使用环境。 CSP…

Android Studio详细安装教程及入门测试

Android Studio 是 Android 开发人员必不可少的工具。 它可以帮助开发者快速、高效地开发高质量的 Android 应用。 这里写目录标题 一、Android Studio1.1 Android Studio主要功能1.2 Android应用 二、Android Studio下载三、Android Studio安装四、SDK工具包下载五、新建测试…

Live800:设计与管理客户忠诚度计划,提升客户满意度与忠诚度

在当今竞争激烈的商业环境中&#xff0c;吸引新客户的成本远高于保留现有客户。因此&#xff0c;设计并实施一套有效的客户忠诚度计划&#xff0c;以提升客户满意度和忠诚度&#xff0c;已经成为企业获得长期成功的关键。文章将探讨如何设计和实施客户忠诚度计划&#xff0c;以…

ehters.js:provider

ethers.jsV5.4文档 安装ethers npm install ethers5.4.0// 引入 import { ethers } from ethersProviders /** Provider类* Provider类是对以太坊网络连接的抽象&#xff0c;为标准以太坊节点功能提供简洁、一致的接口。 */ const provider new ethers.providers.Web3Provider…

【QT入门】 Qt代码创建布局之水平布局、竖直布局详解

往期回顾&#xff1a; 【QT入门】 Qt实现自定义信号-CSDN博客 【QT入门】 Qt自定义信号后跨线程发送信号-CSDN博客 【QT入门】 Qt内存管理机制详解-CSDN博客 【QT入门】 Qt代码创建布局之水平布局、竖直布局详解 先看两个问题&#xff1a; 1、ui设计器设计界面很方便&#xf…

Soft Robotics:两栖环境下螃蟹仿生机器人的行走控制

传统水陆两栖机器人依靠轮胎或履带与表面的接触及摩擦产生推进力&#xff0c;这种对于表面接触的依赖性限制了现有水陆两栖机器人在低重力环境下&#xff08;如水中&#xff09;的机动性。利用生物自身的推进机制&#xff0c;人为激发生物运动行为&#xff0c;由活体生物与微机…