PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型预测)

文章目录

  • 引言
  • 数据预处理
    • 加载字典对象`en2id`和`zh2id`
    • 文本分词
  • 加载训练好的Seq2Seq模型
  • 模型预测完整代码
  • 结束语

引言

随着全球化的深入,翻译需求日益增长。传统的人工翻译方式虽然质量高,但效率低,成本高。机器翻译的出现,为解决这一问题提供了可能。英译中机器翻译任务是机器翻译领域的一个重要分支,旨在将英文文本自动翻译成中文。本博客以《PyTorch自然语言处理入门与实战》第九章的Seq2seq模型处理英译中翻译任务作为基础,附上模型预测模块。

模型的训练及验证模块的详细解析见PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型训练及验证)

数据预处理

加载字典对象en2idzh2id

在预测阶段中,需要加载模型训练及验证阶段保存的字典对象en2idzh2id

代码如下:

import pickle

with open("en2id.pkl", 'rb') as f:
    en2id = pickle.load(f)
with open("zh2id.pkl", 'rb') as f:
    zh2id = pickle.load(f)

文本分词

在对输入文本进行预测时,需要先将文本进行分词操作。参考代码如下:

def extract_words(sentence):  
    """  
    从给定的英文句子中提取单词,并去除单词后的标点符号。  
      
    Args:  
        sentence (str): 要提取单词的英文句子。  
          
    Returns:  
        List[str]: 提取并处理后的单词列表。  
    """  
    en_words = []  
    for w in sentence.split(' '):  # 将英文句子按空格分词  
        w = w.replace('.', '').replace(',', '')  # 去除跟单词连着的标点符号  
        w = w.lower()  # 统一单词大小写  
        if w:  
            en_words.append(w)  
    return en_words  
  
# 测试函数  
sentence = 'I am Dave Gallo.'  
print(extract_words(sentence))

运行结果:

加载训练好的Seq2Seq模型

代码如下:

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = (src len, batch size)
        embedded = self.dropout(self.embedding(src))
        # embedded = (src len, batch size, emb dim)
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = (src len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)
        # rnn的输出总是来自顶部的隐藏层
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # 各输入的形状
        # input = (batch size)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # LSTM是单向的  ==> n directions == 1
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)

        embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # LSTM理论上的输出形状
        # output = (seq len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # 解码器中的序列长度 seq len == 1
        # 解码器的LSTM是单向的 n directions == 1 则实际上
        # output = (1, batch size, hid dim)
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        prediction = self.fc_out(output.squeeze(0))

        # prediction = (batch size, output dim)

        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,
                 encode_dropout, decode_dropout, device):
        """

        :param input_word_count:    英文词表的长度     34737
        :param output_word_count:   中文词表的长度     4015
        :param encode_dim:          编码器的词嵌入维度
        :param decode_dim:          解码器的词嵌入维度
        :param hidden_dim:          LSTM的隐藏层维度
        :param n_layers:            采用n层LSTM
        :param encode_dropout:      编码器的dropout概率
        :param decode_dropout:      编码器的dropout概率
        :param device:              cuda / cpu
        """
        super().__init__()
        self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)
        self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)
        self.device = device

    def forward(self, src):
        # src = (src len, batch size)

        # 编码器的隐藏层输出将作为解码器的第一个隐藏层输入
        hidden, cell = self.encoder(src)

        # 解码器的第一个输入应该是起始标识符<sos>
        input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引
        pred = [0] # 预测的第一个输出应该是起始标识符
        top1 = 0
        while top1 != 1 and len(pred) < 100:
            # 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states
            # 解码器的输出包括:输出张量(predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            top1 = output.argmax(dim=1)  # (batch size, )
            pred.append(top1.item())
            input = top1

        return pred

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU
# Seq2Seq模型实例化
source_word_count = 34737  # 英文词表的长度     34737
target_word_count = 4015  # 中文词表的长度     4015
encode_dim = 256  # 编码器的词嵌入维度
decode_dim = 256  # 解码器的词嵌入维度
hidden_dim = 512  # LSTM的隐藏层维度
n_layers = 2  # 采用n层LSTM
encode_dropout = 0.5  # 编码器的dropout概率
decode_dropout = 0.5  # 编码器的dropout概率
model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,
                decode_dropout, device).to(device)

# 加载训练好的模型
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

