HuggingFace模型头的自定义

 

在线工具推荐:  Three.js AI纹理开发包 -  YOLO合成数据生成器 -  GLTF/GLB在线编辑 -  3D模型格式在线转换 -  可编程3D场景编辑器

在本文中我们将介绍如何使HuggingFace的模型适应你的任务,在Pytorch中建立自定义模型头并将其连接到HF模型的主体,并端到端地训练系统。

1、HF模型头和模型体

这是典型的HF模型的样子:

为什么我需要单独使用模型头(Model Head)和模型体(Model Body)?

一些HF的模型针对下游任务(例如提问或文本分类)训练,并包含有关其权重培训的数据的知识。

有时,尤其是当我们手头的任务包含很少的数据或领域特定(例如医学或运动特定任务)时,我们可以在HUB上使用其他任务训练的模型(不一定与我们的任务相同的任务 手但属于相同领域,例如运动或药物),并利用一些验证的知识来提高我们模型在我们自己任务的性能表现。

  • 一个非常简单的例子是,如果说我们有一个小数据集,比如分类某些财务报表是积极还是负面的。 但是,我们进入了HF,发现许多模型已经经过与金融相关的问答数据集的训练,那么 我们可以使用这些模型的某些层来改进自己的任务。
  • 另一个简单的示例是,某个特定领域的模型经过巨大数据集的训练学会了将文本从中分为5个类别。 假设我们有类似的分类任务,在同一域中的一个完全不同的数据集,只想将数据分类为2个类别而不是5。 这时我们也可以复用模型主体,添加自己的模型头来增强我们自己任务的特定领域知识。

这就是我们要做的事情的示意图:

2、自定义HF模型头

我们的任务是简单的,从Kaggle上的这个数据集进行讽刺检测。

你可以在此处查看完整的代码。 为了时间的考虑,我没有在下面包括预处理和一些训练的详细信息,因此请确保查看整个代码的笔记本。

我将使用一个在大量推文上训练的模型,有5个分类输出不同的情感类型。我们将提取模型体,在pytorch中添加自定义层(2个标签,讽刺/不讽刺),并训练新的模型。

注意:你可以在此示例中使用任何模型(不一定是对分类训练的模型),因为我们只会使用该模型主体并拆除模型头。

这就是我们的工作流程:

我将跳过数据预处理步骤,然后直接跳到主类,但是你可以在本节开头的链接中查看整个代码。

3、令牌化和动态填充

使用如下代码将文本转化为令牌并进行动态填充:

checkpoint = "cardiffnlp/twitter-roberta-base-emotion"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.model_max_len=512

def tokenize(batch):
  return tokenizer(batch["headline"], truncation=True,max_length=512)

tokenized_dataset = data.map(tokenize, batched=True)
print(tokenized_dataset)

tokenized_dataset.set_format("torch",columns=["input_ids", "attention_mask", "label"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

结果如下:

DatasetDict({
    train: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 22802
    })
    test: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2851
    })
    valid: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2850
    })
})

4、提取模型体并添加我们自己的层

代码如下:

class CustomModel(nn.Module):
  def __init__(self,checkpoint,num_labels): 
    super(CustomModel,self).__init__() 
    self.num_labels = num_labels 

    #Load Model with given checkpoint and extract its body
    self.model = model = AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))
    self.dropout = nn.Dropout(0.1) 
    self.classifier = nn.Linear(768,num_labels) # load and initialize weights

  def forward(self, input_ids=None, attention_mask=None,labels=None):
    #Extract outputs from the body
    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

    #Add custom layers
    sequence_output = self.dropout(outputs[0]) #outputs[0]=last hidden state

    logits = self.classifier(sequence_output[:,0,:].view(-1,768)) # calculate losses
    
    loss = None
    if labels is not None:
      loss_fct = nn.CrossEntropyLoss()
      loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    
    return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)

如你所见,我们首先是继承Pytorch中的 nn.Module,使用AutoModel(来自transformers库)提取加载了指定检查点的模型主体。

请注意, forward() 方法返回 TokenClassifierOutput,从而确保我们输出的格式与HF预训练模型一致。

5、端到端训练新的模型

代码如下:

from tqdm.auto import tqdm

progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)))


for epoch in range(num_epochs):
  model.train()
  for batch in train_dataloader:
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = model(**batch)
      loss = outputs.loss
      loss.backward()

      optimizer.step()
      lr_scheduler.step()
      optimizer.zero_grad()
      progress_bar_train.update(1)

  model.eval()
  for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])
    progress_bar_eval.update(1)
    
  print(metric.compute())
  model.eval()

