聊聊拉长LLaMA的一些经验

Sequence Length是指LLM能够处理的文本的最大长度,越长,自然越有优势:

  1. 更强的记忆性。更多轮的历史对话被拼接到对话中,减少出现遗忘现象

  2. 长文本场景下体验更佳。比如文档问答、小说续写等

当今开源LLM中的当红炸子鸡——LLaMA,第一版上下文长度是2048,第二版长度是4096。相比之下ChatGPT、GPT4已经支持到16k,Claude甚至支持到了100k。足以见得将LLaMA拉长是如此的任重而道远。本文将会介绍三种在旋转位置编码(RoPE)基础上扩充上下文的高性价比方案,在文末会介绍我的实践经验。

线性插值法

Kaiokendev的博客[1]中提到了方法,和Meta的一篇工作[2]不谋而合,其思想主要是将目标长度压缩到原始长度。如下图所示,LLaMA-1预训练的长度为2048,如果我们想把它拉长到4096:

  • 方法一:推理时直接拉长到4096。这考虑位置编码的外推性(即在短文本上训练,长文本上推理的能力[2]),而RoPE的外推性则是相当一般[2]。由于训练时长度都是小于2048的,超过2048部分Attention分数会飙升,导致困惑度急剧上升。

  • 方法二:在原始模型基础上做长度为4096的继续训练。这里先岔开介绍另一款模型——MPT-30B的做法,根据官方博客[3]的介绍:

    As mentioned earlier, MPT-30B was trained with a long context window of 8k tokens (vs. 2k for LLaMa and Falcon) and can handle arbitrarily long context windows via ALiBi or with fine-tuning. To build 8k support into MPT-30B efficiently, we first pre-trained on 1T tokens using sequences that were 2k tokens long, and continued training for an additional 50B tokens using sequences that were 8k tokens long.

    MPT-30B采用ALiBi位置编码(外推性优于RoPE),在2k的长度进行1T token的训练,然后在8k长度上进行50B token的预训练——这是在外推性强于RoPE的ALiBi上的情况。LLaMA-1预训练的token数是1T以上,想要在长度为4096样本上效果不下降,那需要训练足够多的token数才行,这就需要较大的成本了。

  • 方法三:另一种思路则是将4096的位置编码通过线性插值法压缩到2048内,这样只需要在少量的长度为4096的数据上进行继续预训练,便可达到不错的效果。

来自论文[2]

来自论文[2]

代码实现

线性插值法的实现代码相当的简单,这需要在原始RoPE上进行微小的改动,即加上下图的scale参数。

来自[7],scaled_rope/LlamaLinearScaledRotaryEmbedding.py

来自[7],scaled_rope/LlamaLinearScaledRotaryEmbedding.py

效果

Meta的工作[2]中进行了充足实验和公式推导证明,如果想看具体的代码,建议看lmsys.org(Vicuna的出品方)的一篇工作[4]:How Long Can Open-Source LLMs Truly Promise on Context Length? 他们对比了商用模型、号称支持长文本的开源模型和“Vicuna+线性插值法”的效果,并给出了几个结论:

  1. 商用模型在长文本的效果上很能打!而那些号称支持长文本的开源模型,在长文本上则表现不佳。

  2. 随着文本长度的增加,越接近边界,Vicuna+线性插值法的效果降低越明显。这可能是因为训练数据存在短文本的情况。

来自[4]

来自[4]

来自[4]

来自[4]

写这篇文章的同时,ChatGLM团队更新了ChatGLM2-6B-32K,也是使用了插值法。同时推出了长文本的中英评测集LongBench,在这个评测集上ChatGLM2-6B-32K展示了强大的实力,但值得注意的是,该评测集的评测方式是使用ChatGLM2-6B来进行评估的。

NTK插值法

NTK插值法的提出于一篇Reddit帖子[5],它提出使用Neural Tangent Kernel (NTK)来解决这个问题。

