政安晨:【示例演绎机器学习】(三)—— 神经网络的多分类问题示例 (新闻分类)

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏政安晨的机器学习笔记

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正,让小伙伴们一起学习、交流进步,不论是学业还是工作都取得好成绩!

这个系列的前两篇文章如下:

政安晨:【示例演绎机器学习】(一)—— 剖析神经网络:学习核心的Keras APIicon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136187781政安晨:【示例演绎机器学习】(二)—— 神经网络的二分类问题示例 (影评分类)icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136204994

准备好环境后,咱们开始。

搭建环境可以参考我机器学习笔记这个栏目中的文章

引言

咱们这个系列的上一篇文章,介绍了如何用密集连接神经网络将向量输入划分为两个互斥的类别。但如果类别不止两个,要怎么做呢?


多分类问题

在机器学习中,神经网络的多分类问题是指将一组输入数据映射到多个离散的类别中的任务。与二分类问题(将数据分为两个类别)不同,多分类问题需要将数据分为多个类别。

在神经网络中,多分类问题通常使用softmax激活函数来将神经网络的输出转化为每个类别的概率。softmax可以将神经网络的输出转化为概率分布,使得每个类别的概率之和为1。然后,根据最高概率来确定数据属于哪个类别。

解决多分类问题的神经网络通常是具有多个输出单元的模型。每个输出单元对应一个类别并且输出概率最高的单元被认为是预测的类别神经网络通过训练数据来学习如何调整权重和偏置,以最小化预测结果与真实类别之间的差距。

咱们这篇文章,将构建一个模型,把某报社新闻划分到46个互斥的主题中。

由于有多个类别,因此这是一个多分类(multiclass classification)问题。

由于每个数据点只能划分到一个类别中,因此更具体地说,这是一个单标签、多分类(single-label,multiclass classification)问题如果每个数据点可以划分到多个类别(主题)中,那就是多标签、多分类(multilabel, multiclass classification)问题。

某报社数据集

咱们这篇文章,将使用某报社数据集,它包含许多短新闻及其对应的主题,由该社于上世纪发布。它是一个简单且广泛使用的文本分类数据集,其中包括46个主题。某些主题的样本相对较多,但训练集中的每个主题都有至少10个样本。与IMDB数据集和MNIST数据集类似,该报社数据集也内置为Keras的一部分。

我们来看一下这个数据集,代码如下所示:

from tensorflow.keras.datasets import reuters
(train_data, train_labels), (test_data,  test_labels) = reuters.load_data(
    num_words=10000)

演绎如下:

与IMDB数据集一样,参数num_words=10000将数据限定为前10 000个最常出现的单词,我们有8982个训练样本和2246个测试样本。

与上篇文章中影评一样,每个样本都是一个整数列表(表示单词索引)。

你可以用如下所示的代码将样本解码为单词:

将新闻解码为文本

word_index = reuters.get_word_index()

reverse_word_index = dict(
    [(value, key) for (key, value) in word_index.items()])

decoded_newswire = " ".join(
    [reverse_word_index.get(i - 3, "?") for i in     train_data[0]])

    # 注意,索引减去了3,因为0、1、2分别是为“padding”(填充)、“start of sequence”(序列开始)、“unknown”(未知词)保留的索引

演绎:

样本对应的标签是一个介于0和45之间的整数,即话题索引编号。

准备数据

你可以沿用上一个例子中的代码将数据向量化,编码输入数据如下:

import numpy as np

def vectorize_sequences(sequences, dimension=10000):

    results = np.zeros((len(sequences), dimension))

    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.

    return results

# 将训练数据向量化
x_train = vectorize_sequences(train_data)

# 将测试数据向量化
x_test = vectorize_sequences(test_data)

将标签向量化有两种方法:

既可以将标签列表转换为一个整数张量,也可以使用one-hot编码

one-hot编码是分类数据的一种常用格式,也叫分类编码(categorical encoding)。

在这个例子中,标签的one-hot编码就是将每个标签表示为全零向量,只有标签索引对应的元素为1,如下代码所示(编码标签):

def to_one_hot(labels, dimension=46):
    results = np.zeros((len(labels), dimension))
    for i, label in enumerate(labels):
        results[i, label] = 1.
    return results

# 将训练标签向量化
y_train = to_one_hot(train_labels)

# 将测试标签向量化
y_test = to_one_hot(test_labels)

演绎:

小伙伴们请注意,Keras有一个内置方法可以实现这种编码。

from tensorflow.keras.utils import to_categorical
y_train = to_categorical(train_labels)
y_test = to_categorical(test_labels)

构建模型

这个主题分类问题与前面的影评分类问题类似,二者都是对简短的文本片段进行分类。

但这个问题有一个新的限制条件:输出类别从2个变成46个。输出空间的维度要大得多。

