(Transfer Learning)迁移学习在IMDB上训练情感分析模型

1. 背景

有些场景下,开始的时候数据量很小,如果我们用一个几千条数据训练一个全新的深度机器学习的文本分类模型,效果不会很好。这个时候你有两种选择,1.用传统的机器学习训练,2.利用迁移学习在一个预训练的模型上训练。本博客教你怎么用tensorflow Hub和keras 在少量的数据上训练一个文本分类模型。

2. 实践

2.1. 下载IMDB 数据集,参考下面博客。

Imdb影评的数据集介绍与下载_imdb影评数据集-CSDN博客

2.2.  预处理数据

替换掉imdb目录 (imdb_raw_data_dir). 创建dataset目录。

import numpy as np
import os as os

import re
from sklearn.model_selection import train_test_split

vocab_size = 30000
maxlen = 200
imdb_raw_data_dir = "/Users/harry/Documents/apps/ml/aclImdb"
save_dir = "dataset"

def get_data(datapath =r'D:\train_data\aclImdb\aclImdb\train' ):
    pos_files = os.listdir(datapath + '/pos')
    neg_files = os.listdir(datapath + '/neg')
    print(len(pos_files))
    print(len(neg_files))

    pos_all = []
    neg_all = []
    for pf, nf in zip(pos_files, neg_files):
        with open(datapath + '/pos' + '/' + pf, encoding='utf-8') as f:
            s = f.read()
            s = process(s)
            pos_all.append(s)
        with open(datapath + '/neg' + '/' + nf, encoding='utf-8') as f:

            s = f.read()
            s = process(s)
            neg_all.append(s)
    print(len(pos_all))
    # print(pos_all[0])

    print(len(neg_all))

    X_orig= np.array(pos_all + neg_all)
    # print(X_orig)
    Y_orig = np.array([1 for _ in range(len(pos_all))] + [0 for _ in range(len(neg_all))])
    print("X_orig:", X_orig.shape)
    print("Y_orig:", Y_orig.shape)

    return X_orig, Y_orig

def generate_dataset():
    X_orig, Y_orig = get_data(imdb_raw_data_dir + r'/train')
    X_orig_test, Y_orig_test = get_data(imdb_raw_data_dir + r'/test')
    X_orig = np.concatenate([X_orig, X_orig_test])
    Y_orig = np.concatenate([Y_orig, Y_orig_test])
    X = X_orig
    Y = Y_orig

    np.random.seed = 1
    random_indexs = np.random.permutation(len(X))
    X = X[random_indexs]
    Y = Y[random_indexs]
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)
    print("X_train:", X_train.shape)
    print("y_train:", y_train.shape)
    print("X_test:", X_test.shape)
    print("y_test:", y_test.shape)
    np.savez(save_dir + '/train_test', X_train=X_train, y_train=y_train, X_test= X_test, y_test=y_test )


def rm_tags(text):
    re_tag = re.compile(r'<[^>]+>')
    return re_tag.sub(' ', text)

def clean_str(string):
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)  # it's -> it 's
    string = re.sub(r"\'ve", " \'ve", string) # I've -> I 've
    string = re.sub(r"n\'t", " n\'t", string) # doesn't -> does n't
    string = re.sub(r"\'re", " \'re", string) # you're -> you are
    string = re.sub(r"\'d", " \'d", string)  # you'd -> you 'd
    string = re.sub(r"\'ll", " \'ll", string) # you'll -> you 'll
    string = re.sub(r"\'m", " \'m", string) # I'm -> I 'm
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()

def process(text):
    text = clean_str(text)
    text = rm_tags(text)
    #text = text.lower()
    return  text

if __name__ == '__main__':

    generate_dataset()


执行完后,产生train_test.npz 文件

2.3.  训练模型

1. 取数据集

def get_dataset_to_train():

    train_test = np.load('dataset/train_test.npz', allow_pickle=True)
    x_train =  train_test['X_train']
    y_train = train_test['y_train']
    x_test =  train_test['X_test']
    y_test = train_test['y_test']

    return x_train, y_train, x_test, y_test

2. 创建模型

基于nnlm-en-dim50/2 预训练的文本嵌入向量,在模型外面加了两层全连接。

def get_model():
    hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)
    # Build the model
    model = Sequential([
        hub_layer,
        Dense(16, activation='relu'),
        Dropout(0.5),
        Dense(2, activation='softmax')
    ])

    print(model.summary())

    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    return model

