LLM - Baichuan7B Tokenizer 生成训练数据

目录

一.引言

二.Tokenizer 原始数据

1.原始数据样例

2.加载并 Token 原始数据

2.1 参数准备

2.2 单条样本处理逻辑

2.3 批量处理逻辑

2.4 主函数与完整代码

三.shell 执行

四.总结


一.引言

前面提到了自己在微调 Baichuan7B Lora 的过程中遇到了一些问题,后面通过调整已经调通。鉴于自己刚刚从推荐算法转 AIGC,所以用笔记的形式记录下用于后面查漏补缺以及对 API 的熟悉。本文主要介绍 LORA 微调时原始数据的处理与编码,即 encode By tokenizer,最终生成可用的 Dataset。

二.Tokenizer 原始数据

1.原始数据样例

{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
{"q": "题目:51/186的答案是什么?", "a": "这是简单的除法运算,51除以186大概为0.274"}
{"q": "鹿妈妈买了24个苹果,她想平均分给她的3只小鹿吃,每只小鹿可以分到几个苹果?", "a":"鹿妈妈买了24个苹果,平均分给3只小鹿吃,那么每只
小鹿可以分到的苹果数就是总苹果数除以小鹿的只数。\n24÷3=8\n每只小鹿可以分到8个苹果。所以,答案是每只小鹿可以分到8个苹果。"}
{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
{"q": "题目:51/186的答案是什么?", "a": "这是简单的除法运算,51除以186大概为0.274"}
{"q": "鹿妈妈买了24个苹果,她想平均分给她的3只小鹿吃,每只小鹿可以分到几个苹果?", "a": "鹿妈妈买了24个苹果,平均分给3只小鹿吃,那么每>只小鹿可以分到的苹果数就是总苹果数除以小鹿的只数。\n24÷3=8\n每只小鹿可以分到8个苹果。所以,答案是每只小鹿可以分到8个苹果。"}
{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
{"q": "题目:51/186的答案是什么?", "a": "这是简单的除法运算,51除以186大概为0.274"}
{"q": "鹿妈妈买了24个苹果,她想平均分给她的3只小鹿吃,每只小鹿可以分到几个苹果?", "a": "鹿妈妈买了24个苹果,平均分给3只小鹿吃,那么每>只小鹿可以分到的苹果数就是总苹果数除以小鹿的只数。\n24÷3=8\n每只小鹿可以分到8个苹果。所以,答案是每只小鹿可以分到8个苹果。"}

这里 q 可以理解为 question,a 可以理解为 answer,上面将基础的训练数据重复了几次生成原始的训练文件 simple.json。

2.加载并 Token 原始数据

2.1 参数准备

import argparse
import json
from tqdm import tqdm
import datasets
import transformers

# 1.参数准备
parser = argparse.ArgumentParser()
parser.add_argument("--model_checkpoint", type=str, help="checkpoint, like `THUDM/chatglm-6b`") # 必填
parser.add_argument("--input_file", type=str, help="Instruction 数据文件地址,文件中每一行都是json格式,包含一个输出和一个输出") # 必填
parser.add_argument("--prompt_key", type=str, default=f"prompt", help="你的jsonl文件里,Instruction 的输入字段是什么") # 选填
parser.add_argument("--target_key", type=str, default=f"target", help="你的jsonl文件里,Instruction 的输出字段是什么") # 必填
parser.add_argument("--save_name", type=str, default=f"temp", help="经过tokenize之后的数据集的存放位置") # 选填
parser.add_argument("--max_seq_length", type=int, default=2040) # 选填
parser.add_argument("--skip_overlength", type=bool, default=False) # 选填
args = parser.parse_args()

参数采用 argparse 类进行初始化:

- model_checkpoint : 预训练模型地址,这里我们提前把 Baichuan7B 或者 ChatGLM 下载好即可

- input_file : 原始训练数据,训练数据格式为 json,可以参考上面的数据示例

- prompt_key : 训练数据在 json 里 prompt 提示对应的 key,上例为 q

- target_key : 训练数据在 json 里 target 提示对应的 key,上例为 a

- save_name : 保存地址,数据最终会议 arrow 的数据将 dataset 保存

- max_seq_length : 最长阶段序列长度

- skip_overlength : 是否忽略超长的文本,True 时忽略,False 时采取截断

2.2 单条样本处理逻辑

以 json 里一条样本为例:

{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
def preprocess(tokenizer, config, example, max_seq_length, prompt_key, target_key):
    prompt = example[prompt_key]
    target = example[target_key]
    prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
    target_ids = tokenizer.encode(target, max_length=max_seq_length, truncation=True, add_special_tokens=False)
    # 最终还是将 instruction 的输入输出都拼在一起,使用经典的 causal-LM 的 next word prediction 方式来训练
    input_ids = prompt_ids + target_ids + [config.eos_token_id] # EOS 用于标识句子结束
    return {"input_ids": input_ids, "seq_len": len(prompt_ids)}

根据配置的 prompt_key 和 target_key 获取 json 里对应的 prompt 与 target 内容,本例下 prompt_key = "q",target_key = "a",通过加载预训练模型获取对应的 Tokenizer 对 q、a 的文本进行 encode 编码。

Q: 请计算:39 * 0 = 什么?
A: 这是简单的乘法运算,39乘以0得到的是0
TokenQ: [31010, 6184, 77, 55, 61, 1734, 31106, 52, 1147, 31106, 1534, 75]
TokenA: [31106, 3908, 14313, 32329, 31257, 31481, 31742, 72, 55, 61, 32329, 31187, 52, 5442, 2585, 52]

为什么把 QA 前后连接拼到一起,上面的注释也给出了原因,该样本用于使用 causal-LM 模型进行 next word 的预测即续写功能的训练。通过将 Q 放在 A 前面训练,学习 QA 的前后文字逻辑。未来模型训练完毕后,我们给出 Q,模型机会根据之前的训练续写出 A 的相关内容。

2.3 批量处理逻辑

上面 preprocess 的逻辑主要在 read_json 里调用,该方法主要用于加载预训练模型生成 Tokenizer 与 config,

def read_json(path, max_seq_length, prompt_key,target_key,skip_overlength=False):
    # 基于预训练模型加载获取 tokenizer 和 config
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_checkpoint, trust_remote_code=True)
    config = transformers.AutoConfig.from_pretrained(
        model_checkpoint, trust_remote_code=True, device_map='auto')
    with open(path, "r") as f:
        for line in tqdm(f.readlines()):
            example = json.loads(line)
            feature = preprocess(tokenizer, config, example, max_seq_length,prompt_key,target_key)
            if skip_overlength and len(feature["input_ids"]) > max_seq_length:
                continue
            # 截取最大长度
            feature["input_ids"] = feature["input_ids"][:max_seq_length]
            yield feature

json.loads 加载一条样本随后调用 preprocess 生成训练 json,这里会根据 skip_overlength 参数决定是否忽略超长样本,最后返回 feature json。这里 tqdm 用于为迭代器 iterator 生成一个可视化的进度条,是一个辅助类。

本着一个参数都不放过的原则,博主查阅了模型加载中用到的两个参数含义:

- trust_remote_code

该参数指示系统在执行远程或外部代码时如何处理安全性和信任性。如果 "trust_remote_code" 设置为 True,则系统将信任并执行远程或外部提供的代码,而不进行严格的安全检查或验证。反之则系统会采取更谨慎的做法,并对远程或外部提供的代码进行安全性检查和验证,以确保其不会造成潜在的风险或恶意操作。这是一种常见的安全策略,用于防止恶意代码或攻击者利用远程执行漏洞来入侵系统。由于我们一般加载的都是官方认可的预训练模型,例如 Baichuan7B、ChatGLM 等等,所以一般看到的代码里都是 True。

- device_map

该参数用于指定设备映射或设备配置的相关信息。可以使用 map 将任务分配给特定的硬件设备或资源,当然也可以像上面一样使用 auto。

2.4 主函数与完整代码

# 输入文件统一放在 data 文件夹下
# 输出文件统一放在 data/tokenized_data 文件夹下
input_file_path = f'data/{args.input_file}'
save_path = f"data/tokenized_data/{args.save_name}"
dataset = datasets.Dataset.from_generator(
    lambda: read_jsonl(input_file_path, args.max_seq_length, args.prompt_key,args.target_key,args.skip_overlength)
)

dataset.save_to_disk(save_path)

这里默认原始训练文件 json 存放在 data 文件夹下,经过 tokenizer 的样本放在 data/tokenized_data 目录下,当然也可以根据自己习惯调整,这个位置影响不大。根据路径调用 datasets.Dataset 的 API 进行 DataSet 的生成与存储。

- 完整代码

import argparse
import json
from tqdm import tqdm
import datasets
import transformers

# 1.参数准备
parser = argparse.ArgumentParser()
parser.add_argument("--model_checkpoint", type=str, help="checkpoint, like `THUDM/chatglm-6b`") # 必填
parser.add_argument("--input_file", type=str, help="Instruction 数据文件地址,文件中每一行都是json格式,包含一个输出和一个输出") # 必填
parser.add_argument("--prompt_key", type=str, default=f"prompt", help="你的jsonl文件里,Instruction 的输入字段是什么") # 选填
parser.add_argument("--target_key", type=str, default=f"target", help="你的jsonl文件里,Instruction 的输出字段是什么") # 必填
parser.add_argument("--save_name", type=str, default=f"temp", help="经过tokenize之后的数据集的存放位置") # 选填
parser.add_argument("--max_seq_length", type=int, default=2040) # 选填
parser.add_argument("--skip_overlength", type=bool, default=False) # 选填
args = parser.parse_args()
model_checkpoint = args.model_checkpoint

#. 2.处理逻辑
def preprocess(tokenizer, config, example, max_seq_length, prompt_key, target_key):
    prompt = example[prompt_key]
    target = example[target_key]
    prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
    target_ids = tokenizer.encode(target, max_length=max_seq_length, truncation=True, add_special_tokens=False)
    # 最终还是将 instruction 的输入输出都拼在一起,使用经典的 causal-LM 的 next word prediction 方式来训练
    input_ids = prompt_ids + target_ids + [config.eos_token_id] # EOS 用于标识句子结束
    return {"input_ids": input_ids, "seq_len": len(prompt_ids)}

# 3.读取训练 JSON
def read_jsonl(path, max_seq_length, prompt_key,target_key,skip_overlength=False):
    # 基于预训练模型加载获取 tokenizer 和 config
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_checkpoint, trust_remote_code=True)
    config = transformers.AutoConfig.from_pretrained(
        model_checkpoint, trust_remote_code=True, device_map='auto')
    with open(path, "r") as f:
        for line in tqdm(f.readlines()):
            example = json.loads(line)
            feature = preprocess(tokenizer, config, example, max_seq_length,prompt_key,target_key)
            if skip_overlength and len(feature["input_ids"]) > max_seq_length:
                continue
            # 截取最大长度
            feature["input_ids"] = feature["input_ids"][:max_seq_length]
            yield feature


# 输入文件统一放在 data 文件夹下
# 输出文件统一放在 data/tokenized_data 文件夹下
input_file_path = f'data/{args.input_file}'
save_path = f"data/tokenized_data/{args.save_name}"
dataset = datasets.Dataset.from_generator(
    lambda: read_jsonl(input_file_path, args.max_seq_length, args.prompt_key,args.target_key,args.skip_overlength)
)

dataset.save_to_disk(save_path)


三.shell 执行

simple.json 为我们的测试样例,tokenizer_data 为存储 token 后 DataSet 的地址。下面看下 tokenizer.sh 的 shell 脚本:

baichuan="/model/baichuan-7B"

input=simple.json

CUDA_VISIBLE_DEVICES=0 python tokenize_dataset_rows.py \
    --model_checkpoint $baichuan \
    --input_file $input \
    --prompt_key q \
    --target_key a \
    --save_name simple_token_by_baichuan-7B \
    --max_seq_length 2000 \
    --skip_overlength False

执行上述脚本即可得到 tokenizer 后的数据:

当使用 dataset.save_to_dist 方法保存数据集合时会生成三个文件: 

dataset.arrow: 这是主要的数据文件,其中包含数据集的实际内容。它以 Apache Arrow 格式存储,这种格式旨在高效地存储和处理大规模数据集。该文件可能包含数据样本、标签、特征、元数据等。

dataset.info.json: 这个 JSON 文件包含与数据集相关的元信息。它提供了关于数据集结构、列名称、数据类型、特征信息、统计摘要等详细信息。通过读取此文件,可以获得数据集的描述性信息,以便更好地理解数据的组织和特征。

dataset.state.json: 这个 JSON 文件包含数据集的状态信息,例如上次更新的时间戳、版本号、数据集大小等。它记录了数据集的状态和元数据,以便在后续操作中能够恢复到相同的点,并确保数据集的一致性。

四.总结

基于 10 条左右的样本基于 Baichuan7B 微调后,我们测试了原始模型与 Lora 后的效果:

可以看到原始模型存在续写混乱的问题,前面再说查字典,后面又在介绍字典,而经过 Lora 微调后,模型已经能够回答该问题,但是如果换个问法可能就又答不上来了,所以 prompt 工程和样本工程在 AIGC 中是重要组成部分。后面我们将基于 tokenized_data 介绍如何进行 Lora 模型训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/36733.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

leetcode 236. 二叉树的最近公共祖先

2023.7.11 这道题是道面试高频题,并且有点抽象。 首先确定终止条件。如果根节点为空,或者其中一个节点是根节点本身(即 p root 或 q root),那么根节点就是它们的最低共同祖先,因此我们直接返回根节点 roo…

产品经理怎么管理项目进度?

作为在职七年的项目管理人员,在项目进度管理上确实有一点发言权。产品经理作为企业的核心骨干岗位之一,在进行项目进度管理时也会有很多问题出现,那么应该怎样去管理项目进度呢?以下是答主的一些拙见,有需要的朋友们就…

接口测试之postman使用详解

我们平常要做接口测试时,可能需要使用一些工具,其实最简单的的做接口测试的工具就是postman,它可以用来模拟http中的get、post接口等,然后我们去验证接口的返回参数及数据是否符合我们的逻辑。那么怎么使用呢?也就是今…

C++之工厂模式

目录 一、为什么要使用工厂模式 优点 缺点 二、简单工厂(Simple Factory) 好处: 不足: 三、工厂方法: 好处: 不足: 四、抽象工厂(Abstract Factory) 一、为什…

【工具推荐】企业微信、企业飞书接口调用工具

github地址: GitHub - fasnow/idebug: 企业微信、企业飞书接口调用工具。 简介 企业微信、企业飞书接口调用工具。 使用方法 wechat模块 使用use wechat 选择模块。 首先设置corpid和corpsecret,如有需要可以设置代理,之后再执行run命令。 导出通信…

chatgpt 与传统3D建模对比分析

推荐:将NSDT场景编辑器加入你的3D工具链 随着人工智能技术的发展,越来越多的领域正逐渐被AI模型所取代。ChatGPT作为一种自然语言处理技术,越来越为人们所熟悉。最近,一些3D建模领域的专家想知道ChatGPT是否可以取代传统的手动3D建…

在?聊聊浏览器事件循环机制

目录 前言 同步/异步编程模型 同步 异步 JS异步模型 调用栈 任务队列 宏任务队列 微任务队列 微任务API 事件循环 队列优先级 混合队列 事件循环实现 总结 参考文章 Event-Loop可视化工具 前言 JS是单线程语言,在某个时间段只能执行一段代码。这…

IP地址定位技术为何如此准确?揭秘背后原理

据最新数据显示,全球互联网用户数量已突破50亿。为确保用户安全和提供个性化服务,IP地址定位技术愈发重要。但你是否好奇,为何IP地址定位如此准确?今天我们将揭秘其背后原理。 IP地址定位技术利用了多种方法来确定用户的地理位置。…

mac苹果电脑,怎么把mkv转换mp4格式

mac苹果电脑,怎么把mkv转换mp4格式?如果你是一名mac苹果电脑的用户,在电脑上下载到mkv格式的视频后会发现它使用起来非常的麻烦,甚至不能直接打开播放。mkv其实也是一种时间比较久远的视频文件格式,但是不知道是什么原…

MAC电脑查看SHA256方式

背景 现在很多网站下载大文件时,以往通过查看文件大小来确定是否下载正确,但是很多情况下,文件下载后大小差不多,但是很多时候却时候出现无法安装的问题,有可能还是下载的文件出现错误,导致文件无法正常使…

研发效能认证学员作品:使用威胁建模进行DevSecOps实践

一、从DevOps到 DevSecOps 作者: 姚圣伟(现就职天津引元科技 天津市区块链技术创新中心) 研发效能(DevOps)工程师认证学员 DevOps 最开始最要是强调开发和运维的协作与配合,至今,已不仅仅涉及开…

【运维工程师学习二】OS系统管理

【运维工程师学习二】OS系统管理 1、操作系统管理2、进程管理3、进程的启动4、进程信息的查看4.1、STAT 进程的状态:进程状态使用字符表示的(STAT的状态码),其状态码对应的含义:4.2、ps命令常用用法(方便查看系统进程&…

stm32(独立看门狗和窗口看门狗)

独立看门狗介绍 什么是看门狗? 在由单片机构成的微型计算机系统中,由于单片机的工作常常会受到来自外界电磁场的干扰,造 成程序的跑飞,而陷入死循环,程序的正常运行被打断,由单片机控制的系统无法继续工作…

Vue3 网络请求——axios 高级用法之 axios 拦截器实战与并发请求

文章目录 📋前言🎯关于拦截器🎯项目创建🎯代码分析🎯补充:并发请求🧩axios.all() 和 Promise.all() 的区别 📝最后 📋前言 Axios 是一个流行的基于 Promise 的 HTTP 客户…

spring系列所有漏洞vulhub复现CVE-2022-22978、CVE-2022-22963、CVE-2022-22965、CVE-2018-1273

文章目录 CVE-2022-22978 Spring-security 认证绕过漏洞漏洞描述:复现: CVE-2022-22963漏洞描述:复现: 提提神Spring框架Data Binding与JDK 9导致的远程代码执行漏洞(CVE-2022-22965)漏洞描述:复现: Spring Data Commo…

机器学习笔记:随机森林

1 集成学习 集成学习通过构建多个学习器采用加权的方式来完成学习任务一般来讲,多个学习器同属于一种模型,比如决策树,线性模型,而不会交叉用多种模型为了保证集成学习的有效性,多个弱分类器之间应该满足两个条件 准确…

【附3.7安装包】python安装包下载及安装(超详细)

python3.7链接:https://pan.baidu.com/s/1Ett3XBMjWhkVOxkOU8NRqw?pwdqz3l 提取码:qz3l 今日资源:Python 适用系统:WINDOWS ​ Python 3.7.0 软件介绍: Python是一款通用型的计算机程序设计语言,Pytho…

4.5 x64dbg 探索钩子劫持技术

钩子劫持技术是计算机编程中的一种技术,它们可以让开发者拦截系统函数或应用程序函数的调用,并在函数调用前或调用后执行自定义代码,钩子劫持技术通常用于病毒和恶意软件,也可以让开发者扩展或修改系统函数的功能,从而…

unbuntu 22.04 安装和卸载企业微信

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 记录有关在ubuntu22.04上安装和卸载企业微信 以及企业微信无法打开问题处理 1. 正文 1.1 安装 下载wine环境 http://archive.ubuntukylin.com/softwar…

【JMeter】同步定时器Synchronizing Timer集合点功能

LoadRunner 中有一个可以设置集合点的功能,顾名思义是设置多个虚拟用户等待到一个时间点,都到齐集合后一起发请求达到并发的目的 集合点是什么意思呢? 阻塞线程,直到指定的线程数量到达后,再一起释放,可以…
最新文章