模型预测完整代码

提示预测代码是我们基于训练及验证代码进行改造的,不一定完全正确,可以参考后自行修改~

import torch
import torch.nn as nn
import pickle


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)  # 词嵌入
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = (src len, batch size)
        embedded = self.dropout(self.embedding(src))
        # embedded = (src len, batch size, emb dim)
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = (src len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)
        # rnn的输出总是来自顶部的隐藏层
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # 各输入的形状
        # input = (batch size)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # LSTM是单向的  ==> n directions == 1
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        input = input.unsqueeze(0)  # (batch size)  --> [1, batch size)

        embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # LSTM理论上的输出形状
        # output = (seq len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # 解码器中的序列长度 seq len == 1
        # 解码器的LSTM是单向的 n directions == 1 则实际上
        # output = (1, batch size, hid dim)
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        prediction = self.fc_out(output.squeeze(0))

        # prediction = (batch size, output dim)

        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,
                 encode_dropout, decode_dropout, device):
        """

        :param input_word_count:    英文词表的长度     34737
        :param output_word_count:   中文词表的长度     4015
        :param encode_dim:          编码器的词嵌入维度
        :param decode_dim:          解码器的词嵌入维度
        :param hidden_dim:          LSTM的隐藏层维度
        :param n_layers:            采用n层LSTM
        :param encode_dropout:      编码器的dropout概率
        :param decode_dropout:      编码器的dropout概率
        :param device:              cuda / cpu
        """
        super().__init__()
        self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)
        self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)
        self.device = device

    def forward(self, src):
        # src = (src len, batch size)

        # 编码器的隐藏层输出将作为解码器的第一个隐藏层输入
        hidden, cell = self.encoder(src)

        # 解码器的第一个输入应该是起始标识符<sos>
        input = src[0, :]  # 取trg的第“0”行所有列  “0”指的是索引
        pred = [0] # 预测的第一个输出应该是起始标识符
        top1 = 0
        while top1 != 1 and len(pred) < 100:
            # 解码器的输入包括:起始标识符的词嵌入input; 编码器输出的 hidden and cell states
            # 解码器的输出包括:输出张量(predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            top1 = output.argmax(dim=1)  # (batch size, )
            pred.append(top1.item())
            input = top1

        return pred


if __name__ == '__main__':
    sentence = 'I am Dave Gallo.'
    en_words = []

    for w in sentence.split(' '):  # 英文内容按照空格字符进行分词
        # 按照空格进行分词后,某些单词后面会跟着标点符号 "." 和 “,”
        w = w.replace('.', '').replace(',', '')  # 去掉跟单词连着的标点符号
        w = w.lower()  # 统一单词大小写
        if w:
            en_words.append(w)

    print(en_words)

    with open("en2id.pkl", 'rb') as f:
        en2id = pickle.load(f)
    with open("zh2id.pkl", 'rb') as f:
        zh2id = pickle.load(f)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU可用 用GPU
    # Seq2Seq模型实例化
    source_word_count = 34737  # 英文词表的长度     34737
    target_word_count = 4015  # 中文词表的长度     4015
    encode_dim = 256  # 编码器的词嵌入维度
    decode_dim = 256  # 解码器的词嵌入维度
    hidden_dim = 512  # LSTM的隐藏层维度
    n_layers = 2  # 采用n层LSTM
    encode_dropout = 0.5  # 编码器的dropout概率
    decode_dropout = 0.5  # 编码器的dropout概率
    model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,
                    decode_dropout, device).to(device)

    model.load_state_dict(torch.load("best_model.pth"))
    model.eval()

    src = [0] # 0 --> 起始标识符的编码
    for i in range(len(en_words)):
        src.append(en2id[en_words[i]])
    src = src + [1] # 1 --> 终止标识符的编码

    text_input = torch.LongTensor(src)
    text_input = text_input.unsqueeze(-1).to(device)

    text_output = model(text_input)
    print(text_output)
    id2zh = dict()
    for k, v in zh2id.items():
        id2zh[v] = k

    text_output = [id2zh[index] for index in text_output]
    text_output = " ".join(text_output)
    print(text_output)