还可以使用来自 TFHub 的许多其他预训练文本嵌入向量:

  • google/nnlm-en-dim128/2 - 基于与 google/nnlm-en-dim50/2 相同的数据并使用相同的 NNLM 架构进行训练,但具有更大的嵌入向量维度。更大维度的嵌入向量可以改进您的任务,但可能需要更长的时间来训练您的模型。
  • google/nnlm-en-dim128-with-normalization/2 - 与 google/nnlm-en-dim128/2 相同,但具有额外的文本归一化,例如移除标点符号。如果您的任务中的文本包含附加字符或标点符号,这会有所帮助。
  • google/universal-sentence-encoder/4 - 一个可产生 512 维嵌入向量的更大模型,使用深度平均网络 (DAN) 编码器训练。

还有很多!在 TFHub 上查找更多文本嵌入向量模型。

3. 评估你的模型

def evaluate_model(test_data, test_labels):
    model = load_trained_model()
    # Evaluate the model
    results = model.evaluate(test_data, test_labels, verbose=2)

    print("Test accuracy:", results[1])


def load_trained_model():
    # model = get_model()
    # model.load_weights('./models/model_new1.h5')

    model = tf.keras.models.load_model('models_pb')

    return model

4. 测试几个例子

def predict(real_data):
    model  = load_trained_model()
    probabilities = model.predict([real_data]);
    print("probabilities :",probabilities)
    result =  get_label(probabilities)
    return result


def get_label(probabilities):

    index = np.argmax(probabilities[0])
    print("index :" + str(index))

    result_str =  index_dic.get(str(index))
    # result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]

    return result_str


def predict_my_module():

    # review = "I don't like it"
    # review = "this is bad movie "
    # review = "This is good movie"
    review = " this is terrible movie"
    # review = "This isn‘t great movie"
    # review = "i think this is bad movie"
    # review = "I'm not very disappoint for this movie"
    # review = "I'm not very disappoint for this movie"
    # review = "I am very happy for this movie"
    #neg:0 postive:1
    s = predict(review)
    print(s)

if __name__ == '__main__':
    x_train, y_train, x_test, y_test = get_dataset_to_train()
    model = get_model()
    model = train(model, x_train, y_train, x_test, y_test)
    evaluate_model(x_test, y_test)
    predict_my_module()

完整代码

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout
import keras as keras
from keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow_hub as hub

embedding_url = "https://tfhub.dev/google/nnlm-en-dim50/2"

index_dic = {"0":"negative", "1": "positive"}

def get_dataset_to_train():

    train_test = np.load('dataset/train_test.npz', allow_pickle=True)
    x_train =  train_test['X_train']
    y_train = train_test['y_train']
    x_test =  train_test['X_test']
    y_test = train_test['y_test']

    return x_train, y_train, x_test, y_test


def get_model():
    hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)
    # Build the model
    model = Sequential([
        hub_layer,
        Dense(16, activation='relu'),
        Dropout(0.5),
        Dense(2, activation='softmax')
    ])

    print(model.summary())

    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])


    return model

def train(model , train_data, train_labels, test_data, test_labels):
    # train_data, train_labels, test_data, test_labels = get_dataset_to_train()
    train_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in train_data]
    test_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in test_data]

    train_data = np.asarray(train_data)  # Convert to numpy array
    test_data = np.asarray(test_data)  # Convert to numpy array
    print(train_data.shape, test_data.shape)

    early_stop = EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=4, mode='max', verbose=1)
    # 定义ModelCheckpoint回调函数
    # checkpoint = ModelCheckpoint( './models/model_new1.h5', monitor='val_sparse_categorical_accuracy', save_best_only=True,
    #                              mode='max', verbose=1)

    checkpoint_pb = ModelCheckpoint(filepath="./models_pb/",  monitor='val_sparse_categorical_accuracy', save_weights_only=False, save_best_only=True)

    history = model.fit(train_data[:2000], train_labels[:2000], epochs=45, batch_size=45, validation_data=(test_data, test_labels), shuffle=True,
                        verbose=1, callbacks=[early_stop, checkpoint_pb])
    print("history", history)

    return model


def evaluate_model(test_data, test_labels):
    model = load_trained_model()
    # Evaluate the model
    results = model.evaluate(test_data, test_labels, verbose=2)

    print("Test accuracy:", results[1])


def predict(real_data):
    model  = load_trained_model()
    probabilities = model.predict([real_data]);
    print("probabilities :",probabilities)
    result =  get_label(probabilities)
    return result


def get_label(probabilities):

    index = np.argmax(probabilities[0])
    print("index :" + str(index))

    result_str =  index_dic.get(str(index))
    # result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]

    return result_str

