AI大模型探索之路-训练篇17:大语言模型预训练-微调技术之QLoRA

系列篇章💥

AI大模型探索之路-训练篇1:大语言模型微调基础认知
AI大模型探索之路-训练篇2:大语言模型预训练基础认知
AI大模型探索之路-训练篇3:大语言模型全景解读
AI大模型探索之路-训练篇4:大语言模型训练数据集概览
AI大模型探索之路-训练篇5:大语言模型预训练数据准备-词元化
AI大模型探索之路-训练篇6:大语言模型预训练数据准备-预处理
AI大模型探索之路-训练篇7:大语言模型Transformer库之HuggingFace介绍
AI大模型探索之路-训练篇8:大语言模型Transformer库-预训练流程编码体验
AI大模型探索之路-训练篇9:大语言模型Transformer库-Pipeline组件实践
AI大模型探索之路-训练篇10:大语言模型Transformer库-Tokenizer组件实践
AI大模型探索之路-训练篇11:大语言模型Transformer库-Model组件实践
AI大模型探索之路-训练篇12:语言模型Transformer库-Datasets组件实践
AI大模型探索之路-训练篇13:大语言模型Transformer库-Evaluate组件实践
AI大模型探索之路-训练篇14:大语言模型Transformer库-Trainer组件实践
AI大模型探索之路-训练篇15:大语言模型预训练之全量参数微调
AI大模型探索之路-训练篇16:大语言模型预训练-微调技术之LoRA


目录

  • 系列篇章💥
  • 前言
  • 一、QLoRA 总体概述
  • 二、QLoRA原理解释(4-bit NormalFloat)
  • 三、QLoRA代码实践
    • 学术资源加速
    • 步骤1 导入相关包
    • 步骤2 加载数据集
    • 步骤3 数据集预处理
      • 1)获取分词器
      • 2)定义数据处理函数
      • 3)对数据进行预处理
    • 步骤4 创建模型
      • 1、PEFT 步骤1 配置文件
      • 2、PEFT 步骤2 创建模型
    • 步骤5 配置训练参数
    • 步骤6 创建训练器
    • 步骤7 模型训练
    • 步骤8 模型推理
  • 总结


前言

在深度学习的不断进步中,大型语言模型(LLMs)的预训练和微调技术成为了研究的热点。其中,量化技术以其在模型压缩和加速方面的潜力备受关注。本文将深入探讨QLoRA(Quantized Low-Rank Adaptation)技术的原理、实践及应用。

一、QLoRA 总体概述

QLoRA技术是一种创新的量化LoRA(Low-Rank Adaptation)的技术,旨在保持模型性能的同时,显著减少模型的内存占用。该技术的核心包括:
1)4bit NormalFloat(NF4): 这是针对正态分布权重设计的一种信息理论上最优的数据类型。相较于传统的4-bit整数和4-bit浮点数,NF4为正态分布数据提供了更优异的实证性能。
2)双量化:QLoRA采用一种独特的双重量化机制,对初次量化后的常量进行二次量化,进一步压缩存储空间。
3)分页优化器:使用NVIDIA统一内存特性,该特性可以在在GPU偶尔OOM的情况下,进行CPU和GPU之间自动分页到分页的传输,以实现无错误的 GPU 处理。该功能的工作方式类似于 CPU 内存和磁盘之间的常规内存分页。使用此功能为优化器状态(Optimizer)分配分页内存, 然后在 GPU 内存不足时将其自动卸载到 CPU 内存,并在优化器更新步骤需要时将其加载回 GPU 内存。
在这里插入图片描述

二、QLoRA原理解释(4-bit NormalFloat)

前面篇章中我们有介绍,通常为了减少GPU的使用,我们会对模型进行量化处理,减少资源的使用;int8、int4量化是一种有效的模型压缩技术,它通过减少数值的精度来换取计算效率的提升,同时尽量保持模型的准确性。
在这里插入图片描述

1)常规int8量化和反量化过程:
在这里插入图片描述

2)常规int4量化和反量化过程:
在这里插入图片描述

3)QLoRA的NF4量化
是一种特殊的4位浮点数(Normal Float 4-bit)量化方法。它不仅定义了一种新的数据类型,还采用了基于分块的分位数量化策略,这种方法能够更有效地保持数值的相对关系,并且减少了由于量化引入的误差。QLoRA的NF4量化通过双重量化进一步减小了缓存占用,并且结合低秩适配器(LoRA)进行模型微调,可以在有限的计算资源下达到较高的性能水平。
在这里插入图片描述

三、QLoRA代码实践

学术资源加速

方便从huggingface下载模型,这云平台autodl提供的,仅适用于autodl。

import subprocess
import os

result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value

步骤1 导入相关包

开始之前,我们需要导入适用于模型训练和推理的必要库,如transformers。

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

步骤2 加载数据集

使用适当的数据加载器,例如datasets库,来加载预处理过的指令遵循性任务数据集。

ds = Dataset.load_from_disk("/root/tuning/lesson01/data/alpaca_data_zh/")
ds

输出:

Dataset({
    features: ['output', 'input', 'instruction'],
    num_rows: 26858
})
ds[:1]

输出

{'output': ['以下是保持健康的三个提示:\n\n1. 保持身体活动。每天做适当的身体运动,如散步、跑步或游泳,能促进心血管健康,增强肌肉力量,并有助于减少体重。\n\n2. 均衡饮食。每天食用新鲜的蔬菜、水果、全谷物和脂肪含量低的蛋白质食物,避免高糖、高脂肪和加工食品,以保持健康的饮食习惯。\n\n3. 睡眠充足。睡眠对人体健康至关重要,成年人每天应保证 7-8 小时的睡眠。良好的睡眠有助于减轻压力,促进身体恢复,并提高注意力和记忆力。'],
 'input': [''],
 'instruction': ['保持健康的三个提示。']}

步骤3 数据集预处理

利用预训练模型的分词器(Tokenizer)对原始文本进行编码,并生成相应的输入ID、注意力掩码和标签。

1)获取分词器

tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")
tokenizer

输出:

BloomTokenizerFast(name_or_path='Langboat/bloom-1b4-zh', vocab_size=46145, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False)

2)定义数据处理函数

def process_func(example):
    # 设置最大长度为256
    MAX_LENGTH = 256
    # 初始化输入ID、注意力掩码和标签列表
    input_ids, attention_mask, labels = [], [], []
    # 对指令和输入进行编码
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    # 对输出进行编码,并添加结束符
    response = tokenizer(example["output"] + tokenizer.eos_token)
    # 将指令和响应的输入ID拼接起来
    input_ids = instruction["input_ids"] + response["input_ids"]
    # 将指令和响应的注意力掩码拼接起来
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    # 将指令的标签设置为-100,表示不计算损失;将响应的输入ID作为标签
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    # 如果输入ID的长度超过最大长度,截断输入ID、注意力掩码和标签
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    # 返回处理后的数据
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

3)对数据进行预处理

tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

输出:

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 26858
})

步骤4 创建模型

然后,我们实例化一个预训练模型,这个模型将作为微调的基础。对于大型模型,我们可能还需要进行一些特定的配置,以适应可用的计算资源。(在实例化时,指定量化参数

import torch
##修改
# low_cpu_mem_usage=True: 这个参数设定为True意味着在模型加载时会尽可能地减少CPU内存的使用。
# torch_dtype=torch.half: 这个参数设置了模型中张量的数据类型为半精度浮点数,这可以减少内存占用和计算时间,但可能会牺牲一些精度。
# device_map="auto": 这个参数设置了模型应该在哪个设备上运行。“auto”意味着它将自动选择可用的设备,优先选择GPU,如果没有GPU则选择CPU。
# load_in_4bit=True: 这个参数设置为True意味着在模型加载时将使用4位量化,这可以进一步减少内存占用。
# bnb_4bit_compute_dtype=torch.half: 这个参数设置了在4位量化时的计算数据类型,这里设置为半精度浮点数。
# bnb_4bit_quant_type="nf4": 这个参数设置了4位量化的类型,"nf4"是一种特定的量化策略。
# bnb_4bit_use_double_quant=True: 这个参数设置为True意味着在4位量化时使用双重量化。
model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh",
                                              torch_dtype=torch.half,
                                              low_cpu_mem_usage=True, 
                                              device_map="auto", 
                                              load_in_4bit=True,
                                              bnb_4bit_quant_type="nf4", 
                                              bnb_4bit_use_double_quant=True)
model.dtype

torch.float16

查看参数,查看模型有哪些层,可以用于添加LoRA旁路

for name, parameter in model.named_parameters():
    print(name,parameter.dtype)

输出

