PyTorch从零开始实现Transformer

文章目录

    • 自注意力
    • Transformer块
    • 编码器
    • 解码器块
    • 解码器
    • 整个Transformer
    • 参考来源
    • 全部代码(可直接运行)

自注意力

计算公式

在这里插入图片描述

代码实现


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size),  "Embed size needs  to  be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0] # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # 矩阵乘法,使用爱因斯坦标记法
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape: (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask==0, float("-1e20")) 
            #Fills elements of self tensor with value where mask is True

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim
        ) # 矩阵乘法,使用爱因斯坦标记法einsum
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions

        out = self.fc_out(out)
        return out

Transformer块

我们把Transfomer块定义为如下图所示的结构,这个Transformer块在编码器和解码器中都有出现过。
在这里插入图片描述

代码实现

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

编码器

编码器结构如下所示,Inputs经过Input Embedding 和Positional Encoding之后,通过多个Transformer块

在这里插入图片描述

代码实现

class Encoder(nn.Module):
    def __init__(self, 
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
            for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_lengh = x.shape
        positions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

解码器块

解码器块结构如下图所示

在这里插入图片描述

代码实现

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out

解码器

解码器块加上word embedding 和 positional embedding之后构成解码器

在这里插入图片描述

代码实现

class Decoder(nn.Module):
    def __init__(self, 
                 trg_vocab_size, 
                 embed_size, 
                 num_layers, 
                 heads, 
                 forward_expansion, 
                 dropout, 
                 device, 
                 max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
             for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

整个Transformer

在这里插入图片描述

代码实现

class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size, 
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=100
                 ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #(N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask,  trg_mask)
        return out
    

参考来源

[1] https://www.youtube.com/watch?v=U0s0f995w14
[2] https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py

[3] https://arxiv.org/abs/1706.03762
[4] https://www.youtube.com/watch?v=pkVwUVEHmfI

全部代码(可直接运行)

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size),  "Embed size needs  to  be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0] # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape: (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask==0, float("-1e20")) 
            #Fills elements of self tensor with value where mask is True

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
    
class Encoder(nn.Module):
    def __init__(self, 
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
            for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_lengh = x.shape
        positions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out
    
class Decoder(nn.Module):
    def __init__(self, 
                 trg_vocab_size, 
                 embed_size, 
                 num_layers, 
                 heads, 
                 forward_expansion, 
                 dropout, 
                 device, 
                 max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
             for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size, 
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=100
                 ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #(N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask,  trg_mask)
        return out
    
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
        device
    )
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10
    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
    out = model(x, trg[:, :-1])
    print(out.shape)



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/37596.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RDS-Tools RDS-Knight Crack

RDS 高级安全性 利用全面的网络安全工具箱中有史以来最强大的安全功能集来保护您的 RDS 基础架构。 全方位 360 保护 无与伦比的功能集 无与伦比的物有所值 企业远程桌面安全。现代工作空间的智能解决方案。 办公室正在权力下放。远程办公室和移动员工数量创历史新高。随…

机器学习技术(四)——特征工程与模型评估

机器学习技术(四)——特征工程与模型评估(1️⃣) 文章目录 机器学习技术(四)——特征工程与模型评估(:one:)一、特征工程1、标准化2、特征缩放3、缩放有离群值的数据4、非线性转换5、样本归一化6、特征二值化7、标称特征编码(one-…

设计模式——命令模式

命令模式 定义 将一个请求封装成一个对象,从而让你使用不同的请求吧客户端参数化,对请求排队或者记录请求日志,可以提供命令的撤销和恢复功能。 命令模式是一个高内聚的模式。 优缺点、应用场景 优点 类间解耦。调用者与接收者之间没有任…

Linux系统使用(超详细)

目录 Linux操作系统简介 Linux和windows区别 Linux常见命令 Linux目录结构 Linux命令提示符 常用命令 ls cd pwd touch cat echo mkdir rm cp mv vim vim的基本使用 grep netstat Linux面试题 Linux操作系统简介 Linux操作系统是和windows操作系统是并列…

Github Pages使用自定义域名

Github Pages使用自定义域名 部署好网站后默认访问地址是xxx.github.io,我们想要自定义为自己的域名 1.DNS解析 这里我使用的是腾讯云,DNS解析DNSPod 添加两条解析记录: 第一个解析记录的记录类型为A,主机记录为,记录值为ping 你的github用户名.githu…

【Java】单例模式

单例模式 设计模式概述单例模式实现思路饿汉式懒汉式饿汉式 vs 懒汉式 设计模式概述 设计模式是在大量的实践中总结和理论化之后优选的代码结构、编程风格、以及解决问题的思考方式。设计模式免去我们自己再思考和摸索。就像是经典的棋谱,不同的棋局,我…

【unity之IMGUI实践】单例模式管理面板对象【一】

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:uni…

electron globalShortcut 快捷键与系统全局快捷键冲突

用 electron 开发自己的接口测试工具(Post Tools),在设置了 globalShortcut 快捷键后,发现应用中的快捷键与系统全局快捷键冲突了,导致系统快捷键不可正常使用。 快捷键配置 export function initGlobalShortcut(main…

【宝塔】宝塔部署ThinkPHP项目

最近搞了个培训教育的小程序,后端服务用的是ThinkPHP。使用的过程中,发现对于这种小项目用php还是很不错的选择,开发便捷,轻量级。宝塔神器也是很不错的,值得推荐使用。 下面介绍一下项目中用宝塔部署ThinkPHP项目&…

【菜鸟の笔记_利用Excel自动总结表格数据_自动链接word文本】

自动更新总结表格数据 1. 撰写原因2. 解决的问题3. Excel自动总结表格数据内容(一段话)。3.1问题引出3.2解决方式 4.Excel数据、总结内容,自动链接更新Word文本 1. 撰写原因 【GPT的答案】利用Excel自动总结表格数据有以下好处: …

redis浅析

一 什么是NoSQL? Nosql not only sql(不仅仅是SQL) 关系型数据库:列行,同一个表下数据的结构是一样的。 非关系型数据库:数据存储没有固定的格式,并且可以进行横向扩展。 NoSQL泛指非关系…

RocketMQ 为何性能高

本文主要从性能角度考虑 RocketMQ 的实现。 整体架构 这是网络上流行的 RocketMQ 的集群部署图。 RocketMQ 主要由 Broker、NameServer、Producer 和 Consumer 组成的一个集群。 **NameServer:整个集群的注册中心和配置中心,管理集群的元数据。包括 T…

解析Android VNDK/VSDK Snapshot编译框架

1.背景 背景一: 为解决Android版本碎片化问题,引入Treble架构,它提供了稳定的新SoC供应商接口,引入HAL 接口定义语言(HIDL/Stable AIDL,技术栈依然是Binder),它指定了 vendor HAL 和system fr…

容器化背后的魔法之Docker底层逻辑解密

Docker内部工作原理是怎样的? 现在我们知道了Docker是什么以及它提供了哪些好处,让我们逐个重要的细节来了解。 什么是容器?它们是如何工作的? 在深入研究Docker的内部机制之前,我们首先要了解容器的概念。简单地说…

从2023中国峰会,看亚马逊云科技的生成式AI战略

“生成式AI的发展就像一场马拉松比赛,当比赛刚刚开始时,如果只跑了三四步就断言某某会赢得这场比赛,显然是不合理的。我们现在还处于非常早期的阶段。” 近日,在2023亚马逊云科技中国峰会上,亚马逊云科技全球产品副总裁…

JSX的基础使用

1. JSX嵌入变量作为子元素的使用 ①当变量是Number、String、Array类型时,可以直接显示; ②当变量是null、undefined、Boolean类型时,内容为空; 若想要展示nul、undefined、Boolean类型,转字符串;转换方式…

增强型视觉系统 (EVS)

增强型视觉系统 EVS 1、增强型视觉系统概览2、车载相机 HAL2.1 EVS 应用2.2 EVS 管理器2.3 EVS HIDL 接口2.4 内核驱动程序 《增强型视觉系统 (EVS) 1.1 集成指南》 车载相机 HAL 1、增强型视觉系统概览 为了增强视频串流管理和错误处理,Android 11 更新了车载相机…

图像处理之图像灰度化

图像灰度化 将彩色图像转化成为灰度图像的过程成为图像的灰度化处理。彩色图像中的每个像素的颜色有R、G、B三个分量决定,而每个分量有255中值可取,这样一个像素点可以有1600多万 (255255255)的颜色的变化范用。而灰度图像是R、G、B三个分量相同的一种特…

纯css3实现小鸡从鸡蛋破壳而出动画特效

实现一个使用纯css3实现小鸡破壳的效果 示例效果如下所示 示例代码 <template><div><div class"eggWrapper"><div class"chickHead"><div class"eyeDiv"></div><div class"eyeDiv"></di…

vue3+mapboxgl鼠标浮动显示cgcs2000

一、需求 鼠标在地图中浮动展示地图的经纬度&#xff0c;cgcs2000 xy 还有显示带号 二、实现效果 展示经度&#xff0c;纬度&#xff0c;x值&#xff0c;y值显示的是带号和y值 三、思路 3.1、mapbox获取经纬度方法 初始化地图后.on方法中有个mousemove方法 mapboxUtil._m…
最新文章