用 Pytorch 训练一个 Transformer模型

昨天说了一下Transformer架构,今天我们来看看怎么 Pytorch 训练一个Transormer模型,真实训练一个模型是个庞大工程,准备数据、准备硬件等等,我只是做一个简单的实现。因为只是做实验,本地用 CPU 也可以运行。
本文包含以下几部分:

  1. 准备环境。
  2. 然后就是跟据架构来定义每一层,包括Embedding、Position Encoding、多头注意力、 网络层。
  3. 准备Encoder。
  4. 准备Decoder。
  5. 运行 Transformer,包括训练和评估。

安装Pytorch 环境

!pip3 install torch torchvision torchaudio

引入所需工具类库

引入需要的类库,pytorch 是强大的训练框架,深度学习中需要的一些函数和基本功能都已经实现。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

Position Embedding

生成位置信息。

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

多头注意力

  1. 初始化model 维度,头数,每个头的维度。
  2. 计算是在 forward 这个方法,主要看这个方法。
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value
        
        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
        self.W_o = nn.Linear(d_model, d_model) # Output transformation
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)
        
        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        # 转换,每一个 head 独立处理
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        # 线性转换并切分
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        # 运行计算公式
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # 合并并返回
        output = self.W_o(self.combine_heads(attn_output))
        return output

网络定义

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

Encoder

在这里插入图片描述
跟据这张图看下面的实现比较直观,初始化了MultiHeadAttention、PositionWiseFeedForward、两个LayerNorm。 forward 方法中 x 是 Encoder 的输入。

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

Decoder

在这里插入图片描述
看代码的方式和 Encoder 类似,比较好理解,2 个MultiHeadAttention、3个 Norm,forward 中cross_attn 把 enc_output作为传入的参数。

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

Transformer

Transformer主类,包括初始化 embedding、position embedding、encoder 和 decoder。forward 方法进行计算。

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

训练

首先准备数据,这里的数据是随机生成,只是做演示。

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

开始训练

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

评估

将模型运行在验证集或者测试集上,这里数据也是随机生成的,只为体验一下完整流程。

transformer.eval()

# Generate random sample validation data
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

with torch.no_grad():

    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data[:, 1:].contiguous().view(-1))
    print(f"Validation Loss: {val_loss.item()}")

如果你对 Pytorch 和神经网络比较熟悉,Transformer整体实现起来并不复杂,如果想我一样对深度学习不太熟悉,理解起来还是有些困难,这里只是大概跑了一下流程,对Transformer训练有一个概念。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/557764.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++ STL 容器 vector

目录 1. vector 对象2. vector 大小 size 和 容量 capacity3. vector 成员函数3.1 迭代器3.2 容量3.3 元素访问3.4 插入3.5 删除3.6 动态扩充与收缩 4. vector 迭代器失效问题总结其他补充 本文测试环境为 编译器 gcc 13.1 vector 是 STL 中的一个顺序容器,它给我们…

如何将静态网页资源“打包“成.exe或者.apk

Hello , 我是小恒不会java。最近有音乐播放器win桌面应用程序的需求,那就说说上手electron 又想到很多人对apk文件不太了解,apk文件就是安卓桌面应用程序,比如你手机现在打开的微信 当然,exe文件基本都清楚,windows可执…

正则表达式(Regular Expression)

正则表达式很重要,是一个合格攻城狮的必备利器,必须要学会!!! (参考视频)10分钟快速掌握正则表达式(奇乐编程学院)https://www.bilibili.com/video/BV1da4y1p7iZ在线测试…

React Hooks(常用)笔记

一、useState(保存组件状态) 1、基本使用 import { useState } from react;function Example() {const [initialState, setInitialState] useState(default); } useState(保存组件状态) :React hooks是function组件(无状态组件) &#xf…

再拓信创版图-Smartbi 与东方国信数据库完成兼容适配认证

近日,思迈特商业智能与数据分析软件 [简称:Smartbi Insight] V11与北京东方国信科技股份有限公司 (以下简称东方国信)CirroData-OLAP分布式数据库V2.14.1完成兼容性测试。经双方严格测试,两款产品能够达到通用兼容性要…

浪潮信息成功打造大规模、高性能、高可靠的单存储集群方案!

为帮助企业应对商业智能应用中面临的关于海量数据存储及实时分析的难题,浪潮信息日前通过技术研发,创新推出全球首个SAP HANA集群方案,该方案实现了最大可支持HANA集群服务器节点数量的翻倍,单存储即可支持16节点的,大…

图片高效批量管理,一键批量旋转150度,高效整理您的图片库

