基于Optuna的transformers模型自动调参

文章目录

  • 一、导入相关包
  • 二、加载数据集
  • 三、划分数据集
  • 四、数据集预处理
  • 五、创建模型(区别一)
  • 六、创建评估函数
  • 七、创建 TrainingArguments(区别二)
  • 八、创建 Trainer(区别三)
  • 九、模型训练
  • 十、模型训练(自动搜索)(区别四)
  • 启动 tensorboard

  • 以文本分类为例

六、Trainer和文本分类
image.png
image.png


一、导入相关包

!pip install transformers datasets evaluate accelerate
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

二、加载数据集

dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 7765
})
'''

三、划分数据集

datasets = dataset.train_test_split(test_size=0.1)
datasets
'''
DatasetDict({
    train: Dataset({
        features: ['label', 'review'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['label', 'review'],
        num_rows: 777
    })
})
'''

四、数据集预处理

import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")

def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples

tokenized_datasets = datasets.map(process_function, batched=True, 
                                  remove_columns=datasets["train"].column_names)
tokenized_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 777
    })
})
'''

五、创建模型(区别一)

def model_init():
    model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
    return model

六、创建评估函数

import evaluate

acc_metric = evaluate.load("accuracy")
f1_metirc = evaluate.load("f1")
def eval_metric(eval_predict):
    predictions, labels = eval_predict
    predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metirc.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc

七、创建 TrainingArguments(区别二)

  • logging_steps=500为了防止多次训练 log 太多可以增大 logging_steps
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=64,  # 训练时的batch_size
                               per_device_eval_batch_size=128,  # 验证时的batch_size
                               logging_steps=500,               # log 打印的频率
                               evaluation_strategy="epoch",     # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="f1",      # 设定评估指标
                               load_best_model_at_end=True)     # 训练完成后加载最优模型

八、创建 Trainer(区别三)

  • 没有指定 model而是指定 model_init
from transformers import DataCollatorWithPadding
trainer = Trainer(model_init=model_init, 
                  args=train_args, 
                  train_dataset=tokenized_datasets["train"], 
                  eval_dataset=tokenized_datasets["test"], 
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)


# 之前
from transformers import DataCollatorWithPadding
trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=tokenized_datasets["train"],
                  eval_dataset=tokenized_datasets["test"],
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)

九、模型训练

trainer.train()

十、模型训练(自动搜索)(区别四)

!pip install optuna
  • 使用默认的超参数空间
  • compute_objective=lambda x: x["eval_f1"]中的 x是指的评价函数的返回值,在这里因为没有显示的指定评价函数返回值的 key,所以 f1key采用默认值 eval_f1
trainer.hyperparameter_search(compute_objective=lambda x: x["eval_f1"], direction="maximize", n_trials=10)
  • 自定义超参数空间
    • 可以在default_hp_space_optuna 函数中增加 trainer 的选项
def default_hp_space_optuna(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
        "optim": trial.suggest_categorical("optim", ["sgd", "adamw_hf"]),
    }

trainer.hyperparameter_search(hp_space=default_hp_space_optuna, compute_objective=lambda x: x["eval_f1"], direction="maximize", n_trials=10)

启动 tensorboard

  • 进入运行日志文件夹
    • 终端启动
!tensorboard --logdir runs
  • jupyter 启动
# 运行这行代码将加载 TensorBoard并允许我们将其用于可视化
%reload_ext tensorboard 
%tensorboard --logdir=./runs/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/163333.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis逆向工程

新建Maven工程 <build><plugins><plugin><!--mybatis代码自动生成插件--><groupId>org.mybatis.generator</groupId><artifactId>mybatis-generator-maven-plugin</artifactId><version>1.3.6</version><confi…

Ubuntu20.04 安装微信 【优麒麟的镜像源方式安装】

缺点&#xff1a;是网页版本的嵌入&#xff0c;功能少。 推荐wine方式安装&#xff1a;Ubuntu20.04 安装微信 【wine方式安装】推荐 从优麒麟的镜像源安装原生微信 应用下载-优麒麟&#xff5c;Linux 开源操作系统 新建文件software.list sudo vi /etc/apt/sources.list.d/…

操作系统(五)| 文件系统上 结构 存取方式 文件目录 检索

文章目录 1 文件系统概述2 文件的结构与存取方式2.1 磁盘2.2 文件的物理结构2.2.1 连续结构2.2.2 链式结构2.2.3 索引结构 2.3 文件的存取方式 3 文件目录3.1 基本概念3.2 目录结构单级目录结构多级目录结构 3.3 文件目录检索3.3.1 目录检索文件寻址 3.4 文件目录的实现 1 文件…

docker容器自启动

场景 当服务器关机重启后&#xff0c;docker容器每次都要去docker start 容器id 怎么可以下次让它自启动呢&#xff1f; 解决 先 # docker ps -a 查到之前启动过的容器id # docker update --restartalways 容器id重启后&#xff0c;reboot&#xff0c;就不用再单独去启动容…

Mol-Instructions:大模型赋能,药物研发新视野

论文标题&#xff1a;Mol-Instructions: A Large-Scale Biomolecular Instruction Dataset for Large Language Models 论文链接&#xff1a; https://arxiv.org/pdf/2306.08018.pdf Github链接&#xff1a; https://github.com/zjunlp/Mol-Instructions 模型下载&#xf…

Docker 可视化面板 ——Portainer

