NLP之LSTM与BiLSTM

文章目录

  • 代码展示
  • 代码解读
  • 双向LSTM介绍(BiLSTM)

代码展示

import pandas as pd
import tensorflow as tf
tf.random.set_seed(1)
df = pd.read_csv("../data/Clothing Reviews.csv")
print(df.info())

df['Review Text'] = df['Review Text'].astype(str)
x_train = df['Review Text']
y_train = df['Rating']
print(y_train.unique())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 23486 entries, 0 to 23485
Data columns (total 11 columns):
 #   Column                   Non-Null Count  Dtype 
---  ------                   --------------  ----- 
 0   Unnamed: 0               23486 non-null  int64 
 1   Clothing ID              23486 non-null  int64 
 2   Age                      23486 non-null  int64 
 3   Title                    19676 non-null  object
 4   Review Text              22641 non-null  object
 5   Rating                   23486 non-null  int64 
 6   Recommended IND          23486 non-null  int64 
 7   Positive Feedback Count  23486 non-null  int64 
 8   Division Name            23472 non-null  object
 9   Department Name          23472 non-null  object
 10  Class Name               23472 non-null  object
[4 5 3 2 1]
from tensorflow.keras.preprocessing.text import Tokenizer

dict_size = 14848
tokenizer = Tokenizer(num_words=dict_size)

tokenizer.fit_on_texts(x_train)
print(len(tokenizer.word_index),tokenizer.index_word)

x_train_tokenized = tokenizer.texts_to_sequences(x_train)
from tensorflow.keras.preprocessing.sequence import pad_sequences
max_comment_length = 120
x_train = pad_sequences(x_train_tokenized,maxlen=max_comment_length)

for v in x_train[:10]:
    print(v,len(v))
# 构建RNN神经网络
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Embedding,LSTM,Bidirectional
import tensorflow as tf

rnn = Sequential()
# 对于rnn来说首先进行词向量的操作
rnn.add(Embedding(input_dim=dict_size,output_dim=60,input_length=max_comment_length))
# RNN:simple_rnn (SimpleRNN)  (None, 100)   16100
# LSTM:simple_rnn (SimpleRNN)  (None, 100)  64400
rnn.add(Bidirectional(LSTM(units=100)))  # 第二层构建了100个RNN神经元
rnn.add(Dense(units=10,activation=tf.nn.relu))
rnn.add(Dense(units=6,activation=tf.nn.softmax))  # 输出分类的结果
rnn.compile(loss='sparse_categorical_crossentropy',optimizer="adam",metrics=['accuracy'])
print(rnn.summary())
result = rnn.fit(x_train,y_train,batch_size=64,validation_split=0.3,epochs=10)
print(result)
print(result.history)

代码解读

首先,我们来总结这段代码的流程:

  1. 导入了必要的TensorFlow Keras模块。
  2. 初始化了一个Sequential模型,这表示我们的模型会按顺序堆叠各层。
  3. 添加了一个Embedding层,用于将整数索引(对应词汇)转换为密集向量。
  4. 添加了一个双向LSTM层,其中包含100个神经元。
  5. 添加了两个Dense全连接层,分别包含10个和6个神经元。
  6. 使用sparse_categorical_crossentropy损失函数编译了模型。
  7. 打印了模型的摘要。
  8. 使用给定的训练数据和验证数据对模型进行了训练。
  9. 打印了训练的结果。

现在,让我们逐行解读代码:

  1. 导入依赖:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Embedding,LSTM,Bidirectional
import tensorflow as tf

你导入了创建和训练RNN模型所需的TensorFlow Keras库。

  1. 初始化模型:
rnn = Sequential()

你选择了一个顺序模型,这意味着你可以简单地按顺序添加层。

  1. 添加Embedding层:
rnn.add(Embedding(input_dim=dict_size,output_dim=60,input_length=max_comment_length))

此层将整数索引转换为固定大小的向量。dict_size是词汇表的大小,max_comment_length是输入评论的最大长度。

  1. 添加LSTM层:
rnn.add(Bidirectional(LSTM(units=100)))

你选择了双向LSTM,这意味着它会考虑过去和未来的信息。它有100个神经元。

  1. 添加全连接层:
