HuggingFace学习笔记--Prompt-Tuning高效微调

1--Prompt-Tuning介绍

        Prompt-Tuning 高效微调只会训练新增的Prompt的表示层,模型的其余参数全部固定;

        新增的 Prompt 内容可以分为 Hard Prompt 和 Soft Prompt 两类;

        Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在生成结果时有更大的自由度,通常用于启发模型进行创造性的生成;

        Hard prompt 是一种更为具体和明确的提示,要求模型按照给定的信息生成精确的结果,通常用于需要模型提供准确答案的任务;

        Soft Prompt 在 peft 中一般是随机初始化prompt的文本内容,而 Hard prompt 则一般需要设置具体的提示文本内容;

2--实例代码

from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer
from peft import PromptTuningConfig, get_peft_model, TaskType, PromptTuningInit, PeftModel

# 分词器
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")

# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

if __name__ == "__main__":
    # 加载数据集
    dataset = load_from_disk("./PEFT/data/alpaca_data_zh")
    
    # 处理数据
    tokenized_ds = dataset.map(process_func, remove_columns = dataset.column_names)
    # print(tokenizer.decode(tokenized_ds[1]["input_ids"]))
    # print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))
    
    # 创建模型
    model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
    
    # 设置 Prompt-Tuning
    # Soft Prompt
    # config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化
    # Hard Prompt
    config = PromptTuningConfig(task_type = TaskType.CAUSAL_LM,
                                prompt_tuning_init = PromptTuningInit.TEXT,
                                prompt_tuning_init_text = "下面是一段人与机器人的对话。", # 设置hard_prompt的具体内容
                                num_virtual_tokens = len(tokenizer("下面是一段人与机器人的对话。")["input_ids"]),
                                tokenizer_name_or_path = "Langboat/bloom-1b4-zh")
    model = get_peft_model(model, config) # 生成Prompt-Tuning对应的model
    print(model.print_trainable_parameters())
    
    # 训练参数
    args = TrainingArguments(
        output_dir = "/tmp_1203",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,
        logging_steps = 10,
        num_train_epochs = 1
    )
    
    # trainer
    trainer = Trainer(
        model = model,
        args = args,
        train_dataset = tokenized_ds,
        data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, padding = True)
    )
    
    # 训练模型
    trainer.train()
    
    # 模型推理
    model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)
    peft_model = PeftModel.from_pretrained(model = model, model_id = "/tmp_1203/checkpoint-500/")
    peft_model = peft_model.cuda()
    ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(peft_model.device)
    print(tokenizer.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/211832.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

规则引擎专题---3、Drools组成和入门

Drools概述 drools是一款由JBoss组织提供的基于Java语言开发的开源规则引擎,可以将复杂且多变的业务规则从硬编码中解放出来,以规则脚本的形式存放在文件或特定的存储介质中(例如存放在数据库中),使得业务规则的变更不需要修改项目代码、重启…

初识RocketMQ

1、简介 RocketMQ 是阿里巴巴在 2012 年开源的消息队列产品,用 Java 语言实现,在设计时参考了 Kafka,并做出了自己的一些改进,后来捐赠给 Apache 软件基金会,2017 正式毕业,成为 Apache 的顶级项目。Rocket…

avue-crud中时间范围选择默认应该是0点却变成了12点

文章目录 一、问题二、解决三、最后 一、问题 在avue-crud中时间范围选择&#xff0c;正常默认应该是0点&#xff0c;但是不知道怎么的了&#xff0c;选完之后就是一直是12点。具体问题如下动图所示&#xff1a; <template><avue-crud :option"option" /&g…

YOLOv8改进 | 2023 | SCConv空间和通道重构卷积(精细化检测,又轻量又提点)

一、本文介绍 本文给大家带来的改进内容是SCConv&#xff0c;即空间和通道重构卷积&#xff0c;是一种发布于2023.9月份的一个新的改进机制。它的核心创新在于能够同时处理图像的空间&#xff08;形状、结构&#xff09;和通道&#xff08;色彩、深度&#xff09;信息&#xf…

计算机组成原理笔记——存储器(静态RAM和动态RAM的区别,动态RAM的刷新, ROM……)

■ 随机存取存储器 ■ 1.随机存取存储器&#xff1a;按存储信息的原理不同分为&#xff1a;静态RAM和动态RAM 2.静态RAM&#xff08;SRAM&#xff09;&#xff1a;用触发器工作原理存储信息&#xff0c;但电源掉电时&#xff0c;存储信息会丢失具有易失性。 3.存储器的基本单元…

代码随想录算法训练营第三十四天|62.不同路径,63. 不同路径 II

62. 不同路径 - 力扣&#xff08;LeetCode&#xff09; 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中标记为 “Finish” &#…

Kubernetes技术与架构-策略

Kubernetes集群提供系统支持的策略&#xff0c;也提供开放接口给第三方定义的策略&#xff0c;这些策略用于可定义的配置文件或者Kubernetes集群的运行时环境&#xff0c;其中包括进程ID数量的申请与限制策略&#xff0c;服务器节点Node内的进程ID的数量限制策略&#xff0c;Po…

PyQt6 QCheckBox复选框按钮控件

​锋哥原创的PyQt6视频教程&#xff1a; 2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计33条视频&#xff0c;包括&#xff1a;2024版 PyQt6 Python桌面开发 视频教程(无废话…

深入了解c语言中的结构体

介绍&#xff1a; 在C语言中&#xff0c;结构体是一种用户自定义的数据类型&#xff0c;它允许我们将不同类型的数据组合在一起&#xff0c;形成一个更为复杂的数据结构。结构体可以用来表示现实世界中的实体&#xff0c;如人员、学生、图书等。本篇博客将介绍结构体的基本概念…

iceoryx(冰羚)-进程间消息同步

iceoryx进程间消息同步 iceoryx进程间消息同步&#xff0c;是用socket或管道实现的,定义在iceoryx\iceoryx_posh\include\iceoryx_posh\internal\runtime\ipc_interface_base.hpp namespace platform { #if defined(_WIN32) using IoxIpcChannelType iox::posix::NamedPipe; …

Linux下为可执行文件添加图标

Ubuntu 18.04上使用Qt5.14.2创建一个简单的Qt Widgets项目test&#xff0c;添加2个Push Button按钮&#xff0c;点击分别获取github和csdn地址&#xff0c;在mainwindow.cpp中添加的代码如下: #include "mainwindow.h" #include "ui_mainwindow.h" #inclu…

Linux 互斥锁 读写锁 条件变量 信号量 (备查)

线程同步 1&#xff09;所谓的同步并不是多个线程同时对内存进行访问&#xff0c;而是按照先后顺序依次进行的。 2&#xff09;如没有对线程进行同步处理&#xff0c;会导致多个线程访问共享资源出现数据混乱的问题。 3&#xff09;所谓共享资源就是多个线程共同访问的变量&…

javaweb校车校园车辆管理系统springboot+jsp

结构设计&#xff1a;总体采用B/S结构设计模式 (1)用户登录模块&#xff1a;用户通过手动登录&#xff0c;检测是否是校内人员的车辆。 (2)用户车辆信息编辑、上传、模块&#xff1a;通过上传车辆入场信息的操作权限&#xff0c;以用户的名义发布资料上传至校园停车场系统中。…

可视化数据库管理客户端:Adminer

简介&#xff1a;Adminer&#xff08;前身为phpMinAdmin&#xff09;是一个用PHP编写的功能齐全的数据库管理工具。与phpMyAdmin相反&#xff0c;它由一个可以部署到目标服务器的文件组成。Adminer可用于MySQL、PostgreSQL、SQLite、MS SQL、Oracle、Firebird、SimpleDB、Elast…

Java+SSM+MySQL基于微信的在线协同办公小程序(附源码 调试 文档)

基于微信的在线协同办公小程序 一、引言二、系统设计三、技术架构四、管理员功能设计五、员工功能设计六、系统实现七、界面展示八、源码获取 一、引言 随着科技的飞速发展&#xff0c;移动互联网已经深入到我们生活的各个角落。在这个信息时代&#xff0c;微信作为全球最大的…

头歌JUnit单元测试相关实验入门

一、入门实验 1.1第一个Junit测试程序 任务描述 请学员写一个名为testSub()的测试函数&#xff0c;来测试给定的减法函数是否正确。 相关知识 Junit编写原则 1、简化测试的编写&#xff0c;这种简化包括测试框架的学习和实际测试单元的编写。 2、测试单元保持持久性。 3、利用…

【Python】Python给工作减负-读Excel文件生成xml文件

目录 ​前言 正文 1.Python基础学习 2.Python读取Excel表格 2.1安装xlrd模块 2.2使用介绍 2.2.1常用单元格中的数据类型 2.2.2 导入模块 2.2.3打开Excel文件读取数据 2.2.4常用函数 2.2.5代码测试 2.2.6 Python操作Excel官方网址 3.Python创建xml文件 3.1 xml语法…

计算机组成原理,硬件组成,存储器,控制器,控制器的任务, 运算器,中央处理器CPU,主存

计算机组成原理 课程需求 前导课程&#xff1a; 后继课程 汇编 操作系统 数逻 组成 系统结构 数电 微机原理 课程结构 计算机特性 1 从外部角度来看计算机的特性 快速 通用 准确 逻辑 2从外部特性与内部特性的关系 计算机组成 一 硬件组成 运算器 主要功能是进行算术…

强化学习(一)——基本概念及DQN

1 基本概念 智能体 agent &#xff0c;做动作的主体&#xff0c;&#xff08;大模型中的AI agent&#xff09; 环境 environment&#xff1a;与智能体交互的对象 状态 state &#xff1b;当前所处状态&#xff0c;如围棋棋局 动作 action&#xff1a;执行的动作&#xff0c;…

CRM系统是怎样帮助销售流程自动化的?

销售业绩是衡量企业经营的重要指标&#xff0c;也是销售人员一直要达成的目标。销售业绩能否提高取决于销售人员的能力、客户服务水平&#xff0c;还需要借助有效的工具。CRM系统就是这样的一款软件。企业如何提高销售业绩&#xff1f;不妨试试CRM销售流程自动化。 CRM如何实现…
最新文章