NLP深入学习:结合源码详解 BERT 模型(三)

文章目录

  • 1. 前言
  • 2. 预训练
    • 2.1 modeling.BertModel
      • 2.1.1 embedding_lookup
      • 2.1.2 embedding_postprocessor
      • 2.1.3 transformer_model
    • 2.2 get_masked_lm_output
    • 2.3 get_next_sentence_output
    • 2.4 训练
  • 3. 参考


1. 前言

前情提要:
《NLP深入学习:结合源码详解 BERT 模型(一)》
《NLP深入学习:结合源码详解 BERT 模型(二)》

之前已经详细说明了 BERT 模型的主要架构和思想,并且讲解了 BERT 源代码对于数据准备的流程,回顾下关键字段的含义:

# 以下是输出到文件的值,也是会作为后续预训练的输入值,重点看!
input_ids:tokens在字典的索引位置,不足max_seq_length(128)则补0
input_mask:初始化为1,不足max_seq_length(128)则补0
segment_ids: 句子A的token和句子B的token,按照0/1排列区分。不足max_seq_length(128)则补0
masked_lm_positions: 被选中 MASK 的token位置索引
masked_lm_ids:被选中 MASK 的token原始值在字典的索引位置
masked_lm_weights:初始化为1
next_sentence_labels:对应is_random_next,1表示随机选择,0表示正常语序

下面我们结合预训练代码详细讲解下 BERT 的预训练流程。

2. 预训练

预训练代码在 run_pretraing.py 文件中,注意我们需要把数据准备的结果作为预训练的输入:
在这里插入图片描述
那我们打上断点,继续开启 debug 吧!
在这里插入图片描述

2.1 modeling.BertModel

看预训练代码,大部分的核心代码集中在 modeling.BertModel 这个 class 的 __init__ 代码中:
在这里插入图片描述
解释下 modeling.BertModel 的参数:

  • config: BERT 的配置文件,后续的很多参数都来源于此。我放到路径 ./multi_cased_L-12_H-768_A-12/bert_config.json ,内容如下:
{
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": 119547
}
  • is_training:True 表示训练,False 表示评估
  • input_ids:对应于数据准备的字段 input_ids,形状 [batch_size, seq_length],即 [32, 128]
  • input_mask:对应于数据准备的字段 input_mask,形状 [batch_size, seq_length],即 [32, 128]
  • token_type_ids:对应于数据准备的字段 segment_ids,形状 [batch_size, seq_length],即 [32, 128]
  • use_one_hot_embeddings:词嵌入是否用 one_hot 模式
  • scope:变量的scope,用于 tf.variable_scope(scope, default_name="bert") 默认是 bert

2.1.1 embedding_lookup

modeling.BertModel__init__ 代码中,第一个重要的方法是 embedding_lookup
在这里插入图片描述
我们看下具体的代码,返回值有两个:

  • out_put 是根据输入的 input_ids 在字典中找到对应的词,并且返回词对应的 embedding 向量,out_put 的形状是 [batch_size, seq_length, embedding_size]
  • embedding_table 是字典每一个词对应的向量,形状是 [vocab_size, embedding_size]

在这里插入图片描述
ps: 有些同学不清楚字典是什么?字典在项目的 ./multi_cased_L-12_H-768_A-12/vocab.txt 里,每一行对应一个词,里例如id=0则表示字典第一个对应的词[PAD],字典内容如下:

[PAD]
[unused1]
[unused2]
[unused3]
[unused4]
...
[unused99]
[UNK]
[CLS]
[SEP]
[MASK]
<S>
<T>
!
"
#
$
%
...
A
B
C
D
E
F
G
H

2.1.2 embedding_postprocessor

后续的该方法是用于加上位置编码!
在这里插入图片描述
我们进到函数内部查看具体细节:
在这里插入图片描述
上面代码中,token_type_ids 对应的是 segment_ids,即句子的表示(用0/1来表示),细节见《NLP深入学习:结合源码详解 BERT 模型(二)》 的 2.3章节。token_type_table 和上一节的 embedding_table 是一样的含义,这里就是向量化 segment_ids。由于 segment_ids 只用 0和1来表示,所以token_type_vocab_size=2,并且最终将 out_put 加上了 segment_ids 向量化的结果,就是图中的 TokenEmbeddings + SegmentEmbeddings
在这里插入图片描述
那么显而易见,下一段代码就是再加上 PositionEmbeddings 了!
在这里插入图片描述
注意,这里的 position_embeddings 实际就是词在句子中的位置对应的 embedding~

最后将输出加上了 layer_norm_and_dropout ,即层归一和dropout。

2.1.3 transformer_model

顺着代码debug下去,在准备好了数据之后,就是经典的 Transformer 模型了:
在这里插入图片描述
希望深入了解 Transformer 的,建议参考:
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