rnn.add(Dense(units=10,activation=tf.nn.relu))
rnn.add(Dense(units=6,activation=tf.nn.softmax))

这两个Dense层用于模型的输出,最后一层使用softmax激活函数进行6类的分类。

  1. 编译模型:
rnn.compile(loss='sparse_categorical_crossentropy',optimizer="adam",metrics=['accuracy'])

你选择了一个适合分类问题的损失函数,并选择了adam优化器。

  1. 显示模型摘要:
print(rnn.summary())

这将展示模型的结构和参数数量。

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding (Embedding)       (None, 120, 60)           890880    
                                                                 
 bidirectional (Bidirectiona  (None, 200)              128800    
 l)                                                              
                                                                 
 dense (Dense)               (None, 10)                2010      
                                                                 
 dense_1 (Dense)             (None, 6)                 66        
                                                                 
=================================================================
Total params: 1,021,756
Trainable params: 1,021,756
Non-trainable params: 0
_________________________________________________________________
None
  1. 训练模型:
result = rnn.fit(x_train,y_train,batch_size=64,validation_split=0.3,epochs=10)

你用训练数据集训练了模型,其中30%的数据用作验证,训练了10个周期。

Epoch 1/10
257/257 [==============================] - 74s 258ms/step - loss: 1.2142 - accuracy: 0.5470 - val_loss: 1.0998 - val_accuracy: 0.5521
Epoch 2/10
257/257 [==============================] - 57s 221ms/step - loss: 0.9335 - accuracy: 0.6293 - val_loss: 0.9554 - val_accuracy: 0.6094
Epoch 3/10
257/257 [==============================] - 59s 229ms/step - loss: 0.8363 - accuracy: 0.6616 - val_loss: 0.9321 - val_accuracy: 0.6168
Epoch 4/10
257/257 [==============================] - 61s 236ms/step - loss: 0.7795 - accuracy: 0.6833 - val_loss: 0.9812 - val_accuracy: 0.6089
Epoch 5/10
257/257 [==============================] - 56s 217ms/step - loss: 0.7281 - accuracy: 0.7010 - val_loss: 0.9559 - val_accuracy: 0.6043
Epoch 6/10
257/257 [==============================] - 56s 219ms/step - loss: 0.6934 - accuracy: 0.7156 - val_loss: 1.0197 - val_accuracy: 0.5999
Epoch 7/10
257/257 [==============================] - 57s 220ms/step - loss: 0.6514 - accuracy: 0.7364 - val_loss: 1.1192 - val_accuracy: 0.6080
Epoch 8/10
257/257 [==============================] - 57s 222ms/step - loss: 0.6258 - accuracy: 0.7486 - val_loss: 1.1350 - val_accuracy: 0.6100
Epoch 9/10
257/257 [==============================] - 57s 220ms/step - loss: 0.5839 - accuracy: 0.7749 - val_loss: 1.1537 - val_accuracy: 0.6019
Epoch 10/10
257/257 [==============================] - 57s 222ms/step - loss: 0.5424 - accuracy: 0.7945 - val_loss: 1.1715 - val_accuracy: 0.5744
<keras.callbacks.History object at 0x00000244DCE06D90>
  1. 显示训练结果:
print(result)
<keras.callbacks.History object at 0x0000013AEAAE1A30>
print(result.history)
{'loss': [1.2142471075057983, 0.9334620833396912, 0.8363043069839478, 0.7795010805130005, 0.7280740141868591, 0.693393349647522, 0.6514003872871399, 0.6257606744766235, 0.5839114189147949, 0.5423741340637207], 
'accuracy': [0.5469586253166199, 0.6292579174041748, 0.6616179943084717, 0.6833333373069763, 0.7010340690612793, 0.7156326174736023, 0.7363746762275696, 0.748600959777832, 0.7748783230781555, 0.7944647073745728], 
'val_loss': [1.0997602939605713, 0.9553984999656677, 0.932131290435791, 0.9812102317810059, 0.9558586478233337, 1.019730806350708, 1.11918044090271, 1.1349923610687256, 1.1536787748336792, 1.1715185642242432], 
'val_accuracy': [0.5520862936973572, 0.609423816204071, 0.6168038845062256, 0.6088560819625854, 0.6043145060539246, 0.5999148488044739, 0.6080045700073242, 0.6099914908409119, 0.6019017696380615, 0.574368417263031]
}

