ViT: transformer在图像领域的应用

文章目录

  • 1. 概要
  • 2. 方法
  • 3. 实验
    • 3.1 Compare with SOTA
    • 3.2 PRE-TRAINING DATA REQUIREMENTS
    • 3.3 SCALING STUDY
    • 3.4 自监督学习
  • 4. 总结
  • 参考

论文: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
代码:https://github.com/google-research/vision_transformer
代码2:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py

我们在Transformer详解(1)—原理部分详细介绍了transformer在NLP领域应用的原理,transformer架构自发布以来已经在自然语言处理任务上广泛应用,今天我们将介绍如何将transformer架构应用在图像领域。

1. 概要

基于self-attention的网络架构在NLP领域中取得了很大的成功,但是在CV领域卷积网络架构仍然占据主导地位。受到transformer在NLP中应用成功的启发,也有很多工作尝试将self-attention与CNN网络结合,甚至有些工作直接替换CNN网络,理论上这些模型是高效的,由于这些特殊的注意力机制未与硬件加速器有效适配,因此在大规模的图像检测中,经典的ResNet网络架构仍然是SOTA

受到Transformer网络在NLP领域中成功适配的启发,作者提出对transformer尽可能少的修改,直接在图片上应用标准的transformer。为了实现这个目标,首先需要将图片分割成多个patch,并将这些patch转换成embedding作为transformer的输入。图片的patch就相当于NLP中的token

最后作者得到结论:在数据量不足的情况下进行训练时,ViT不能很好地泛化,效果不如CNN,不过在训练大规模数据时,vit的效果会反超CNN

2. 方法

在模型设计方面,version transformer尽量与原始transformer结构保持一致,因为NLP中的transformer具有高效的实现方式,这样可以开箱即用。模型的整体结构如下所示:
在这里插入图片描述
标准的 transformer 输入是一维向量序列,为了处理二维图像,将输入图片 x ∈ R H × W × C \mathbf{x}\in\mathbb{R}^{H\times W\times C} xRH×W×C 分割成一系列的patch,并将这些patch平整成一维向量,最终得到 x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_p\in\mathbb{R}^{N\times(P^2\cdot C)} xpRN×(P2C),其中 ( H , W ) (H,W) (H,W)是原始图片分辨率, C C C 是图片的通道数, ( P , P ) (P,P) (P,P)是每个patch的分辨率, N = H W P 2 N=\frac{HW}{P^2} N=P2HW 是patch的个数,也可以看作是输入序列的长度。由于transformer每一层的输入向量维度都是固定的 D D D,因此需要通过一个可训练的线性层将 flatten patch 的维度从 P 2 C P^2C P2C转换成 D D D,这个线性层的输出称为patch的embedding.

和BERT的 [class] token 类似,在path embedding序列的首位增加了一个可学习向量 z 0 0 = x c l a s s z_0^0=x_{class} z00=xclass,该向量在transformer encoder的输出部分看做是图片的表征,在预训练和微调阶段,该表征后都会接一个分类层。

为了保持位置信息,位置embedding会加到patch embedding上,这里作者使用了一个一维可学习的位置向量,因为通过实验发现使用二维位置向量并没有获得很大的性能提升,通过以上流程处理后的embedding就是transformer的输入embedding。从输入图片到transformer encoder输出可由以下式子表示:

z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯   ; x p N E ] + E p o s ;     E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D z ′ ℓ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 ;    ℓ = 1 … L z ℓ = M L P ( L N ( z ′ ℓ ) ) + z ′ ℓ ;     ℓ = 1 … L y = L N ( z L 0 ) \begin{align} z_0 =&[\mathbf{x}_\text{class};\mathbf{x}_p^1\mathbf{E};\mathbf{x}_p^2\mathbf{E};\cdots;\mathbf{x}_p^N\mathbf{E}]+\mathbf{E}_{pos}; \ \ \ \mathbf{E}\in\mathbb{R}^{(P^{2}\cdot C)\times D}, \mathbf{E}_{pos}\in\mathbb{R}^{(N+1)\times D}\\ \mathbf{z}^{\prime}{}_{\ell} =& \mathrm{MSA(LN(z_{\ell-1}))+z_{\ell-1}};\ \ \ell=1\ldots L \\ \mathbf{z}_{\ell} = &\mathrm{MLP}(\mathrm{LN}(\mathbf{z^{\prime}}_\ell))+\mathbf{z^{\prime}}_\ell; \ \ \ \ell=1\ldots L \\ y =& \mathrm{LN}(\mathbf{z}_{L}^{0}) \end{align} z0=z=z=y=[xclass;xp1E;xp2E;;xpNE]+Epos;   ER(P2C)×D,EposR(N+1)×DMSA(LN(z1))+z1;  =1LMLP(LN(z))+z;   =1LLN(zL0)

其中 E E E 是patch维度转换矩阵, M S A MSA MSA是多头注意力层(multi-head self attention), L N LN LN是layer normalization 层, M L P MLP MLP是transformer中前馈网络层