Portainer 是一个非常好用的 Docker 可视化面板&#xff0c;可以让你轻松地管理你的 Docker 容器。 官网&#xff1a;Portainer: Container Management Software for Kubernetes and Docker 【Docker系列】超级好用的Docker可视化工具——Portainer_哔哩哔哩_bilibili 环境 …

zabbix的安装配置,邮件告警,钉钉告警

zabbix监控架构 zabbix优点 开源&#xff0c;无软件成本投入server对设备性能要求低支持设备多&#xff0c;自带多种监控模板支持分布式集中管理&#xff0c;有自动发现功能&#xff0c;可以实现自动化监控开放式接口&#xff0c;扩展性强&#xff0c;插件编写容易当监控的item…

【Linux网络】详解使用http和ftp搭建yum仓库,以及yum网络源优化

目录 一、回顾yum的原理 1.1yum简介 yum安装的底层原理&#xff1a; yum的好处&#xff1a; 二、学习yum的配置文件及命令 1、yum的配置文件 2、yum的相关命令详解 3、yum的命令相关案例 三、搭建yum仓库的方式 1、本地yum仓库建立 2、通过http搭建内网的yum仓库 3、…

Android SdkManager简介

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、商业变现、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、 安装使用3.1 安装3.2 使用3.3 选项…

vue项目中使用vant轮播图组件(桌面端)

一. 内容简介 vue使用vant轮播图组件(桌面端) 二. 软件环境 2.1 Visual Studio Code 1.75.0 2.2 chrome浏览器 2.3 node v18.14.0 三.主要流程 3.1 安装环境 3.2 添加代码 3.3 结果展示 四.具体步骤 4.1 安装环境 先安装包 # Vue 3 项目&#xff0c;安装最新版 Va…

Unity中Shader立方体纹理Cubemap

文章目录 前言一、什么是立方体纹理二、立方体纹理的生成方式1、使用6个面的生成方式2、使用单张图片的生成方式 三、Cubemap的采样方式四、在Unity中看一下Cubemap五、在Shader中&#xff0c;对立方体纹理进行采样使用1、我们在属性面板定义一个Cube类型的变量来存放立方体纹理…

二分查找算法合集

二分查找也称折半查找&#xff08;Binary Search&#xff09;&#xff0c;它是一种效率较高的查找方法。但是&#xff0c;折半查找要求线性表必须采用顺序存储结构&#xff0c;而且表中元素按关键字有序排列。 时间复杂度 O(logn) 自己写二分算法 左闭右开 左开右闭C算法&a…

java智慧校园信息管理系统源码带微信小程序

一、智慧校园的定义 智慧校园指的是以云计算和物联网为基础的智慧化的校园工作、学习和生活一体化环境。以各种应用服务系统为载体&#xff0c;将教学、科研、管理和校园生活进行充分融合&#xff0c;让校园实现无处不在的网络学习、融合创新的网络科研、透明高效的校务治理、…

OpenGL 坐标投影与反投影(Qt)

文章目录 一、简介1.1投影1.2反投影二、应用代码三、实现效果参考资料一、简介 在学习OpenGL一段时间之后,我们都会了解坐标的转换过程,如下图所示: 1.1投影 正如图中所述,OpenGL将一个3D坐标投影到一个2D空间主要有以下几个步骤,这也是我们比较熟知的几个步骤: 现实局部…

asp.net学生成绩评估系统VS开发sqlserver数据库web结构c#编程计算机网页项目

一、源码特点 asp.net 学生成绩评估系统 是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 系统运行视频连接&#xff1a;https://www.bilibili.com/video/BV1Wz4y1A7CG/ 二、功能介绍 本系统使用Microsof…

读像火箭科学家一样思考笔记02_与不确定性共舞(下)

1. 万有理论 1.1. 相对论 1.1.1. 适用于体积非常大的物体 1.2. 量子力学 1.2.1. 适用于非常小的物体 1.2.2. 在量子力学诞生之前&#xff0c;物理学一直强调的是因果关系&#xff0c;即做这件事&#xff0c;就会得到那个结果 1.2.3. 量子力学讲的似乎是&#xff1a;当我们…

Linux发展历程

<!DOCTYPE html> <html> <head> <meta charset"UTF-8"> <title>Linux历史发展</title> <style> /* CSS样式 */ body { font-family: Arial, sans-serif; margin: 0;…

【计算机网络笔记】IPv6简介

系列文章目录 什么是计算机网络&#xff1f; 什么是网络协议&#xff1f; 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能&#xff08;1&#xff09;——速率、带宽、延迟 计算机网络性能&#xff08;2&#xff09;…

C++二分查找算法:查找和最小的 K 对数字

相关专题 二分查找相关题目 题目 给定两个以 非递减顺序排列 的整数数组 nums1 和 nums2 , 以及一个整数 k 。 定义一对值 (u,v)&#xff0c;其中第一个元素来自 nums1&#xff0c;第二个元素来自 nums2 。 请找到和最小的 k 个数对 (u1,v1), (u2,v2) … (uk,vk) 。 示例 1:…

Linux shell编程学习笔记27:tputs

除了stty命令&#xff0c;我们还可以使用tput命令来更改终端的参数和功能。 1 tput 命令的功能 tput 命令的主要功能有&#xff1a;移动更改光标、更改文本显示属性&#xff08;如颜色、下划线、粗体&#xff09;&#xff0c;清除屏幕特定区域等。 2 tput 命令格式 tput [选…