我们先回忆下 Transformer 的结构,因为下面的代码完全是对论文的编码器实现:
在这里插入图片描述
为了方便查看,我把代码的结构和论文的结构对比在一起:
在这里插入图片描述
transformer 结构构建完成之后,下面的self.sequence_out 是把最后一层的输出作为 transformer 的 encoding 结果输出。
在这里插入图片描述
此外,first_token_tensor 是取第一个 token 的输出结果,即 [CLS] 的结果。因为 [CLS] 已经带有上下文信息了,因此对于分类而言,用 [CLS] 的输出即可。这个论文中也有说明:
在这里插入图片描述
以上就是 BERT 模型的构建整体流程,下面来看 BERT 模型的评估流程,包含 Masked Language Model(MLM)和 Next Sentence Prediction(NSP)。

2.2 get_masked_lm_output

先来看 Masked Language Model(MLM)的评估,对应代码中的 get_masked_lm_out ,见下图:

首先看下 get_masked_lm_out 的输入参数:

  • bert_config : BERT 的配置文件,对应我的路径 ./multi_cased_L-12_H-768_A-12/bert_config.json
  • input_tensor:BERT 模型的输出,即上文的 self.sequence_out
  • output_weights:对应上文 embedding_lookup 的第二个输出,即字典每一个词对应的向量,形状是 [vocab_size, embedding_size]
  • positions:对应 features["masked_lm_positions"] ,即被选中 MASK 的 token 位置索引
  • label_ids:对应 features["masked_lm_ids"],即被选中 MASK 的 token 原始值在字典的索引位置
  • label_weights:对应 features["masked_lm_weights"]

下面是整体的代码,代码有些地方需要细细品味:

在这里插入图片描述
要看懂这里的代码,首先我们要知道 BERT 在 Masked Language Model(MLM)上要干啥。BERT 首先给句子的词打上了 [MASK] ,后续就要对 [MASK] 的词进行预测。预测,就是在词典中出现的词给出一个概率,看属于哪个词,本质上就是多分类问题。那么对于多分类问题,通常的做法是计算交叉熵。

这里就不详细阐述交叉熵的来龙去脉了,直接说明交叉熵如何计算。我们假设真实分布为 y,而模型输出分布为 y ^ \widehat{y} y ,总的类别数为 n,交叉熵损失函数的计算方法为:
l o s s = ∑ i = 1 n [ − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) ] loss = \sum_{i=1}^{n}[-ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i)] loss=i=1n[ylogy i(1y)log(1y i)]
好,我们来看代码中关键的几个步骤:

  • log_probs = tf.nn.log_softmax(logits, axis=-1) ,这个方法实际上计算的是:
    l o g _ p r o b s = [ l o g y ^ 1 , l o g y ^ 2 , . . . , l o g y ^ n ] log\_probs = [log\widehat{y}_1, log\widehat{y}_2,...,log\widehat{y}_n] log_probs=[logy 1,logy 2,...,logy n]
    其中 l o g y ^ i log\widehat{y}_i logy i 表达的是属于词典第 i 个词的概率的对数值。

  • one_hot_labels = tf.one_hot(label_ids, depth=bert_config.vocab_size, dtype=tf.float32),计算每个词的在字典的 one_hot 结果,形状是 [batch_size*seq_len, vocab_size]。例如,“animal” 在字典第18883位置,那么"animal"对应的 one_hot 就是 [0,0,…0,1,0,…,0],其中向量长度就是字典的大小,1排在向量的18883个。

  • per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) ,这个方法是用于交叉熵的。因为我们知道真实的分布情况,就是 one_hot_labels 对应的结果,那么对于某一个具体的词,其交叉熵的计算就是 − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) -ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i) ylogy i(1y)log(1y i),将 y=1(即事先知道一定属于某个词)代入,即交叉熵为 − l o g y ^ i -log\widehat{y}_i logy i。所以事先计算了 log_probsper_example_loss 可以直接得到每个词的交叉熵的结果。

  • lossper_example_loss 得到的结果赋予权重进行加权平均,得到一个最终的 loss,实际上就相当于 l o s s = ∑ i = 1 n w i [ − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) ] loss = \sum_{i=1}^{n}w_i[-ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i)] loss=i=1nwi[ylogy i(1y)log(1y i)]

2.3 get_next_sentence_output

再来看 Next Sentence Prediction(NSP)评估,预测句子的下一句:
在这里插入图片描述
首先看下 get_next_sentence_output 的输入参数:

  • bert_config: BERT 的配置文件,对应我的路径 ./multi_cased_L-12_H-768_A-12/bert_config.json
  • input_tensor[CLS] 的输出线性变换后的结果,简单理解为 [CLS] 的输出作为当前函数的输入
  • labels:对应 features["next_sentence_labels"] ,1表示下一个句子是随机选择的,0表示正常语序