transformer.word_embeddings.weight torch.float16
transformer.word_embeddings_layernorm.weight torch.float16
transformer.word_embeddings_layernorm.bias torch.float16
transformer.h.0.input_layernorm.weight torch.float16
transformer.h.0.input_layernorm.bias torch.float16
transformer.h.0.self_attention.query_key_value.weight torch.uint8
transformer.h.0.self_attention.query_key_value.bias torch.float16
transformer.h.0.self_attention.dense.weight torch.uint8
transformer.h.0.self_attention.dense.bias torch.float16
transformer.h.0.post_attention_layernorm.weight torch.float16
transformer.h.0.post_attention_layernorm.bias torch.float16
transformer.h.0.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.0.mlp.dense_h_to_4h.bias torch.float16
transformer.h.0.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.0.mlp.dense_4h_to_h.bias torch.float16
transformer.h.1.input_layernorm.weight torch.float16
transformer.h.1.input_layernorm.bias torch.float16
transformer.h.1.self_attention.query_key_value.weight torch.uint8
transformer.h.1.self_attention.query_key_value.bias torch.float16
transformer.h.1.self_attention.dense.weight torch.uint8
transformer.h.1.self_attention.dense.bias torch.float16
transformer.h.1.post_attention_layernorm.weight torch.float16
transformer.h.1.post_attention_layernorm.bias torch.float16
transformer.h.1.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.1.mlp.dense_h_to_4h.bias torch.float16
transformer.h.1.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.1.mlp.dense_4h_to_h.bias torch.float16
transformer.h.2.input_layernorm.weight torch.float16
transformer.h.2.input_layernorm.bias torch.float16
transformer.h.2.self_attention.query_key_value.weight torch.uint8
transformer.h.2.self_attention.query_key_value.bias torch.float16
transformer.h.2.self_attention.dense.weight torch.uint8
transformer.h.2.self_attention.dense.bias torch.float16
transformer.h.2.post_attention_layernorm.weight torch.float16
transformer.h.2.post_attention_layernorm.bias torch.float16
transformer.h.2.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.2.mlp.dense_h_to_4h.bias torch.float16
transformer.h.2.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.2.mlp.dense_4h_to_h.bias torch.float16
transformer.h.3.input_layernorm.weight torch.float16
transformer.h.3.input_layernorm.bias torch.float16
transformer.h.3.self_attention.query_key_value.weight torch.uint8
transformer.h.3.self_attention.query_key_value.bias torch.float16
transformer.h.3.self_attention.dense.weight torch.uint8
transformer.h.3.self_attention.dense.bias torch.float16
transformer.h.3.post_attention_layernorm.weight torch.float16
transformer.h.3.post_attention_layernorm.bias torch.float16
transformer.h.3.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.3.mlp.dense_h_to_4h.bias torch.float16
transformer.h.3.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.3.mlp.dense_4h_to_h.bias torch.float16
transformer.h.4.input_layernorm.weight torch.float16
transformer.h.4.input_layernorm.bias torch.float16
transformer.h.4.self_attention.query_key_value.weight torch.uint8
transformer.h.4.self_attention.query_key_value.bias torch.float16
transformer.h.4.self_attention.dense.weight torch.uint8
transformer.h.4.self_attention.dense.bias torch.float16
transformer.h.4.post_attention_layernorm.weight torch.float16
transformer.h.4.post_attention_layernorm.bias torch.float16
transformer.h.4.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.4.mlp.dense_h_to_4h.bias torch.float16
transformer.h.4.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.4.mlp.dense_4h_to_h.bias torch.float16
transformer.h.5.input_layernorm.weight torch.float16
transformer.h.5.input_layernorm.bias torch.float16
transformer.h.5.self_attention.query_key_value.weight torch.uint8
transformer.h.5.self_attention.query_key_value.bias torch.float16
transformer.h.5.self_attention.dense.weight torch.uint8
transformer.h.5.self_attention.dense.bias torch.float16
transformer.h.5.post_attention_layernorm.weight torch.float16
transformer.h.5.post_attention_layernorm.bias torch.float16
transformer.h.5.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.5.mlp.dense_h_to_4h.bias torch.float16
transformer.h.5.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.5.mlp.dense_4h_to_h.bias torch.float16
transformer.h.6.input_layernorm.weight torch.float16
transformer.h.6.input_layernorm.bias torch.float16
transformer.h.6.self_attention.query_key_value.weight torch.uint8
transformer.h.6.self_attention.query_key_value.bias torch.float16
transformer.h.6.self_attention.dense.weight torch.uint8
transformer.h.6.self_attention.dense.bias torch.float16
transformer.h.6.post_attention_layernorm.weight torch.float16
transformer.h.6.post_attention_layernorm.bias torch.float16
transformer.h.6.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.6.mlp.dense_h_to_4h.bias torch.float16
transformer.h.6.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.6.mlp.dense_4h_to_h.bias torch.float16
transformer.h.7.input_layernorm.weight torch.float16
transformer.h.7.input_layernorm.bias torch.float16
transformer.h.7.self_attention.query_key_value.weight torch.uint8
transformer.h.7.self_attention.query_key_value.bias torch.float16
transformer.h.7.self_attention.dense.weight torch.uint8
transformer.h.7.self_attention.dense.bias torch.float16
transformer.h.7.post_attention_layernorm.weight torch.float16
transformer.h.7.post_attention_layernorm.bias torch.float16
transformer.h.7.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.7.mlp.dense_h_to_4h.bias torch.float16
transformer.h.7.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.7.mlp.dense_4h_to_h.bias torch.float16
transformer.h.8.input_layernorm.weight torch.float16
transformer.h.8.input_layernorm.bias torch.float16
transformer.h.8.self_attention.query_key_value.weight torch.uint8
transformer.h.8.self_attention.query_key_value.bias torch.float16
transformer.h.8.self_attention.dense.weight torch.uint8
transformer.h.8.self_attention.dense.bias torch.float16
transformer.h.8.post_attention_layernorm.weight torch.float16
transformer.h.8.post_attention_layernorm.bias torch.float16
transformer.h.8.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.8.mlp.dense_h_to_4h.bias torch.float16
transformer.h.8.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.8.mlp.dense_4h_to_h.bias torch.float16
transformer.h.9.input_layernorm.weight torch.float16
transformer.h.9.input_layernorm.bias torch.float16
transformer.h.9.self_attention.query_key_value.weight torch.uint8
transformer.h.9.self_attention.query_key_value.bias torch.float16
transformer.h.9.self_attention.dense.weight torch.uint8
transformer.h.9.self_attention.dense.bias torch.float16
transformer.h.9.post_attention_layernorm.weight torch.float16
transformer.h.9.post_attention_layernorm.bias torch.float16
transformer.h.9.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.9.mlp.dense_h_to_4h.bias torch.float16
transformer.h.9.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.9.mlp.dense_4h_to_h.bias torch.float16
transformer.h.10.input_layernorm.weight torch.float16
transformer.h.10.input_layernorm.bias torch.float16
transformer.h.10.self_attention.query_key_value.weight torch.uint8
transformer.h.10.self_attention.query_key_value.bias torch.float16
transformer.h.10.self_attention.dense.weight torch.uint8
transformer.h.10.self_attention.dense.bias torch.float16
transformer.h.10.post_attention_layernorm.weight torch.float16
transformer.h.10.post_attention_layernorm.bias torch.float16
transformer.h.10.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.10.mlp.dense_h_to_4h.bias torch.float16
transformer.h.10.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.10.mlp.dense_4h_to_h.bias torch.float16
transformer.h.11.input_layernorm.weight torch.float16
transformer.h.11.input_layernorm.bias torch.float16
transformer.h.11.self_attention.query_key_value.weight torch.uint8
transformer.h.11.self_attention.query_key_value.bias torch.float16
transformer.h.11.self_attention.dense.weight torch.uint8
transformer.h.11.self_attention.dense.bias torch.float16
transformer.h.11.post_attention_layernorm.weight torch.float16
transformer.h.11.post_attention_layernorm.bias torch.float16
transformer.h.11.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.11.mlp.dense_h_to_4h.bias torch.float16
transformer.h.11.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.11.mlp.dense_4h_to_h.bias torch.float16
transformer.h.12.input_layernorm.weight torch.float16
transformer.h.12.input_layernorm.bias torch.float16
transformer.h.12.self_attention.query_key_value.weight torch.uint8
transformer.h.12.self_attention.query_key_value.bias torch.float16
transformer.h.12.self_attention.dense.weight torch.uint8
transformer.h.12.self_attention.dense.bias torch.float16
transformer.h.12.post_attention_layernorm.weight torch.float16
transformer.h.12.post_attention_layernorm.bias torch.float16
transformer.h.12.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.12.mlp.dense_h_to_4h.bias torch.float16
transformer.h.12.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.12.mlp.dense_4h_to_h.bias torch.float16
transformer.h.13.input_layernorm.weight torch.float16
transformer.h.13.input_layernorm.bias torch.float16
transformer.h.13.self_attention.query_key_value.weight torch.uint8
transformer.h.13.self_attention.query_key_value.bias torch.float16
transformer.h.13.self_attention.dense.weight torch.uint8
transformer.h.13.self_attention.dense.bias torch.float16
transformer.h.13.post_attention_layernorm.weight torch.float16
transformer.h.13.post_attention_layernorm.bias torch.float16
transformer.h.13.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.13.mlp.dense_h_to_4h.bias torch.float16
transformer.h.13.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.13.mlp.dense_4h_to_h.bias torch.float16
transformer.h.14.input_layernorm.weight torch.float16
transformer.h.14.input_layernorm.bias torch.float16
transformer.h.14.self_attention.query_key_value.weight torch.uint8
transformer.h.14.self_attention.query_key_value.bias torch.float16
transformer.h.14.self_attention.dense.weight torch.uint8
transformer.h.14.self_attention.dense.bias torch.float16
transformer.h.14.post_attention_layernorm.weight torch.float16
transformer.h.14.post_attention_layernorm.bias torch.float16
transformer.h.14.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.14.mlp.dense_h_to_4h.bias torch.float16
transformer.h.14.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.14.mlp.dense_4h_to_h.bias torch.float16
transformer.h.15.input_layernorm.weight torch.float16
transformer.h.15.input_layernorm.bias torch.float16
transformer.h.15.self_attention.query_key_value.weight torch.uint8
transformer.h.15.self_attention.query_key_value.bias torch.float16
transformer.h.15.self_attention.dense.weight torch.uint8
transformer.h.15.self_attention.dense.bias torch.float16
transformer.h.15.post_attention_layernorm.weight torch.float16
transformer.h.15.post_attention_layernorm.bias torch.float16
transformer.h.15.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.15.mlp.dense_h_to_4h.bias torch.float16
transformer.h.15.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.15.mlp.dense_4h_to_h.bias torch.float16
transformer.h.16.input_layernorm.weight torch.float16
transformer.h.16.input_layernorm.bias torch.float16
transformer.h.16.self_attention.query_key_value.weight torch.uint8
transformer.h.16.self_attention.query_key_value.bias torch.float16
transformer.h.16.self_attention.dense.weight torch.uint8
transformer.h.16.self_attention.dense.bias torch.float16
transformer.h.16.post_attention_layernorm.weight torch.float16
transformer.h.16.post_attention_layernorm.bias torch.float16
transformer.h.16.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.16.mlp.dense_h_to_4h.bias torch.float16
transformer.h.16.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.16.mlp.dense_4h_to_h.bias torch.float16
transformer.h.17.input_layernorm.weight torch.float16
transformer.h.17.input_layernorm.bias torch.float16
transformer.h.17.self_attention.query_key_value.weight torch.uint8
transformer.h.17.self_attention.query_key_value.bias torch.float16
transformer.h.17.self_attention.dense.weight torch.uint8
transformer.h.17.self_attention.dense.bias torch.float16
transformer.h.17.post_attention_layernorm.weight torch.float16
transformer.h.17.post_attention_layernorm.bias torch.float16
transformer.h.17.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.17.mlp.dense_h_to_4h.bias torch.float16
transformer.h.17.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.17.mlp.dense_4h_to_h.bias torch.float16
transformer.h.18.input_layernorm.weight torch.float16
transformer.h.18.input_layernorm.bias torch.float16
transformer.h.18.self_attention.query_key_value.weight torch.uint8
transformer.h.18.self_attention.query_key_value.bias torch.float16
transformer.h.18.self_attention.dense.weight torch.uint8
transformer.h.18.self_attention.dense.bias torch.float16
transformer.h.18.post_attention_layernorm.weight torch.float16
transformer.h.18.post_attention_layernorm.bias torch.float16
transformer.h.18.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.18.mlp.dense_h_to_4h.bias torch.float16
transformer.h.18.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.18.mlp.dense_4h_to_h.bias torch.float16
transformer.h.19.input_layernorm.weight torch.float16
transformer.h.19.input_layernorm.bias torch.float16
transformer.h.19.self_attention.query_key_value.weight torch.uint8
transformer.h.19.self_attention.query_key_value.bias torch.float16
transformer.h.19.self_attention.dense.weight torch.uint8
transformer.h.19.self_attention.dense.bias torch.float16
transformer.h.19.post_attention_layernorm.weight torch.float16
transformer.h.19.post_attention_layernorm.bias torch.float16
transformer.h.19.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.19.mlp.dense_h_to_4h.bias torch.float16
transformer.h.19.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.19.mlp.dense_4h_to_h.bias torch.float16
transformer.h.20.input_layernorm.weight torch.float16
transformer.h.20.input_layernorm.bias torch.float16
transformer.h.20.self_attention.query_key_value.weight torch.uint8
transformer.h.20.self_attention.query_key_value.bias torch.float16
transformer.h.20.self_attention.dense.weight torch.uint8
transformer.h.20.self_attention.dense.bias torch.float16
transformer.h.20.post_attention_layernorm.weight torch.float16
transformer.h.20.post_attention_layernorm.bias torch.float16
transformer.h.20.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.20.mlp.dense_h_to_4h.bias torch.float16
transformer.h.20.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.20.mlp.dense_4h_to_h.bias torch.float16
transformer.h.21.input_layernorm.weight torch.float16
transformer.h.21.input_layernorm.bias torch.float16
transformer.h.21.self_attention.query_key_value.weight torch.uint8
transformer.h.21.self_attention.query_key_value.bias torch.float16
transformer.h.21.self_attention.dense.weight torch.uint8
transformer.h.21.self_attention.dense.bias torch.float16
transformer.h.21.post_attention_layernorm.weight torch.float16
transformer.h.21.post_attention_layernorm.bias torch.float16
transformer.h.21.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.21.mlp.dense_h_to_4h.bias torch.float16
transformer.h.21.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.21.mlp.dense_4h_to_h.bias torch.float16
transformer.h.22.input_layernorm.weight torch.float16
transformer.h.22.input_layernorm.bias torch.float16
transformer.h.22.self_attention.query_key_value.weight torch.uint8
transformer.h.22.self_attention.query_key_value.bias torch.float16
transformer.h.22.self_attention.dense.weight torch.uint8
transformer.h.22.self_attention.dense.bias torch.float16
transformer.h.22.post_attention_layernorm.weight torch.float16
transformer.h.22.post_attention_layernorm.bias torch.float16
transformer.h.22.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.22.mlp.dense_h_to_4h.bias torch.float16
transformer.h.22.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.22.mlp.dense_4h_to_h.bias torch.float16
transformer.h.23.input_layernorm.weight torch.float16
transformer.h.23.input_layernorm.bias torch.float16
transformer.h.23.self_attention.query_key_value.weight torch.uint8
transformer.h.23.self_attention.query_key_value.bias torch.float16
transformer.h.23.self_attention.dense.weight torch.uint8
transformer.h.23.self_attention.dense.bias torch.float16
transformer.h.23.post_attention_layernorm.weight torch.float16
transformer.h.23.post_attention_layernorm.bias torch.float16
transformer.h.23.mlp.dense_h_to_4h.weight torch.uint8
transformer.h.23.mlp.dense_h_to_4h.bias torch.float16
transformer.h.23.mlp.dense_4h_to_h.weight torch.uint8
transformer.h.23.mlp.dense_4h_to_h.bias torch.float16
transformer.ln_f.weight torch.float16
transformer.ln_f.bias torch.float16