这将展示训练过程中的损失和准确性等信息。

双向LSTM介绍(BiLSTM)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
例子:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/114063.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Django实战项目-学习任务系统-自定义URL拦截器

接着上期代码框架&#xff0c;6个主要功能基本实现&#xff0c;剩下的就是细节点的完善优化了。 首先增加URL拦截器&#xff0c;你不会希望没有登录用户就可以进入用户主页各种功能的&#xff0c;所以增加URL拦截器可以解决这个问题。 Django框架本身也有URL拦截器&#xff0…

非递归(迭代)遍历二叉树

前言 在树结构中我们经常使用递归算法&#xff0c;但是递归本身的特质会带来很多疑难痛点&#xff0c;比如递归过深导致爆栈&#xff0c;或者是逻辑复杂... 本文将以树的前序遍历为例&#xff0c;浅析迭代算法如何模拟递归过程。 思路 我们先来看看这个算法的具体思想。 在递…

泛微e-office系统存在SQL注入漏洞

泛微e-office系统存在SQL注入漏洞 一、泛微简介二、漏洞描述三、影响版本四、fofa查询语句五、漏洞复现 免责声明&#xff1a;请勿利用文章内的相关技术从事非法测试&#xff0c;由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失&#xff0c;均由使…

【算能】stream在docker的环境下编译报错

错误问题一&#xff1a; /workspace/sophon-stream/element/multimedia/encode/../../../3rdparty/websocketpp/websocketpp/common/asio.hpp:56:14: fatal error: boost/version.hpp: No such file or directory 56 | #include <boost/version.hpp> 解决方法&a…

rwkv模型lora微调之accelerate和deepspeed训练加速

目录 一、rwkv模型简介 二、lora原理简介 三、rwkv-lora微调 1、数据整理 2、环境搭建 a、Dockerfile编写 b、制造镜像 c、容器启动 3、训练代码修改 四、模型推理 1、模型推理 2、lora权重合并 3、推理web服务 五、总结 由于业务采用的ChatGLM模型推理成本太大了…

软件测试---边界值分析(功能测试)

能对限定边界规则设计测试点---边界值分析 选取正好等于、刚好大于、刚好小于边界的值作为测试数据 上点: 边界上的点 (正好等于)&#xff1b;必选(不考虑区开闭) 内点: 范围内的点 (区间范围内的数据)&#xff1b;必选(建议选择中间范围) 离点: 距离上点最近的点 (刚好…

linux下mysql-8.2.0集群部署(python版本要在2.7以上)

目录 一、三台主机准备工作 1、mysql官方下载地址&#xff1a;https://dev.mysql.com/downloads/ 2、修改/etc/hosts 3、关闭防火墙 二、三台主机安装mysql-8.2.0 1、解压 2、下载相应配置 3、初始化mysql&#xff0c;启动myslq&#xff0c;设置开机自启 4、查看初始密…

代码训练营第59天:动态规划part17|leetcode647回文子串|leetcode516最长回文子序列

leetcode647&#xff1a;回文子串 文章讲解&#xff1a;leetcode647 leetcode516&#xff1a;最长回文子序列 文章讲解&#xff1a;leetcode516 DP总结&#xff1a;动态规划总结 目录 1&#xff0c;leeetcode647 回文子串。 2&#xff0c;leetcode516 最长回文子串&#xff1…

Agent 应用于提示工程

如果Agent模仿了人类在现实世界中的操作方式&#xff0c;那么&#xff0c;能否应用于提示工程即Prompt Engingeering 呢&#xff1f; 从LLM到Prompt Engineering 大型语言模型(LLM)是一种基于Transformer的模型&#xff0c;已经在一个巨大的语料库或文本数据集上进行了训练&…

Docker(1)