由于下一个句子只有两种选择,要么是随机的,要么是原先正常的句子,所以其实就是一个二分类问题:
在这里插入图片描述
二分类的交叉熵:
l o s s = ∑ i = 1 n − y l o g y ^ i loss = \sum_{i=1}^{n}-ylog\widehat{y}_i loss=i=1nylogy i
上面的核心逻辑跟 get_masked_lm_output 一模一样。只不过这里的 loss 用的是平均值,没有用加权平均

2.4 训练

计算了 masked_lm_loss 以及 next_sentence_loss 之后,将两种 loss 相加,即是总的 loss
在这里插入图片描述
后续就训练模型降低 loss

3. 参考

《NLP深入学习:结合源码详解 BERT 模型(一)》
《NLP深入学习:结合源码详解 BERT 模型(二)》
《BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎:SmallerFL;

也欢迎关注我的wx公众号:一个比特定乾坤

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/496558.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024年N1叉车司机证考试题库及N1叉车司机试题解析

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年N1叉车司机证考试题库及N1叉车司机试题解析是安全生产模拟考试一点通结合&#xff08;安监局&#xff09;特种作业人员操作证考试大纲和&#xff08;质检局&#xff09;特种设备作业人员上岗证考试大纲随机出的…

串口通信标准RS232 RS485 RS422的区别

RS-232、RS-422、RS-485是关于串口通讯的一个机械和电气接口标准&#xff08;顶多是网络协议中的物理层&#xff09;&#xff0c;不是通讯协议&#xff0c;它们之间的几个不同点如下&#xff1a; 一、硬件管脚接口定义不同 二、工作方式不同 RS232&#xff1a; 3线全双工 RS…

在线教学软件推荐!一站式白板让线上教学更顺畅!

可以用于线上教学的软件&#xff0c;之前大家最为熟悉的莫过于使用各类视频会议软件&#xff0c;如腾讯会议、钉钉会议、飞书会议、Zoom 等&#xff0c;基于视频会议软件来共享电脑屏幕&#xff0c;然后再切换到本地的 PPT 演示文稿进行讲解。 但采用这个线上教学方案存在一些…

如何用磁力仪探测管缆的位置和埋深?

不论是航空磁测&#xff0c;还是海洋磁测&#xff0c;都是直接测量磁场总强度T&#xff0c;而后以总磁异常ΔT成图。磁异常总强度Ta是磁场总强度T与正常场T0的矢量差&#xff0c;即&#xff1a; Ta&#xff1d; T&#xff0d; T0 根据参考文献1&#xff0c;2的推导&#xff0c…

2024信息通信展览会|中国通信展览会|通讯大会

2024信息通信展览会|中国通信展览会|通讯大会 2024年中国国际信息通信展览会与同期举办的ICT.中国论坛于2024年9月25-27日在北京.国家会议中心隆重举办&#xff0c;共同奋力开启信息通信的新篇章。这是一场集交流、展示、共赢于一体的盛大盛典&#xff0c;为信息通信领域的企业…

保研线性代数机器学习基础复习2

1.什么是群&#xff08;Group&#xff09;&#xff1f; 对于一个集合 G 以及集合上的操作 &#xff0c;如果G G-> G&#xff0c;那么称&#xff08;G&#xff0c;&#xff09;为一个群&#xff0c;并且满足如下性质&#xff1a; 封闭性&#xff1a;结合性&#xff1a;中性…

一种重要却容易被我们忽略的能力

你有多久没有「发呆」过了&#xff1f; 我指的不是那种偶尔的走神和分心&#xff0c;而是那种持续一段时间&#xff0c;什么也不做、什么也不想&#xff0c;就这样静静站着或坐着&#xff0c;让大脑放空的状态。 可能有人会觉得&#xff1a;这太奢侈了&#xff0c;我们每天都恨…

【任职资格】某大型制造型企业任职资格体系项目纪实

该企业以业绩、责任、能力为导向&#xff0c;确定了分层分类的整体薪酬模式&#xff0c;但是每一名员工到底应该拿多少工资&#xff0c;同一个岗位的人员是否应该拿同样的工资是管理人员比较头疼的事情。华恒智信顾问认为&#xff0c;通过任职资格评价能实现真正的人岗匹配&…

基于Spring boot + Vue协同过滤算法的电影推荐系统

末尾获取源码作者介绍&#xff1a;大家好&#xff0c;我是墨韵&#xff0c;本人4年开发经验&#xff0c;专注定制项目开发 更多项目&#xff1a;CSDN主页YAML墨韵 学如逆水行舟&#xff0c;不进则退。学习如赶路&#xff0c;不能慢一步。 目录 一、项目简介 二、开发技术与环…