下面2个部分是LoRA相关的配置。

1、PEFT 步骤1 配置文件

在使用PEFT进行微调时,我们首先需要创建一个配置文件,该文件定义了微调过程中的各种设置,如学习率调度、优化器选择等。

from peft import LoraConfig, TaskType, get_peft_model
## ,target_modules=["query_key_value"],r=8
config = LoraConfig(task_type=TaskType.CAUSAL_LM,r=8, target_modules=['query_key_value'])
config

输出

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules=['query_key_value'], lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)

启用梯度计算

# 在深度神经网络 [deep neural network] 训练时,需要对每个参数或权重 [parameter/weight] 计算其对损失函数 
# [loss function] 的梯度 [gradient],从而进行反向传播 [back propagation] 和优化[optimization]。
# 默认情况下不会计算输入数据 [input data] 的梯度,即使它们在计算中起到了关键的作用。但是,在某些应用场景中,
# 例如图像生成 [image generation]、注意力机制 [attention mechanism] 等,需要计算输入数据的梯度。此时,
# 可以通过启用计算输入梯度的功能,对输入数据进行求导并利用其梯度信息进行优化。
# 作用: 启用该功能这对于在保持模型权重固定的同时微调适配器权重非常有用。

model.enable_input_require_grads()

2、PEFT 步骤2 创建模型

