Pytorch 文本情感分类案例

一共六个脚本,分别是:

        ①generateDictionary.py用于生成词典

        ②datasets.py定义了数据集加载的方法

        ③models.py定义了网络模型

        ④configs.py配置一些参数

        ⑤run_train.py训练模型

        ⑥run_test.py测试模型

数据集icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486959?spm=1001.2014.3001.5501停用词表icon-default.png?t=N7T8https://download.csdn.net/download/Victor_Li_/88486973?spm=1001.2014.3001.5501

generateDictionary.py如下

import jieba

data_path = "./weibo_senti_100k.csv"
data_stop_path = "./hit_stopwords.txt"
data_list = open(data_path,encoding='utf-8').readlines()[1:]
stops_word = open(data_stop_path,encoding='utf-8').readlines()
stops_word = [line.strip() for line in stops_word]
stops_word.append(" ")
stops_word.append("\n")

voc_dict = {}
min_seq = 1
top_n = 1000
UNK = "UNK"
PAD = "PAD"
for item in data_list:
    label = item[0]
    content = item[2:].strip()
    seg_list = jieba.cut(content,cut_all=False)

    seg_res = []
    for seg_item in seg_list:
        if seg_item in stops_word:
            continue
        seg_res.append(seg_item)
        if seg_item in voc_dict.keys():
            voc_dict[seg_item] += 1
        else:
            voc_dict[seg_item] = 1

    # print(content)
    # print(seg_res)

    voc_list = sorted([_ for _ in voc_dict.items() if _[1] > min_seq],key=lambda x:x[1],reverse=True)[:top_n]
    voc_dict = {word_count[0]:idx for idx,word_count in enumerate(voc_list)}
    voc_dict.update({UNK:len(voc_dict),PAD:len(voc_dict)+1})

ff = open("./dict","w")
for item in voc_dict.keys():
    ff.writelines("{},{}\n".format(item,voc_dict[item]))
ff.close()

datasets.py如下

from torch.utils.data import Dataset, DataLoader
import jieba
import numpy as np


