Tensorflow2.0笔记 - 循环神经网络RNN做IMDB评价分析

        本笔记记录使用SimpleRNNCell做一个IMDB评价系统情感二分类问题的例子。

import os
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
#np.random.seed(22)
tf.__version__


#取常见的10000个单词
total_words = 10000
#句子最长的单词数量设置为80
max_review_len = 80
#embedding设置为100,表示每个单词用100维向量表示
embedding_len = 100
#加载IMDB数据集
(x_train,y_train), (x_test, y_test) = datasets.imdb.load_data(num_words = total_words)
#对训练数据和测试数据的句子进行填充或截断
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)

#构建数据集
batchsize = 128
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsize, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsize, drop_remainder=True)

#x_train包含25000个句子,每个句子包含80个单词,y_train标签为1表示好评,0表示差评
print('x_train: shape - ', x_train.shape, ' y_train: max/min -', tf.reduce_max(y_train).numpy(), '/', tf.reduce_min(y_train).numpy())
print('x_test: shape - ', x_test.shape)


class MyRNN(keras.Model):
    #units:state的维度
    def __init__(self, total_words, embedding_len, max_review_len, units):
        super(MyRNN, self).__init__()
        #初始的序列状态初始化为0(第0时刻的状态)
        self.state0 = [tf.zeros([batchsize, units])]
        self.state1 = [tf.zeros([batchsize, units])]
        #embedding层,将文本转换为embedding表示
        #[b, 80] => [b, 80, 100]
        self.embedding = layers.Embedding(total_words, embedding_len, input_length=max_review_len)
        #[b, 80, 100] , units: 64 - 转换为64维的state [b, 64]
        self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.2)
        self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.2)
        #全连接层 [[b, 64] => [b, 1]
        self.outlayer = layers.Dense(1)
    #inputs: [b, 80] 
    def call(self, inputs, training=None):
        x = inputs
        #做embedding,[b,80] => [b, 80, 100]
        x = self.embedding(x)
        #做RNN cell计算
        #[b, 80, 100] => [b,  64]
        #遍历句子中的每个单词
        # word: [b, 100]
        state0 = self.state0
        state1 = self.state1
        for word in tf.unstack(x, axis=1):
            #h1 = x*w_xh + h0*w_hh
            out0, state0 = self.rnn_cell0(word, state0, training)
            out1, state1 = self.rnn_cell1(out0, state1)
        #循环完毕后,得到的out为[b, 64],表示每个句子最终得到的状态
        x = self.outlayer(out1)
        #计算最终评价结果
        prob = tf.sigmoid(x)
        return prob

def main():
    units = 64
    epochs = 15
    lr = 0.001

    model = MyRNN(total_words, embedding_len, max_review_len, units)
    model.compile(optimizer = optimizers.Adam(lr), loss = tf.losses.BinaryCrossentropy(),
                 metrics=['accuracy'])
    model.fit(db_train, epochs=epochs, validation_data=db_test)

    model.evaluate(db_test)

if __name__ == '__main__':
    main()

运行结果:

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/608294.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

VisualGLM-6B微调(V100)

Visualglm-6b-CSDN博客文章浏览阅读1.3k次。【官方教程】XrayGLM微调实践,(加强后的GPT-3.5)能力媲美4.0,无次数限制。_visualglm-6bhttps://blog.csdn.net/u012193416/article/details/131074962?ops_request_misc%257B%2522req…

2. Linux 基本指令(上)|ls|pwd|cd|tree|touch|mkdir|rmdir|rm

前言 计算机软硬件体系结构 层状结构应用软件Word,Matlab操作系统Windows,Linux设备驱动声卡驱动硬件CPU,内存,磁盘,显示器,键盘 操作系统概念 操作系统 是一款进行软硬件资源管理的软件 例子 比如在学…

Join优化规则及应用层BI系统实践

目录 一、背景 二、查询优化器概述​编辑 2.1 System R Optimizer 2.2 Volcano Optimizer 2.3 Cascade Optimizer 三、Join相关优化规则 3.1 JoinReorder 3.1.1 少量表的Reorder 3.1.2 大量表的Reorder 3.1.3 星型模型的Reorder 3.2 外连接消除 3.3 Join消除 3.4 谓…

使用ROW_NUMBER()分组遇到的坑