对于前面用过的Dense层堆叠,每一层只能访问上一层输出的信息。如果某一层丢失了与分类问题相关的信息,那么后面的层永远无法恢复这些信息,也就是说,每一层都可能成为信息瓶颈。上一个例子使用了16维的中间层,但对这个例子来说,16维空间可能太小了,无法学会区分46个类别。这种维度较小的层可能成为信息瓶颈,导致相关信息永久性丢失。

因此,我们将使用维度更大的层,它包含64个单元,代码如下(模型定义):

from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(64, activation="relu"),
    layers.Dense(46, activation="softmax")
])

关于这个架构还应注意以下两点:

第一,模型的最后一层是大小为46的Dense层。也就是说,对于每个输入样本,神经网络都会输出一个46维向量。这个向量的每个元素(每个维度)代表不同的输出类别。

第二,最后一层使用了softmax激活函数。你在MNIST例子中见过这种用法。模型将输出一个在46个输出类别上的概率分布——对于每个输入样本,模型都会生成一个46维输出向量,其中output[i]是样本属于第i个类别的概率。46个概率值的总和为1。

对于这个例子,最好的损失函数是categorical_crossentropy(分类交叉熵),代码如下所示:

model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

它衡量的是两个概率分布之间的距离,这里两个概率分布分别是模型输出的概率分布和标签的真实分布。我们训练模型将这两个分布的距离最小化,从而让输出结果尽可能接近真实标签。

验证你的方法

我们从训练数据中留出1000个样本作为验证集,如下所示(留出验证集):

x_val = x_train[:1000]
partial_x_train = x_train[1000:]
y_val = y_train[:1000]
partial_y_train = y_train[1000:]

现在开始训练模型,共训练20轮,如下代码所示(训练模型):

history = model.fit(partial_x_train,
                    partial_y_train,
                    epochs=20,
                    batch_size=512,
                    validation_data=(x_val, y_val))

演绎:

模型训练完毕后,咱们来绘制损失曲线和精度曲线,代码演绎如下:

(绘制训练损失和验证损失

import matplotlib.pyplot as plt

loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

演绎:

绘制训练精度和验证精度

# 清空图像
plt.clf()
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
plt.plot(epochs, acc, "bo", label="Training accuracy")
plt.plot(epochs, val_acc, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

演绎:

模型在9轮之后开始过拟合。我们从头开始训练一个新模型,训练9轮,然后在测试集上评估模型,如下代码所示:

下面咱们从头开始训练一个模型:

model = keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(64, activation="relu"),
    layers.Dense(46, activation="softmax")
])
model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train,
          y_train,
          epochs=9,
          batch_size=512)
results = model.evaluate(x_test, y_test)

演绎:

最终结果如下:

这种方法可以达到约80%的精度。对于均衡的二分类问题,完全随机的分类器能达到50%的精度。但在这个例子中,我们有46个类别,各类别的样本数量可能还不一样。

那么一个随机基准模型的精度是多少呢?我们可以通过快速实现随机基准模型来验证一下。

import copy

test_labels_copy = copy.copy(test_labels)
np.random.shuffle(test_labels_copy)
hits_array = np.array(test_labels) == np.array(test_labels_copy)
hits_array.mean()

演绎:

可以看到,随机分类器的分类精度约为19%。从这个角度来看,我们的模型结果看起来相当不错。

对新数据进行预测

对新样本调用模型的predict方法,将返回每个样本在46个主题上的概率分布。我们对所有测试数据生成主题预测。

predictions的每个元素都是长度为46的向量:

这个向量的所有元素总和为1,因为它们形成了一个概率分布。

向量的最大元素就是预测类别,即概率最高的类别:

处理标签和损失的另一种方法

前面提到过另一种编码标签的方法,也就是将其转换为整数张量,如下所示:

y_train = np.array(train_labels)
y_test = np.array(test_labels)

对于这种编码方法,唯一需要改变的就是损失函数的选择。对于代码清单4-21使用的损失函数categorical_crossentropy,标签应遵循分类编码。

对于整数标签,你应该使用sparse_categorical_crossentropy(稀疏分类交叉熵)损失函数。

model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

这个新的损失函数在数学上与categorical_crossentropy相同,二者只是接口不同。

拥有足够大的中间层的重要性

前面说过,因为最终输出是46维的,所以中间层的单元不应少于46个。现在我们来看一下,如果中间层的维度远小于46(比如4维),造成了信息瓶颈,那么会发生什么?模型如下代码所示(具有信息瓶颈的模型):

model = keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(4, activation="relu"),
    layers.Dense(46, activation="softmax")
])

model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

model.fit(partial_x_train,
          partial_y_train,
          epochs=20,
          batch_size=128,
          validation_data=(x_val, y_val))

演绎如下:

现在模型的最大验证精度约为71%,比之前下降了8%。导致下降的主要原因在于,我们试图将大量信息(这些信息足以找到46个类别的分离超平面)压缩到维度过小的中间层。

模型能够将大部分必要信息塞进这个4维表示中,但并不是全部信息。

可以尝试使用更小或更大的层,比如32个单元、128个单元等。你在最终的softmax分类层之前使用了两个中间层。现在尝试使用一个或三个中间层。

(很多情况下,您对这类模型配置的选择靠的是直觉。)

结论

如果要对N个类别的数据点进行分类,那么模型的最后一层应该是大小为N的Dense层。

对于单标签、多分类问题,模型的最后一层应该使用softmax激活函数,这样可以输出一个在N个输出类别上的概率分布。

对于这种问题,损失函数几乎总是应该使用分类交叉熵。

它将模型输出的概率分布与目标的真实分布之间的距离最小化。

处理多分类问题的标签有两种方法

通过分类编码(也叫one-hot编码)对标签进行编码,然后使用categorical_crossentropy损失函数;

将标签编码为整数,然后使用sparse_categorical_crossentropy损失函数。

如果你需要将数据划分到多个类别中,那么应避免使用太小的中间层,以免在模型中造成信息瓶颈。


该示例演绎完毕,供小伙伴们参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/403360.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

简单几步通过DD工具把云服务器系统Linux改为windows

简单几部通过DD安装其他系统,当服务器的web控制台没有我们要装的系统,就需要通过DD(Linux磁盘)工具来更改系统,(已知支持KVM系统) 本文如何简单的更换系统,不通过web控制台来更换&a…

蓝桥 算法训练 粘木棍(C++)

问题描述 有N根木棍,需要将其粘贴成M个长木棍,使得最长的和最短的的差距最小。 输入格式 第一行两个整数N,M。   一行N个整数,表示木棍的长度。 输出格式 一行一个整数,表示最小的差距 样例输入 3 2 10 20 40 样例输出 10…

Excel面试题及答案(1)

1.辅助列添加,快速填充方式填充隔行的编号;定位条件定位到空值后,右击---插入整行 2.利用通配符计算A3:A9含有车间的单元格个数(保留计算公式)。 3.利用身份证号提取 “性别”、“年月日”、“年龄” 性别:利用mid()方法,添加了一列辅助列,根据提取身份证后面第2位…

十八、图像像素类型转换和归一化操作

项目功能实现&#xff1a;对一张图像进行类型转换和归一化操作 按照之前的博文结构来&#xff0c;这里就不在赘述了 一、头文件 norm.h #pragma once#include<opencv2/opencv.hpp>using namespace cv;class NORM { public:void norm(Mat& image); };#pragma once二…

大语言模型的开山之作—探秘GPT系列:GPT-1-GPT2-GPT-3的进化之路

模型模型参数创新点评价GPT1预训练微调&#xff0c; 创新点在于Task-specific input transformations。GPT215亿参数预训练PromptPredict&#xff0c; 创新点在于Zero-shotZero-shot新颖度拉满&#xff0c;但模型性能拉胯GPT31750亿参数预训练PromptPredict&#xff0c; 创新点…

洛谷P3371【模板】单源最短路径(弱化版)(RE版本和AC版本都有,这篇解析很长但受益匪浅)

解释一下什么叫邻接矩阵&#xff1a; 假设有以下无向图&#xff1a; 1/ \2---3/ \ / \4---5---6对应的邻接矩阵为&#xff1a; 1 2 3 4 5 6 1 0 1 1 0 0 0 2 1 0 1 1 1 0 3 1 1 0 0 1 1 4 0 1 0 0 1 0 5 0 1 1 1 0 1 6 0 0 1 0 1 0 …

SpringCloud全家桶---常用微服务组件(1)

注册中心: *作用: 服务管理 Eureka(不推荐)[读音: 优瑞卡] Nacos(推荐) Zookeeper [读音: 如k波] Consul [读音:康寿] **注册中心的核心功能原理(nacos)** 服务注册: 当服务启动时,会通过rest接口请求的方式向Nacos注册自己的服务 服务心跳: NacosClient 会维护一个定时心跳持…

【Python笔记-设计模式】原型模式

一、说明 原型模式是一种创建型设计模式&#xff0c; 用于创建重复的对象&#xff0c;同时又能保证性能。 使一个原型实例指定了要创建的对象的种类&#xff0c;并且通过拷贝这个原型来创建新的对象。 (一) 解决问题 主要解决了对象的创建与复制过程中的性能问题。主要针对…

【stm32】hal库-双通道ADC采集

