python pytorch实现RNN,LSTM,GRU,文本情感分类

python pytorch实现RNN,LSTM,GRU,文本情感分类

数据集格式:
在这里插入图片描述
有需要的可以联系我

实现步骤就是:
1.先对句子进行分词并构建词表
2.生成word2id
3.构建模型
4.训练模型
5.测试模型

代码如下:


import pandas as pd
import torch
import matplotlib.pyplot as plt
import jieba
import numpy as np

"""
作业:
一、完成优化
优化思路

1 jieba
2 取常用的3000字
3 修改model:rnn、lstm、gru

二、完成测试代码
"""

# 了解数据
dd = pd.read_csv(r'E:\peixun\data\train.csv')
# print(dd.head())

# print(dd['label'].value_counts())

# 句子长度分析
# 确定输入句子长度为 500
text_len = [len(i) for i in dd['text']]
# plt.hist(text_len)
# plt.show()
# print(max(text_len), min(text_len))

# 基本参数 config
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('my device:', DEVICE)

MAX_LEN = 500
BATCH_SIZE = 16
EPOCH = 1
LR = 3e-4

# 构建词表 word2id
vocab = []
for i in dd['text']:
    vocab.extend(jieba.lcut(i, cut_all=True))  # 使用 jieba 分词
    # vocab.extend(list(i))

vocab_se = pd.Series(vocab)
print(vocab_se.head())
print(vocab_se.value_counts().head())

vocab = vocab_se.value_counts().index.tolist()[:3000]  # 取频率最高的 3000 token
# print(vocab[:10])
# exit()

WORD_PAD = "<PAD>"
WORD_UNK = "<UNK>"
WORD_PAD_ID = 0
WORD_UNK_ID = 1

vocab = [WORD_PAD, WORD_UNK] + list(set(vocab))

print(vocab[:10])
print(len(vocab))

vocab_dict = {k: v for v, k in enumerate(vocab)}

# 词表大小,vocab_dict: word2id; vocab: id2word
print(len(vocab_dict))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import pandas as pd


# 定义数据集 Dataset
class Dataset(data.Dataset):
    def __init__(self, split='train'):
        # ChnSentiCorp 情感分类数据集
        path =  r'E:/peixun/data/' + str(split) + '.csv'
        self.data = pd.read_csv(path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        text = self.data.loc[i, 'text']
        label = self.data.loc[i, 'label']

        return text, label


# 实例化 Dataset
dataset = Dataset('train')

# 样本数量
print(len(dataset))
print(dataset[0])


# 句子批处理函数
def collate_fn(batch):
    # [(text1, label1), (text2, label2), (3, 3)...]
    sents = [i[0][:MAX_LEN] for i in batch]
    labels = [i[1] for i in batch]

    inputs = []
    # masks = []

    for sent in sents:
        sent = [vocab_dict.get(i, WORD_UNK_ID) for i in list(sent)]
        pad_len = MAX_LEN - len(sent)

        # mask = len(sent) * [1] + pad_len * [0]
        # masks.append(mask)

        sent += pad_len * [WORD_PAD_ID]

        inputs.append(sent)

    # 只使用 lstm 不需要用 masks
    # masks = torch.tensor(masks)
   # print(inputs)
    inputs = torch.tensor(inputs)
    labels = torch.LongTensor(labels)

    return inputs.to(DEVICE), labels.to(DEVICE)


# 测试 loader
loader = data.DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         collate_fn=collate_fn,
                         shuffle=True,
                         drop_last=False)

inputs, labels = iter(loader).__next__()
print(inputs.shape, labels)


# 定义模型
class Model(nn.Module):
    def __init__(self, vocab_size=5000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 100, padding_idx=WORD_PAD_ID)

        # 多种 rnn
        self.rnn = nn.RNN(100, 100, 1, batch_first=True, bidirectional=True)
        self.gru = nn.GRU(100, 100, 1, batch_first=True, bidirectional=True)
        self.lstm = nn.LSTM(100, 100, 1, batch_first=True, bidirectional=True)

        self.l1 = nn.Linear(500 * 100 * 2, 100)
        self.l2 = nn.Linear(100, 2)

    def forward(self, inputs):
        out = self.embed(inputs)
        out, _ = self.lstm(out)
        out = out.reshape(BATCH_SIZE, -1)  # 16 * 100000
        out = F.relu(self.l1(out))  # 16 * 100
        out = F.softmax(self.l2(out))  # 16 * 2

        return out


# 测试 Model
model = Model()
print(model)

# 模型训练
dataset = Dataset()
loader = data.DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         collate_fn=collate_fn,
                         shuffle=True)