test_dataloader = DataLoader(
    tokenized_dataset["test"], batch_size=32, collate_fn=data_collator
)

for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

结果如下:

  0%|          | 0/2139 [00:00<?, ?it/s]
  0%|          | 0/270 [00:00<?, ?it/s]
{'f1': 0.9335347432024169}
{'f1': 0.9360090874668686}
{'f1': 0.9274912756882513}

如你所见,我们使用此方法实现了不错的性能。 请记住,该博客的目的不是分析此特定数据集的性能,而是要学习如何使用预训练的身体并添加自定义头。

6、结束语

在本文中,我们看到了如何在HF预训练模型上添加自定义层。

一些收获:

  • 在我们拥有特定于域的数据集并希望利用在同一域(任务 - 努力的task-agnostic)上训练的模型以增强小型数据集中的性能的情况下,此技术特别有用。
  • 我们可以选择接受过与自己任务不同的下游任务训练的模型,并且仍然使用该模型主体的知识。
  • 如果你的数据集足够大且通用,那么这可能根本不需要,在这种情况下,你可以使用 AutoModeForSequenceCecrification或使用 BERT 解决的任何其他任务。 实际上,如果是这样,我强烈建议不要建立自己的模型头。

原文链接:HF自定义模型头 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/134535.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大数据Doris(二十一):数据导入演示

文章目录 数据导入演示 一、启动zookeeper集群(三台节点都启动) 二、启动hdfs集群

LLM代码生成器的挑战【GDELT早期观察】

越来越多的研究开始对LLM大模型生成的代码的质量提出质疑&#xff0c;尽管科技行业不断推出越来越多的旨在增强甚至取代人类编码员的工具。 随着我们&#xff08;GDELT&#xff09;继续探索和评估越来越多的此类工具&#xff0c;以下是我们的一些早期观察结果。 在线工具推荐&a…

CCF ChinaSoft 2023 论坛巡礼|机器人大模型与具身智能挑战赛

2023年CCF中国软件大会&#xff08;CCF ChinaSoft 2023&#xff09;由CCF主办&#xff0c;CCF系统软件专委会、形式化方法专委会、软件工程专委会以及复旦大学联合承办&#xff0c;将于2023年12月1-3日在上海国际会议中心举行。 本次大会主题是“智能化软件创新推动数字经济与社…

银行卡转账记录p图软件,建设邮政工商招商农业,易语言回执单生成开发!

花了好长时间设计出来了这么一个软件&#xff0c;当然各个功能我都做了防范处理界面还有生成的图片都有对应的水印提示&#xff0c;做不了啥坏事&#xff0c;这里就是分享下原理和代码还有运行逻辑&#xff0c;仅此而已&#xff0c;软件加了一个画板&#xff0c;画面上面的图片…

Semantic Kernel 学习笔记1

1. 挂代理跑通openai API 2. 无需魔法跑通Azure API 下载Semantic Kernel的github代码包到本地&#xff0c;主要用于方便学习python->notebooks文件夹中的内容。 1. Openai API&#xff1a;根据上述文件夹中的.env.example示例创建.env文件&#xff0c;需要填写下方两个内…

计网:第一章 概述

目录 1.1计算机网络在信息时代作用 1.2因特网概述 1.3三种交换方式 1.4计算机网络的定义和分类 1.5计算机网络的性能指标 1.6计算机网络的体系结构 基于湖科大教书匠b站计算机网络教学视频以及本校课程老师ppt 整合出的计算机网络学习笔记 根据文章目录&#xff0c;具体内…

删除成绩(数组)

任务要求 设计程序&#xff0c;实现从多名学生某门课程的成绩查找到第一个不及格的成绩&#xff0c;删除其成绩&#xff0c;输出删除成绩后的多名学生这一门课程的成绩。任务保证至少存在1个学生的成绩为不及格。

Vuex:模块化Module :VOA模式

由于使用单一状态树&#xff0c;应用的所有状态会集中到一个比较大的对象。当应用变得非常复杂时&#xff0c;store 对象就有可能变得相当臃肿。 这句话的意思是&#xff0c;如果把所有的状态都放在/src/store/index.js中&#xff0c;当项目变得越来越大的时候&#xff0c;Vue…

推荐 8 款OCR工具(二)完结篇

双十一&#xff0c;又要剁手了&#xff0c;但我还是 推荐 8 款OCR工具&#xff01; 当你感到迷茫时&#xff0c;不妨停下来&#xff0c;深呼吸&#xff0c;重新审视自己所处的位置和你的内心。这样的简单行为可能会帮助你找到方向。 SimpleOCR 网址&#xff1a;https://simple…