结束语

  • 亲爱的读者,感谢您花时间阅读我们的博客。我们非常重视您的反馈和意见,因此在这里鼓励您对我们的博客进行评论。
  • 您的建议和看法对我们来说非常重要,这有助于我们更好地了解您的需求,并提供更高质量的内容和服务。
  • 无论您是喜欢我们的博客还是对其有任何疑问或建议,我们都非常期待您的留言。让我们一起互动,共同进步!谢谢您的支持和参与!
  • 我会坚持不懈地创作,并持续优化博文质量,为您提供更好的阅读体验。
  • 谢谢您的阅读!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/274397.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

八股文打卡day12——计算机网络(12)

面试题&#xff1a;HTTPS的工作原理&#xff1f;HTTPS是怎么建立连接的&#xff1f; 我的回答&#xff1a; 1.客户端向服务器发起请求&#xff0c;请求建立连接。 2.服务器收到请求之后&#xff0c;向客户端发送其SSL证书&#xff0c;这个证书包含服务器的公钥和一些其他信息…

Python 网络编程之搭建简易服务器和客户端

用Python搭建简易的CS架构并通信 文章目录 用Python搭建简易的CS架构并通信前言一、基本结构二、代码编写1.服务器端2.客户端 三、效果展示总结 前言 本文主要是用Python写一个CS架构的东西&#xff0c;包括服务器和客户端。程序运行后在客户端输入消息&#xff0c;服务器端会…

文件操作安全之-目录穿越流量告警运营分析篇

本文从目录穿越的定义,目录穿越的多种编码流量数据包示例,目录穿越的suricata规则,目录穿越的告警分析研判,目录穿越的处置建议等几个方面阐述如何通过IDS/NDR,态势感知等流量平台的目录穿越类型的告警的线索,开展日常安全运营工作,从而挖掘有意义的安全事件。 目录穿越…

Quartus的Signal Tap II的使用技巧

概述&#xff1a; Signal Tap II全称Signal Tap II Logic Analyzer&#xff0c;是第二代系统级调试工具&#xff0c;它集成在Quartus II软件中&#xff0c;可以捕获和显示实时信号&#xff0c;是一款功能强大、极具实用性的FPGA片上调试工具软件。 传统的FPGA板级调试是由外接…

TCP的三次握手

TCP 是一种面向连接的单播协议&#xff0c;在发送数据前&#xff0c;通信双方必须在彼此间建立一条连接。所谓的“连接”&#xff0c;其实是客户端和服务器的内存里保存的一份关于对方的信息&#xff0c;如 IP 地址、端口号等。 TCP 可以看成是一种字节流&#xff0c;它…

《Spring Cloud学习笔记:微服务保护Sentinel》

Review 解决了服务拆分之后的服务治理问题&#xff1a;Nacos解决了服务治理问题OpenFeign解决了服务之间的远程调用问题网关与前端进行交互&#xff0c;基于网关的过滤器解决了登录校验的问题 流量控制&#xff1a;避免因为突发流量而导致的服务宕机。 隔离和降级&#xff1a…

hadoop hive spark flink 安装

下载地址 Index of /dist ubuntu安装hadoop集群 准备 IP地址主机名称192.168.1.21node1192.168.1.22node2192.168.1.23node3 上传 hadoop-3.3.5.tar.gz、jdk-8u391-linux-x64.tar.gz JDK环境 node1、node2、node3三个节点 解压 tar -zxvf jdk-8u391-linux-x64.tar.gz…

苹果cmsV10蜘蛛统计插件+集合采集插件

苹果cmsV10蜘蛛统计插件集合采集插件 安装苹果cms盒子方法&#xff1a; 1.下载到的盒子客户端压缩包内拥有一个application文件夹&#xff0c;直接上传到网站根目录中。 2.添加苹果cms盒子快捷菜单&#xff1a;苹果cms盒子,macBox/stylelist 相信做网站的都想要百度 搜狗 3…

RabbitMQ 和 Kafka 对比

本文对RabbitMQ 和 Kafka 进行下比较 文章目录 前言RabbitMQ架构队列消费队列生产 Kafka本文小结 前言 开源社区有好多优秀的队列中间件&#xff0c;比如RabbitMQ和Kafka&#xff0c;每个队列都貌似有其特性&#xff0c;在进行工程选择时&#xff0c;往往眼花缭乱&#xff0c;不…

浅谈WPF之ToolTip工具提示