另外,也可以使用CNN网络的特征图作为输入序列,在这种混合模型中,patch embeding 投影层将被用于改变CNN特征图的形状。

在微调阶段,将移除预训练的prediction layer,并新增一个零初始化的预测层,一般来说,在更高分辨率图像上微调是非常有益的。在喂入更高分辨率图像时,保持patch的尺寸不变,这样会造成输入序列长度增加,虽然ViT模型可以处理任意长的输入序列(直到内存不够),但是预训练的位置编码将无效,因此作者根据当前位置在原始图片中的位置,对预训练的位置编码采用2D插值的方法获取最新的位置编码

3. 实验

下文中将用一些简写来代表模型的尺寸和输入patch的尺寸,如ViT-L/16 代表模型为ViT-Large,输入patch的尺寸为 16 × 16 16 \times 16 16×16,下表展示了不同尺寸模型的配置及参数量
在这里插入图片描述
这里需要注意,由于输入序列长度与patch的尺寸成反比,所以,patch 尺寸越小,反而计算量越大

3.1 Compare with SOTA

在这里插入图片描述
TPU v3-core-days:代表计算量,All models were trained on TPUv3 hardware, and we
report the number of TPUv3-core-days taken to pre-train each of them, that is, the number of TPU
v3 cores (2 per chip) used for training multiplied by the training time in days

在这里插入图片描述
不同模型简介:

  • Big Transfer (BiT), which performs supervised transfer learning with large ResNets
  • VIVI – a ResNet co-trained on ImageNet and Youtube
  • S4L – supervised plus semi-supervised learning on ImageNet

3.2 PRE-TRAINING DATA REQUIREMENTS

作者经过实验得到如下结论:

  • 在小数据集上预训练,ViT-Large比ViT-Base要差,在大数据集上训练对ViT-Large比较有益
  • 在小数据集上预训练,ViT的效果比CNN还要差,在大数据集上预训练ViT的效果超过CNN
  • CNN网络的归纳有偏性在小数据集上是有用的,但是在大数据集上,直接从数据中学习相关的模式更有效
    在这里插入图片描述

3.3 SCALING STUDY

如下图所示,作者得到如下结论:

  • ViT在效果和计算量平衡之间相比ResNet占绝对优势,ResNet需要使用约3倍的算力来获得与ViT相似的结果
  • 混合模型在小计算量上相比ViT具有一定的优势,但是这种优势在大模型(大计算量)上逐渐消失
  • ViT在当前实验中貌似并没有饱和,这激励着未来的研究
    在这里插入图片描述

3.4 自监督学习

作者模仿BERT通过mask patch prediction任务进行自监督预训练,ViT-B/16在ImageNet上获得了79.9%的准确率,相比从随机初始化开始训练提升了2%,但是相比于监督学习仍然落后4%。

4. 总结

作者将图片看作是patch序列,并使用标准的Transformer对patch序列进行处理,最终在大数据集上预训练取得了很不错的效果,在图片分类任务上超过了很多SOTA模型。但也还存在一些挑战等待后期处理:

  • 将ViT应用在其他计算机视觉任务中,如目标检测、语义分割等
  • 还需进一步探索自监督预训练方法
  • 进一步扩大ViT模型的规模,可能会取得更好的效果

参考

如何理解Inductive bias?
Translation Equivariance
CNN中的Translation Equivariance【理解】
2D插值(2D interpolation)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/392177.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

删除windows自带输入法

ctrl shift F 搜狗简繁体切换

【大厂AI课学习笔记】【2.1 人工智能项目开发规划与目标】(4)数据准备的流程

今天学习的是数据准备的流程。 我们已经知道,数据准备占了AI项目超过一半甚至79%的时间。 那么数据准备,都做些什么,有哪些流程。 1.数据采集 观测数据人工收集调查问卷线上数据库 2.数据清洗 有缺失的数据有重复的数据有内容错误的数据…

CSS的注释:以“ /* ”开头,以“ */ ”结尾

CSS的注释:以“ /* ”开头,以“*/”结尾 CSS的注释: 以“ /* ”开头,以“ */ ”结尾 在CSS中,注释是一种非常重要的工具,它们可以帮助开发者记录代码的功能、用法或其他重要信息。这些信息对于理解代码、维护代码以及与他人合作都…

SpringBoot实现OneDrive文件上传

SpringBoot实现OneDrive文件上传 源码 OneDriveUpload: SpringBoot实现OneDrive文件上传 获取accessToken步骤 参考文档:针对 OneDrive API 的 Microsoft 帐户授权 - OneDrive dev center | Microsoft Learn 1.访问Azure创建应用Microsoft Azure,使…

Sora 文生视频提示词实例集 2

Prompt: Historical footage of California during the gold rush. 加利福尼亚淘金热期间的历史影像。 Prompt: A close up view of a glass sphere that has a zen garden within it. There is a small dwarf in the sphere who is raking the zen garden and creating patter…

Ubuntu 20.04 安装RVM

