【探索AI】二十一 深度学习之第4周:循环神经网络(RNN)与长短时记忆(LSTM)

循环神经网络(RNN)与长短时记忆(LSTM)

    • RNN的基本原理与结构
    • LSTM的原理与实现
    • 序列建模与文本生成任务
    • 实践:使用RNN或LSTM进行文本分类或生成任务
      • 步骤 1: 数据准备
      • 步骤 2: 构建模型
      • 步骤 3: 定义损失函数和优化器
      • 步骤 4: 训练模型
      • 步骤 5: 评估模型
      • 步骤 6: 使用模型进行文本生成

RNN的基本原理与结构

在这里插入图片描述

RNN,即循环神经网络(Recurrent Neural Network),是一种专门用于处理序列数据的神经网络。它的基本原理和结构主要基于以下几点:

基本原理:RNN的基本原理是,序列中的每个元素都与其前后元素存在某种关联或依赖,这种关联或依赖就是序列的时序关系。RNN通过捕捉并记忆这种时序关系,实现对序列数据的建模。RNN不是刚性地记忆所有固定长度的序列,而是通过隐藏状态来存储之前时间步的信息。
结构特点:RNN的结构特点主要体现在其循环性上。RNN的每个神经元不仅接收当前时刻的输入,还接收上一时刻的输出,并将其作为当前时刻的输入。这种结构使得RNN能够处理具有时序关系的数据,并且在处理过程中,RNN会不断地将之前时刻的信息传递到当前时刻,从而实现对序列数据的建模。
权重共享:RNN的另一个重要特点是权重共享。在RNN中,每个时刻的神经元都使用相同的权重,这意味着RNN在处理不同时刻的数据时,使用的是相同的参数。这种权重共享的方式大大减少了RNN的参数量,使得RNN能够更有效地处理序列数据。
综上所述,RNN的基本原理和结构使得它能够有效地处理具有时序关系的数据,实现对序列数据的建模和预测。这使得RNN在自然语言处理、时间序列分析等领域具有广泛的应用前景。

LSTM的原理与实现

LSTM(长短期记忆)是一种特殊的循环神经网络(RNN),设计用于解决传统RNN在处理长期依赖关系时遇到的问题。LSTM通过引入“门”的概念和细胞状态来实现这一点。

  1. 细胞状态与水平线:LSTM的关键在于细胞状态,它类似于传送带,直接在整个链上运行。这个状态只有少量的线性交互,因此信息在上面流传保持不变会很容易。

  2. 门结构:为了实现信息的添加或删除,LSTM使用了一种叫做“门”的结构。门可以实现选择性地让信息通过,这主要通过一个sigmoid神经层和一个逐点相乘的操作来实现。sigmoid层输出的每个元素都是一个在0和1之间的实数,表示让对应信息通过的权重。例如,0表示“不让任何信息通过”,1表示“让所有信息通过”。

  3. 三个门:LSTM通过三个这样的门结构来实现信息的保护和控制,分别是输入门、遗忘门和输出门。

    • 遗忘门:LSTM的第一步是决定从细胞状态中丢弃什么信息。这个决定通过一个称为忘记门层完成。该门会读取上一个时刻的隐藏状态 h t − 1 h_{t-1} ht1和当前时刻的输入 x t x_t xt,输出一个在0到1之间的数值给每个在细胞状态 C t − 1 C_{t-1} Ct1中的数字。1表示“完全保留”,0表示“完全舍弃”。
    • 输入门:负责处理当前时刻的输入,决定哪些信息需要被存储在细胞状态中。它包含两个步骤:首先,一个sigmoid层决定哪些信息需要更新;其次,一个tanh层生成新的候选值向量,这些值可能会被添加到状态中。
    • 输出门:基于细胞状态来决定当前的输出。它首先通过sigmoid层来决定细胞状态的哪些部分将输出到LSTM的当前输出值;然后,将细胞状态通过tanh进行处理(得到一个在-1到1之间的值),再与sigmoid门的输出相乘,从而得到最终的输出。
  4. 实现:在实现LSTM时,通常使用深度学习框架(如TensorFlow、PyTorch等)来构建网络。这些框架提供了高级的API,使得构建和训练LSTM模型变得相对简单。在实现过程中,需要定义网络结构、损失函数、优化器等,并进行模型的训练和评估。