在日常应用中&#xff0c;当鼠标放置在某些控件上时&#xff0c;都会有相应的信息提示&#xff0c;从软件易用性上来说&#xff0c;这是一个非常友好的功能设计。那在WPF中&#xff0c;如何进行控件信息提示呢&#xff1f;这就是本文需要介绍的ToolTip【工具提示】内容&#xf…

【INTEL(ALTERA)】如何使用Tcl打开quartus IP自带的例程

前言 很多INTEL&#xff08;ALTERA&#xff09; IP生成的时候会自带例程&#xff0c;如LVDS SERDES IP&#xff0c;在菜单Generate中可以选择生成官方例程。 之后会在IP所在目录下生产【lvds_0_example_design】文件夹&#xff0c;但在这个文件夹中并没有FPGA工程。 例程在哪&…

【Linux】 last 命令使用

last 命令 用于检索和展示系统中用户的登录信息。它从/var/log/wtmp文件中读取记录&#xff0c;并将登录信息按时间顺序列出。 著者 Miquel van Smoorenburg 语法 last [-R] [-num] [ -n num ] [-adiox] [ -f file ] [name...] [tty...]last 命令 -Linux手册页 选项及作用…

Flask登陆后登陆状态及密码的修改和处理

web/templates/common 是统一布局 登录成功 后flask框架服务器默认由login.html进入仪表盘页面index.html(/),该页面的设置在 (web/controllers/user/index.py)&#xff0c;如果想在 该仪表盘页面 将 用户信息 展示出来&#xff0c;就得想办法先获取到 当前用户的 登陆状态。…

Flutter配置Android和IOS允许http访问

默认情况下&#xff0c;Android和IOS只支持对https的访问&#xff0c;如果需要访问不安全的连接&#xff0c;也就是http&#xff0c;需要做以下配置。 Android 在res目录下的xml目录中(如果不存在&#xff0c;先创建xml目录)&#xff0c;创建一个xml文件network_security_con…

动态规划 多源路径 字典树 LeetCode2977:转换字符串的最小成本

涉及知识点 动态规划 多源最短路径 字典树 题目 给你两个下标从 0 开始的字符串 source 和 target &#xff0c;它们的长度均为 n 并且由 小写 英文字母组成。 另给你两个下标从 0 开始的字符串数组 original 和 changed &#xff0c;以及一个整数数组 cost &#xff0c;其中…

Python+OpenGL绘制3D模型(六)材质文件载入和贴图映射

系列文章 一、逆向工程 Sketchup 逆向工程&#xff08;一&#xff09;破解.skp文件数据结构 Sketchup 逆向工程&#xff08;二&#xff09;分析三维模型数据结构 Sketchup 逆向工程&#xff08;三&#xff09;软件逆向工程从何处入手 Sketchup 逆向工程&#xff08;四&#xf…

家庭教育小妙招让孩子做事学会坚持到底

在生活中&#xff0c;我们常常会遇到一些孩子&#xff0c;他们做事情总是三分钟热度&#xff0c;不能坚持到底。而在面对困难时&#xff0c;他们很容易选择放弃。 看似是孩子缺乏热情&#xff0c;其实这是孩子缺乏自制力和毅力的表现。作为家长&#xff0c;我们需要培养孩子的…

开源项目推荐:Frooodle/Stirling-PDF

简介一个本地的处理 PDF 的工具&#xff0c;界面是 Web UI&#xff0c;可以支持 Docker 部署。各种主要的 PDF 操作都可以支持。比如拆分、合并、转换格式、重新排列、添加图片、旋转、压缩等等。这个本地托管的网络应用最初完全由 ChatGPT 制作&#xff0c;后来逐渐发展&#…

MySQL常用命令合集(Mac版)

mysql信息 MySQL位置 which mysql查看版本 mysql --version启动与关闭 使用mysql.server启用脚本来执行&#xff0c;默认在/usr/local/mysql/support-files这个目录中。 启动 sudo /usr/local/mysql/support-files/mysql.server start关闭 sudo /usr/local/mysql/suppor…

人体检测、跟踪实例 | 附代码

人体检测和跟踪是一项基本的计算机视觉任务&#xff0c;涉及在给定场景中识别并跟踪个体的移动。这项技术在各种实际应用中发挥着关键作用&#xff0c;从监视和安全到自动驾驶车辆和人机交互都有涉及。人体检测的主要目标是在图像或视频帧中定位和分类人体&#xff0c;而跟踪侧…