def load_trained_model():
    # model = get_model()
    # model.load_weights('./models/model_new1.h5')

    model = tf.keras.models.load_model('models_pb')

    return model

def predict_my_module():

    # review = "I don't like it"
    # review = "this is bad movie "
    # review = "This is good movie"
    review = " this is terrible movie"
    # review = "This isn‘t great movie"
    # review = "i think this is bad movie"
    # review = "I'm not very disappoint for this movie"
    # review = "I'm not very disappoint for this movie"
    # review = "I am very happy for this movie"
    #neg:0 postive:1
    s = predict(review)
    print(s)


if __name__ == '__main__':
    x_train, y_train, x_test, y_test = get_dataset_to_train()
    model = get_model()
    model = train(model, x_train, y_train, x_test, y_test)
    evaluate_model(x_test, y_test)
    predict_my_module()











本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/167414.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

设置指定时间之前的时间不可选

1、el-date-picker设置今天之前的日期不可选 <el-date-picker style"width: 100%" type"date" v-model"form.resetDate" align"right" :value-format"yyyy-MM-dd" placeholder"选择调整日期":disabled"t…

【带头学C++】----- 七、链表 ---- 7.5 学生管理系统(链表--上)

目录 1.main函数设计 2.定义Node节点类型 3.链表插入结点 在main函数中调用插入函数、打印函数 插入结点函数实现&#xff08;头插法&#xff09; 插入结点函数实现&#xff08;尾插法&#xff09; 遍历链表函数实现 4.演示插入、遍历结果 目录 1.main函数设计 2.定义…

dgl的cuda版本安装+对应torch的cuda版本安装

优秀教程 1、pytorch 1.8.1 CUDA11.1 对应的DGL版本是0.6.1&#xff08;linux、windows也适用&#xff09; 2、CUDApytorchDGL安装 3、安装dgl-cuda 4、torch 安装备忘录 dgl文件下载 https://anaconda.org/dglteam/repo可以将下面网址中的cu111改成你想要的版本&#xff0…

Vue网页中使用PDF.js弹窗显示pdf文档所有内容

本文中使用的PDF.js组件版本为3.11.174&#xff08;最新版使用上会有所不同&#xff09;&#xff0c;引入文件如下&#xff1a; 首先页面定义一个隐藏的弹窗块&#xff08;此处用ElementUI的Dialog组件&#xff09; <el-dialog ref"dialogPDF" :title"pdffi…

Zookeeper中的Watch机制的原理?

前言 Zookeeper是一个分布式协调组件&#xff0c;为分布式架构下的多个应用组件提供了顺序访问控制能力。它的数据存储采用了类似于文件系统的树形结构&#xff0c;以节点的方式来管理存储在Zookeeper上的数据 Zookeeper提供了一个Watch机制&#xff0c;可以让客户端感知到Zook…

大批量合并识别成一个表或文档的方法

金鸣表格文字识别系统功能强大&#xff0c;其中可以将上百张图片或上百页PDF中的表格文字合并识别成一个表格或文档的功能尤其受到广大用户的欢迎&#xff0c;那应该怎么操作呢&#xff1f; 一、打开金鸣表格文字识别软件&#xff0c;点击左上角的“表格识别”&#xff0c;选择…

如何远程控制别人电脑进行技术支持?

怎么提供远程技术支持&#xff1f; “我朋友的电脑出了一些问题&#xff0c;问我是否可以远程控制他的电脑帮他解决。请问有什么办法能快速的远程控制别人的电脑进行故障排除呢&#xff1f;” 当电脑出问题时&#xff0c;多数情况下会采用电话沟通进行解决&#…

深度学习人体语义分割在弹幕防遮挡上的实现 - python 计算机竞赛

文章目录 1 前言1 课题背景2 技术原理和方法2.1基本原理2.2 技术选型和方法 3 实例分割4 实现效果5 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习人体语义分割在弹幕防遮挡上的应用 该项目较为新颖&#xff0c;适合作为竞…

Unity 中 TextMesh Pro 认识学习

TextMesh Pro User Guide | TextMeshPro | 3.0.6官方文档 有两个 TextMesh Pro 组件可用。 第一个 TMP 文本组件的类型为 <TextMeshPro> 旨在与 MeshRenderer 配合使用。该组件是旧版 TextMesh 组件的理想替代品。 要添加新的 <TextMeshPro> 文本对象&#xff…

Linux环境搭建(tomcat,jdk,mysql下载)