if you apply Neural Tangent Kernel (NTK) theory to this problem, it becomes clear that simply interpolating the RoPE's fourier space "linearly" is very sub-optimal, as it prevents the network to distinguish the order and positions of tokens that are very close by. Borrowing from NTK literature, scaling down the fourier features too much will eventually even prevent succesful finetunes (this is corroborated by the recent paper by Meta that suggests an upper bound of ~600x)

Instead of the simple linear interpolation scheme, I've tried to design a nonlinear interpolation scheme using tools from NTK literature. Basically this interpolation scheme changes the base of the RoPE instead of the scale, which intuitively changes the "spinning" speed which each of the RoPE's dimension vectors compared to the next. Because it does not scale the fourier features directly, all the positions are perfectly distinguishable from eachother, even when taken to the extreme (eg. streched 1million times, which is effectively a context size of 2 Billion)

帖子中作者还用了时钟的例子来解释线性插值和NTK插值的异同:

RoPE behaves like a clock. Your 12 hours wall clock is basically a RoPE of dimension 3 with a base of 60. So for each second, the minute hand turns 1/60th of a minute, and for each minute, the hour hand turns 1/60th.

Now if you slowed down time by a factor of 4x, that is a linear RoPE scaling used in SuperHOT. Unfortunately now it is really really hard to distinguish each second, because now the seconds hand barely moves each second. So if someone gave you two different times, which is only different by a single second, you won't be able to distinguish them from afar (let's say the NNs have myopia because that's basically what NTK predicts)

Now NTK-Aware RoPE scaling does not slow down the seconds. One second is still one second, but it slows down minutes by a factor of let's say 1.5, and the hours by a factor of 2. This way you can fit 90 minutes in a hour, and fit 24 hours in half a day. So now you basically have a clock that can measure 129.6k seconds instead of 43.2k seconds.

Because you don't need a precise measurement of the hour hand when looking at the time, scaling the hours more compared to seconds is crucial. You don't want to lose the precision of the seconds hand, but you can afford to lose precision on the minutes hand and even more on the hours hand.

Then, it's just a matter of deriving the base change formula in order to obtain such a scaling. (where less precise dimensions are scaled more and more)

代码实现

NTK的实现则更加简单了,根据超参数alpha,对应修改base变量即可:

来自[7],scaled-rope/scaled_rope/LlamaNTKScaledRotaryEmbedding.py

来自[7],scaled-rope/scaled_rope/LlamaNTKScaledRotaryEmbedding.py

效果

在效果上,帖子中也给出了NTK插值法和线性插值法的PPL比较,可以看到,在二者都不做Finetune的情况下,NTK插值法具备更低的PPL。

来自[5]

来自[5]

动态插值法

动态插值法同样出自于一篇Reddit帖子[6],它的出发点很简单:

My idea was to use the exact position values for the first 2k context (after all, why mess with a good thing?) and then re-calculate the position vector for every new sequence length as the model generates token by token.

这种做法可以和先前的两种方法相结合,[7]中也给出了详细的代码实现。

效果

作者和前两种方法做了对比,展示了动态插值法在PPL下降上的优势。

实践经验

我在实践的过程中,评估效果主要使用longChat[4]中使用的评估方式,以下是一些takeaway tips,欢迎大家一起交流。

  1. 线性插值法具备完整的理论支持和大量的实验证明,在我的实践中,“线性插值法+Finetune”取得了最佳效果。

  2. NTK插值法的实验中,对比的是不做Finetune的情况,在我的实践中,“NTK插值+Finetune”效果会明显优于单独的“NTK插值”,但它的收敛速度会慢于“线性插值法+Finetune”。

  3. 动态插值法的实验同样是在不做Finetune的情况对比的,目前为止我并没有尝试过这种方法。在Reddit的评论区有人提出一个很好的问题:如果采取这种方法,逐token推理时,文本的长度是在变化的,则导致无法使用kv-cache,这会对性能产生很大的影响。

最后,拉长LLaMA的方案可以不从RoPE入手(如:LongLLaMA[8]),但“线性插值法+Finetune”无疑是一种性价比很高的方案,推荐大家尝试!