RVM是管理Ruby版本的工具,使用RVM可以在单机上方便地管理多个Ruby版本。 下载安装脚本 首先使下载安装脚本 wget https://raw.githubusercontent.com/rvm/rvm/master/binscripts/rvm-installer 如果出现了 Connection refused 的情况, 可以考虑执行以下命令修改dns,再执…

win10下wsl2使用记录(系统迁移到D盘、配置国内源、安装conda环境、配置pip源、安装pytorch-gpu环境、安装paddle-gpu环境)

wsl2 安装好后环境测试效果如下,支持命令nvidia-smi,不支持命令nvcc,usr/local目录下没有cuda文件夹。 系统迁移到非C盘 wsl安装的系统默认在c盘,为节省c盘空间进行迁移。 1、输出wsl -l 查看要迁移的系统名称 2、执行导出命…

配置oracle连接管理器(cman)

Oracle Connection Manager是一个软件组件,可以在oracle客户端上指定安装这个组件,Oracle连接管理器代理发送给数据库服务器的请求,在连接管理器中,我们可以通过配置各种规则来控制会话访问。 简而言之,不同于专用连接…

c入门第十八篇——支持学生数的动态增长(链表,指针的典型应用)

数组最大的问题,就是不支持动态的扩缩容,它是静态内存分配的,一旦分配完成,其容量是固定的。为了支持学生的动态增长,这里可以引入链表。 链表 在C语言中,链表是一种常用的数据结构,它由一系列…

深入解析鸿蒙系统的页面路由(Router)机制

鸿蒙系统以其独特的分布式架构和跨设备的统一体验而备受瞩目。在这个系统中,页面路由(Router)机制是连接应用各页面的关键组成部分。本文将深入探讨鸿蒙系统的页面路由,揭示其工作原理、特点以及在应用开发中的实际应用。 1. 实现…

使用Autodl云服务器或其他远程机实现在本地部署知识图谱数据库Neo4j

本篇博客的目的在于提高读者的使用效率 温馨提醒:以下操作均可在无卡开机状态下就可完成 一.安装JDK 和 Neo4j 1.1 ssh至云服务器 打开你的pycharm或者其他IDE工具或者本地终端,ssh连接到autodl的服务器。(这一步很简单如下图) 1.2 安装JDK 由于我…

gitlab代码控制平台搭建

docker-compose容器化gitlab docker-compose安装 # 官方链接(不推荐,太慢了) curl -L "https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose# 下面的官方链接会快一…

JAVA面试题基础篇

1. 二分查找 要求 能够用自己语言描述二分查找算法 能够手写二分查找代码 能够解答一些变化后的考法 算法描述 前提:有已排序数组 A(假设已经做好) 定义左边界 L、右边界 R,确定搜索范围,循环执行二分查找&#…

计算机网络——15套接字编程

套接字编程 Socket编程 Socket编程:应用进程使用传输层提供的服务才能够交换报文,实现应用协议,实现应用 TCP/IP:应用进程使用Socket API访问传输服务 地点:界面上的SAP 方式:Socket API 目标&#xff1…

鸿蒙开发系列教程(二十四)--List 列表操作(3)

列表编辑 1、新增列表项 定义列表项数据结构和初始化列表数据,构建列表整体布局和列表项。 提供新增列表项入口,即给新增按钮添加点击事件。 响应用户确定新增事件,更新列表数据。 2、删除列表项 列表的删除功能一般进入编辑模式后才可…

stable diffusion webui学习总结(2):技巧汇总

一、脸部修复:解决在低分辨率下,脸部生成异常的问题 勾选ADetailer,会在生成图片后,用更高的分辨率,对于脸部重新生成一遍 二、高清放大:低分辨率照片提升到高分辨率,并丰富内容细节 1、先通过…

Leetcode-429.N叉树的层序遍历

题目: 给定一个 N 叉树,返回其节点值的层序遍历。(即从左到右,逐层遍历)。 树的序列化输入是用层序遍历,每组子节点都由 null 值分隔(参见示例)。 示例 1: 输入&#xff…

Rocky Linux 下载安装

一、VMware Workstation下载安装 1、安装教程 VMware Workstation下载安装(含密钥) 二、VMware Workstation 创建虚拟机 1、创建教程 VMware Workstation 创建虚拟机 三、Rocky Linux 下载 1、下载官网 RockyLinux.org 2、选择X86架构_64位系统_DVD镜…

【C++初阶】第三站:类和对象(中) -- 日期计算器

目录 前言 日期类的声明.h 日期类的实现.cpp 获取某年某月的天数 全缺省的构造函数 拷贝构造函数 打印函数 日期 天数 日期 天数 日期 - 天数 日期 - 天数 前置 后置 前置 -- 后置-- 日期类中比较运算符的重载 <运算符重载 运算符重载 ! 运算符重载 …

JavaScript设计模式与开发实战

JavaScript设计模式与开发实践 第一章、面向对象的JavaScript 1.1 多态 类似java面向对象&#xff0c;通过继承共有特征&#xff0c;来实现不同方法。JavaScript的多态就是把“做什么”和“谁去做”分离&#xff0c;消除类型间的耦合关系。 他的作用就是把过程化的条件分支…
最新文章