是否具备环境&#xff08;前端node&#xff0c;后端环境jdk&#xff09;安装jdk,配置环境变量 JDK下载 - 编程宝库 (codebaoku.com) 进入opt目录 把下好的安装包拖到我们的工具中 把解压包解压 解压完成&#xff0c;可以删除解压包 复制解压文件的目录&#xff0c;配置环境变量…

如何看待阿里云发布的全球首个容器计算服务 ACS?

如何看待阿里云发布的全球首个容器计算服务 ACS&#xff1f; 本文目录&#xff1a; 前言 一、什么是ACS 二、ACS 的核心特性 三、ACS 的关键技术 四、本期话题讨论 4.1、你如何看待容器计算服务 ACS 的发布&#xff1f; 4.2、你认为 ACS 的产品设计能降低企业使用 K8s的…

学习UI第一天

在工作闲暇之余&#xff0c;自己画的原型图&#xff0c;再次做一次记录&#xff0c;哈哈哈 萌宠领养UI设计原型图 https://modao.cc/proto/lq2KqIVBs48xwylNZlA7OP/sharing?view_moderead_only #萌宠领养-分享 可以点击此链接&#xff0c;进行查看O(∩_∩)O哈哈~

淘宝详情api(获取主图)2023年11月20日最新版本

返回数据格式&#xff1a; 请求链接 {"item": {"apiStack": [{"name": "esi","value": "{\"delivery\": {\"from\": \"福建莆田\", \"to\": \"全国\", \"com…

APP源码|智慧校园电子班牌源码 智慧校园云平台

智慧校园云平台电子班牌系统包括&#xff1a;智慧校园信息管理平台、saas后台管理平台、微信客户端平台、智慧班牌智能终端软件。主要用于构建学校基础架构&#xff0c;进行成员管理、权限分配以及运营数据监管等&#xff0c;是“智慧校园”的“根基”&#xff0c;是各项应用和…

“智能与未来”2024世亚国际智能机器人展会(简称:世亚智博会)

随着科技的不断发展&#xff0c;智能机器人已经成为了当今社会的热门话题。无论是工业生产、医疗护理、家庭服务等领域&#xff0c;智能机器人都发挥着越来越重要的作用。而世亚智博会作为智能机器人领域的专业展会之一&#xff0c;旨在为全球的智能机器人产业提供一个展示、交…

C语言模拟实现Liunx操作系统与用户之间的桥梁shell(代码详解)

什么是shell&#xff1f; Shell&#xff08;壳&#xff09;是指命令行界面&#xff08;CLI&#xff09;或脚本语言&#xff0c;它为用户提供了与操作系统交互的方式。它是一个程序&#xff0c;从用户那里接收命令&#xff0c;并通过与操作系统内核交互来执行这些命令。Shell充当…

px、em、rem、百分比的区别

文章目录 1. 单位类型与相对基准2. 相对长度1.em2.rem3.%百分比4.vw 和 vh5.vmin 和 vmax6.vm7.ch8. ex 3. 总结 1. 单位类型与相对基准 单位名称 单位类型&#xff08;相对/绝对&#xff09; 相对基准 px 相对 屏幕像素缩放后的尺寸 百分比 相对 font-size相对于继承&#xf…

一步一步教你如何在Windows 10上使用Java,包括下载、安装和配置等

Java开发工具包(JDK)是用于Java编程的软件,与Java虚拟机(JVM)和Java运行时环境(JRE)一起使用。JDK包括编译器和类库,允许开发人员创建可由JVM和JRE执行的Java程序。 在本教程中,你将学习在Windows上安装Java开发工具包。 检查是否安装了Java 在安装Java开发工具包之…

Vatee万腾外汇市场新力量:vatee科技决策力

在当今数字化时代&#xff0c;Vatee万腾崭露头角&#xff0c;以其强大的科技决策力进军外汇市场&#xff0c;成为该领域的新力量。这一新动向将不仅塑造外汇市场的未来&#xff0c;也展现Vatee科技决策力在金融领域的引领作用。 Vatee万腾带着先进的科技决策力进入外汇市场&…

Modbus转Profinet网关在金银精炼控制系统中应用案例

金银精炼控制系统中采用Modbus转Profinet网关&#xff08;XD-MDPN100&#xff09;连接1200plc与PID控制阀门进行通讯&#xff0c;通过控制PID阀门的大小来实现温度的恒温控制。这一系统的好处在于它能够提高金银精炼过程的效率和精确度。PID控制阀门可以根据温度的变化实时调整…