ViT论文Pytorch代码解读

ViT论文代码实现

论文地址:https://arxiv.org/abs/2010.11929
Pytorch代码地址:https://github.com/lucidrains/vit-pytorch

ViT结构图

在这里插入图片描述

调用代码

import torch
from vit_pytorch import ViT

def test():
    v = ViT(
        image_size = 256, 
        patch_size = 32,  
        num_classes = 1000,  
        dim = 1024,  
        depth = 6,  
        heads = 16,  
        mlp_dim = 2048,  
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 256, 256)

    preds = v(img)
    print(preds.shape)
    assert preds.shape == (1, 1000), 'correct logits outputted'

if __name__ == '__main__':
    test()

ViT结构

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        
        # 将image_size和patch_size都转换为(height, width)形式
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
		
		# 检查图像尺寸是否可以被patch尺寸整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

		# 计算图像中的patch数量
        num_patches = (image_height // patch_height) * (image_width // patch_width)
		
		# 计算每个patch的维度(即每个patch的元素数量)
        patch_dim = channels * patch_height * patch_width
        
        # 确保池化方式是'cls'或'mean'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

		# 将图像转换为patch嵌入的操作
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),  # 图像切分重排,后文有注释
            # 注:此时的维度为[b, h*w/p1/p2, p1*p2*c]:[批处理尺寸、图像中patch的数、每个patch的元素数量]
            nn.LayerNorm(patch_dim),  # 对patch进行层归一化
            nn.Linear(patch_dim, dim),  # 使用线性层将patch的维度从patch_dim转化为dim
            nn.LayerNorm(dim),  # 对结果进行层归一化
        )
		
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # 初始化位置嵌入
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # 初始化CLS token(用于分类任务的特殊token)
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # 定义Transformer模块
 
        self.pool = pool  # 设置池化方式('cls'或'mean')
        self.to_latent = nn.Identity()  # 设置一个恒等映射(在此实现中不改变数据,但可以在子类或其他变种中进行修改)

        self.mlp_head = nn.Linear(dim, num_classes)   # 定义MLP头部,用于最终的分类

    def forward(self, img):
        x = self.to_patch_embedding(img) # 第一步,将图片切分为若干小块
		# 此时维度为:[b, h*w/p1/p2, dim]
        b, n, _ = x.shape
		
		# 第二步,设置位置编码
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)  # 将cls_token复制b个 
        # (为每个输入图像复制一个CLS token,使输入批次中的每张图像都有一个相应的CLS token)
        x = torch.cat((cls_tokens, x), dim=1)  # 将CLS token与patch嵌入合并; cat之后,原来的维度[1,64,1024],就变成了[1,65,1024]
        x += self.pos_embedding[:, :(n + 1)] # 原数据和位置编码直接进行相加操作,即完成结构图中的【Patch + Position Embedding】操作
        
        x = self.dropout(x)

		# 第三步,Transformer的Encoder结构
        x = self.transformer(x)
        
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]   # 根据所选的池化方式进行池化

        x = self.to_latent(x)  # 将数据传递给恒等映射
        return self.mlp_head(x)  # 使用MLP头部进行分类
  

Rearrange解释:
y = x.transpose(0, 2, 3, 1)
可以写成:y = rearrange(x, ‘b c h w -> b h w c’)

关于pos_embedding和cls_token的逻辑讲解:
在这里插入图片描述如图所示,红色框框出的部分。
图像被切分为多个小块之后,经过self.to_patch_embedding 中的Rearrange,原本的[b,c,h,w]维度变为[b, h*w/p1/p2, p1*p2*c]。
再经过线性层nn.Linear(patch_dim, dim),维度变为[b, h*w/p1/p2, dim]。
输出结果即为上图中黄色框标出的部分的粉色条(不包括紫色条,是因为此处还没进行Position Embedding操作)。
继续往下走,进行torch.cat((cls_tokens, x), dim=1),此时将xcls_tokens进行concat操作,得到红色框框出的所有粉色条(在原本的基础上增加了带*号的粉色条)。
记下来的x += self.pos_embedding[:, :(n + 1)]操作就是将xpos_embedding直接进行相加,用图表示出来就是上图中整个红色框框出的部分了(紫色条就是传说中的pos_embedding)。
举一个有数字的例子:
原本输入图像维度为[1, 3, 256, 256],dim设置为1023,经过self.to_patch_embedding后维度变为:[1,64,1024],cls_tokens的维度为:[1,1,1024],经过concat操作后,x的维度变为[1,65,1024],然后经过pos_embedding加操作后,维度依然是[1,65,1024],因为在设置变量pos_embedding时的维度就是torch.randn(1, num_patches + 1, dim)
~这个解释应该够清晰了吧!~

Transformer Encoder结构