分享一下自己成功入职为AIGC工程师的经历

据外媒援引知情人士消息&#xff0c;OpenAI预计2023年收入将达到2亿美元&#xff0c;到2024年将达到10亿美元&#xff0c;全世界都看出了AIGC工程师的市场潜力。 而对于广大职场人士而言&#xff0c;则是意味着新的职场机遇出现了&#xff0c;学习好AIGC技术&#xff0c;无论是…

gemma 大模型(gemma 2B,gemma 7B)微调及基本使用

待整理… gemma介绍 Gemma是Google推出的一系列轻量级、最先进的开放模型&#xff0c;基于创建Gemini模型的相同研究和技术构建。提供了 2B 和 7B 两种不同规模的版本&#xff0c;每种都包含了预训练基础版本和经过指令优化的版本。所有版本均可在各类消费级硬件上运行&#x…

ThreadLocal和Synchronized的区别

目录 背景过程ThreadLocal什么是ThreadLocal&#xff1f;既然都是保证线程访问的安全性&#xff0c;那么和Synchronized区别是什么呢&#xff1f;ThreadLocal的使用TheadLocal使用场景原理高并发场景下ThreadLocal会造成内存泄漏吗&#xff1f;什么原因导致&#xff1f;如何避免…

aws 入门篇 01.aws学习的方法论

aws入门篇 01.aws学习的方法论 第1章 aws学习的方法论 aws的服务很多&#xff0c;现在应该有100多个服务了&#xff0c;怎么来学习aws呢&#xff1f; 这几年也使用了一些aws的服务&#xff0c;谈谈自己对学习aws的理解。 1.先横向&#xff0c;后纵深 比如说&#xff0c;aws最…

SpringCloud微服务集成Dubbo

1、Dubbo介绍 Apache Dubbo 是一款易用、高性能的 WEB 和 RPC 框架,同时为构建企业级微服务提供服务发现、流量治理、可观测、认证鉴权等能力、工具与最佳实践。用于解决微服务架构下的服务治理与通信问题,官方提供了 Java、Golang 等多语言 SDK 实现。使用 Dubbo 开发的微服…

手撕算法-最小覆盖子串

描述 分析 滑动窗口。 参考力扣官方的题解思路 本问题要求我们返回字符串 s 中包含字符串 t 的全部字符的最小窗口。我们称包含 t 的全部字母的窗口为「可行」窗口。 我们可以用滑动窗口的思想解决这个问题。在滑动窗口类型的问题中都会有两个指针&#xff0c;一个用于「延伸…

文件操作(下)(想要了解如何操作文件,那么看这一片就足够了!)

前言&#xff1a;在文件操作&#xff08;上&#xff09;中&#xff0c;我们讲到了基础的文件操作&#xff0c;包括文件的打开&#xff0c;文件的关闭&#xff0c;以及文件的基础读写&#xff0c;那么除了之前学习的读写之外&#xff0c;还有什么其他的方式对文件进行读写操作吗…

P5725 【深基4.习8】求三角形

【深基4.习8】求三角形 - 洛谷https://www.luogu.com.cn/problem/P5725 import java.util.*;public class Main {public static void main(String[] args) {Scanner sc new Scanner(System.in); // 创建一个 Scanner 对象来读取用户输入int n sc.nextInt(); // 从用户输入中…

Linux根据时间删除文件或目录

《liunx根据时间删除文件》和 《Linux 根据时间删除文件或者目录》已经讲述了根据时间删除文件或目录的方法。 下面我做一些补充&#xff0c;讲述一个具体例子。以删除/home目录下的文件为例。 首先通过命令&#xff1a; ls -l --time-style"%Y-%m-%d %H:%M:%S"…

【数据结构与算法】快速排序(详解:快排的Hoare原版,挖坑法和双指针法|避免快排最坏时间复杂度的两种解决方案|小区间优化|非递归的快排)

引言 快速排序作为交换排序的一种&#xff0c;在排序界的影响力毋庸置疑&#xff0c;我们C语言中用的qsort&#xff0c;C中用的sort&#xff0c;底层的排序方式都是快速排序。相比于同为交换排序的冒泡&#xff0c;其效率和性能就要差的多了&#xff0c;本篇博客就是要重点介绍…

2024 ccfcsp认证打卡 2023 03 01 田地丈量

import java.util.Scanner;public class Main {public static void main(String[] args) {Scanner in new Scanner(System.in);int n in.nextInt(); // 输入 n&#xff0c;表示矩形的数量int a in.nextInt(); // 输入 a&#xff0c;表示整个区域的长度int b in.nextInt()…
最新文章