web:[网鼎杯 2018]Fakebook

题目 点进页面&#xff0c;页面显示为 查看源代码 用dirsearch扫一下&#xff0c;看一下有什么敏感信息泄露 扫出另一个flag.php和robots.txt&#xff0c;访问flag.php回显内容为空 请求robots.txt 网页提示/user.php.bak&#xff0c;直接访问会自动下载.bak备份文件 进行代码…

过去5年,Python生态有什么变化?

你好&#xff0c;我是 EarlGrey&#xff0c;一名双语学习者&#xff0c;会一点编程&#xff0c;目前已翻译出版《Python 无师自通》、《Python 并行编程手册》等书籍。 点击上方蓝字关注我&#xff0c;持续接收优质好书、高效工具和赚钱机会&#xff0c;一起提升认知和思维。 过…

智慧水利整体解决方案:PPT全文43页,附下载

关键词&#xff1a;智慧水利发展前景&#xff0c;智慧水利解决方案&#xff0c;智慧水利建设方案&#xff0c;智慧水利平台系统 一、智慧水利建设背景 传统水利系统存在一些问题&#xff1a; 现有基础感知不能满足更高标准的水利管理需求&#xff1b;决策调度支撑能力亟需加强…

Benchmarking Large Language Models in Retrieval-Augmented Generation-学习翻译

提检索增强生成中大型语言模型的基准测试文献学习 作者将在https://github.com/chen700564/RGB上发布本文的代码和RGB。 y ˇ \check{y} yˇ​ 文章目录 摘要IntroductionRelated workRetrieval-Augmented Generation BenchmarkRAG所需能力数据构建评估指标 ExperimentsSetting…

PCBA表面污染的分类及处理方法

NO.1 引言 在PCBA生产过程中&#xff0c;锡膏和助焊剂会产生残留物质&#xff0c;残留物中包含的有机酸和电离子&#xff0c;前者易腐蚀PCBA&#xff0c;后者会造成焊盘间短路故障。且近年来&#xff0c;用户对产品的清洁度要求越来越严格&#xff0c;PCBA清洗工艺逐渐被电子组…

国际阿里云:无法访问ECS实例中的服务的排查方法!!!

操作场景 无法访问ECS实例中的服务可能有以下原因&#xff1a; 可能原因 排查方案 ECS实例的安全组未开放相应端口 检查ECS实例安全组规则 ECS实例中&#xff0c;该服务未启动/开启或服务对应端口未被监听 检查服务状态及端口监听状态 ECS实例内防火墙设置错误 检查ECS…

Linux下的调试工具——GDB

GDB 1.什么是GDB GDB 是由 GNU 软件系统社区提供的调试工具&#xff0c;同 GCC 配套组成了一套完整的开发环境&#xff0c;GDB 是 Linux 和许多 类Unix系统的标准开发环境。 一般来说&#xff0c;GDB 主要能够提供以下四个方面的帮助&#xff1a; 启动程序&#xff0c;可以按…

FRI的Commit、Query以及FRI Batching内部机制

1. 引言 前序博客见&#xff1a; Reed-Solomon Codes及其与RISC Zero zkVM的关系RISC Zero ZKP协议中的商多项式 FRI用途&#xff1a; 用于证明某vector commitment是&#xff08;接近&#xff09;某Reed-Solomon codeword。即证明某low-degree多项式。 FRI工作原理&#…

时间序列预测实战(十二)DLinear模型实现滚动长期预测并可视化预测结果

官方论文地址->官方论文地址 官方代码地址->官方代码地址 个人修改代码->个人修改的代码已经上传CSDN免费下载 一、本文介绍 本文给大家带来是DLinear模型&#xff0c;DLinear是一种用于时间序列预测&#xff08;TSF&#xff09;的简单架构&#xff0c;DLinear的核…

uni-app点击按钮弹出提示框-uni.showModal(OBJECT),选择确定和取消

参考文档&#xff1a; https://uniapp.dcloud.io/api/ui/prompt?idshowmodal 显示模态弹窗&#xff0c;可以只有一个确定按钮&#xff0c;也可以同时有确定和取消按钮。类似于一个API整合了 html 中&#xff1a;alert、confirm。 uni.showModal({title: 提示,content: 这是一…

【计算机网络笔记】IP分片

系列文章目录 什么是计算机网络&#xff1f; 什么是网络协议&#xff1f; 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能&#xff08;1&#xff09;——速率、带宽、延迟 计算机网络性能&#xff08;2&#xff09;…