Softmax回归

一、Softmax回归关键思想

1、回归问题和分类问题的区别

       Softmax回归虽然叫“回归”,但是它本质是一个分类问题。回归是估计一个连续值,而分类是预测一个离散类别。

2、Softmax回归模型

       Softmax回归跟线性回归一样将输入特征与权重做线性叠加。与线性回归的一个主要不同在于,Softmax回归的输出值个数等于标签里的类别数。比如一共有4种特征和3种输出动物类别(猫、狗、猪),则权重包含12个标量(带下标的$w$),偏差包含3个标量(带下标的$b$),且对每个输入计算$ O_1,O_2,O_3 $这三个输出:

$ \begin{aligned} o_1 &= x_1 w_{11} + x_2 w_{12} + x_3 w_{13} + x_4 w_{14} + b_1,\\ o_2 &= x_1 w_{21} + x_2 w_{22} + x_3 w_{23} + x_4 w_{24} + b_2,\\ o_3 &= x_1 w_{31} + x_2 w_{32} + x_3 w_{33} + x_4 w_{34} + b_3. \end{aligned} $

最后,再对这些输出值进行Softmax函数运算

       softmax回归同线性回归一样,也是一个单层神经网络。由于每个输出$ O_1,O_2,O_3 $的计算都要依赖于所有的输入$ X_1,X_2,X_3,X_4 $,所以softmax回归的输出层也是一个全连接层。

3、Softmax函数

       Softmax用于多分类过程中,它将多个神经元的输出(比如$ O_1,O_2,O_3 $)映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!它通过下式将输出值变换成值为正且和为1的概率分布:

$\widehat{y_1},\widehat{y_2},\widehat{y_3} = \mathrm{softmax}(o_1,o_2,o_3)$

其中:

$ \widehat{y}_j=\frac{\exp \left( o_1 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $, $ \widehat{y}_j=\frac{\exp \left( o_2 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $, $ \widehat{y}_j=\frac{\exp \left( o_3 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $

       容易看出 $ \widehat{y_1}+\widehat{y_2}+\widehat{y_3}=1 $ 且 $ \widehat{y_1}+\widehat{y_2}+\widehat{y_3}=1 $,因此 $ \widehat{y_1},\widehat{y_2},\widehat{y_3} $ 是一个合法的概率分布。此外,我们注意到:

$ arg\max\text{\ }o_i=arg\max\text{\ }\widehat{y_i} $

 因此softmax运算不改变预测类别输出。

       下图可以更好的理解Softmax函数,其实就是取自然常数e的指数相加后算比例,由于自然常数的指数($ e^x $)在$ \left( -\infty ,+\infty \right) $单调递增,因此softmax运算不改变预测类别输出。

4、交叉熵损失函数

       假设我们希望根据图片动物的轮廓、颜色等特征,来预测动物的类别,有三种可预测类别:猫、狗、猪。假设我们当前有两个模型(参数不同),这两个模型都是通过sigmoid/softmax的方式得到对于每个预测结果的概率值:

模型1:

模型1
预测真实是否正确
0.30.30.4001正确
0.30.40.3010正确
0.10.20.7100错误

       模型评价:模型1对于样本1和样本2以非常微弱的优势判断正确,对于样本3的判断则彻底错误。

模型2:

模型2
预测真实是否正确
0.10.20.7001正确
0.10.70.2010正确
0.30.40.3100错误

       模型评价:模型2对于样本1和样本2判断非常准确,对于样本3判断错误,但是相对来说没有错得太离谱。

       好了,有了模型之后,我们需要通过定义损失函数来判断模型在样本上的表现了,那么我们可以定义哪些损失函数呢?我们可以先尝试使用以下几种损失函数,然后讨论哪种效果更好。

(1)Classification Error(分类错误率)

       最为直接的损失函数定义为:

$ classification\ error=\frac{count\ of\ error\ items}{count\ of\ all\ items} $

模型1:$ classification\ error=\frac{1}{3} $

模型2:$ classification\ error=\frac{2}{3} $

       我们知道,模型1模型2虽然都是预测错了1个,但是相对来说模型2表现得更好,损失函数值照理来说应该更小,但是,很遗憾的是,classification error 并不能判断出来,所以这种损失函数虽然好理解,但表现不太好。

(2)Mean Squared Error(均方误差MSE)

       均方误差损失也是一种比较常见的损失函数,其定义为:

$ MSE=\frac{1}{n}\sum_i^n{\left( \widehat{y_i}-y_i \right) ^2} $

模型1:

对所有样本的loss求平均:

模型2:

对所有样本的loss求平均:

       我们发现,MSE能够判断出来模型2优于模型1,那为什么不采样这种损失函数呢?主要原因是在分类问题中,使用sigmoid/softmx得到概率,配合MSE损失函数时,采用梯度下降法进行学习时,会出现模型一开始训练时,学习速率非常慢的情况(损失函数 | Mean-Squared Loss - 知乎)。

       有了上面的直观分析,我们可以清楚的看到,对于分类问题的损失函数来说,分类错误率和均方误差损失都不是很好的损失函数,下面我们来看一下交叉熵损失函数的表现情况。

(3)Cross Entropy Loss Function(交叉熵损失函数)

其中:

$M$:类别的数量

$ y_{ic} $:符号函数(0或1),如果样本 i 的真实类别等于 c 取 1,否则取 0

$ p_{ic} $:观测样本 i 属于类别 c 的预测概率

$N$:样本的数量

现在我们利用这个表达式计算上面例子中的损失函数值:

模型1

对所有样本的loss求平均:

模型2:

对所有样本的loss求平均:

       可以发现,交叉熵损失函数可以捕捉到模型1和模型2预测效果的差异,因此对于Softmax回归问题我们常用交叉熵损失函数。

      下面两图可以很清晰的反应整个Softmax回归算法的流程:

二、图像分类数据集

       MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集。

       在这里我们定义一些函数用于数据的读取与显示,这些函数已经在Python包d2l中定义好了,但为了便于大家理解,这里没有直接调用d2l中的函数。

1、读取数据集

       我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

       Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像和测试数据集(test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

print(len(mnist_train), len(mnist_test))
60000 10000

       每个输入图像的高度和宽度均为28像素。数据集由灰度图像组成,其通道数为1。为了简洁起见,本书将高度$h$像素、宽度$w$像素图像的形状记为$h \times w$($h$,$w$)。接下来我们可以打印一下mnist_train的类型和mnist_train的第一个元素。

print(type(mnist_train))
print(type(mnist_train[0]))
print(mnist_train[0])
print(mnist_train[0][0].shape)

       可以看出mnist_train的类型为<class 'torchvision.datasets.mnist.FashionMNIST'>。mnist_train的第一个元素的类型是<class 'tuple'>,是一个元组,元组第一个元素是转化为tensor后的灰度值,第二个元素是图像所属类别index,这里是9。因为是灰度图,因此channel数量为1,图片长和宽都是28,因此形状是(1,28,28)。

       Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)

       以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):   # labels:mnist_train和mnist_test里面图像的类别index(数字)
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]    # 根据index返回文本标签列表('t-shirt', 'trouser'...)

       我们现在可以创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    """
    imgs: tensor向量
    num_rows: 画图时的行数
    num_cols: 画图时的列数
    titles: 每张图片的标题
    scales: 因为要将num_rows*num_cols张图片画到一张图上,并且还要添加一些文字,
    因此需要对大图进行一定的缩放才能保证每张小图之间的间隙
    """
    figsize = (num_cols * scale, num_rows * scale)
    # figsize = (num_cols, num_rows)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

       以下是训练数据集中前18个样本的图像及其相应的标签。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

2、读取小批量数据

       为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。在每次迭代中,数据加载器每次都会读取一小批量数据,大小为`batch_size`。通过内置数据迭代器,我们可以随机打乱所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

3、整合所有组件

       现在我们定义`load_data_fashion_mnist`函数,用于获取和读取Fashion-MNIST数据集。这个函数返回训练集和验证集的数据迭代器。此外,这个函数还接受一个可选参数`resize`,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]    # 此时的trans是一个列表
    if resize:
        trans.insert(0, transforms.Resize(resize))    # 如果提供了resize参数,则在转换链中插入Resize操作
    trans = transforms.Compose(trans)    # 将一系列的图像转换操作组合成一个转换链。
    # trans是一个由多个图像转换操作组成的列表。它按照列表中的顺序依次应用这些转换操作。
    # 这样可以将多个转换操作组合在一起,以便在加载数据时一次性应用它们。
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

       下面,我们通过指定`resize`参数来测试`load_data_fashion_mnist`函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

三、softmax回归的从零开始实现

...

参考文献

[1]  损失函数|交叉熵损失函数

[2]  深度学习模型系列一——多分类模型——Softmax 回归-CSDN博客

[3]  Softmax 回归_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/237239.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

排序算法:【冒泡排序】、逻辑运算符not用法、解释if not tag:

注意&#xff1a; 1、排序&#xff1a;将一组无序序列&#xff0c;调整为有序的序列。所谓有序&#xff0c;就是说&#xff0c;要么升序要么降序。 2、列表排序&#xff1a;将无序列表变成有序列表。 3、列表这个类里&#xff0c;内置排序方法&#xff1a;sort( )&#xff0…

排序-选择排序与堆排序

文章目录 一、选择排序二、堆排序三、时间复杂度四、稳定性 一、选择排序 思想&#xff1a; 将数组第一个元素作为min&#xff0c;然后进行遍历与其他元素对比&#xff0c;找到比min小的数就进行交换&#xff0c;直到最后一个元素就停止&#xff0c;然后再将第二个元素min&…

使用NCNN在华为M5部署MobileNet-SSD

一、下载ncnn-android-vulkan ncnn-android-vulkan.zip 文件是一个压缩文件&#xff0c;其中包含了 ncnn 框架在 Android 平台上使用 Vulkan 图形库加速的相关文件和代码。 在 Android 平台上&#xff0c;ncnn 框架可以利用 Vulkan 的并行计算能力来进行神经网络模型的推理计…

算法——位运算

常见位运算总结 基础位运算 << >> ~与&&#xff1a;有0就是0或|&#xff1a;有1就是1异或^&#xff1a;相同为0&#xff0c;相异为1 / 无进位相加 给一个数n&#xff0c;确定他的二进制表示中的第x位是0还是1 让第x位与上1即可先让n右移x位&上一个1&#…

【docker三】Docker镜像的创建方法

目录 一、Docker镜像&#xff1a; 1、 镜像的概念 2、docker的创建镜像方式&#xff1a; 1.1、基于已有镜像进行创建&#xff1a; 1.2、基于模版创建&#xff1a; 1.3、基于dockerfile创建&#xff1a; 二、Dockerfile概述 1、Dockerfile概念&#xff1a; 2、dockerfile…

【UI自动化测试】appium+python+unittest+HTMLRunner

简介 获取AppPackage和AppActivity 定位UI控件的工具 脚本结构 PageObject分层管理 HTMLTestRunner生成测试报告 启动appium server服务 以python文件模式执行脚本生成测试报告 下载与安装 下载需要自动化测试的App并安装到手机 获取AppPackage和AppActivity 参考&#xff…

AWS攻略——使用中转网关(Transit Gateway)连接不同区域(Region)VPC

文章目录 Peering方案Transit Gateway方案环境准备创建Transit Gateway Peering Connection接受邀请修改中转网关路由修改被邀请方中转网关路由修改邀请方中转网关路由 测试修改Public子网路由 知识点参考资料 区别于 《AWS攻略——使用中转网关(Transit Gateway)连接同区域(R…

云降水物理基础

云降水物理基础 云的分类 相对湿度变化方程 由相对湿度的定义&#xff0c;两边取对数之后可以推出 联立克劳修斯-克拉佩龙方程&#xff08;L和R都为常数&#xff09; 由右式看出&#xff0c;增加相对湿度的方式&#xff1a;增加水汽&#xff08;de增大&#xff09;和降低…

SpringData JPA 搭建 xml的 配置方式

1.导入版本管理依赖 到父项目里 <dependencyManagement><dependencies><dependency><groupId>org.springframework.data</groupId><artifactId>spring-data-bom</artifactId><version>2021.1.10</version><scope>…

全新UI彩虹外链网盘系统源码V5.5/支持批量封禁+优化加载速度+用户系统与分块上传

源码简介&#xff1a; 全新UI彩虹外链网盘系统源码V5.5&#xff0c;它可以支持批量封禁优化加载速度。新增用户系统与分块上传。 彩虹外链网盘&#xff0c;作为一款PHP网盘与外链分享程序&#xff0c;具备广泛的文件格式支持能力。它不仅能够实现各种格式文件的上传&#xff…

数据接口测试工具 Postman 介绍!

此文介绍好用的数据接口测试工具 Postman&#xff0c;能帮助您方便、快速、统一地管理项目中使用以及测试的数据接口。 1. Postman 简介 Postman 一款非常流行的 API 调试工具。其实&#xff0c;开发人员用的更多。因为测试人员做接口测试会有更多选择&#xff0c;例如 Jmeter…

LeetCode-周赛-思维训练-中等难度

第一题 1798. 你能构造出连续值的最大数目 解题思路 我们先抛开原题不看&#xff0c;可以先完成一道简单的题目&#xff0c;假设现在就给你一个目标值X&#xff0c;问你能够构造出从【1~X】的连续整数&#xff0c;最小需要几个数&#xff1f; 贪心假设期望&#xff1a;我们要…

node14升级node16之后,webpack3项目无法启动处理

node从14升级到16之后&#xff0c;项目就无法启动了&#xff0c;研究了webpack3升级5&#xff0c;研究好几个小时都无法启动&#xff0c;最后发现&#xff0c;微微升级几个版本就可以了。webpack还是3 版本改了好多个的&#xff0c;但是不确定具体是哪几个起作用的&#xff0c;…

【LVGL】STM32F429IGT6(在野火官网的LCD例程上)移植LVGL官方的例程(还没写完,有问题 排查中)

这里写目录标题 前言一、本次实验准备1、硬件2、软件 二、移植LVGL代码1、获取LVGL官方源码2、整理一下&#xff0c;下载后的源码文件3、开始移植 三、移植显示驱动1、enable LVGL2、修改报错部分3、修改lv_config4、修改lv_port_disp.c文件到此步遇到的问题 Undefined symbol …

Docker中部署ElasticSearch 和Kibana,用脚本实现对数据库资源的未授权访问

图未保存&#xff0c;不过文章当中的某一步骤可能会帮助到您&#xff0c;那么&#xff1a;感恩&#xff01; 1、docker中拉取镜像 #拉取镜像 docker pull elasticsearch:7.7.0#启动镜像 docker run --name elasticsearch -d -e ES_JAVA_OPTS"-Xms512m -Xmx512m" -e…

数字图像处理(实践篇)二十一 人脸识别

目录 1 安装face_recognition 2 涉及的函数 3 人脸识别方案 4 实践 使用face_recognition进行人脸识别。 1 安装face_recognition pip install face_recognition 或者 pip --default-timeout100 install face_recognition -i http://pypi.douban.com/simple --trusted-…

c#读取XML文件实现晶圆wafermapping显示demo计算电机坐标控制电机移动

c#读取XML文件实现晶圆wafermapping显示 功能&#xff1a; 1.读取XML文件&#xff0c;显示mapping图 2.在mapping视图图标移动&#xff0c;实时查看bincode,x,y索引与计算的电机坐标 3.通过设置wafer放在平台的位置x,y轴电机编码值&#xff0c;相机在wafer的中心位置&#…

类与接口常见面试题

抽象类和接口的对比 抽象类是用来捕捉子类的通用特性的。接口是抽象方法的集合。 从设计层面来说&#xff0c;抽象类是对类的抽象&#xff0c;是一种模板设计&#xff0c;接口是行为的抽象&#xff0c;是一种行为的规范。 相同点 接口和抽象类都不能实例化都位于继承的顶端…

每日一题,头歌平台c语言题目

任务描述 题目描述:输入一个字符串&#xff0c;输出反序后的字符串。 相关知识&#xff08;略&#xff09; 编程要求 请仔细阅读右侧代码&#xff0c;结合相关知识&#xff0c;在Begin-End区域内进行代码补充。 输入 一行字符 输出 逆序后的字符串 测试说明 样例输入&…

老师们居然这样把考试成绩发给家长

教育是一个复杂而多元的过程&#xff0c;其中考试成绩的发布和沟通是教育过程中的一个重要环节。然而&#xff0c;有些老师在发布考试成绩时&#xff0c;采取了一些不恰当的方式&#xff0c;给家长和学生带来了不必要的困扰和压力。本文将探讨老师们不应该采取的发布考试成绩的…
最新文章