总的来说,LSTM通过引入细胞状态和门结构,有效地解决了传统RNN在处理长期依赖关系时遇到的问题。这使得LSTM在许多序列处理任务中取得了显著的成果,如语音识别、机器翻译、情感分析等。

序列建模与文本生成任务

序列建模与文本生成任务是自然语言处理(NLP)领域中的两个重要概念。

序列建模是指在给定一组输入序列的情况下,预测或生成相应的输出序列。在机器学习和自然语言处理领域,序列建模问题被广泛应用。例如,在语音识别任务中,根据音频信号的输入序列预测对应的语音文本序列;在机器翻译任务中,根据源语言的输入序列生成目标语言的输出序列。序列建模面临着许多挑战,如数据稀疏性、长距离依赖、多模态输入等。为了解决这些问题,研究者们提出了许多有效的方法,如循环神经网络(RNN)、长短时记忆网络(LSTM)、门控循环单元(GRU)等。

文本生成任务是指通过计算机算法和模型,以一定的策略和规则生成特定领域的文本内容。文本生成任务在多个领域都有广泛的应用,如机器翻译、文本摘要、文本生成等。这些任务可以通过深度学习技术,如递归神经网络(RNN)特别是长短期记忆网络(LSTM)和门控循环单元(GRU)等生成模型来实现。生成模型根据输入信息生成文本,而训练数据则用于训练这些模型。损失函数用于评估模型的预测性能,并指导模型的优化。

综上所述,序列建模与文本生成任务是自然语言处理领域中两个密切相关的概念。序列建模为文本生成提供了基础和支持,而文本生成任务则是序列建模的一个重要应用领域。随着深度学习技术的不断发展,序列建模与文本生成任务将在更多的领域发挥重要作用。

实践:使用RNN或LSTM进行文本分类或生成任务

实践使用RNN或LSTM进行文本分类或生成任务需要一系列步骤。下面我将提供一个简单的指导,使用PyTorch库进行文本分类任务。请注意,为了执行此实践,您需要具备Python编程和PyTorch库的基础知识。

步骤 1: 数据准备

首先,您需要准备用于训练和测试的数据集。数据集应该包含文本数据和相应的标签。您可以将文本数据预处理为适合RNN或LSTM模型的形式。

# 导入必要的库
import torch
from torchtext.legacy import data
from torchtext.legacy import datasets

# 定义字段和文本预处理器
TEXT = data.Field(sequential=True, tokenize='spacy', lower=True)
LABEL = data.LabelField(dtype=torch.float)

# 创建数据管道
fields = [('text', TEXT), ('label', LABEL)]
train_data, test_data = datasets.TabularDataset.splits(
    path='.', train='train.csv', test='test.csv', format='csv',
    skip_header=True, fields=fields
)

# 构建词汇表
TEXT.build_vocab(train_data, max_size=25000, vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

# 创建数据迭代器
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, test_iterator = data.BucketIterator.splits((train_data, test_data), batch_size=batch_size, device=device)

步骤 2: 构建模型

接下来,您需要定义RNN或LSTM模型。下面是一个简单的LSTM模型的例子。

import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super(LSTMModel, self).__init__()
        
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, text):
        embedded = self.embedding(text)
        output, (hidden, cell) = self.lstm(embedded)
        assert torch.equal(hidden[-1,:,:], cell[-1,:,:])
        return self.fc(hidden[-1,:,:])

input_dim = len(TEXT.vocab)
embedding_dim = 100
hidden_dim = 256
output_dim = 1  # 假设是二分类问题

model = LSTMModel(input_dim, embedding_dim, hidden_dim, output_dim)

步骤 3: 定义损失函数和优化器

选择适当的损失函数和优化器。对于分类任务,通常使用交叉熵损失(CrossEntropyLoss)和Adam优化器。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