文章目录 Docker物理机部署的缺点虚拟机Docker 与虚拟机的区别Docker 的优势 Docker 概念安装 DockerDocker 架构镜像加速Docker 命令进程服务相关命令 镜像相关文件命令容器相关的命令 镜像加载的原理UnionFS(联合文件系统)docker 镜像加载原理 容器的数据卷数据卷概念配置数据…

一座 “数智桥梁”,华为助力“天堑变通途”

《水调歌头游泳》中的一句话&#xff0c;“一桥飞架南北&#xff0c;天堑变通途”&#xff0c;广为人们所熟知&#xff0c;其中展现出的&#xff0c;是中国人对美好出行的无限向往。 天堑变通途从来不易。 中国是当今世界上交通运输最繁忙、最快捷的国家之一&#xff0c;交通行…

2024上海国际人工智能展(CSITF)以“技术,让生活更精彩”为核心理念,以“创新驱动发展,保护知识产权,促进技术贸易”为主题

2024上海国际人工智能展&#xff08;CSITF&#xff09; China&#xff08;Shanghai&#xff09;International Technology Fair 时间:2024年6月12-14日 地点:上海世博展览馆 主办单位 中华人民共和国商务部 中华人民共和国科学技术部 中华人民共和国国家知识产权局 上海市…

C#,数值计算——求解一组m维线性Volterra方程组的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { /// <summary> /// 求解一组m维线性Volterra方程组 /// Solves a set of m linear Volterra equations of the second kind using the /// extended trapezoidal rule.On input, t0 is the st…

Git 标签(Tag)实战:打标签和删除标签的步骤指南

目录 前言使用 Git 打本地和远程标签&#xff08;Tag&#xff09;删除本地和远程 Git 标签&#xff08;Tag&#xff09;开源项目标签&#xff08;Tag&#xff09;实战打标签删除标签 结语开源微服务商城项目前后端分离项目 前言 在开源项目中&#xff0c;版本控制是至关重要的…

python脚本监听域名证书过期时间,并将通知消息到钉钉

版本一&#xff1a; 执行脚本带上 --dingtalk-webhook和–domains后指定钉钉token和域名 python3 ssl_spirtime.py --dingtalk-webhook https://oapi.dingtalk.com/robot/send?access_tokenavd345324 --domains www.abc1.com www.abc2.com www.abc3.com脚本如下 #!/usr/bin…

什么是Node.js的流(stream)?它们有什么作用?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

LeetCode算法题解|​ 669. 修剪二叉搜索树​、108. 将有序数组转换为二叉搜索树、​538. 把二叉搜索树转换为累加树​

一、LeetCode 669. 修剪二叉搜索树​ 题目链接&#xff1a;669. 修剪二叉搜索树 题目描述&#xff1a; 给你二叉搜索树的根节点 root &#xff0c;同时给定最小边界low 和最大边界 high。通过修剪二叉搜索树&#xff0c;使得所有节点的值在[low, high]中。修剪树 不应该 改变…

(免费领源码)java#springboot#MYSQL 电影推荐网站30760-计算机毕业设计项目选题推荐

摘 要 随着互联网时代的到来&#xff0c;同时计算机网络技术高速发展&#xff0c;网络管理运用也变得越来越广泛。因此&#xff0c;建立一个B/S结构的电影推荐网站&#xff1b;电影推荐网站的管理工作系统化、规范化&#xff0c;也会提高平台形象&#xff0c;提高管理效率。 本…

《TCP/IP详解 卷一:协议》第5章的IPv4数据报的IHL字段解释

首先说明一下&#xff0c;这里并不解释整个IPv4数据报各个字段的含义&#xff0c;仅仅针对IHL字段作解释。 我们先看下IPv4数据报格式 对于IHL字段&#xff0c; 《TCP/IP详解 卷一&#xff1a;协议》这么解释&#xff1a; IPv4数据报。头部大小可变&#xff0c;4位的IHL字段…

MongoDB系例全教程

一、系列文章目录 一、MongoDB安装教程—官方原版 二、MongoDB 使用教程(配置、管理、监控)_linux mongodb 监控 三、MongoDB 基于角色的访问控制 四、MongoDB用户管理 五、MongoDB基础知识详解 六、MongoDB—Indexs 七、MongoDB事务详解 八、MongoDB分片教程 九、Mo…
最新文章