人工智能-循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/179653.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于Vue和SpringBoot的学校热点新闻推送系统

项目编号: S 047 ,文末获取源码。 \color{red}{项目编号:S047,文末获取源码。} 项目编号:S047,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 新闻类型模块2.2 新闻档案模块2.3 新…

【深度学习】卷积神经网络(CNN)

一、引子————边界检测 我们来看一个最简单的例子:“边界检测(edge detection)”,假设我们有这样的一张图片,大小88: 图片中的数字代表该位置的像素值,我们知道,像素值越大&#…

ubuntu上查看各个进程的实时CPUMEM占用的办法

top常见参数top界面分析system monitorhtop1、查看htop的使用说明2、显示树状结构3、htop使用好文推荐top top的用法应该是最为普遍的 常见参数 -d 更新频率,top显示的界面几秒钟更新一次 -n 更新的次数,top显示的界面更新多少次之后就自动结束了 当然也可以将top日志通过…

3分钟使用 WebSocket 搭建属于自己的聊天室(WebSocket 原理、应用解析)

文章目录 WebSocket 的由来WebSocket 是什么WebSocket 优缺点优点缺点 WebSocket 适用场景主流浏览器对 WebSocket 的兼容性WebSocket 通信过程以及原理建立连接具体过程示例Sec-WebSocket-KeySec-WebSocket-Extensions 数据通信数据帧帧头(Frame Header&#xff09…

Pandas一键爬取解析代理IP与代理IP池的维护

目录 前言 一、获取代理IP 二、解析代理IP 三、维护代理IP池 四、完整代码 总结 前言 在爬虫过程中,我们经常会使用代理IP来绕过一些限制,比如防止被封IP等问题。而代理IP的获取和维护是一个比较麻烦的问题,需要花费一定的时间和精力。…

机器学习/sklearn 笔记:K-means,kmeans++,MiniBatchKMeans

1 K-means介绍 1.0 方法介绍 KMeans算法通过尝试将样本分成n个方差相等的组来聚类,该算法要求指定群集的数量。它适用于大量样本,并已在许多不同领域的广泛应用领域中使用。KMeans算法将一组样本分成不相交的簇,每个簇由簇中样本的平均值描…

Ocam——自由录屏工具~

当我们想要做一些混剪、恶搞类型的视频时,往往需要源影视作品中的诸多素材,虽然可以通过裁减mp4文件的方式来获取片段,但在高画质的条件下,mp4文件本身通常会非常大,长此以往,会给剪辑工作带来诸多不便&…

使用 PowerShell 创建共享目录

在 Windows 中,可以使用共享目录来将文件和文件夹共享给其他用户或计算机。共享目录可以通过网络访问,这使得它们非常适合用于文件共享、协作和远程访问。 要使用 PowerShell 创建共享目录,可以使用 New-SmbShare cmdlet。New-SmbShare cmdl…

arduino入门一:点亮第一个led

void setup() { pinMode(12, OUTPUT);//12引脚设置为输出模式 } void loop() { digitalWrite(12, HIGH);//设置12引脚为高电平 delay(1000);//延迟1000毫秒(1秒) digitalWrite(12, LOW);//设置12引脚为低电平 delay(1000); }

聚观早报 |快手Q3营收;拼多多杀入大模型;Redmi K70E开启预约

【聚观365】11月23日消息 快手Q3营收 拼多多杀入大模型 Redmi K70E开启预约 华为nova 12系列或下周发布 亚马逊启动“AI就绪”新计划 快手Q3营收 财报显示,快手第三季度营收279亿元,同比增长20.8%;期内盈利21.8亿元,去年同期…

2023 年 亚太赛 APMCM (A题)国际大学生数学建模挑战赛 |数学建模完整代码+建模过程全解全析

当大家面临着复杂的数学建模问题时,你是否曾经感到茫然无措?作为2022年美国大学生数学建模比赛的O奖得主,我为大家提供了一套优秀的解题思路,让你轻松应对各种难题。 完整内容可以在文章末尾领取! 问题1 图像处理&am…

免费多域名SSL证书

顾名思义,免费多域名SSL证书就是一种能够为多个域名或子域提供HTTPS安全保护的证书。这意味着,如果您有三个域名——例如example.com、example.cn和company.com,您可以使用一个免费的多域名SSL证书为所有这些域名提供安全保障,而无…

Missing file libarclite_iphoneos.a 问题解决方案

问题 在Xcode 运行项目会报以下错误 File not found: /Applications/Xcode-beta.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/arc/libarclite_iphoneos.a解决方案 打开URL https://github.com/kamyarelyasi/Libarclite-Files ,下载liba…

解决Android端libc++_shared.so库冲突问题

前言 随着App功能增多,集成的so库也会增多,如果系统中多个so库都使用系统自动生成的libc_shared.so库,如果多个SDK都有该so包,就会出现报错: 解决办法 如果出现该问题,说明您的项目中有多个SDK共同依赖了C标…

什么?Postman也能测WebSocket接口了?

01 WebSocket 简介 WebSocket是一种在单个TCP连接上进行全双工通信的协议。 WebSocket使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务器只需要完成一次握手,两者之间就直接…

牛客 最小公配数 golang版实现

题目请参考: HJ108 求最小公倍数 题解: 在大的数的倍数里面去找最小的能整除另外一个数的数,就是最小公倍数,按照大的来找,循环次数能够降到很少,提升效率 golang实现: package mainimport ("fmt" )func main() {a : …

电商API接口|电商数据接入|拼多多平台根据商品ID查商品详情SKU和商品价格参数

随着科技的不断进步,API开发领域也逐渐呈现出蓬勃发展的势头。今天我将向大家介绍API接口,电商API接口具备独特的特点,使得数据获取变得更加高效便捷。 快速获取API数据——优化数据访问速度 传统的数据获取方式可能需要经过多个中介环节&…

AI写代码 可以代替人工吗?

近年AI技术非常火热,有人就说,用AI写代码程序员不就都得下岗吗?对此我的回答是否定的,因为AI虽然已经有了编写代码的能力,但它现在的水平大多还仅限于根据业务需求搭建框架,而具体的功能实现还尚且稚嫩&…

【Vue】Node.js的下载安装与配置

目录 一.下载安装 官网: 二.环境变量的配置 三.设置全局路径和缓存路径 四.配置淘宝镜像 五.查看配置 六.使用npm安装cnpm ​ 一.下载安装 官网: https://nodejs.org/en/download 下载完之后,安装的时候一直点next即可&#xff0c…

【C++高阶(四)】红黑树深度剖析--手撕红黑树!

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:C从入门到精通⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习C   🔝🔝 红黑树 1. 前言2. 红黑树的概念以及性质3. 红黑…
最新文章