步骤 4: 训练模型

现在,您可以开始训练模型了。

num_epochs = 10

for epoch in range(num_epochs):
    for batch in train_iterator:
        optimizer.zero_grad()
        predictions = model(batch.text).squeeze(1)
        loss = criterion(predictions, batch.label)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch: {epoch+1:02}, Loss: {loss.item():.4f}')

步骤 5: 评估模型

在训练完成后,您可以使用测试数据集评估模型的性能。

model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for batch in test_iterator:
        predictions = model(batch.text).squeeze(1)
        predicted = torch.argmax(predictions, dim=1)
        correct += (predicted == batch.label).sum().item()
        total += batch.label.size(0)

print(f'Accuracy of the network on the test data: {100 * correct / total:.2f}%')

步骤 6: 使用模型进行文本生成

在步骤6中,我们将使用之前训练好的模型来进行文本生成。这通常涉及到将模型设置为评估模式(evaluation mode),然后提供一个初始的文本片段作为种子(seed),让模型从这个种子开始生成后续的文本。

以下是一个使用PyTorch和LSTM模型进行文本生成的例子:

import torch
from torch import nn
from torch.autograd import Variable
from model import LSTM, vocab  # 假设你有一个名为LSTM的模型和名为vocab的词汇表

# 假设我们有一个训练好的LSTM模型
model = LSTM(vocab_size=len(vocab), embedding_dim=256, hidden_dim=512, num_layers=2)
model.load_state_dict(torch.load('path_to_saved_model.pt'))  # 加载预训练模型
model.eval()  # 将模型设置为评估模式

# 设置超参数
max_length = 100  # 生成文本的最大长度
starting_text = "I enjoy"  # 初始文本种子
temperature = 1.0  # 控制生成文本多样性的参数(softmax的温度参数)

# 将初始文本转换为模型可以理解的格式
starting_text = starting_text.lower().replace('.', ' .')  # 将所有文本转换为小写,并在句末添加空格
starting_text = [vocab.stoi[word] for word in starting_text.split()]  # 将单词转换为索引
starting_text = torch.LongTensor(starting_text).to(device)  # 转换为PyTorch张量并移到设备上

# 初始化隐藏状态
hidden = model.init_hidden(starting_text.size(0))

# 开始生成文本
generated_text = []
for i in range(max_length):
    # 获取模型的输出
    output, hidden = model(starting_text, hidden)
    
    # 使用softmax函数获取预测单词的概率分布
    predicted = torch.multinomial(output.squeeze(1), num_samples=1)
    
    # 根据概率分布选择一个单词
    predicted_index = predicted.item()
    
    # 将预测的单词添加到生成的文本中
    generated_text.append(predicted_index)
    
    # 将预测的单词作为下一个输入
    starting_text = Variable(torch.LongTensor([predicted_index]).to(device))
    
    # 如果生成的单词是结束标记,则停止生成
    if vocab.itos[predicted_index] == '<eos>':
        break

# 将生成的索引转换为文本
generated_text = [vocab.itos[idx] for idx in generated_text]
generated_sentence = ' '.join(generated_text)

# 打印生成的文本
print(generated_sentence)

在上面的代码中,我们首先将预训练的模型加载到内存中,并将其设置为评估模式。然后,我们定义了生成文本的超参数,包括最大长度和初始文本种子。

我们将初始文本转换为模型可以理解的格式,即单词的索引序列。然后,我们初始化模型的隐藏状态,并开始循环生成文本。在每个循环中,我们将当前生成的文本作为输入传递给模型,并获取模型的输出。我们使用torch.multinomial函数根据模型的输出概率分布选择一个单词,并将其添加到生成的文本中。如果生成的单词是结束标记(例如<eos>),则我们停止生成。

最后,我们将生成的索引序列转换回文本,并打印生成的文本。

请注意,生成的文本的质量和多样性取决于许多因素,包括模型的结构、训练数据的质量、训练时间以及超参数的选择(如温度参数)。在实际应用中,可能需要调整这些参数以获得最佳结果。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/429604.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