model = Model().to(DEVICE)

# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

model.train()
for e in range(EPOCH):
    for idx, (inputs, labels) in enumerate(loader):
        # 前向传播,计算预测值
        out = model(inputs)
        # 计算损失
        loss = loss_fn(out, labels)
        # 反向传播,计算梯度
        loss.backward()
        # 参数更新
        optimizer.step()
        # 梯度清零
        optimizer.zero_grad()

        if idx % 10 == 0:
            out = out.argmax(dim=-1)
            acc = (out == labels).sum().item() / len(labels)

            print('>>epoch:', e,
                  '\tbatch:', idx,
                  '\tloss:', loss.item(),
                  '\tacc:', acc)

# 模型测试
test_dataset = Dataset('test')
test_loader = data.DataLoader(test_dataset,
                              batch_size=BATCH_SIZE,
                              collate_fn=collate_fn,
                              shuffle=False)

loss_fn = nn.CrossEntropyLoss()

out_total = []
labels_total = []

model.eval()
for idx, (inputs, labels) in enumerate(test_loader):
    out = model(inputs)
    loss = loss_fn(out, labels)

    out_total.append(out)
    labels_total.append(labels)

    if idx % 50 == 0:
        print('>>batch:', idx, '\tloss:', loss.item())
        
correct=0
sumz=0
for i in range(len(out_total)):
   out = out_total[i].argmax(dim=-1)
   correct = (out == labels_total[i]).sum().item() +correct
   sumz=sumz+len(labels_total[i])
    #acc = (out_total == labels_total).sum().item() / len(labels_total)

print('>>acc:', correct/sumz)

运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/205320.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【强迫症患者必备】SpringBoot项目中Mybatis使用mybatis-redis开启三级缓存必须创建redis.properties优化方案

springboot项目中mybatis使用mybatis-redis开启三级缓存需要创建redis.properties优化方案 前言下载mybatis-redis源码分析RedisCache 代码RedisConfigurationBuilder的parseConfiguration方法 优化改造1.创建JedisConfig类2.复制RedisCache代码创建自定义的MyRedisCache3.指定…

INA219电流感应芯片_程序代码

详细跳转借鉴链接INA219例程此处进行总结 简单介绍一下 INA219&#xff1a; 1、 输入脚电压可以从 0V~26V,INA219 采用 3.3V/5V 供电. 2、 能够检测电流&#xff0c;电压和功率&#xff0c;INA219 内置基准器和乘法器使之能够直接以 A 为单位 读出电流值。 3、 16 位可编程地…

计算机体系结构----流水线技术(三)

本文仅供学习&#xff0c;不作任何商业用途&#xff0c;严禁转载。绝大部分资料来自----计算机系统结构教程(第二版)张晨曦等 计算机体系结构----流水线技术&#xff08;三&#xff09; 3.1 流水线的基本概念3.1.1 什么是流水线3.1.2 流水线的分类1. 部件级流水线、处理机级流…

DCAMnet网络复现与讲解

距论文阅读完毕已经过了整整一周多。。。终于抽出时间来写这篇辣&#xff01;~ 论文阅读笔记放这里&#xff1a; 基于可变形卷积和注意力机制的带钢表面缺陷快速检测网络DCAM-Net&#xff08;论文阅读笔记&#xff09;-CSDN博客 为了方便观看&#xff0c;我把结构图也拿过来了。…

Unity 一些常用注解

在Unity中有一些比较常用的注解&#xff1a; 1、[SerializeField]&#xff1a;将私有字段或属性显示在 Unity 编辑器中&#xff0c;使其可以在 Inspector 窗口中进行编辑。 2、[Range(min, max)]&#xff1a;限制数值字段或属性的范围&#xff0c;在 Inspector 窗口中以滑动条…

【合集】MQ消息队列——Message Queue消息队列的合集文章 RabbitMQ入门到使用

前言 RabbitMQ作为一款常用的消息中间件&#xff0c;在微服务项目中得到大量应用&#xff0c;其本身是微服务中的重点和难点。本篇博客是Message Queue相关的学习博客文章的合集篇&#xff0c;目前主要是RabbitMQ入门到使用文章&#xff0c;后续会扩展其他MQ。 目录 前言一、R…

中国毫米波雷达产业分析3——毫米波雷达市场分析(四、五、六)

四、康养雷达市场 &#xff08;一&#xff09;市场背景 1、政府出台系列政策提升智慧健康养老产品供给和应用 康养雷达是一种以老年人为主要监测对象&#xff0c;可以实现人体感应探测、跌倒检测报警、睡眠呼吸心率监测等重要养老监护功能的新型智慧健康养老产品。 随着我国经…

