AIGC:【LLM(一)】——LoRA微调加速技术

文章目录

    • 一.微调方法
      • 1.1 Instruct微调
      • 1.2 LoRA微调
    • 二.LoRA原理
    • 三.LoRA使用

一.微调方法

Instruct微调和LoRA微调是两种不同的技术。

1.1 Instruct微调

Instruct微调是指在深度神经网络训练过程中调整模型参数的过程,以优化模型的性能。在微调过程中,使用一个预先训练好的模型作为基础模型,然后在新的数据集上对该模型进行微调。Instruct微调是一种通过更新预训练模型的所有参数来完成的微调方法,通过微调使其适用于多个下游应用。

1.2 LoRA微调

LoRA(Low-Rank Adaptation)微调冻结了预训练的模型权重,并将可训练的秩分解矩阵注入到 Transformer 架构的每一层,极大地减少了下游任务的可训练参数的数量。与Instruct微调相比,LoRA在每个Transformer块中注入可训练层,因为不需要为大多数模型权重计算梯度,大大减少了需要训练参数的数量并且降低了GPU内存的要求。研究发现,使用LoRA进行的微调质量与全模型微调相当,速度更快并且需要更少的计算。
基于LoRA的微调产生保存了新的权重,可以将生成的LoRA权重认为是一个原来预训练模型的补丁权重 。所以LoRA模型无法单独使用,需要搭配原模型,两者进行合并即可获得完整版权重。
在这里插入图片描述

二.LoRA原理

神经网络包含许多密集的层,这些层执行矩阵乘法。这些层中的权重矩阵通常具有全秩。当适应特定的任务时,预训练的语言模型往往具有较低的“instrisic dimension”,尽管随机投影到较小的子空间,但仍然可以有效地学习。
LoRA的实现原理在于,冻结预训练模型权重,并将可训练的秩分解矩阵注入到Transformer层的每个权重中,大大减少了下游任务的可训练参数数量。直白的来说,实际上是增加了右侧的“旁支”,也就是先用一个Linear层A,将数据从 d维降到r,再用第二个Linear层B,将数据从r变回d维。最后再将左右两部分的结果相加融合,得到输出的hidden_state。
如上图所示,左边是预训练模型的权重,输入输出维度都是d,在训练期间被冻结,不接受梯度更新。右边部分对A使用随机的高斯初始化,B在训练开始时为零,r是秩,会对△Wx做缩放 α/r。
在这里插入图片描述

Lora 性质 1:全面微调的推广
通过将 LoRA 秩 r 设置为预训练的权重矩阵的秩,大致恢复了完整微调的表达性。换句话说,当增加可训练参数的数量时,训练LoRA大致收敛于训练原始模型。

Lora 性质 2:没有额外的推断延迟
在生产中部署时,可以显式地计算和存储 W = W 0 + B A W = W_0 + BA W=W0+BA,并像往常一样执行推理,也就是将 LoRA 权重和原始模型权重合并,不增加任何的推断耗时 W 0 W_0 W0 B A BA BA 都是 R d × k R^{d×k} Rd×k。当需要切换到另一个下游任务时,可以通过减去 BA 然后添加不同的 B ′ A ′ B'A' BA 来恢复 W 0 W_0 W0,这是一个内存开销很小的快速操作。

LoRA的优势
1)一个预先训练好的模型可以被共享,并用于为不同的任务建立许多小的LoRA模块。可以冻结共享模型,并通过替换图中的矩阵A和B来有效地切换任务,大大降低了存储需求和任务切换的难度。
2)在使用自适应优化器时,LoRA使训练更加有效,并将硬件进入门槛降低了3倍,因为我们不需要计算梯度或维护大多数参数的优化器状态。相反,我们只优化注入的、小得多的低秩矩阵。
3)简单的线性设计允许在部署时将可训练矩阵与冻结权重合并,与完全微调的模型相比,在结构上没有引入推理延迟。
4)LoRA与许多先前的方法是正交的,可以与许多方法结合,如前缀调整。

三.LoRA使用

HuggingFace的PEFT(Parameter-Efficient Fine-Tuning)中提供了模型微调加速的方法,参数高效微调(PEFT)方法能够使预先训练好的语言模型(PLMs)有效地适应各种下游应用,而不需要对模型的所有参数进行微调。