我选项目和做项目的两大准则

一、项目不在多&#xff0c;做精做透最重要 其实我们做生意做项目的一定要记住&#xff0c;手头上永远要有一样东西&#xff0c;即便是天塌下来你也能挣钱的&#xff0c;我经常称其为基本盘。比如我们手上的CSGO游戏搬砖&#xff0c;可能在经营这个项目的过程中会或多或少的接…

蓝桥杯练习题——dp

五部曲&#xff08;代码随想录&#xff09; 1.确定 dp 数组以及下标含义 2.确定递推公式 3.确定 dp 数组初始化 4.确定遍历顺序 5.debug 入门题 1.斐波那契数 思路 1.f[i]&#xff1a;第 i 个数的值 2.f[i] f[i - 1] f[i - 2] 3.f[0] 0, f[1] 1 4.顺序遍历 5.记得特判 …

项目管理:实现高效团队协作与成功交付的关键

在当今竞争激烈的市场环境中&#xff0c;项目管理已成为企业成功的关键因素之一。项目管理不仅涉及时间、成本和资源的有效管理&#xff0c;还涉及到团队协作、风险管理、沟通和交付。本文将探讨项目管理的核心要素&#xff0c;以及如何实现高效团队协作和成功交付。 一、明确项…

es集群的详细搭建过程

目录 一、VM配置二、集群搭建三、集群配置 一、VM配置 VM的安装 VMware Workstation 15 Pro的安装与破解 VM新建虚拟机 VM新建虚拟机 二、集群搭建 打开新建好的服务器&#xff0c;node1&#xff0c;使用xshell远程连接 下载es&#xff1a;https://www.elastic.co/cn/down…

【自然语言处理六-最重要的模型-transformer-下】

自然语言处理六-最重要的模型-transformer-下 transformer decoderMasked multi-head attentionencoder和decoder的连接部分-cross attentiondecoder的输出AT(Autoregresssive)NAT transformer decoder 今天接上一篇文章讲的encoder 自然语言处理六-最重要的模型-transformer-…

Carbondata编译适配Spark3

背景 当前carbondata版本2.3.1-rc1中项目源码适配的spark版本最高为3.1,我们需要进行spark3.3版本的编译适配。 原始编译 linux系统下载源码后&#xff0c;安装maven3.6.3&#xff0c;然后执行&#xff1a; mvn -DskipTests -Pspark-3.1 clean package会遇到一些网络问题&a…

SpringCloud-RabbitMQ消息模型

本文深入介绍了RabbitMQ消息模型&#xff0c;涵盖了基本消息队列、工作消息队列、广播、路由和主题等五种常见消息模型。每种模型都具有独特的特点和适用场景&#xff0c;为开发者提供了灵活而强大的消息传递工具。通过这些模型&#xff0c;RabbitMQ实现了解耦、异步通信以及高…

如何远程连接MySQL数据库?

在现代互联网时代&#xff0c;远程连接MySQL数据库成为了许多开发者和管理员必备的技能。这不仅方便了数据的共享和管理&#xff0c;还可以使多个团队在全球范围内协同工作。本文将介绍如何通过天联组网实现远程连接MySQL数据库&#xff0c;并实现高效的信息远程通信。 天联组网…

tomcat nginx 动静分离

实验目的:当访问静态资源的时候&#xff0c;nginx自己处理 当访问动态资源的时候&#xff0c;转给tomcat处理 第一步 关闭防火墙 关闭防护 代理服务器操作&#xff1a; 用yum安装nginx tomcat &#xff08;centos 3&#xff09;下载 跟tomcat&#xff08;centos 4&#xff0…

Shell管道和过滤器

一、Shell管道 Shell 还有一种功能&#xff0c;就是可以将两个或者多个命令&#xff08;程序或者进程&#xff09;连接到一起&#xff0c;把一个命令的输出作为下一个命令的输入&#xff0c;以这种方式连接的两个或者多个命令就形成了管道&#xff08;pipe&#xff09;。 重定…