C#常见的设计模式-创建型模式

引言 在软件开发过程中&#xff0c;设计模式是一种被广泛采用的思想和实践&#xff0c;可以提供一种标准化的解决方案&#xff0c;以解决特定问题。设计模式分为三种类型&#xff1a;创建型模式、结构型模式和行为型模式。本篇文章将重点介绍C#中常见的创建型模式。 目录 引言…

【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析

文章目录 行为型模式1、模板方法模式&#xff08;1&#xff09;概述&#xff08;2&#xff09;结构&#xff08;3&#xff09;案例实现&#xff08;4&#xff09;优缺点&#xff08;5&#xff09;适用场景&#xff08;6&#xff09;JDK源码解析&#xff08;7&#xff09;模板方…

1-3、DOSBox环境搭建

语雀原文链接 文章目录 1、安装DOSBox2、Debug进入Debugrdeautq 1、安装DOSBox 官网下载下载地址&#xff1a;https://www.dosbox.com/download.php?main1此处直接下载这个附件&#xff08;内部有8086的DEBUG.EXE环境&#xff09;8086汇编工作环境.rar执行安装DOSBox0.74-wi…

pandas-profiling / ydata-profiling介绍与使用教程

文章目录 pandas-profilingydata-profilingydata-profiling实际应用iris鸢尾花数据集分析 pandas-profiling pandas_profiling 官网&#xff08;https://pypi.org/project/pandas-profiling/&#xff09;大概在23年4月前发出如下公告&#xff1a; Deprecated pandas-profilin…

keepalive路由缓存实现前进刷新后退缓存

1.在app.vue中配置全局的keepalive并用includes指定要缓存的组件路由name名字数组 <keep-alive :include"keepCachedViews"><router-view /></keep-alive>computed: {keepCachedViews() {console.log(this.$store.getters.keepCachedViews, this.…

提升技能素养,AMCAP做出合适的决策

近年来&#xff0c;智能配置投资与理财逐渐受到关注并走俏。这是一种简单快捷的智慧化理财方式&#xff0c;通过将个人和家族的闲置资金投入到低风险高流动性的产品中。 国际财富管理投资机构AMCAP集团金融分析师表示&#xff1a;智能配置投资与理财之所以持续走俏&#xff0c…

Redis String类型

String 类型是 Redis 最基本的数据类型&#xff0c;String 类型在 Redis 内部使用动态长度数组实现&#xff0c;Redis 在存储数据时会根据数据的大小动态地调整数组的长度。Redis 中字符串类型的值最大可以达到 512 MB。 关于字符串需要特别注意∶ 首先&#xff0c;Redis 中所…

Linux下查看目录大小

查看目录大小 Linux下查看当前目录大小&#xff0c;可用一下命令&#xff1a; du -h --max-depth1它会从下到大的显示文件的大小。

FastDFS+Nginx - 本地搭建文件服务器同时实现在外远程访问「内网穿透」

文章目录 前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx 2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.…

【MySQL】视图:简化查询

文章目录 create view … as创建视图更改或删除视图drop view 删除视图replace关键字&#xff1a;更改视图 可更新视图with check option子句&#xff1a;防止行被删除视图的其他优点简化查询减小数据库设计改动的影响使用视图限制基础表访问 create view … as创建视图 把常用…

Python 异常处理(try except)

文章目录 1 概述1.1 异常示例 2 异常处理2.1 捕获异常 try except2.2 抛出异常 raise 3 异常类型3.1 内置异常3.2 自定义异常 1 概述 1.1 异常示例 异常&#xff1a;程序执行中出现错误&#xff0c;若不处理&#xff0c;则程序终止 示例代码&#xff1a; v 6 / 0 # 除数不…

GPT带我学Openpyxl操作Excel

注&#xff1a;以下文字大部分文字和代码由GPT生成 一、openpyxl详细介绍 Openpyxl是一个用于读取和编写Excel 2010 xlsx/xlsm/xltx/xltm文件的Python库。它允许您使用Python操作Excel文件&#xff0c;包括创建新的工作簿、读取和修改现有工作簿中的数据、设置单元格格式以及编…

信贷专员简历模板

这份简历内容&#xff0c;以信贷专员招聘需求为背景&#xff0c;我们制作了1份全面、专业且具有参考价值的简历案例&#xff0c;大家可以灵活借鉴。 信贷专员简历在线编辑下载&#xff1a;百度幻主简历 求职意向 求职类型&#xff1a;全职 意向岗位&#xff1a;信贷专员 …
最新文章