——2023.07.31

Reference

[1] Extending Context is Hard…but not Impossible†

[2] EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

[3] MPT-30B: Raising the bar for open-source foundation models

[4] How Long Can Open-Source LLMs Truly Promise on Context Length?

[5] NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.

[6] Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

[7] GitHub - jquesnelle/scaled-rope

[8] Focused Transformer: Contrastive Training for Context Scaling

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/54014.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[Docker实现测试部署CI/CD----相关服务器的安装配置(1)]

目录 0、CI/CD系统最终架构图规划IP地址 1、git配置Git下载pycharm配置gitidea配置git 2、GitLab安装与配置主机要求拉取镜像定义 compose.yml启动gitlab浏览器访问并修改密码查看登录密码修改密码 3、SonarQube 安装与配置拉取镜像修改虚拟内存的大小启动SonarQube登录 SonarQ…

2023 蓝桥杯真题B组 C/C++

https://www.dotcpp.com/oj/train/1089/ 题目 3150: 蓝桥杯2023年第十四届省赛真题-冶炼金属 题目描述 小蓝有一个神奇的炉子用于将普通金属 O 冶炼成为一种特殊金属 X。这个炉子有一个称作转换率的属性 V,V 是一个正整数,这意味着消耗 V 个普通金 属 O…

价值 1k 嵌入式面试题-计算机网络 OSI

开门见山 请讲下 OSI 各层协议的主要功能? 常见问题 回答不系统回答不确切无法和实际网络协议做关联对应 答题思路 OSI 代表了开放互联系统中信息从一台计算机的一个软件应用流到另一个计算机的另一个软件应用的参考模型 OSI 包含 7 层,每一层负责特…

51单片机学习--串口通信

首先需要配置寄存器: 下面这里SCON配0x40和0x50都可以,因为暂时还不需要接受信息,所以REN置1置0都可 void Uart_Init(void) //4800bps11.0592MHz {PCON | 0x80; //使能波特率倍速位SMODSCON 0x50; //8位数据,可变波特率TMOD & 0x0F…

tinkerCAD案例:29. 摇头娃娃

Research Your Favorite Bobblehead 摇头娃娃 Project Overview: 项目概况: Design and create your favorite Minecraft 3D bobble head. All you need is a computer, 3D printer, spring and your creativity to your favorite Minecraft character in the for…

1、Hadoop3.x 从入门到放弃,第一章:概念

Hadoop3.x从入门到放弃,第一章:概念 一、什么是大数据 1、主要解决什么 大数据主要解决:海量数据的“采集”、“存储” 和 "分析计算" 问题2、大数据特点 1> Volume 大量 2> velocity 高速 3> variety 多样性数据分为…

网络层:IP协议/Mac协议