接下来,我们使用PEFT和预训练模型来创建一个微调模型。这个模型将包含原始的预训练模型以及由PEFT引入的低秩参数。

model = get_peft_model(model, config)
model

输出

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): BloomForCausalLM(
      (transformer): BloomModel(
        (word_embeddings): Embedding(46145, 2048)
        (word_embeddings_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (h): ModuleList(
          (0-23): 24 x BloomBlock(
            (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
            (self_attention): BloomAttention(
              (query_key_value): Linear4bit(
                in_features=2048, out_features=6144, bias=True
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=6144, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (dense): Linear4bit(in_features=2048, out_features=2048, bias=True)
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
            (mlp): BloomMLP(
              (dense_h_to_4h): Linear4bit(in_features=2048, out_features=8192, bias=True)
              (gelu_impl): BloomGelu()
              (dense_4h_to_h): Linear4bit(in_features=8192, out_features=2048, bias=True)
            )
          )
        )
        (ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
      (lm_head): Linear(in_features=2048, out_features=46145, bias=False)
    )
  )
)

查看配置

config

输出

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'query_key_value', 'dense_4h_to_h'}, lora_alpha=8, lora_dropout=0.0, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={})

步骤5 配置训练参数

在这一步,我们定义训练参数,这些参数包括输出目录、学习率、权重衰减、梯度累积步数、训练周期数等。这些参数将被用来配置训练过程。

指定分页优化器为"paged_adamw_32bit",这是一种针对低秩模型的优化算法

args = TrainingArguments(
    output_dir="/root/autodl-tmp/tuningdata/qlora",  # 指定模型训练结果的输出目录
    per_device_train_batch_size=4,  # 设置每个设备(如GPU)在训练过程中的批次大小为4
    gradient_accumulation_steps=8,  # 指定梯度累积步数为8,即将多个批次的梯度累加后再进行一次参数更新
    logging_steps=20,  # 每20个步骤记录一次日志信息
    num_train_epochs=1,  # 指定训练的总轮数为1
    gradient_checkpointing=True,  # 启用梯度检查点技术,可以减少内存占用并加速训练过程
    optim="paged_adamw_32bit"  # 指定分页优化器为"paged_adamw_32bit",这是一种针对低秩模型的优化算法
)

步骤6 创建训练器

最后,我们创建一个训练器实例,它封装了训练循环。训练器将负责运行训练过程,并根据我们之前定义的参数进行优化。

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

步骤7 模型训练

通过调用训练器的train()方法,我们启动模型的训练过程。这将根据之前定义的参数执行模型的训练。

trainer.train()

步骤8 模型推理

训练完成后,我们可以使用训练好的模型进行推理。

from peft import PeftModel
from transformers import pipeline

#加载基础模型
model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")

#加载lora模型
p_model = PeftModel.from_pretrained(model=model, model_id="/root/autodl-tmp/tuningdata/qlora/checkpoint-500")

#模型推理
pipe = pipeline("text-generation", model=p_model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("如何写好一个简历?", "").strip() + "\n\nAssistant: "
pipe(ipt, max_length=500, do_sample=True, )

输出

[{'generated_text': 'Human: 如何写好一个简历?\n\nAssistant: 好的,那么你应该考虑以下几点:\n\n1. 职位相关性\n\n有些职位可能会要求你具有相应的学历或工作经验,所以你需要在简历中附上这些信息,以确保你不会被误解为没有相关经验或学历。\n\n2. 个人信息部分\n\n在你的个人信息部分上,一定要附上你自述的职位,并提供你详细的职位描述。\n\n3. 背景与工作经历\n\n在这里你可以列出你过去的工作经历,包括工作项目、取得的奖励、你的发展方向、参加过的课程等。\n\n4. 优势项目\n\n除了上面提到的经历外,你还可以补充一些你擅长的项目,这样你就可以更容易让招聘人员了解到你的特质并做出判断。\n\n5. 能力证明部分\n\n这里你需要附上你过去工作中涉及到的关键工具和流程,以及通过这些工具和流程实现的实际结果。\n\n6. 本职工作领域\n\n你还应该附上你的主要工作领域,这样招聘人员就可以了解你的技能、经验和知识在哪些领域发挥着作用。\n\n7. 专长描述部分\n\n在这些方面,你可以描述一下你在这个专业领域拥有过哪些独特的技能,哪些领域你比其他人更有优势,以及有哪些是你自己擅长的。\n\n8. 职位描述部分\n\n在这里你可以附上你在当前工作领域取得的成就,描述下你在这项工作中能够为公司带来的价值,并证明你能够胜任这份工作。\n\n9. 未来的发展规划\n\n除了这个部分,你也可以补充一些未来的发展规划,这样招聘人员就可以了解你的目标和野心。\n\n10. 联系方式\n\n这里你可以附上你的联络方式,以便招聘人员能够及时与你联系,讨论相关事宜。\n\n11. 备注部分\n\n在简历的最后,你可以附上一个个人备注部分,你可以在这里说明如何能够更好的帮助面试者了解你,并阐述下你想找工作的原因。'}]

总结

QLoRA技术为大型语言模型的预训练与微调提供了一种高效、节省资源的方案。通过精心设计的量化策略和低秩适配器,QLoRA在保证模型性能的同时,显著降低了内存占用,为AI领域的研究者和工程师提供了宝贵的实践经验。

在这里插入图片描述

🎯🔖更多专栏系列文章:AIGC-AI大模型探索之路

如果文章内容对您有所触动,别忘了点赞、⭐关注,收藏!加入我,让我们携手同行AI的探索之旅,一起开启智能时代的大门!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/602073.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

浅谈消息队列和云存储

1970年代末&#xff0c;消息系统用于管理多主机的打印作业&#xff0c;这种削峰解耦的能力逐渐被标准化为“点对点模型”和稍复杂的“发布订阅模型”&#xff0c;实现了数据处理的分布式协同。随着时代的发展&#xff0c;Kafka&#xff0c;Amazon SQS&#xff0c;RocketMQ&…

基于大数据+Hadoop的豆瓣电子图书推荐系统实现

&#x1f339;作者主页&#xff1a;青花锁 &#x1f339;简介&#xff1a;Java领域优质创作者&#x1f3c6;、Java微服务架构公号作者&#x1f604; &#x1f339;简历模板、学习资料、面试题库、技术互助 &#x1f339;文末获取联系方式 &#x1f4dd; 系列文章目录 基于大数…

组合模式(Composite)——结构型模式

组合模式(Composite)——结构型模式 组合模式是一种结构型设计模式&#xff0c; 你可以使用它将对象组合成树状结构&#xff0c; 并且能通过通用接口像独立整体对象一样使用它们。如果应用的核心模型能用树状结构表示&#xff0c; 在应用中使用组合模式才有价值。 例如一个场景…

新能源汽车充电站智慧充电电能服务综合解决方案

安科瑞薛瑶瑶18701709087/17343930412 ★解决方案 ✔目的地充电-EMS微电网平台 基于EMS解决方案从设备运维的角度解决本地充电的能量管理及运维问题&#xff0c;与充电管理平台打通数据&#xff0c;为企业微电网提供源、网、荷、储、充一体化解决方案。 ✔运营场站--电能服务…

​「Python绘图」绘制太极图

python 绘制太极 一、预期结果 二、核心代码 import turtlepen turtle.Turtle()print("开始绘制太极")radius 100 pen.color("black", "black") pen.begin_fill() pen.circle(radius/2, 180) pen.circle(radius, 180) pen.left(180) pen.circ…

英语口语情景对话视频软件分享!

在当今全球化的时代&#xff0c;英语已成为一种通用的国际语言。为了提高英语口语能力&#xff0c;越来越多的人选择使用英语口语情景对话视频软件。本文将为您推荐几款备受欢迎的英语口语情景对话视频软件&#xff0c;帮助您轻松提高英语口语水平。 AI外语陪练 AI外语陪练软件…

营养补充品软胶囊:弹性测试与市场表现的深度解析

营养补充品软胶囊&#xff1a;弹性测试与市场表现的深度解析 在追求健康生活的时代&#xff0c;营养补充品市场蓬勃发展&#xff0c;其中软胶囊作为一种方便、易吸收的剂型&#xff0c;受到了消费者的广泛欢迎。然而&#xff0c;在这个竞争激烈的市场中&#xff0c;如何确保产…

推荐5个AI工具平替GPT

随着AI技术的快速发展&#xff0c;AI写作正成为创作的新风口。但是面对GPT-4这样的国际巨头&#xff0c;国内很多小伙伴往往望而却步&#xff0c;究其原因&#xff0c;就是它的使用门槛高&#xff0c;还有成本的考量。 不过&#xff0c;随着GPT技术的火热&#xff0c;国内也涌…

window11事件查看器中“在事件中只要触发此事件,就会执行相关非XX.xml脚本”

在事件中只要触发此事件&#xff0c;就会执行相关非XX.xml脚本 一、操作过程 1、在时间查看器中&#xff0c;将任务附加到此事件上 2、按照提示逐步下一步添加完成 3、只要触发1中的事件&#xff0c;那么就会执行对应的关联脚本xx.xml。 二、解决办法 1、通过开始菜单搜索打…

riscv交叉编译ports软件@FreeBSD15

当前FreeBSD的riscv版本下&#xff0c;软件包还很贫乏&#xff0c;再加上RISCV的板子有很多种&#xff0c;大部分时候都需要自己动手编译。但是在RISCV环境下编译太慢了&#xff0c;所以我们要使用交叉编译&#xff0c;在很快的AMD64服务器上交叉编译RISCV的软件包。 这里使用…

Promise魔鬼面试题

文章目录 题目解析难点分析分析输出step1step2step3step4step5step6 参考/致谢&#xff1a;渡一袁老师 题目 Promise.resolve().then(() > {console.log(0);return Promise.resolve(4);}).then((res) > {console.log(res);});Promise.resolve().then(() > {console.l…

基于FPGA的数字信号处理(10)--定点数的舍入模式(1)四舍五入round

1、前言 将浮点数定量化为定点数时&#xff0c;有一个避不开的问题&#xff1a;某些小数是无法用有限个数的2进制数来表示的。比如&#xff1a; 0.5(D) 0.1(B) 0.1(D) 0.0001100110011001~~~~(B) 可以看到0.5是可以精准表示的&#xff0c;但是0.1却不行。原因是整数是离散的…

AngusTester安装请求代理

一、介绍 请求代理程序(AngusProxy)提供两个方面作用&#xff1a; 代理Http和WebSocket协议接口调试请求&#xff0c;解决浏览器跨域限制问题。对代理请求客户化处理支持&#xff0c;允许用户对代理请求进行二次处理&#xff0c;如&#xff1a;请求参数签名。 二、类型 为了…

【经验01】spark执行离线任务的一些坑

项目背景: 目前使用spark跑大体量的数据,效率还是挺高的,机器多,120多台的hadoop集群,还是相当的给力的。数据大概有10T的量。 最近在出月报数据的时候发现有一个任务节点一直跑不过去,已经超过失败次数的阈值,报警了。 预警很让人头疼,不能上班摸鱼了。 经过分析发现…

多个glibc库存在时如何查看ldd调用的哪个

但是发现存在多个版本的glibc版本&#xff0c;需要查看具体的库的信息&#xff0c;和相应的关键函数的信息&#xff0c;但是并不知道具体的libc.so.6的路径信息 rootalg-dev04:~/xingqiao# ldd --version ldd (GNU libc) 2.29 rootalg-dev04:/opt# which ldd /usr/local/bin/…

工厂自动化升级改造(2)-RS485与Modbus通信协议

在工业控制、电力通信、智能仪表等领域,数据交换通常依赖于串口通信。最初,RS232接口是主流选择,然而,由于工业现场的复杂性,各种电气设备产生的电磁干扰可能导致信号传输错误。 RS232和RS485是两种不同的串行通信协议,它们在电气特性、传输距离和拓扑结构等方面有所不同…

基于springboot的篮球联盟管理系统

文章目录 项目介绍主要功能截图&#xff1a;部分代码展示设计总结项目获取方式 &#x1f345; 作者主页&#xff1a;超级无敌暴龙战士塔塔开 &#x1f345; 简介&#xff1a;Java领域优质创作者&#x1f3c6;、 简历模板、学习资料、面试题库【关注我&#xff0c;都给你】 &…

长难句打卡5.8

If it is trying to upset Google, which relies almost wholly on advertising, it has chosen an indirect method: there is no guarantee that DNT by default will become the norm. 如果它想激怒几乎全靠广告业务运营的谷歌公司的话&#xff0c;那么它选择了一个间接的方…

目标检测CNN 目标检测发展历程 应用场景 智慧交通 自动驾驶 工业生产 智慧医疗

目标检测 目标检测是计算机视觉领域中的一个重要任务,其主要目的是让计算机能够自动识别图像或视频帧中所有目标的类别,并在目标周围绘制边界框以标示出每个目标的位置。 目标检测的过程通常包括两个主要步骤:目标定位和目标分类。目标定位是确定图像中是否存在感兴趣的目…

【功耗问题排查】

一、如何处理具体功耗case 在手机功耗测试中&#xff0c;因为我们在功耗测试中&#xff08;电源电压&#xff09;为固定值&#xff08;老手机一般为3.8V左右&#xff0c;现在的大多项目采用4V左右&#xff09;&#xff0c;那么的大小直接由决定&#xff0c;所以&#xff0c;在沟…
最新文章