【Python】科研代码学习:三 PreTrainedModel, PretrainedConfig, PreTrainedTokenizer

【Python】科研代码学习:三 PreTrainedModel, PretrainedConfig, PreTrainedTokenizer

  • 前言
  • Models : PreTrainedModel
    • PreTrainedModel 中重要的方法
  • tensorflow & pytorch 简单对比
  • Configuration : PretrainedConfig
    • PretrainedConfig 中重要的方法
  • Tokenizer : PreTrainedTokenizer
    • PreTrainedTokenizer 中重要的方法

前言

  • HF 官网API
    本文主要从官网API与源代码中学习调用HF的关键模组

Models : PreTrainedModel

  • HF 提供的基础模型类有 PreTrainedModel, TFPreTrainedModel, and FlaxPreTrainedModel
  • 这三者有什么区别呢
    PreTrainedModel 指的是用 torch 的框架
    在这里插入图片描述
    TFPreTrainedModel 指的是用 tensorflow 框架
    在这里插入图片描述
    FlaxPreTrainedModel 指的是用 flax 框架,是用 jax 做的
    在这里插入图片描述
    (哈哈,搜了好久都没搜到,去看源码导包瞬间明白了,也可能是我比较笨)
  • Transformers的大部分模型都会继承PretrainedModel基类。PretrainedModel主要负责管理模型的配置,模型的参数加载、下载和保存。
  • PretrainedModel继承自 nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin
    在初始化时需要提供给它一个 config: PretrainedConfig
  • 所以,我们可以视为它是所有模型的基类
    可以看到很多其他代码在判断模型类型时,一般写 model: Union[PreTrainedModel, nn.Module]

PreTrainedModel 中重要的方法

  • push_to_hub:将模型传到HF hub
from transformers import AutoModel

model = AutoModel.from_pretrained("google-bert/bert-base-cased")

# Push the model to your namespace with the name "my-finetuned-bert".
model.push_to_hub("my-finetuned-bert")

# Push the model to an organization with the name "my-finetuned-bert".
model.push_to_hub("huggingface/my-finetuned-bert")
  • from_pretrained:根据config实例化预训练pytorch模型(Instantiate a pretrained pytorch model from a pre-trained model configuration.)
    默认使用评估模式 .eval()
    可以打开训练模式 .train()

    看下面的例子,可以从官方加载,也可以从本地模型参数加载。如果本地参数是tf的,转pytorch需要设置 from_tf=True,并且会慢些;本地参数是flax的话类似同理。
from transformers import BertConfig, BertModel

# Download model and configuration from huggingface.co and cache.
model = BertModel.from_pretrained("google-bert/bert-base-uncased")
# Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
model = BertModel.from_pretrained("./test/saved_model/")
# Update configuration during loading.
model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
assert model.config.output_attentions == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
# Loading from a Flax checkpoint file instead of a PyTorch model (slower)
model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)

可以给 torch_dtype 设置数据类型。若不给,则默认为 torch.float16。也可以给 torch_dtype="auto"

  • get_input_embeddings:获得输入的词嵌入在这里插入图片描述
    对应还有 get_output_embeddings
  • init_weights:设置参数初始化
    如果需要自己调整参数初始化的,在 _init_weights_initialize_weights 中设置
  • save_pretrained:把模型和配置参数保存在文件夹中
    保存完后,便可以通过 from_pretrained 再次加载模型了
    在这里插入图片描述

tensorflow & pytorch 简单对比

  • 知乎:Tensorflow 到底比 Pytorch 好在哪里?
    下面截取了比较重要的图
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
  • 里面还提到了一个内容叫做 Keras

Keras是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化

Configuration : PretrainedConfig

  • 刚才看了,对于 PretrainedModel 初始化提供的参数是 PretrainedConfig 类型的参数。
    它主要为不同的任务,提供了不同的重要参数
    HF官网:PretrainedConfig
  • 列一下对于NLP中比较重要的参数吧,所有的就看官方文档吧
返回信息
output_hidden_states (bool, optional, defaults to False) — Whether or not the model should return all hidden-states.
output_attentions (bool, optional, defaults to False) — Whether or not the model should returns all attentions.
return_dict (bool, optional, defaults to True) — Whether or not the model should return a ModelOutput instead of a plain tuple.
output_scores (bool, optional, defaults to False) — Whether the model should return the logits when used for generation.
return_dict_in_generate (bool, optional, defaults to False) — Whether the model should return a ModelOutput instead of a torch.LongTensor.