关于 CTF 中 php 考点与绕过那些事的总结

关于 CTF 中常见 php 绕过的总结可以参考我之前的博客&#xff1a; CTF之PHP特性与绕过 PHP特性之CTF中常见的PHP绕过-CSDN博客 其中主要介绍了 md5()、sha1()、strcmp、switch、intval、$_SERVER 函数、三元运算符、strpos() 、数组、非法参数名传参等相关的绕过。 在此基础上…

vue点击按钮同时下载多个文件

点击下载按钮根据需要的id调接口拿到返回需要下载的文件 再看返回的数据结构 数组中一个对象&#xff0c;就是一个文件&#xff0c;多个对象就是多个文件 下载函数 // 下载tableDownload(row) {getuploadInventoryDownload({ sysBatch: row.sysBatch, fileName: row.fileName…

Linux 进程间通信

目录 管道 匿名管道&#xff08;pipe&#xff09; 有名管道&#xff08;fifo&#xff09; 小结 共享内存 消息队列 信号量 System V IPC的结构设计 Posix与System V的关系 管道 匿名管道&#xff08;pipe&#xff09; 我们知道&#xff0c;在Linux中通过fork创建的子…

YOLOv5优化改进:下采样创新篇 | 新颖的下采样ADown | YOLOv9

💡💡💡本文独家改进:新颖的下采样ADown来自于YOLOv9,助力YOLOv5,将ADown添加在backbone和head处,提供多个yaml改进方法 💡💡💡在多个私有数据集和公开数据集VisDrone2019、PASCAL VOC实现涨点 收录 YOLOv5原创自研 https://blog.csdn.net/m0_63774211/categ…

MongoDB 快速入门

&#x1f4d5;作者简介&#xff1a; 过去日记&#xff0c;致力于Java、GoLang,Rust等多种编程语言&#xff0c;热爱技术&#xff0c;喜欢游戏的博主。 &#x1f4d7;本文收录于MongoDB系列&#xff0c;大家有兴趣的可以看一看 &#x1f4d8;相关专栏Rust初阶教程、go语言基础…

SpringMVC--03--前端传数组给后台

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 案例1乘客个人信息方法1&#xff1a;表单提交&#xff0c;以字段数组接收方法2&#xff1a;表单提交&#xff0c;以BeanListModel接收方法3&#xff1a;将Json对象序…

电脑黑屏如何重装系统 电脑黑屏安装系统操作方法

据了解,75%以上的用户在使用电脑时都有碰到黑屏的现象,而电脑黑屏不但会影响自己的工作,而且还会影响自己的心情,因此,不可马虎,那么,应该怎么办呢?下面我们就来详细介绍一下 据了解,75%以上的用户在使用电脑时都有碰到黑屏的现象,而电脑黑屏不但会影响自己的工作,而…

vue3三级嵌套复选框(element-plus)

一、功能描述 当选择第一级的复选框时下面所有内容全选和取消全选&#xff0c;当选择第二的复选框时第三级的所有内容全选和取消全选。只要有一个第三级的内容没有选&#xff0c;二级和一级则不能勾上。第三级内容全选上了&#xff0c;第二级复选框就钩上。第二级也是同样的道理…

使用GitHub API 查询开源项目信息

一、GitHub API介绍 GitHub API 是一组 RESTful API 接口&#xff0c;用于与 GitHub 平台进行交互。通过使用 GitHub API&#xff0c;开发人员可以访问和操作 GitHub 平台上的各种资源&#xff0c;如仓库、提交记录、问题等。 GitHub API 提供了多种功能和端点&#xff0c;以…

HTTP有什么缺陷,HTTPS是怎么解决的

缺陷 HTTP是明文的&#xff0c;谁都能看得懂&#xff0c;HTTPS是加了TLS/SSL加密的&#xff0c;这样就不容易被拦截和攻击了。 SSL是TLS的前身&#xff0c;他俩都是加密安全协议。前者大部分浏览器都不支持了&#xff0c;后者现在用的多。 对称加密 通信双方握有加密解密算法…