对大规模的PLM进行微调往往成本过高,在这方面,PEFT方法只对少数(额外的)模型参数进行微调,基本思想在于仅微调少量 (额外) 模型参数,同时冻结预训练 LLM 的大部分参数,从而大大降低了计算和存储成本,这也克服了灾难性遗忘的问题,这是在 LLM 的全参数微调期间观察到的一种现象PEFT 方法也显示出在低数据状态下比微调更好,可以更好地泛化到域外场景。

例如,使用PEFT-lora进行加速微调的效果如下,从中我们可以看到该方案的优势:
在这里插入图片描述

## 1、引入组件并设置参数
from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
import torch
from datasets import load_dataset
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup
from tqdm import tqdm

## 2、搭建模型

peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

## 3、加载数据
dataset = load_dataset("financial_phrasebank", "sentences_allagree")
dataset = dataset["train"].train_test_split(test_size=0.1)
dataset["validation"] = dataset["test"]
del dataset["test"]

classes = dataset["train"].features["label"].names
dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["label"]]},
    batched=True,
    num_proc=1,
)

## 4、训练数据预处理
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

def preprocess_function(examples):
    inputs = examples[text_column]
    targets = examples[label_column]
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    labels = tokenizer(targets, max_length=3, padding="max_length", truncation=True, return_tensors="pt")
    labels = labels["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs


processed_datasets = dataset.map(
 
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]

train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)

## 5、设定优化器和正则项
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

## 6、训练与评估

model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
 
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
 
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
	    eval_epoch_loss = eval_loss / len(eval_dataloader)
	    eval_ppl = torch.exp(eval_epoch_loss)
	    train_epoch_loss = total_loss / len(train_dataloader)
	    train_ppl = torch.exp(train_epoch_loss)
	    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

## 7、模型保存
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id)

## 8、模型推理预测

from peft import PeftModel, PeftConfig
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
 
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()

inputs = tokenizer(dataset["validation"][text_column][i], return_tensors="pt")
print(dataset["validation"][text_column][i])
print(inputs)
with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
    print(outputs)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/23468.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

抖音账号运营技巧,让你的短视频更火爆

抖音是目前最火爆的短视频平台之一,拥有着庞大的用户群体和广阔的市场前景。在这个平台上,每天都有大量的用户在发布自己的短视频内容,让自己的账号脱颖而出并吸引更多的粉丝,成为每个用户所追求的目标。下面就来介绍一些抖音账号…

QDir拼接路径解决各种斜杠问题

一般在项目中经常需要组合路径,与其他程序进行相互调用传递消息通信。 经常可能因为多加斜杠、少加斜杠等问题导致很多问题。 为了解决这些问题,我们可以使用QDir来完成路径的拼接,不直接拼接字符串。 QDir的静态方法QDir::cleanPath() 是为了规范化路径名的,在使用QDir组…

Atlassian数据迁移攻略:迁移前必备须知

到2024年2月,Atlassian将终止对Server产品及插件的所有支持。是时候制定您的迁移计划了——Atlassian为您提供两种迁移选择,一是本地部署的数据中心版本,中国用户25人以上即可使用,二是云版。作为Atlassian全球白金合作伙伴&#…

Mybatisplus真实高效批量插入附容错机制

文章目录 概要优化技术细节小结 概要 提示:mybatisplus自带真实批量插入 在mybatisplus已知常用批量插入为继承Iservice里的saveBatch方法和saveOrUpdateBatch方法, 进入源码可知,此两种方法的插入均为单条插入,如图: 其中可看出&#xff0…

JavaScript之DOM基础

1. 初识DOM DOM: 文档对象模型 是W3C组织推荐的处理可扩展标记语言(HTML或者XML)的标准编程接口。 w3c已经定义了一系列的DOM接口,通过这些DOM接口可以改变网页的内容,结构和样式。 DOM树:文档:一个页面就是一个文档, DOM中使用d…

LeetCode_DFS_困难_1377.T 秒后青蛙的位置