序列生成
max_length (int, optional, defaults to 20) — Maximum length that will be used by default in the generate method of the model.
min_length (int, optional, defaults to 0) — Minimum length that will be used by default in the generate method of the model.
do_sample (bool, optional, defaults to False) — Flag that will be used by default in the generate method of the model. Whether or not to use sampling ; use greedy decoding otherwise.
num_beams (int, optional, defaults to 1) — Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
diversity_penalty (float, optional, defaults to 0.0) — Value to control diversity for group beam search. that will be used by default in the generate method of the model. 0 means no diversity penalty. The higher the penalty, the more diverse are the outputs.
temperature (float, optional, defaults to 1.0) — The value used to module the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive.
top_k (int, optional, defaults to 50) — Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
top_p (float, optional, defaults to 1) — Value that will be used by default in the generate method of the model for top_p. If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
epetition_penalty (float, optional, defaults to 1) — Parameter for repetition penalty that will be used by default in the generate method of the model. 1.0 means no penalty.
length_penalty (float, optional, defaults to 1) — Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
bad_words_ids (List[int], optional) — List of token ids that are not allowed to be generated that will be used by default in the generate method of the model. In order to get the tokens of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).

tokenizer相关
bos_token_id (int, optional) — The id of the beginning-of-stream token.
pad_token_id (int, optional) — The id of the padding token.
eos_token_id (int, optional) — The id of the end-of-stream token.

