【深度学习笔记】3_6 代码实现softmax-regression

注:本文为《动手学深度学习》开源内容,仅为个人学习记录,无抄袭搬运意图

3.6 softmax回归的从零开始实现

这一节我们来动手实现softmax回归。首先导入本节实现所需的包或模块。

import torch
import torchvision
import numpy as np
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

3.6.1 获取和读取数据

我们将使用Fashion-MNIST数据集,并设置批量大小为256。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.6.2 初始化模型参数

跟线性回归中的例子一样,我们将使用向量表示每个样本。已知每个样本输入是高和宽均为28像素的图像。模型的输入向量的长度是 28 × 28 = 784 28 \times 28 = 784 28×28=784:该向量的每个元素对应图像中每个像素。由于图像有10个类别,单层神经网络输出层的输出个数为10,因此softmax回归的权重和偏差参数分别为 784 × 10 784 \times 10 784×10 1 × 10 1 \times 10 1×10的矩阵。

num_inputs = 784
num_outputs = 10

W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)

同之前一样,我们需要模型参数梯度。

W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True) 

3.6.3 实现softmax运算

在介绍如何定义softmax回归之前,我们先描述一下对如何对多维Tensor按维度操作。在下面的例子中,给定一个Tensor矩阵X。我们可以只对其中同一列(dim=0)或同一行(dim=1)的元素求和,并在结果中保留行和列这两个维度(keepdim=True)。

X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True))
print(X.sum(dim=1, keepdim=True))

输出:

tensor([[5, 7, 9]])
tensor([[ 6],
        [15]])

下面我们就可以定义前面小节里介绍的softmax运算了。在下面的函数中,矩阵X的行数是样本数,列数是输出个数。为了表达样本预测各个输出的概率,softmax运算会先通过exp函数对每个元素做指数运算,再对exp矩阵同行元素求和,最后令矩阵每行各元素与该行元素之和相除。这样一来,最终得到的矩阵每行元素和为1且非负。因此,该矩阵每行都是合法的概率分布。softmax运算的输出矩阵中的任意一行元素代表了一个样本在各个输出类别上的预测概率。

def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

可以看到,对于随机输入,我们将每个元素变成了非负数,且每一行和为1。

X = torch.rand((2, 5))
X_prob = softmax(X)
print(X_prob, X_prob.sum(dim=1))

输出:

tensor([[0.2206, 0.1520, 0.1446, 0.2690, 0.2138],
        [0.1540, 0.2290, 0.1387, 0.2019, 0.2765]]) tensor([1., 1.])

3.6.4 定义模型

有了softmax运算,我们可以定义上节描述的softmax回归模型了。这里通过view函数将每张原始图像改成长度为num_inputs的向量。

def net(X):
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

3.6.5 定义损失函数

上一节中,我们介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用gather函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用gather函数,我们得到了2个样本的标签的预测概率。与3.4节(softmax回归)数学表述中标签类别离散值从1开始逐一递增不同,在代码中,标签类别的离散值是从0开始逐一递增的。

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))

输出:

tensor([[0.1000],
        [0.5000]])

下面实现了3.4节(softmax回归)中介绍的交叉熵损失函数。

def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

3.6.6 计算分类准确率

给定一个类别的预测概率分布y_hat,我们把预测概率最大的类别作为输出类别。如果它与真实类别y一致,说明这次预测是正确的。分类准确率即正确预测数量与总预测数量之比。

为了演示准确率的计算,下面定义准确率accuracy函数。其中y_hat.argmax(dim=1)返回矩阵y_hat每行中最大元素的索引,且返回结果与变量y形状相同。相等条件判断式(y_hat.argmax(dim=1) == y)是一个类型为ByteTensorTensor,我们用float()将其转换为值为0(相等为假)或1(相等为真)的浮点型Tensor

def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()

让我们继续使用在演示gather函数时定义的变量y_haty,并将它们分别作为预测概率分布和标签。可以看到,第一个样本预测类别为2(该行最大元素0.6在本行的索引为2),与真实标签0不一致;第二个样本预测类别为2(该行最大元素0.5在本行的索引为2),与真实标签2一致。因此,这两个样本上的分类准确率为0.5。

print(accuracy(y_hat, y))

输出:

0.5

类似地,我们可以评价模型net在数据集data_iter上的准确率。

# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