IP协议 主机: 配有IP地址, 但是不进行路由控制的设备; 路由器: 即配有IP地址, 又能进行路由控制; 节点: 主机和路由器的统 称; IP 目标网络(前半部分) 目标主机(后半部分) IP层的核心:IP地址定位主机(定…

Socks IP轮换:为什么是数据挖掘和Web爬取的最佳选择?

在数据挖掘和Web爬取的过程中,IP轮换是一个非常重要的概念。数据挖掘和Web爬取需要从多个网站或来源获取数据,而这些网站通常会对来自同一IP地址的请求进行限制或封锁。为了避免这些问题,数据挖掘和Web爬取过程中需要使用Socks IP轮换技术。在…

《向量数据库指南》——如何持久化存储 LlamaIndex 向量索引?

随着 AGI 时代的到来,越来越多的开发者开始思考如何有效利用大模型,不过,大家在构建 LLM 应用时普遍会面临三大挑战: LLM 的使用成本高昂LLM 无法及时提供最新信息LLM 缺乏特定专业领域的知识 针对上述问题,业界主流的做法是采用两种主要框架:微调和缓存 + 注入。 …

集团MySQL的酒店管理系统

酒店管理系统 概述 基于Spring Spring MVC MyBatis的酒店管理系统,主要实现酒店客房的预定、入住以及结账等功能。使用Maven进行包管理。 用户端主要功能包括: 登录注册、客房预订、客房评论(编写评论和查看评论) 后台管理主要…

如何在 Ubuntu 22.04 下编译 StoneDB for MySQL 8.0 | StoneDB 使用教程 #1

作者:双飞(花名:小鱼) 杭州电子科技大学在读硕士 StoneDB 内核研发实习生 ❝ 大家好,我是 StoneDB 的实习生小鱼,目前正在做 StoneDB 8.0 内核升级相关的一些事情。刚开始接触数据库开发没多久&#xff0c…

Linux 学习记录59(ARM篇)

Linux 学习记录59(ARM篇) 本文目录 Linux 学习记录59(ARM篇)一、IIC总线1. 概念2. IIC总线硬件连接 二、系统框图三、IIC时序1. 起始信号 / 停止信号2. 数据传输信号3. 应答信号 / 非应答信号4. 寻址信号 四、IIC协议1. 主机给从机发送一个字节(写)2. 主机给从机发送多个连续字…

MySQL 的 Join 查询及 Hash Join 优化 | StoneDB 技术分享会 #3

StoneDB开源地址 https://github.com/stoneatom/stonedb 设计:小艾 审核:丁奇、宇亭 编辑:宇亭 作者一:徐鑫强(花名:无花果) 电子科技大学-计算机技术-在读硕士、StoneDB 内核研发实习生 作…

Android 卡顿分析与布局优化

一、什么是卡顿?或者说我们怎么感知APP卡顿? 这里面涉及到android UI渲染机制,我们先了解一下android UI是怎么渲染的,android的View到底是如何一步一步显示到屏幕上的? android系统渲染页面流程: 1&…

重新审视MHA与Transformer

本文将基于PyTorch源码重新审视MultiheadAttention与Transformer。事实上,早在一年前博主就已经分别介绍了两者:各种注意力机制的PyTorch实现、从零开始手写一个Transformer,但当时的实现大部分是基于d2l教程的,这次将基于PyTorch…

使用javax.validation.constraints进行数据验证

使用javax.validation.constraints进行数据验证 在Java应用中,数据的验证是一个很重要的部分,特别是在接收用户输入或处理外部数据时。为了简化和标准化数据验证的过程,Java提供了javax.validation.constraints包,其中包含一系列注…

乳腺癌CT影像数据的深度学习:R语言与ANN神经网络构建高性能分类诊断模型

一、引言 乳腺癌是全球最常见的女性恶性肿瘤之一,也影响着男性的健康。据统计,每年有数百万人被诊断出患有乳腺癌[1]。乳腺癌的早期检测和准确诊断对于治疗和预后至关重要。然而,乳腺癌的早期诊断面临许多挑战,如图像解读的主观性…

uniapp 微信小程序:v-model双向绑定问题(自定义 props 名无效)

uniapp 微信小程序:v-model双向绑定问题(自定义 props 名无效) 前言问题双向绑定示例使用 v-model使用 v-bind v-on使用 sync 修饰符 参考资料 前言 VUE中父子组件传递数据的基本套路: 父传子 props子传父 this.$emit(事件名, …

Linux安装VScode

从本篇开始,打算有时间就写写在VScode中编写一些ros相关的案例程序用于学习记录。本篇是如何在Linux安装VScode的第一篇。 一、下载VScode 在Linux中打开浏览器输入:https://code.visualstudio.com/Download,选择与你电脑相匹配的版本下载&…

AssertionError: CUDA_HOME does not exist, unable to compile CUDA op(s)

安装deepspeed的时候出现如下错误: 检查是否有CUDA: 根据提示安装: 安装完之后检测,重新安装,成功安装。 参考资料 A100单机多卡大模型训练踩坑记录(CUDA环境、多GPU卡住且显存100%)