PyTorch相关
torch_dtype (str, optional) — The dtype of the weights. This attribute can be used to initialize the model to a non-default dtype (which is normally float32) and thus allow for optimal storage allocation. For example, if the saved model is float16, ideally we want to load it back using the minimal amount of memory needed to load float16 weights. Since the config object is stored in plain text, this attribute contains just the floating type string without the torch. prefix. For example, for torch.float16 `torch_dtype is the "float16" string.

常见参数
vocab_size (int) — The number of tokens in the vocabulary, which is also the first dimension of the embeddings matrix (this attribute may be missing for models that don’t have a text modality like ViT).
hidden_size (int) — The hidden size of the model.
num_attention_heads (int) — The number of attention heads used in the multi-head attention layers of the model.
num_hidden_layers (int) — The number of blocks in the model.

PretrainedConfig 中重要的方法

  • push_to_hub:依然是上传到 HF hub
  • from_dict:把一个 dict 类型转到 PretrainedConfig 类型
  • from_json_file:把一个 json 文件转到 PretrainedConfig 类型,传入的是文件路径
  • to_dict:转成 dict 类型
  • to_json_file:保存到 json 文件
  • to_json_string:转成 json 字符串
  • from_pretrained:从预训练模型配置文件中直接获取配置
    可以是HF模型,也可以是本地模型,见下方例子
# We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
# derived class: BertConfig
config = BertConfig.from_pretrained(
    "google-bert/bert-base-uncased"
)  # Download configuration from huggingface.co and cache.
config = BertConfig.from_pretrained(
    "./test/saved_model/"
)  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
assert config.output_attentions == True
config, unused_kwargs = BertConfig.from_pretrained(
    "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
assert config.output_attentions == True
assert unused_kwargs == {"foo": False}
  • save_pretrained:把配置文件保存到文件夹中,方便下次 from_pretrained 直接读取

Tokenizer : PreTrainedTokenizer

  • HF官网:PreTrainedTokenizer
    Tokenizer 是用来把输入的字符串,转成 id 数组用的
    先来看一下其中相关的类的继承关系
    在这里插入图片描述
  • PreTrainedTokenizer 的初始化方法是直接给了 **kwargs
    调几个重要的列在下面,可以看到大部分都是设置一些token的含义。
bos_token (str or tokenizers.AddedToken, optional) — A special token representing the beginning of a sentence. Will be associated to self.bos_token and self.bos_token_id.
eos_token (str or tokenizers.AddedToken, optional) — A special token representing the end of a sentence. Will be associated to self.eos_token and self.eos_token_id.
unk_token (str or tokenizers.AddedToken, optional) — A special token representing an out-of-vocabulary token. Will be associated to self.unk_token and self.unk_token_id.
sep_token (str or tokenizers.AddedToken, optional) — A special token separating two different sentences in the same input (used by BERT for instance). Will be associated to self.sep_token and self.sep_token_id.
pad_token (str or tokenizers.AddedToken, optional) — A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms or loss computation. Will be associated to self.pad_token and self.pad_token_id.
cls_token (str or tokenizers.AddedToken, optional) — A special token representing the class of the input (used by BERT for instance). Will be associated to self.cls_token and self.cls_token_id.
mask_token (str or tokenizers.AddedToken, optional) — A special token representing a masked token (used by masked-language modeling pretraining objectives, like BERT). Will be associated to self.mask_token and self.mask_token_id.

PreTrainedTokenizer 中重要的方法

  • add_tokens:添加一些新的token
    它强调了,添加新token需要确保 token 嵌入矩阵与tokenizer是匹配的,即多调用一下 resize_token_embeddings 方法
    在这里插入图片描述
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")
model = BertModel.from_pretrained("google-bert/bert-base-uncased")

num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
print("We have added", num_added_toks, "tokens")
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
  • add_special_tokens:添加特殊tokens,比如之前的 eos,pad 等,与之前普通的tokens是不大一样的,但要确保该token不在词汇表里
# Let's see how to add a new classification token to GPT-2
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2Model.from_pretrained("openai-community/gpt2")

special_tokens_dict = {"cls_token": "<CLS>"}

num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print("We have added", num_added_toks, "tokens")
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))

assert tokenizer.cls_token == "<CLS>"
  • encode, decode:字符串转id数组,id数组转字符串,即词嵌入
    encodeself.convert_tokens_to_ids(self.tokenize(text)) 等价
    decodeself.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids)) 等价
  • tokenize:把字符串转成token序列,即分词 str → list[str]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/440736.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

influxdb2.0插入数据字段类型出现冲突问题解决

一、问题出现 一个学校换热站自控系统&#xff0c;会定时从换热站获取测点数据&#xff0c;并插入到influxdb数据库中。influxdb插入数据时&#xff0c;报错提示&#xff1a; com.influxdb.exceptions.UnprocessableEntityException: failure writing points to database: par…

组合逻辑电路(二)(译码器和编码器)

目录 译码器 简单逻辑门译码器 二进制译码器 2线-4线译码器 3线-8线译码器 二-十进制译码器 4线-10线译码器 七段显示译码器 编码器 二进制普通编码器 二-十进制普通编码器&#xff08;8421BCD码编码器&#xff09; 优先编码器&#xff08;Priority Encoder&#xff09; 译…

《解密云计算:企业之选》

前言 在当今数字化时代&#xff0c;企业面临着巨大的数据处理压力和信息化需求&#xff0c;传统的IT架构已经无法满足日益增长的业务需求。在这样的背景下&#xff0c;越来越多的企业开始转向云计算&#xff0c;以实现灵活、高效和可扩展的IT资源管理和利用。 云计算 云计算是…

【QT中如何生成导出.exe可执行文件并打包给其他人使用】

1、将QT的部署设置改成Release编译模式。 2、运行项目生成release文件夹&#xff0c;其中包含.exe文件。 3、新建空文件夹&#xff0c;将release文件夹中的.exe文件复制到里面去。&#xff08;此处新建了hellofile空文件夹来存放hello.exe文件&#xff09; 4、在QT终端里&#…

SpringBoot学习之自定义注解和AOP 切面统一保存操作日志(二十九)

一、定义一个注解 这个注解是用来控制是否需要保存操作日志的自定义注解(这个类似标记或者开关) package com.xu.demo.common.anotation;import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; i…

Filter过滤器+JWT令牌实现登陆验证

一、背景 我们需要在客户端访问服务器的时候给定用户一定的操作权限&#xff0c;比如没有登陆时就不能进行其他操作。如果他需要进行其他操作&#xff0c;而在这之前他没有登陆过&#xff0c;服务端则需要将该请求拦截下来&#xff0c;这就需要用到过滤器&#xff0c;过滤器可以…

【YOLO v5 v7 v8 v9小目标改进】AFPN 渐进式特征金字塔网络:解决多尺度特征融合中,信息在传递过程丢失

AFPN 渐进式特征金字塔网络&#xff1a;解决多尺度特征融合中&#xff0c;信息在传递过程丢失 提出背景AFPN 多尺度特征金字塔 非邻近层次的直接特征融合 自适应空间融合操作 小目标涨点YOLO v5 魔改YOLO v7 魔改YOLO v8 魔改YOLO v9 魔改 提出背景 论文&#xff1a;https:…

复试人工智能前沿概念总结

1.大模型相关概念&#xff08;了解即可&#xff09; 1.1 GPT GPT&#xff0c;全称为Generative Pre-training Transformer&#xff0c;是OpenAI开发的一种基于Transformer的大规模自然语言生成模型。GPT模型采用了自监督学习的方式&#xff0c;首先在大量的无标签文本数据上进…

力扣hot100题解(python版55-59题)

55、全排列 给定一个不含重复数字的数组 nums &#xff0c;返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1&#xff1a; 输入&#xff1a;nums [1,2,3] 输出&#xff1a;[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2&#xff1a; 输入&…

论文研读笔记1:

1.Improving Domain-Adapted Sentiment Classification by Deep Adversarial Mutual Learning&#xff1a; 1.1本篇论文提出了一种名为深度对抗性互学习&#xff08;Deep Adversarial Mutual Learning, DAML&#xff09;的新方法&#xff0c;用于改进领域适应性情感分类。 对…

使用 Cypress 进行可视化回归测试:一种务实的方法

每次组件库 Picasso 发布新版本时&#xff0c;都会更新所有的前端应用程序&#xff0c;让绝大部分新功能能与整个平台的设计保持一致。上个月&#xff0c;推出了 Toptal Talent Portal 的 Picasso 更新&#xff0c;这是我们的用户用来找工作和与客户互动的平台。 已知了这个版本…

C++指针(四)万字图文详解!

个人主页&#xff1a;PingdiGuo_guo 收录专栏&#xff1a;C干货专栏 前言 相关文章&#xff1a;C指针&#xff08;一&#xff09;、C指针&#xff08;二&#xff09;、C指针&#xff08;三&#xff09; 本篇博客是介绍函数指针、函数指针数组、回调函数、指针函数的。 点赞破六…

结构体和malloc学习笔记

结构体学习&#xff1a; 为什么会出现结构体&#xff1a; 为了表示一些复杂的数据&#xff0c;而普通的基本类型变量无法满足要求&#xff1b; 定义&#xff1a; 结构体是用户根据实际需要自己定义的符合数类型&#xff1b; 如何使用结构体&#xff1a; //定义结构体 struc…

【工具】Raycast – Mac提效工具

引入 以前看到同事们锁屏的时候&#xff0c;不知按了什么键&#xff0c;直接调出这个框&#xff0c;然后输入lock屏幕就锁了。 跟我习惯的按Mac开机键不大一样。个人觉得还是蛮炫酷的&#xff5e; 调研 但是由于之前比较繁忙&#xff0c;这件事其实都忘的差不多了&#xff0…

C++ · 代码笔记4 ·继承与派生

目录 前言010继承与派生简单例程020多级继承030使用using关键词更改访问权限040隐藏050派生类与基类成员函数同名时不构成重载060使用多级继承展示成员变量在内存中的分布情况071派生类在函数头调用基类构造函数072构造函数调用顺序080构造函数与析构函数的调用顺序091多重继承…

【常见集合】Java 常见集合重点解析

Java 常见集合重点解析 1. 什么是算法时间复杂度&#xff1f; 时间复杂度表示了算法的 执行时间 和 数据规模 之间的增长关系&#xff1b; 什么是算法的空间复杂度&#xff1f; 表示了算法占用的额外 存储空间 与 数据规模 之间的增长关系&#xff1b; 常见的复杂度&#x…

超实用的公众号搭建教程分享,纯干货

微信公众号已经成为了企业、个人和品牌进行宣传和互动的重要平台。在这个拥有海量公众号的时代&#xff0c;如何让你的公众号脱颖而出&#xff0c;吸引更多的关注者&#xff0c;实现有效传播呢&#xff1f;接下来&#xff0c;伯乐网络传媒将为你详细解析公众号搭建教程&#xf…

便捷在线导入:完整Axure元件库集合,让你的设计更高效!

Axure元件库包含基本的工具组件&#xff0c;可以使原型绘制节省大量的重复工作&#xff0c;保持整个设计页面的一致性和标准化&#xff0c;同时显得专业。Axure元件库就像我们日常生活中的门把手、自行车踏板和桌子上的螺丝钉&#xff0c;需要组装才能使用。作为一名成熟的产品…

信息安全管理与评估DCST-6000B-Pro神州数码堡垒机沙箱连接教程

信息安全管理与评估DCST-6000B-Pro神州数码堡垒机沙箱连接教程 一、前言 在全国职业院校技能大赛-信息安全管理与评估赛项中&#xff0c;我们会用到DCST-6000B-Pro神州数码堡垒机沙箱&#xff0c;简称堡垒机&#xff0c; 很多院校并没有购买该设备&#xff0c;导致备赛学生可…

阿珊详解Vue Router的守卫机制

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…
最新文章