# 定义前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            # Vit_base: dim=768,hidden_dim=3072
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),  # 将输入从dim维映射到hidden_dim维
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # 将隐藏状态从hidden_dim维映射回到dim维
            nn.Dropout(dropout) 
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads  # 64*8=512  # 计算内部维度
        project_out = not (heads == 1 and dim_head == dim) # 判断是否需要投影输出,投影输出就是是否需要经过线性层
        # 如果只有一个attention头并且其维度与输入相同则不需要投影输出,否则需要。

        self.heads = heads
        self.scale = dim_head ** -0.5 # 缩放因子,通常是头维度的平方根的倒数

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)   # softmax函数用于最后一个维度,计算注意力权重
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) # 一个线性层生成Q, K, V

		# 判断是否需要投影输出
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 用线性层生成QKV,并在最后一个维度上分块;相当于写3遍nn.Linear
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 
        # 将[batch_size, sequence_length, heads_dimension] 转换为 [batch_size, number_of_heads, sequence_length, dimension_per_head]

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # 计算Q和K的点乘,然后进行缩放
        # q: [batch_size, number_of_heads, sequence_length, dimension_per_head]
        # k转置后:[batch_size, number_of_heads, sequence_length, dimension_per_head] -> [batch_size, number_of_heads, dimension_per_head, sequence_length]
        # q和k点乘后:[batch_size, number_of_heads, sequence_length, sequence_length]

        attn = self.attend(dots)   # 使用softmax函数获取注意力权重
        attn = self.dropout(attn)
		
		# 使用注意力权重对V进行加权
        out = torch.matmul(attn, v) 
        out = rearrange(out, 'b h n d -> b n (h d)') # 使用rearrange函数重新组织输出的维度
        return self.to_out(out)  # 投影输出(如果需要)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):  # depth设置为几层,就重复几次
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:  # 残差
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

如上就是ViT的整体结构了。

附:完整代码

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)


# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            # Vit_base: dim=768,hidden_dim=3072
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads  # 64*8=512
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 相当于写3遍nn.Linear
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        # 将[batch_size, sequence_length, heads_dimension] 转换为 [batch_size, number_of_heads, sequence_length, dimension_per_head]

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # q: [batch_size, number_of_heads, sequence_length, dimension_per_head]
        # k转置后:[batch_size, number_of_heads, sequence_length, dimension_per_head] -> [batch_size, number_of_heads, dimension_per_head, sequence_length]
        # q和k点乘后:[batch_size, number_of_heads, sequence_length, sequence_length]

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),  # 图像切分重排
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        # Rearrange解释:
        # y = x.transpose(0, 2, 3, 1)
        # 可以写成:y = rearrange(x, 'b c h w -> b h w c')

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)  # 数字编码,将cls_token复制b个
        x = torch.cat((cls_tokens, x), dim=1)  # cat之后,原来的维度[1,64,1024],就变成了[1,65,1024]
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

附:训练代码

model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=3,
).to(device)


# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)


for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/100035.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Rust 学习笔记(持续更新中…)

一、 编译和运行是单独的两步 运行 Rust 程序之前必须先编译,命令为:rustc 源文件名 - rustc main.rs编译成功之后,会生成一个二进制文件 - 在 Windows 上还会生产一个 .pdb 文件 ,里面包含调试信息Rust 是 ahead-of-time 编译的…

外贸爬虫系统

全球智能搜索 全球智能搜索 支持全球所有国家搜索引擎,及社交平台,精准定位优质的外贸客户,免翻墙 全球任意国家地区实时采集 搜索引擎全网邮箱电话采集 社交平台一键查看采集(Facebook,Twitter,Linkedin等) 职位…

Pytorch 的基本概念和使用场景介绍

文章目录 一、基本概念1. 张量(Tensor)2. 自动微分(Autograd)3. 计算图(Computation Graph)4. 动态计算图(Dynamic Computation Graph)5. 变量(Variable) 二、…

微软表示Visual Studio的IDE即日起开启“退休”倒计时

据了解,日前有消息透露称,适用于 Mac平台的Visual Studio集成开发环境(IDE)于8月31日启动“退休”进程。 而这意味着Visual Studio for Mac 17.6将继续支持12个月,一直到2024年8月31日。    微软表示后续不再为Visual Studio for Mac开发…

windows自带远程桌面连接的正确使用姿势

摘要 目前远程办公场景日趋广泛,对远程控制的需求也更加多样化,windows系统自带了远程桌面控制,在局域网内可以实现流程的远程桌面访问及控制。互联网使用远程桌面则通常需要使用arp等内网穿透软件,市场上teamviewer、Todesk、向…

进程的挂起状态

进程的挂起状态详解 当我们谈论操作系统和进程管理时,我们经常听到进程的各种状态,如“就绪”、“运行”和“阻塞”。但其中一个不那么常被提及,但同样重要的状态是“挂起”状态。本文将深入探讨挂起状态,以及为什么和在何时进程…

直播预约|哪吒汽车岳文强:OEM和Tier1如何有效对接网络安全需求