1、再一次清洗数据时,需要过滤重复数据,使用了ROW_NUMBER() 来分组给每组数据排序号 在获取每组的第一行数据 with records as(select cc.F_Id as Id,REPLACE(cc.F_CNKITitle,char(10),1) as F_CNKITitle,REPLACE(REPLACE(cc.F_Special,专题&#xff1…

适合大学生的鸿蒙开发板-Purple Pi OH之安装Docker

一、介绍 本文基于purple-pi-oh系列主板演示Linux 系统安装Docker,方法适用于RK3566全系列产品。本教程将指导你在基于RK3566的LInux系统上安装Docker。Docker是一个开放源代码的应用容器引擎,允许开发者打包他们的应用及依赖包到一个可移植的容器中&am…

【银角大王——Django课程——分页显示功能实现】

分页显示功能实现 添加假数据,然后演示分页功能分页——功能实现基于之前的靓号列表函数添加代码只显示10条——按照等级排序页码list表样式——bootstrap样式显示当前页面——前五页,后五页给当前页添加样式页码bug更改——出现负数,没有数据…

【neteq】tgcall的调用、neteq的创建及接收侧ReceiveStatisticsImpl统计

G:\CDN\P2P-DEV\Libraries\tg_owt\src\call\call.cc基本是按照原生webrtc的来的:G:\CDN\P2P-DEV\tdesktop-offical\Telegram\ThirdParty\tgcalls\tgcalls\group\GroupInstanceCustomImpl.cpptg对neteq的使用 worker 线程创建call Call的config需要neteqfactory Call::CreateAu…

MySQL——变量的浮点数问题处理

新建链接,自带world数据库,里面自带city表格。 DQL #MySQL变量的浮点数问题处理 set dx3.14,dy3.25; select dxdy;#计算显示异常,会有很多00000的提示set resultdxdy; select result; 查询结果

为何预测预测蛋白质结构这么重要AlphaFold 3;阿里巴巴的开源语音转文字;抱抱脸开源LeRobot

✨ 1: AlphaFold 3 谷歌DeepMind和同构实验室推出AlphaFold 3 AI模型,旨在精确预测生命分子的结构和相互作用。 AlphaFold 3 是由谷歌DeepMind和Isomorphic Labs开发的一款新型AI模型,它可以以前所未有的精确度预测蛋白质、DNA、RNA、配体(…

【VTKExamples::Rendering】第一期 TestAmbientSpheres(环境照明系数)

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例TestAmbientShperes,介绍环境照明系数对Actor颜色的影响,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动…

C++:重载、重写与重定义

一、重载、重写与重定义的概念 C中,重载、重写和重定义是三个与函数和类成员相关的概念,但它们具有不同的含义和用途。 重载:是指在同一作用域内,可以有多个名称相同但参数列表(参数类型、参数个数或参数顺序&#x…

PyCharm安装教程(超详细图文教程)

一、下载和安装 1.进入PyCharm官方下载,官网下载地址: https://www.jetbrains.com/pycharm/download/ 专业版安装插件放网盘了,网盘下载即可:itcxy.xyz/229.html2.安装 1.下载后找到PyCharm安装包,然后双击双击.ex…

网工内推 | 技术支持工程师,最高15k,加班有补贴

01 星网信通 招聘岗位:售前技术支持 职责描述: 1、售前技术支持:技术交流、产品选型报价、方案制作等工作; 2、招投标支持:项目招标参数撰写、标书质疑、应标文件技术部分撰写及资质文件归纳准备、现场讲标及技术澄清…

Linux学习笔记1---Windows上运行Linux

在正点原子的教程中学习linux需要安装虚拟机或者在电脑上安装一个Ubuntu系统,但个人觉得太麻烦了,现在linux之父加入了微软,因此在Windows上也可以运行linux 了。具体方法如下: 一、 在Windows上的设置 在window的搜索框内&#…

vivado 低级别 SVF JTAG 命令、多链 SVF 操作

多链 SVF 操作 以下示例显示了如何在 SVF 链上处理操作。 每个链中连接有 2 个器件 : xcku11 和 xcku9 。配置存储器连接到链中的第 2 个器件 (xcku9) 。为访问此配置存储器 , SVF 会使用 HIR 、 HDR 、 TIR 和 TDR 命令来生成命令。为刷写此…

自动驾驶学习2-毫米波雷达

1、简介 1.1 频段 毫米波波长短、频段宽,比较容易实现窄波束,雷达分辨率高,不易受干扰。波长介于1~10mm的电磁波,频率大致范围是30GHz~300GHz 毫米波雷达是测量被测物体相对距离、相对速度、方位的高精度传感器。 车载毫米波雷达主要有24GHz、60GHz、77GHz、79GHz四个频段。 …

【JavaWeb】Servlet+JSP+EL表达式+JSTL标签库+Filter过滤器+Listener监听器

需要提前准备了哪些技术,接下来的课才能听懂? JavaSE(Java语言的标准版,Java提供的最基本的类库) Java的开发环境搭建Java的基础语法Java的面向对象数组常用类异常集合多线程IO流反射机制注解Annotation… MySQL&…

守护数字疆域:2024年网络安全报告深度解读

在这个数据如潮涌动的数字时代,每一比特信息都可能是攻防双方角力的战场。《Check Point 2024年网络安全报告》不但为我们揭示了过去一年网络安全世界的风云变幻,更以前瞻性的视角勾勒出未来的挑战与机遇。此刻,让我们携手深潜这份权威指南的…

【智能算法】人工原生动物优化算法(APO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献5.获取代码 1.背景 2024年,X Wang受到自然界原生动物启发,提出了人工原生动物优化算法( Artificial Protozoa Optimizer, APO)。 2.算法原理 2.1算法思想 AP…

【比邻智选】MR880A模组

🚀高性价比,5G/4G双模,稳定可靠 🌐功能丰富,5G特性一应俱全 🧩多封装兼容,适配性强,灵活升级智能设备