在数字化时代,我们的生活中充满了各种图片。从手机拍照到网络下载,从社交媒体到工作文档,图片无处不在。然而,随着图片数量的不断增加,如何高效管理这些图片,让它们有序、易于查找,成为了许多人…

Vue3从入门到实战:深度了解相关API

shallowRef 作用:创建一个响应式数据,但只对顶层属性进行响应式处理。 用法: let myVar shallowRef(initialValue); 特点:只跟踪引用值的变化,不关心值内部的属性变化。 shallowReactive 作用:创建一个…

【MySQL】表的基本约束

文章目录 1、约束类型1.1NOT NULL约束1.2UNIQUE:唯一约束1.3DEFAULT:默认值约束1.4PRIMARY KEY:主键约束1.5FOREIGN KEY:外键约束 2、表的设计2.1一对一2.2一对多2.3多对多 1、约束类型 关键字解释NOT NULL指示某列不能存储NULL值…

点赞列表查询列表

点赞列表查询列表 BlogController GetMapping("/likes/{id}") public Result queryBlogLikes(PathVariable("id") Long id) {return blogService.queryBlogLikes(id); }BlogService Override public Result queryBlogLikes(Long id) {String key BLOG_…

【C++航海王:追寻罗杰的编程之路】C++11(上)

目录 1 -> C11简介 2 -> 统一的列表初始化 2.1 -> {}初始化 2.2 -> std::initializer_list 3 -> 声明 3.1 -> auto 3.2 -> decltype 3.3 -> nullptr 1 -> C11简介 在2003年C标准委员会曾经提交了一份技术勘误表(简称TC1),使得C…

Debian

文章目录 前言一、使用root用户操作二、配置用户使用sudo命令三、添加桌面图标显示1.打开终端2.执行安装命令3.执行成功后重启4. 打开扩展,配置图标 四、图形化界面关闭和打开五、设置静态IP1.查询自己系统网络接口2.修改网络配置文件 总结 前言 Debian 系统在安装…

基于Springboot+Vue的Java项目-在线文档管理系统开发实战(附演示视频+源码+LW)

大家好!我是程序员一帆,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &am…

RUOYI 若依 横向菜单

保留移动端适配 小屏适配 菜单权限等 可轻松进行深度自定义菜单样式 以及分布 仅支持横向布局 如需源码 教程等 ➕ wx 技术支持 wx : 17339827025

【IEEE出版 | 中山大学主办 | 往届会后2-4个月EI检索】第五届电子通讯与人工智能学术会议(ICECAI 2024)

第五届电子通讯与人工智能国际学术会议(ICECAI 2024) 2024 5th International Conference on Electronic communication and Artificial Intelligence 第五届电子通讯与人工智能国际学术会议(ICECAI 2024)将于2024年5月31日-6月…

淘宝订单交易详情查询API是淘宝开放平台提供的接口,可以通过该接口获取淘宝订单的详细信息。

淘宝订单交易详情查询API是淘宝开放平台提供的接口,可以通过该接口获取淘宝订单的详细信息。通过该API,你可以获取订单的基本信息、商品信息、买家信息、物流信息等。 具体使用该API需要进行以下步骤: 在淘宝开放平台注册开发者账号&#xf…

QA测试开发工程师面试题满分问答15: 讲一讲InnoDB和MyISAM

InnoDB和MyISAM是MySQL中两种常见的存储引擎,它们在数据存储和处理方面有着显著的区别。让我们逐一来看一下它们的区别、原理以及适用场景。 区别: 事务支持:InnoDB是一个支持事务的存储引擎,而MyISAM不支持事务。事务是一种用于维…

L2-045 堆宝塔

L2-045 堆宝塔 分数 25 全屏浏览 切换布局 作者 陈越 单位 浙江大学 堆宝塔游戏是让小朋友根据抓到的彩虹圈的直径大小,按照从大到小的顺序堆起宝塔。但彩虹圈不一定是按照直径的大小顺序抓到的。聪明宝宝采取的策略如下: 首先准备两根柱子&#xff…

C++运算符重载和日期类的实现

运算符重载 参数个数与操作个数应该一致(双目操作符就是2个参数,同时参数中包括this) 不能被重载的运算符 " .* "运算符的作用 .*就是用来调用成员函数指针的 调用 1.显式调用 运算符重载可以显式调用 eg. 2.转换调用 运算符重载增强了程序的可读性 bool operato…

SpringBoot版本配置问题与端口占用

前言 ​ 今天在配置springboot项目时遇到了一些问题,jdk版本与springboot版本不一致,在使用idea的脚手架创建项目时,idea的下载地址是spring的官方网站,这导致所下载的版本都是比较高的,而我们使用最多的jdk版本是jdk…
最新文章