def read_dict(voc_dict_path):
    voc_dict = {}
    with open(voc_dict_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line == '':
                continue
            word, index = line.split(",")
            voc_dict[word] = int(index)
    return voc_dict


def load_data(data_path, data_stop_path,isTest):
    data_list = open(data_path, encoding='utf-8').readlines()[1:]
    stops_word = open(data_stop_path, encoding='utf-8').readlines()
    stops_word = [line.strip() for line in stops_word]
    stops_word.append(" ")
    stops_word.append("\n")

    voc_dict = {}
    data = []
    max_len_seq = 0
    for item in data_list:
        label = item[0]
        content = item[2:].strip()
        seg_list = jieba.cut(content, cut_all=False)

        seg_res = []
        for seg_item in seg_list:
            if seg_item in stops_word:
                continue
            seg_res.append(seg_item)
            if seg_item in voc_dict.keys():
                voc_dict[seg_item] += 1
            else:
                voc_dict[seg_item] = 1
        if len(seg_res) > max_len_seq:
            max_len_seq = len(seg_res)
        if isTest:
            data.append([label, seg_res,content])
        else:
            data.append([label, seg_res])
    return data, max_len_seq


class text_ClS(Dataset):
    def __init__(self, data_path, data_stop_path,voc_dict_path,isTest=False):
        self.isTest = isTest
        self.data_path = data_path
        self.data_stop_path = data_stop_path
        self.voc_dict = read_dict(voc_dict_path)
        self.data, self.max_len_seq = load_data(self.data_path, self.data_stop_path,isTest)
        np.random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        data = self.data[item]
        label = int(data[0])
        word_list = data[1]
        if self.isTest:
            content = data[2]
        input_idx = []
        for word in word_list:
            if word in self.voc_dict.keys():
                input_idx.append(self.voc_dict[word])
            else:
                input_idx.append(self.voc_dict["UNK"])
        if len(input_idx) < self.max_len_seq:
            input_idx += [self.voc_dict["PAD"] for _ in range(self.max_len_seq - len(input_idx))]
        data = np.array(input_idx)
        if self.isTest:
            return label,data,content
        else:
            return label, data

def data_loader(dataset,config):
    return DataLoader(dataset,batch_size=config.batch_size,shuffle=config.is_shuffle,num_workers=4,pin_memory=True)

models.py如下

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Model(nn.Module):
    def __init__(self,config):
        super(Model,self).__init__()
        self.embeding = nn.Embedding(config.n_vocab,config.embed_size,padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed_size,config.hidden_size,config.num_layers,batch_first=True,bidirectional=True,dropout=config.dropout)
        self.maxpool = nn.MaxPool1d(config.pad_size)
        self.fc = nn.Linear(config.hidden_size * 2 + config.embed_size,config.num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self,x):
        embed = self.embeding(x)
        out, _ = self.lstm(embed)
        out = torch.cat((embed, out), 2)
        out = F.relu(out)
        out = out.permute(0, 2, 1)
        out = self.maxpool(out).reshape(out.size()[0],-1)
        out = self.fc(out)
        out = self.softmax(out)
        return out

configs.py如下

import torch.types


class Config():
    def __init__(self):
        self.n_vocab = 1002
        self.embed_size = 256
        self.hidden_size = 256
        self.num_layers = 5
        self.dropout = 0.8
        self.num_classes = 2
        self.pad_size = 32
        self.batch_size = 32
        self.is_shuffle = True
        self.learning_rate = 0.001
        self.num_epochs = 100
        self.devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

run_train.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mp

if __name__ == '__main__':
    mp.freeze_support()
    cfg = Config()

    data_path = "./weibo_senti_100k.csv"
    data_stop_path = "./hit_stopwords.txt"
    dict_path = "./dict"

    dataset = text_ClS(data_path, data_stop_path, dict_path)
    train_dataloader = data_loader(dataset,cfg)

    cfg.pad_size = dataset.max_len_seq

    model_text_cls = Model(cfg)
    model_text_cls.to(cfg.devices)

    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model_text_cls.parameters(), lr=cfg.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    for epoch in range(cfg.num_epochs):
        running_loss = 0
        correct = 0
        total = 0
        epoch_start_time = time.time()
        for i,(labels,datas) in enumerate(train_dataloader):
            datas = datas.to(cfg.devices)
            labels = labels.to(cfg.devices)

            pred = model_text_cls.forward(datas)
            loss_val = loss_func(pred,labels)
            running_loss += loss_val.item()
            loss_val.backward()
            if ((i + 1) % 4 == 0) or (i + 1 == len(train_dataloader)):
                optimizer.step()
                optimizer.zero_grad()
            _, predicted = torch.max(pred.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
        scheduler.step()
        accuracy_train = 100 * correct / total
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        tain_loss = running_loss / len(train_dataloader)
        print("Epoch [{}/{}],Time: {:.4f}s,Loss: {:.4f},Acc: {:.2f}%".format(epoch + 1, cfg.num_epochs, epoch_time, tain_loss,accuracy_train))
        torch.save(model_text_cls.state_dict(),"./text_cls_model/text_cls_model{}.pth".format(epoch))

run_test.py如下

import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader,text_ClS
from configs import Config
import time
import torch.multiprocessing as mp

if __name__ == '__main__':
    mp.freeze_support()
    cfg = Config()
    data_path = "./test.csv"
    data_stop_path = "./hit_stopwords.txt"
    dict_path = "./dict"
    cfg.batch_size = 1
    dataset = text_ClS(data_path, data_stop_path, dict_path,isTest=True)
    dataloader = data_loader(dataset,cfg)

    cfg.pad_size = dataset.max_len_seq

    model_text_cls = Model(cfg)
    model_text_cls.load_state_dict(torch.load('./text_cls_model/text_cls_model0.pth'))
    model_text_cls.to(cfg.devices)
    classes_name = ['负面的','正面的']
    for i,(label,input,content) in enumerate(dataloader):
        label = label.to(cfg.devices)
        input = input.to(cfg.devices)
        pred = model_text_cls.forward(input)
        _, predicted = torch.max(pred.data, 1)
        print("内容:{}, 实际结果:{}, 预测结果:{}".format(content,classes_name[label],classes_name[predicted[0]]))

测试结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/115888.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

《研发效能(DevOps)工程师》课程简介(三)丨IDCF

在研发效能领域中&#xff0c;【开发与交付】的学习重点在于掌握高效的开发工具和框架&#xff0c;了解敏捷开发方法&#xff0c;掌握持续集成与持续交付技术&#xff0c;以及如何保证应用程序的安全性和合规性等方面。 由国家工业和信息化部教育与考试中心颁发的职业技术证书…

【多线程】龟兔赛跑

package org.example;public class Race implements Runnable {//胜利者private static String winner;Overridepublic void run() {for(int i0;i<100;i){boolean flag gameOver(i);//如果flag>100,结束比赛if(flag){break;}System.out.println(Thread.currentThread().g…

Vue实现消费清单明细饼图展示

功能 可以进行消费项增删消费额大于500会标红消费金额合计饼图展示消费项 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-…

U盘显示无媒体怎么办?方法很简单

当出现U盘无媒体情况时&#xff0c;您可以在磁盘管理工具中看到一个空白的磁盘框&#xff0c;并且在文件资源管理器中不会显示出来。那么&#xff0c;导致这种问题的原因是什么呢&#xff1f;我们又该怎么解决呢&#xff1f; 导致U盘无媒体的原因是什么&#xff1f; 当您遇到上…

错误号码2058 Plugin caching_sha2_password could not be loaded:vX八白白白白白令自砸

sqlyog连接数据库时报错&#xff1a; 错误号码2058 Plugin caching_sha2_password could not be loaded:vX八白白白白白令自砸 网上查了资料&#xff0c;是MySQL 从 8.0 版本开始加密方式改变导致的原因。具体的咋也不再这里分析了&#xff0c;就直说如何解决这个问题。下面三…

软考之软件工程基础理论知识

软件工程基础 软件开发方法 结构化方法 将整个系统的开发过程分为若干阶段&#xff0c;然后依次进行&#xff0c;前一阶段是后一阶段的工作依据按顺序完成。应用最广泛。特点是注重开发过程的整体性和全局性。缺点是开发周期长文档设计说明繁琐&#xff0c;工作效率低开发前要…

BEV-YOLO 论文学习

1. 解决了什么问题&#xff1f; 出于安全和导航的目的&#xff0c;自驾感知系统需要全面而迅速地理解周围的环境。目前主流的研究方向有两个&#xff1a;第一种传感器融合方案整合激光雷达、相机和毫米波雷达&#xff0c;和第二种纯视觉方案。传感器融合方案的感知表现鲁棒&am…

docker学习笔记

1. docker安装 可以从以下地址下载并安装docker&#xff0c;Linux&#xff0c;Windows&#xff0c;MacOS均支持&#xff1a; 官网&#xff1a;https://docs.docker.com/engine/install/阿里云镜像安装&#xff1a;https://developer.aliyun.com/mirror/docker-ce Docker 安装…

3.Docker的客户端指令学习与实战

1.Docker的命令 1.1 启动Docker&#xff08;systemctl start docker&#xff09; systemctl start docker1.2 查看docker的版本信息&#xff08;docker version&#xff09; docker version1.3 显示docker系统范围的信息&#xff08;docker info&#xff09; docker info1.4…

3.21每日一题(区间在现求定积分)

当发现一个定积分&#xff0c;原函数根本找不出来时&#xff0c;可以用变量代换&#xff1a;区间再现&#xff01;&#xff01;&#xff01;

【JVM】双亲委派机制、打破双亲委派机制

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaEE 操作系统 Redis 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 JVM 一、双亲委派机制1.1 双亲委派的作用1.…

AST注入-从原型链污染到RCE

文章目录 概念漏洞Handlebarspug 例题 [湖湘杯 2021 final]vote 概念 什么是AST注入 在NodeJS中&#xff0c;AST经常被在JS中使用&#xff0c;作为template engines(引擎模版)和typescript等。对于引擎模版&#xff0c;结构如下图所示。 如果在JS应用中存在原型污染漏洞&…

【蓝桥杯选拔赛真题10】C++求奇数和 青少年组蓝桥杯C++选拔赛真题 STEMA比赛真题解析

目录 C/C++求奇数和 一、题目要求 1、编程实现 2、输入输出 二、算法分析 <

sqlsugar查询数据库下的所有表,批量修改表名字

查询数据库中的所有表 using SqlSugar;namespace 批量修改数据库表名 {internal class Program{static void Main(string[] args){SqlSugarClient sqlSugarClient new SqlSugarClient(new ConnectionConfig(){ConnectionString "Data Source(localdb)\\MSSQLLocalDB;In…

VS Code 开发 Spring Boot 类型的项目

在VS Code中开发Spring Boot的项目&#xff0c; 可以导入如下的扩展&#xff1a; Spring Boot ToolsSpring InitializrSpring Boot Dashboard 比较建议的方式是安装Spring Boot Extension Pack&#xff0c; 这里面就包含了上面的扩展。 安装方式就是在扩展查找 “Spring Boot…

故障诊断 | MATLAB实现GRNN广义回归神经网络故障诊断

故障诊断 | MATLAB实现GRNN广义回归神经网络故障诊断 目录 故障诊断 | MATLAB实现GRNN广义回归神经网络故障诊断故障诊断基本介绍模型描述预测过程程序设计参考资料故障诊断 基本介绍 MATLAB实现GRNN广义回归神经网络故障诊断 数据为多特征分类数据,输入12个特征,分3

SpringSecurity全家桶 (二) ——实现原理

1. SpringSecurity的强大之处 当我们并未设置登录页面时&#xff0c;我们只需要导入SpringSecurity的依赖就可以令我们的界面进入保护状态&#xff0c;由下面例子可以凸显出&#xff1a; 随便写个接口 RequestMapping("/hello")public String hello(){return "H…

【this详解】学习JavaScript的this的使用和原理这一篇就够了,超详细,建议收藏!!!

&#x1f601; 作者简介&#xff1a;一名大四的学生&#xff0c;致力学习前端开发技术 ⭐️个人主页&#xff1a;夜宵饽饽的主页 ❔ 系列专栏&#xff1a;JavaScript进阶指南 &#x1f450;学习格言&#xff1a;成功不是终点&#xff0c;失败也并非末日&#xff0c;最重要的是继…

虹科案例 | AR内窥镜手术应用为手术节约45分钟?

相信医疗从业者都知道&#xff0c;在手术室中有非常多的医疗器械屏幕&#xff0c;特别是内窥镜手术室中医生依赖这些内窥镜画面来帮助病患进行手术。但手术室空间有限&#xff0c;屏幕缩放位置相对固定&#xff0c;在特殊场景下医生观看内窥镜画面时无法关注到病患的状态。这存…

FFmpeg 硬件加速视频转码指南

基于 Windows 下演示&#xff0c;Linux 下也可以适用。 所使用 ffmpeg 版本为 BtbN 编译的 win64-gpl 版&#xff08;非 gpl-share&#xff09;&#xff0c;项目地址&#xff1a;BtbN / FFmpeg-Builds 也可以使用 gyan.dev 编译的 git-full 版&#xff0c;地址&#xff1a;gyan…
最新文章