目录 1.题目2.思路3.代码实现(Java) 1.题目 给你一棵由 n 个顶点组成的无向树,顶点编号从 1 到 n。青蛙从 顶点 1 开始起跳。规则如下: 在一秒内,青蛙从它所在的当前顶点跳到另一个未访问过的顶点(如果它…

redis【stream】:对redis流数据类型的详细介绍

目录 stream产生原因 stream的概念 stream底层实现 stream的常用指令 常用命令一览: xadd命令 xread命令 xlen命令 xrange命令 xrevrange命令 xtrim命令 xdel命令 xgroup命令 xinfo命令 xpending命令 xreadgroup命令 xack命令 xclaim命令 stream产…

linux周六串讲

esc. //粘贴复制上一条命令的参数 cat /etc/resolv.conf //查看DNS地址 route -n //查看网关 hostname //临时修改主机名 hostnamectl set-hostname 名称 //永久修改主机名 ssh root192.168.10.233 //用windows远程的格式,在CMD窗口输入这个命令 …

zabbix部署

zabbix 是什么? zabbix 是一个基于 Web 界面的提供分布式系统监视以及网络监视功能的企业级的开源解决方案。 zabbix 能监视各种网络参数,保证服务器系统的安全运营;并提供灵活的通知机制以让系统管理员快速定位/解决存在的各种问题。 zabbi…

ZooKeeper快速入门学习+在springboot中的应用+监听机制的业务使用

目录 前言 基础知识 一、什么是ZooKeeper 二、为什么使用ZooKeeper 三、数据结构 四、监听通知机制 五、选举机制 使用 1 下载zookeeper 2 修改 3 排错 在SpringBoot中的使用 安装可视化插件 依赖 配置 安装httpclient方便测试 增删查改 新建控制器 创建节点…

云计算:优势与未来趋势

文章目录 前言一、云计算的优势1. 降低IT成本2. 提高工作效率3. 提高业务的可靠性和稳定性4. 提升安全性 二、未来发展趋势1. AI与云计算的融合2. 边缘计算的发展3. 多云的趋势4. 服务器和存储的创新 三、 行业应用案例1.金融行业2.医疗保健行业3.教育行业4.零售和物流行业 四、…

【章节1】git commit规范 + husky + lint-staged实现commit的时候格式化代码

创建项目我们不多说,可以选择默认的,也可以用你们现有的项目。 前言: git commit 的时候总有人填写一堆花里胡哨乱写的内容,甚至看了commit 的描述都不知道他这次提交到底做了个啥,那我们有没有办法规范大家的commit提…

边沿检测电路

目录 同步信号的边沿检测 异步信号的边沿检测 所谓的边沿检测(幼教边沿提取),就是检测输入信号的上升沿和下降沿。在设计数字系统时,边沿检测是一种很重要的思想,实际编程时用的最多的时序电路应该就是边沿检测电路和…

docker网络管理

1、常用的网络模式 Docker 容器的网络默认与宿主机、与其他容器都是相互隔离。 1.1、host 使用主机的网络命名空间,这意味着容器与主机共享同一个IP地址和端口号。使用host网络可以提高容器的网络性能,但是会降低容器的隔离性(容器直接使用宿主机网络栈…

浅谈电解电容在电路设计中的作用

谈起电解电容我们不得下多了解一下它的作用 1、滤波作用 在电源电路中,整流电路将交流变成脉动的直流,而在整流电路之后接入一个较大容量的电解电容,利用其充放电特性(储能作用),使整流后的脉动直流电压变成相对比较稳定的直流电…

C++是如何从代码到游戏的

有一个Student类。C怎么创建一个学生类的对象? // 嗯我会!有两种方式: Student s; Student *s2 new Student("张三");现在这学生的行为有:吃饭,睡觉,上网课。现在你执行个上网课的行为&#xf…

次氯酸消毒剂制备中的全氟醚橡胶密封耐腐蚀电动阀门解决方案

摘要:次氯酸作为是一种新型消毒剂,近年来广泛应用于医疗卫生机构、公共卫生场所和家庭的一般物体表面、医疗器械、医疗废物等。由于次氯酸的酸性和强氧化性,使得次氯酸生产制备过程中会给流量调节阀门带来腐蚀并影响寿命和控制精度&#xff0…

UE5电脑配置要求是什么?2023虚幻5电脑配置推荐

虚幻引擎对于游戏创作者来说已经不再陌生。该软件为程序员构建和设计终极视频游戏,以创建壮观的游戏场景和流畅的动作。此外,它还处理音效、物理碰撞效果和控制。尤其是人工智能对角色的控制。与其他软件一样,Unreal Engine也有最低系统要求才…

数据类型的陷进,从表象看本质!

哪些值转为布尔值为false 1、undefined(未定义,找不到值时出现) 2、null(代表空值) 3、false(布尔值的false,字符串"false"布尔值为true) 4、0(数字0&…

notepad++查询指定内容并复制

背景说明 记录一下使用notepad进行文本内容查找以及替换的相关场景,简单记录方便后期查看,场景内容: 1.从指定的给出内容中筛选出所有的人员id集合 2.将每一行后面添加逗号 1.从指定的给出内容中筛选出所有的人员id集合 要求从指定的给出内容中筛选出所有的人员id集…
最新文章