因为我们随机初始化了模型net,所以这个随机模型的准确率应该接近于类别个数10的倒数即0.1。

print(evaluate_accuracy(test_iter, net))

输出:

0.0681

3.6.7 训练模型

训练softmax回归的实现跟3.2(线性回归的从零开始实现)一节介绍的线性回归中的实现非常相似。我们同样使用小批量随机梯度下降来优化模型的损失函数。在训练模型时,迭代周期数num_epochs和学习率lr都是可以调的超参数。改变它们的值可能会得到分类更准确的模型。

num_epochs, lr = 5, 0.1

# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            
            # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            
            l.backward()
            if optimizer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                optimizer.step()  # “softmax回归的简洁实现”一节将用到
            
            
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

输出:

epoch 1, loss 0.7878, train acc 0.749, test acc 0.794
epoch 2, loss 0.5702, train acc 0.814, test acc 0.813
epoch 3, loss 0.5252, train acc 0.827, test acc 0.819
epoch 4, loss 0.5010, train acc 0.833, test acc 0.824
epoch 5, loss 0.4858, train acc 0.836, test acc 0.815

3.6.8 预测

训练完成后,现在就可以演示如何对图像进行分类了。给定一系列图像(第三行图像输出),我们比较一下它们的真实标签(第一行文本输出)和模型预测结果(第二行文本输出)。

X, y = next(iter(test_iter))# 注意这里有坑,pytorch版本不一样next写法可能不一样,可根据自己所使用的版本修改

true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

d2l.show_fashion_mnist(X[0:9], titles[0:9])

在这里插入图片描述

小结

  • 可以使用softmax回归做多类别分类。与训练线性回归相比,你会发现训练softmax回归的步骤和它非常相似:获取并读取数据、定义模型和损失函数并使用优化算法训练模型。事实上,绝大多数深度学习模型的训练都有着类似的步骤。

注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/407264.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode206: 反转链表.

题目描述 给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 解题方法 假设链表为 1→2→3→∅,我们想要把它改成∅←1←2←3。在遍历链表时,将当前节点的 next指针改为指向前一个节点。由于节点没有引用其前一…

详细讲解缓冲区

目录 理解回车和换行(\r&&\n) 那如何实现单独的回车和换行呢? 缓冲区 证明有缓冲区的存在 ​编辑 怎么刷新缓冲区(显示器缓冲区)? fflush函数​编辑 缓冲区出现的意义 I/O流 模拟倒计时小程…

Nodejs 第四十章(prisma)

什么是 prisma? Prisma 是一个现代化的数据库工具套件,用于简化和改进应用程序与数据库之间的交互。它提供了一个类型安全的查询构建器和一个强大的 ORM(对象关系映射)层,使开发人员能够以声明性的方式操作数据库。 Prisma 支持…

EasyRecovery破解版补丁免费钥匙下载

说起数据恢复软件,相信没有小伙伴不知道EasyRecovery这个软件吧,该软件具有快捷、高效、便捷的特点,且提供的功能也非常全面,不仅可以恢复各样被删除的文件、视频、图片等,还可以支持SD卡数据恢复,TF卡等各…

深入浅出CChart 每日一课——快乐高四第六十一课 飞梯十二重,CChart三维曲线图绘制

同学们好,今天继续介绍CChart本身的功能。接下来这几节课呢,笨笨老师准备对CChart的三维视图和场图功能进行详细一些的介绍。本节课首先介绍三维曲线图。 CChart软件库的开发,首先是从二维曲线图开始的,这一部分经过长时间的打磨…

SpringBoot3+Vue3 基础知识(持续更新中~)

bean 把方法的返回结果注入到ioc中 1: 2: 3: 组合注解封装 实战篇: 解析token: 统一携带token: 驼峰命名与下划线命名转换: NotEmpty!!! mybatis: PageHelper设置后,会将pageNum,和pageSize自己拼接…

ubuntu22.04@Jetson Orin Nano之OpenCV安装

ubuntu22.04Jetson Orin Nano之OpenCV安装 1. 源由2. 分析3. 证实3.1 jtop安装3.2 jtop指令3.3 GPU支持情况 4. 安装OpenCV4.1 修改内容4.2 Python2环境【不需要】4.3 ubuntu22.04环境4.4 国内/本地环境问题4.5 cudnn版本问题 5. 总结6. 参考资料 1. 源由 昨天用Jetson跑demo程…