【stm32】hal库-双通道ADC采集 CubeMX图形化配置 程序编写 /* USER CODE BEGIN PV */ #define BATCH_DATA_LEN 1 uint32_t dmaDataBuffer[BATCH_DATA_LEN]; /* USER CODE END PV *//* USER CODE BEGIN 2 */lcd_init();lcd_show_str(10, 10, 24, "Demo14_4:ADC1 ADC2 S…

SpringCloud(15)之SpringCloud Gateway

一、Spring Cloud Gateway介绍 Spring Cloud Gateway 是Spring Cloud团队的一个全新项目&#xff0c;基于Spring 5.0、SpringBoot2.0、 Project Reactor 等技术开发的网关。旨在为微服务架构提供一种简单有效统一的API路由管理方式。 Spring Cloud Gateway 作为SpringCloud生态…

文件上传---->生僻字解析漏洞

现在的现实生活中&#xff0c;存在文件上传的点&#xff0c;基本上都是白名单判断&#xff08;很少黑名单了&#xff09; 对于白名单&#xff0c;我们有截断&#xff0c;图片马&#xff0c;二次渲染&#xff0c;服务器解析漏洞这些&#xff0c;于是今天我就来补充一种在upload…

银河麒麟桌面版操作系统修改主机名

1图形化方式修改 1.1在计算机图标上右键&#xff0c;选择属性 1.2修改 1.2.1点击修改计算机名 选择玩属性后会自动跳转到关于中&#xff0c;在计算机名中点击修改图标本质就是设置里面的系统下的关于&#xff0c;我们右键计算机选择属性就直接跳转过来了 1.2.2修改系统名字 …

【Spring】SpringBoot 日志文件

目 录 一.日志有什么用&#xff1f;二.日志怎么用&#xff1f;三.自定义日志打印四.日志持久化五.日志级别六.更简单的日志输出—lombok 日志的主要掌握内容&#xff1a; 输出自定义日志信息 将日志持久化 通过设置日志的级别来筛选和控制日志的内容 一.日志有什么用&#…

YOLOv8改进 | Conv篇 | 利用YOLOv9的GELAN模块替换C2f结构(附轻量化版本 + 高效涨点版本 + 结构图)

一、本文介绍 本文给大家带来的改进机制是利用2024/02/21号最新发布的YOLOv9其中提出的GELAN模块来改进YOLOv8中的C2f,GELAN融合了CSPNet和ELAN机制同时其中利用到了RepConv在获取更多有效特征的同时在推理时专用单分支结构从而不影响推理速度,同时本文的内容提供了两种版本…

集合框架之List集合

目录 ​编辑 一、什么是UML 二、集合框架 三、List集合 1.特点 2.遍历方式 3.删除 4.优化 四、迭代器原理 五、泛型 六、装拆箱 七、ArrayList、LinkedList和Vector的区别 ArrayList和Vector的区别 LinkedList和Vector的区别 一、什么是UML UML&#xff08;Unif…

Flask——基于python完整实现客户端和服务器后端流式请求及响应

文章目录 本地客户端Flask服务器后端客户端/服务器端流式接收[打字机]效果 看了很多相关博客&#xff0c;但是都没有本地客户端和服务器后端的完整代码示例&#xff0c;有的也只说了如何流式获取后端结果&#xff0c;基本没有讲两端如何同时实现流式输入输出&#xff0c;特此整…

Nginx基础入门

一、Nginx的优势 nginx是一个高性能的HTTP和反向代理服务器&#xff0c;也是一个SMTP&#xff08;邮局&#xff09;服务器。 Nginx的web优势&#xff1a;IO多路复用&#xff0c;时分多路复用&#xff0c;频分多路复用 高并发&#xff0c;IO多路复用&#xff0c;epoll&#xf…

备战蓝桥杯---基础算法刷题1

最近在忙学校官网上的题&#xff0c;就借此记录分享一下有价值的题&#xff1a; 1.注意枚举角度 如果我们就对于不同的k常规的枚举&#xff0c;复杂度直接炸了。 于是我们考虑换一个角度&#xff0c;我们不妨从1开始枚举因子&#xff0c;我们记录下他的倍数的个数sum个&#…

每日五道java面试题之spring篇(二)

目录&#xff1a; 第一题 Spring事务传播机制第二题 Spring事务什么时候会失效?第三题 什么是bean的⾃动装配&#xff0c;有哪些⽅式&#xff1f;第四题 Spring中的Bean创建的⽣命周期有哪些步骤&#xff1f;第五题 Spring中Bean是线程安全的吗&#xff1f; 第一题 Spring事务…

QT中调用python

一.概述 1.Python功能强大&#xff0c;很多Qt或者c/c开发不方便的功能可以由Python编码开发&#xff0c;尤其是一些算法库的应用上&#xff0c;然后Qt调用Python。 2.在Qt调用Python的过程中&#xff0c;必须要安装python环境&#xff0c;并且Qt Creator中编译器与Python的版…
最新文章