信息安全是一个防护市场。如果数字化程度低,数据量不够,对外接口少,攻击成本高,所获利益少,自然就没有什么攻击,车厂因此也不需要在防护上花费太多成本。所以此前尽管说得热闹,但并没有太多真实…

css让多个盒子强制自动等宽

1.width: calc( 100 / n% ) 2.display:flex; flex:1;width:100px; 3.display:grid;grid-template-columns: repeat(auto-fit, minmax(100px, 1fr)); 但是其中某一个内容较长的时候 会破坏1:1:1的平衡 这个时候发现附件名字过长导致不等比例,通过查看阮一峰flex文…

滑动窗口实例4(将x减到0的最小操作数)

题目: 给你一个整数数组 nums 和一个整数 x 。每一次操作时,你应当移除数组 nums 最左边或最右边的元素,然后从 x 中减去该元素的值。请注意,需要 修改 数组以供接下来的操作使用。 如果可以将 x 恰好 减到 0 ,返回 …

DOM破坏绕过XSSfilter例题

目录 一、什么是DOM破坏 二、例题1 三、多层关系 1.Collection集合方式 2.标签关系 3.三层标签如何获取 四、例题2 五、例题3 1.代码审计 2.payload分析 一、什么是DOM破坏 DOM破坏(DOM Clobbering)指的是对网页上的DOM结构进行不当的修改&am…

评估安全 Wi-Fi 接入:Cisco ISE、Aruba、Portnox 和 Foxpass

在当今不断变化的数字环境中,对 Wi-Fi 网络进行强大访问控制的需求从未像现在这样重要。各组织一直在寻找能够为其用户提供无缝而安全的体验的解决方案。 在本博客中,我们将深入探讨保护 Wi-Fi(和有线)网络的四种领先解决方案——…

春秋云镜 CVE-2018-19422

春秋云镜 CVE-2018-19422 Subrion CMS 4.2.1 存在文件上传漏洞 靶标介绍 Subrion CMS 4.2.1 存在文件上传漏洞。CVE-2021-41947同一套cms。 启动场景 漏洞利用 admin/admin登陆后台管理界面 执行SQL命令,获取flag select load_file(/flag); 得到flag flag{174…

泥石流山体滑坡监控视觉识别检测算法

泥石流山体滑坡监控视觉识别检测算法通过yolov8python深度学习框架模型,泥石流山体滑坡监控视觉识别检测算法识别到泥石流及山体滑坡灾害事件的发生,算法会立即进行图像抓拍,并及时进行预警。Yolo的源码是用C实现的,但是好在Githu…

Qt---对话框 事件处理 如何发布自己写的软件

目录 一、对话框 1.1 消息对话框(QMessageBox) 1> 消息对话框提供了一个模态的对话框,用来提示用户信息,或者询问用户问题并得到回答 2> 基于属性版本的API 3> 基于静态成员函数版本 4> 对话框案例 1、ui界面 …

vue之若依分页组件的导入使用(不直接使用若依框架,只使用若依分页组件)

vue之若依分页组件的导入使用 步骤 步骤: 工具类:src/utils/scroll-to.js 样式:src/assets/styles/ruoyi.scss 组件:src/components/Pagination 全局挂载:src/main.js 复制工具类 复制若依框架中的src/utils/scrol…

【UE 材质】实现角度渐变材质、棋盘纹理材质

目标 步骤 一、角度渐变材质 1. 首先通过“Mask”节点将"Texture Coordinate" 节点的R、G通道分离 2. 通过“RemapValueRange”节点将0~1范围映射到-1~1 可以看到此时R通道效果: G通道效果: 继续补充如下节点 二、棋盘纹理材质 原视频链接&…

20230901工作心得:IDEA列操作lambda表达式加强版用法

今天是中小学开学时间,亦是9月的开始,继续努力。 今日收获较大的有四个地方,先说这四点。 1、IDEA列操作 使用场景:需要批量将Excel表格里的数据插入到数据库中,此时需要写大量的insert SQL语句。 比如像这样的&am…

Python之父加入微软三年后,Python嵌入Excel!

近日,微软传发布消息,Python被嵌入Excel,从此Excel里可以平民化地进行机器学习了。只要直接在单元格里输入“PY”,回车,调出Python,马上可以轻松实现数据清理、预测分析、可视化等等等等任务,甚…

知识图谱笔记:TransH

1 TransE存在的问题 一对多 假设有一个关系 "是父亲",其中一个父亲(头实体)可能有多个孩子(尾实体) 父亲 A -> 孩子 1父亲 A -> 孩子 2在 TransE 中,这两个关系会被建模为: A是…

网络入门基础

目录 计算机网络背景 网络发展 认识协议 协议的制订 网络协议详解 协议分层 OSI七层模型 TCP/IP模型 网络传输的基本流程 局域网通信 跨网络通信 网络中的地址管理 IP地址 MAC地址 计算机网络背景 网络发展 独立模式:计算机之间相互独立 在早期的时候…
最新文章