Spring Session:入门案例

Spring Session provides an API and implementations for managing a user’s session information. Spring Session提供了一种用于管理用户session信息管理的API。 Spring Session特点 传统的Servlet应用中,Session是存储在服务端的,即:Ses…

聚道云软件连接器:高科技企业财务自动化,提升效率准确性!

客户介绍: 某互联信息技术有限公司是一家专业从事信息技术服务的高科技企业,在业内享有较高的知名度和影响力。近年来,公司业务快速发展,对信息化建设提出了更高的要求。 客户痛点: 在传统情况下,该公司的…

【探索Linux】—— 强大的命令行工具 P.23(线程池 —— 简单模拟)

阅读导航 引言一、线程池简单介绍二、Linux下线程池代码⭕Makefile文件⭕ . h 头文件✅Task.hpp✅thread.hpp✅threadPool.hpp ⭕ . cpp 文件✅testMain.cpp 三、线程池的优点温馨提示 引言 在Linux下,线程池是一种常见的并发编程模型,它能够有效地管理…

大模型综述总结--第一部分

1 目录 本文是学习https://github.com/le-wei/LLMSurvey/blob/main/assets/LLM_Survey_Chinese.pdf的总结,仅供学习,侵权联系就删 目录如下图 本次只总结一部分,刚学习有错请指出,VX关注晓理紫,关注后续。 2、概述…

字符函数和字符串函数(C语言进阶)(一)

前言 C语言中对字符和字符串的处理是很频繁的,但是c语言本身是没有字符串类型的,字符串通常放在常量字符串中或着字符数组中。 字符串常量适用于哪些对它不做修改的字符串函数。 1、函数介绍 1.1 strlen strlen:计算字符串长度 看一个代码&…

“AI教父”李一舟翻车,中国AI培训路在何方

近日,AIGC领域掀起了一场不小的风波,知名AI博主李一舟在各大平台推出的AI课程突然下架,其账号遭到禁止关注的情况。 这一事件不仅引发了广泛关注和热议,更让许多真正想学习AIGC的用户感到迷茫和困惑:在众多的AIGC课程中…

ONLYOFFICE 桌面编辑器现已更新至v8.0啦

希望你开心,希望你健康,希望你幸福,希望你点赞! 最后的最后,关注喵,关注喵,关注喵,佬佬会看到更多有趣的博客哦!!! 喵喵喵,你对我真的…

一个div最简方法画太极图

一个div最简方法画太极图 直接上代码&#xff0c;一目了然 html <div class"太极图"/>css .太极图 {position: relative;width: 400px;height: 400px;background: linear-gradient(to right,white 50%,black 50%);border-radius: 50%;box-shadow:0 0 12px …

c#高级——插件开发

案例&#xff1a;WinForm计算器插件开发 1.建立插件库&#xff0c;设置各种自己所需的插件组件 如下图所示&#xff1a;进行了计算器的加减法插件计算组件 Calculator_DLL为总插件父类 Calculator_DLL_ADD 为插件子类的控件对象 Calculator_DLL_Sub Calculator_DLL_Factory 为…

Map集合特点、遍历方式、TreeMap排序及Collections和Arrays

目录 ​编辑 一、集合框架 二、 Map集合 特点 遍历方式 HashMap与Hashtable的区别 TreeMap Collections Arrays 一、集合框架 二、 Map集合 Map集合是一种键值对的集合&#xff0c;其中每个键对应一个值。在Java中&#xff0c;Map接口定义了一种将键映射到值的数据结…

【Ubuntu】使用WSL安装Ubuntu

WSL 适用于 Linux 的 Windows 子系统 (WSL) 是 Windows 的一项功能&#xff0c;可用于在 Windows 计算机上运行 Linux 环境&#xff0c;而无需单独的虚拟机或双引导。 WSL 旨在为希望同时使用 Windows 和 Linux 的开发人员提供无缝高效的体验。安装 Linux 发行版时&#xff0c…

喝多少瓶汽水

喝多少瓶汽水 题目描述&#xff1a;解法思路&#xff1a;解法代码&#xff1a;运行结果: 题目描述&#xff1a; 水已知1瓶汽水1元&#xff0c;2个空瓶可以换⼀瓶汽水&#xff0c;输入整数n&#xff08;n>0&#xff09;&#xff0c;表示n元钱&#xff0c;计算可以多少汽水&a…
最新文章