深度学习_17_丢弃法调整过拟合

除了权重衰退法调整过拟合,还有丢弃法调整模型得过拟合现象

过拟合:

在这里插入图片描述
丢弃法如果直接丢弃会导致新期望的不确定性,为了防止这个不确定被模型学到,所以要保证丢弃后的期望和丢弃前的期望一样(个人观点)
在这里插入图片描述

顾名思义,丢弃一些元素,单保持整体期望不变,让模型自己去权衡哪些元素是最重要得部分,从而着重选择那些元素

过拟合是因为模型学习过多无用杂质,用丢弃法,丢弃的可能是重要特征,或者杂质,通过丢弃的结果,让模型权衡哪些部分是最重要的,从而学习更稳健的特征

在这里插入图片描述
丢弃法在含隐藏层的模型中应用非常广泛

实例代码:

完整代码:

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    net.eval()  # 评估状态
    for X, y in data_iter:
        out = net(X)
       # y = y.float()
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]
def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

dropout1, dropout2 = 0.7, 0.7  # 0.7 0.7

class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)  # 形状(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)  # 总体输出为(num_inputs, num_outputs), num_output类别
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        out = self.lin3(H2)
        return out

if __name__ == '__main__':

    net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

    num_epochs, lr, batch_size = 10, 0.5, 256
    loss = nn.CrossEntropyLoss()
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)  # 取衣服数据集
    trainer = torch.optim.SGD(net.parameters(), lr=lr)  # 优化器
# d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
    train_losses = []
    test_losses = []
    test_acces = []
    for epoch in range(num_epochs):
        train_metrics, _ = d2l.train_epoch_ch3(net, train_iter, loss, trainer)
        train_losses.append(train_metrics)
        test_acc = d2l.evaluate_accuracy(net, test_iter)
        test_loss = evaluate_loss(net, test_iter, loss)
        test_acces.append(test_acc)
        test_losses.append(test_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"  训练损失: {train_metrics:.4f}, 测试损失: {test_loss:.4f}, 测试精度: {test_acc:.4f}")

    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.plot(test_acces, label='train_acc', color='red', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss & acc')
    plt.title('Test Loss and Train Accuracy over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 1)  # 设置y轴的范围从0到1
    plt.show()

实例是 深度学习_11_softmax_图片识别代码&原理解析 和 深度学习_14_单层|多层感知机及代码实现 两者代码结合,为了达到过拟合效果,所以上述代码是三层感知机,在原先两层感知机的条件下多加了一层,这样模型就会过拟合,再用丢弃法调整上述三层感知机

代码讲解:

丢弃函数

def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

求模型损失函数

def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
       # y = y.float()
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]

其他不再赘述

过拟合:

在这里插入图片描述
在正常情况下,模型测试损失波动比较大,存在过拟合现象

丢弃法调整过拟合:
在这里插入图片描述

丢弃率都是0.7,测试损失比较稳定,过拟合被缓解
在这里插入图片描述
丢弃率0.2和0.7的效果

补充:

代码1:

import torch
import torch.nn as nn

# 创建一个均方误差损失函数,使用 'sum' reduction
loss_fn = nn.MSELoss(reduction='none')

# 生成一些示例数据
predictions = torch.randn(3, requires_grad=True)
targets = torch.randn(3)

# 计算均方误差损失
loss = loss_fn(predictions, targets)

# 通过对损失张量调用 .sum() 也可以得到相同的结果
loss_sum = loss_fn(predictions, targets).sum()

# 打印两者的值
print(loss)  # 输出总体均方误差损失值
print(loss_sum.item())  # 输出通过 .sum() 得到的总体均方误差损失值

在这里插入图片描述

损失函数求得是每个样本的损失所以两者输出不一样

代码2:

import torch
import torch.nn as nn

# 创建一个均方误差损失函数,使用 'sum' reduction
loss_fn = nn.MSELoss(reduction='sum')

# 生成一些示例数据
predictions = torch.randn(3, requires_grad=True)
targets = torch.randn(3)

# 计算均方误差损失
loss = loss_fn(predictions, targets)

# 通过对损失张量调用 .sum() 也可以得到相同的结果
loss_sum = loss_fn(predictions, targets).sum()

# 打印两者的值
print(loss.item())  # 输出总体均方误差损失值
print(loss_sum.item())  # 输出通过 .sum() 得到的总体均方误差损失值

在这里插入图片描述

损失函数求得是整体样本的损失和再加.sum()无效,所以两者输出相同

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/424119.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微服务笔记

什么是微服务? 微服务是一种经过良好架构设计的分布式架构方案&#xff0c;微服务架构特征: 1.单一职责:微服务拆分粒度更小&#xff0c;每一个服务都对应唯一的业务能力&#xff0c;做到单一职责&#xff0c;避免重复业发。 2.面向服务:微服务对外暴露业务接口 3.自治:团…

基于Springboot的人事管理系统 (有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的人事管理系统 &#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&am…

数据结构与算法学习【算法思想之二分法基础】

文章目录 数据结构与算法学习【算法思想之二分查找基础】本文学习目标或巩固的知识点 最基础的二分查找&#x1f7e2;通过题目可知题解结果验证 数据结构与算法学习【算法思想之二分查找基础】 本文学习目标或巩固的知识点 学习二分法类题目 巩固基础的二分法 提前说明&#…

Matlab 机器人工具箱 Link类

文章目录 1 Link类1.1 机械臂Link类1.2 构造函数1.3 信息/显示方法1.4 转换方法1.5 操作方法1.6 测试方法1.7 重载操作1.8 属性(读/写)1.9 例子2 Link.Link2.1 创建机器人连杆对象2.2 OPTIONS2.3 注意2.4 旧语法2.5 例子3 Link的其他函数3.1 Link.A3.2 Link.char3.3 Link.displ…

ABAP - SALV教程10 添加可编辑checkbox列

几乎所有的功能报表都会有那么一个选择列&#xff0c;问了业务顾问&#xff0c;业务顾问说是用户不习惯使用报表原生的选择模式。效果图SALV的选择列是通过将列设置成checkbox_hotspot样式&#xff0c;注册单击事件完成勾选功能的。完成步骤 将SEL列设置成checkbox_hotspot样式…

递推算法(c++)

递推可以说是递归反过来的一种算法&#xff0c;递归是从后往前倒着算&#xff0c;递推是从前往后正着算。 统计每个月兔子的总数 题目描述 有一对兔子&#xff0c;从出生后第3个月起每个月都生一对兔子&#xff0c;一对小兔子长到第三个月后每个月又生一对兔子&#xff0c; …

网络编程(IP、端口、协议、UDP、TCP)【详解】

目录 1.什么是网络编程&#xff1f; 2.基本的通信架构 3.网络通信三要素 4.UDP通信-快速入门 5.UDP通信-多发多收 6.TCP通信-快速入门 7.TCP通信-多发多收 8.TCP通信-同时接收多个客户端 9.TCP通信-综合案例 1.什么是网络编程&#xff1f; 网络编程是可以让设…

Cloud整合Zookeeper代替Eureka

微服务间通信重构与服务治理笔记-CSDN博客 Zookeeper是一个分布式协调工具,可以实现注册中心功能 安装Zookeeper 随便 就用最新版本吧 进入Zookeeper 包目录 cd /usr/local/develop/ 解压 tar -zxvf apache-zookeeper-3.9.1-bin.tar.gz -C /usr/local/develop 进入配置文件…

electron nsis 安装包 window下任务栏无法正常固定与取消固定 Pin to taskbar

问题 win10系统下&#xff0c;程序任务栏在固定后取消固定&#xff0c;展示的程序内容异常。 排查 1.通过论坛查询&#xff0c;应该是与app的api setAppUserModelId 相关 https://github.com/electron/electron/issues/3303 2.electron-builder脚本 electron-builder…

VUE CLI3项目搭建 ESLint配置

VUE项目框架配置 一、工具准备 Node.js安装 安装方法&#xff1a;点击查看WebStorm安装 下载地址&#xff1a;点击查看 二、环境准备 镜像准备 1.查看代理&#xff1a;npm get registry 2.设置淘宝镜像 2.1临时使用. npm --registry https://registry.npm.taobao.org ins…

vue 使用vue-scroller 列表滑动到底部加载更多数据

安装插件 npm install vue-scroller -dmain.js import VueScroller from vue-scroller Vue.use(VueScroller)<template><div class"wrap"><div class"footer"><div class"btn" click"open true">新增</d…

LeetCode --- 三数之和

题目描述 三数之和 代码解析 暴力 在做这一道题的时候&#xff0c;脑海里先想出来的是暴力方法&#xff0c;一次排序&#xff0c;将这个数组变为有序的&#xff0c;再通过三次for循环来寻找满足条件的数字&#xff0c;然后将符合条件的数组与之前符合条件的数组进行一一对比…

Matlab 机器人工具箱 运动学

文章目录 R.fkine()R.ikine()R.ikine6s()R.jacob0、R.jacobn、R.jacob_dotjtrajctraj参考链接官网:Robotics Toolbox - Peter Corke R.fkine() 正运动学,根据关节坐标求末端执行器位姿 mdl_puma560; % 加载puma560模型 qz % 零角度 qr

Spring学习笔记(六)利用Spring的jdbc实现学生管理系统的用户登录功能

一、案例分析 本案例要求学生在控制台输入用户名密码&#xff0c;如果用户账号密码正确则显示用户所属班级&#xff0c;如果登录失败则显示登录失败。 &#xff08;1&#xff09;为了存储学生信息&#xff0c;需要创建一个数据库。 &#xff08;2&#xff09;为了程序连接数…

备战蓝桥杯Day21 - 堆排序的内置模块+topk问题

一、内置模块 在python中&#xff0c;堆排序已经设置好了内置模块&#xff0c;不想自己写的话可以使用内置模块&#xff0c;真的很方便&#xff0c;但是堆排序算法的底层逻辑最好还是要了解并掌握一下的。 使用heapq模块的heapify()函数将列表转换为堆&#xff0c;然后使用he…

CVPR2023 RIFormer, 无需TokenMixer也能达成SOTA性能的极简ViT架构

编辑 | Happy 首发 | AIWalker 链接 | https://mp.weixin.qq.com/s/l3US8Dsd0yNC19o7B1ZBgw project, paper, code Token Mixer是ViT骨干非常重要的组成成分&#xff0c;它用于对不同空域位置信息进行自适应聚合&#xff0c;但常规的自注意力往往存在高计算复杂度与高延迟问题…

记录一次架构优化处理性能从3千->3万

0.背景 优化Kafka消费入Es&#xff0c;适配600台设备上报数据&#xff0c;吞吐量到达2万每秒 1.环境配置 2.压测工具 3.未优化之前的消费逻辑 4.优化之后的消费流程 5.多线程多ESclient 6.修改ES配置&#xff0c;增加kafka分区&#xff0c;增加线程&#xff0c;提升吞吐量 7.…

pytest多重断言插件-pytest-assume

最近准备废弃之前用metersphere做的接口自动化&#xff0c;转战pytest了&#xff0c;先来分享下最近接触到的一个插件&#xff1a;pytest-assume。 在使用这个插件之前&#xff0c;如果一个用例里面有多个断言的话&#xff0c;前面的断言失败了&#xff0c;就不会去执行后面的断…

vite+vue3图片引入方式不生效解决方案

vitevue3图片引入方式不生效解决方案 引入方式改成 const wordImgnew URL(/src/assets/MicsosoftWord.png,import.meta.url).href;原理

Pycharm的下载安装与汉化

一.下载安装包 1.接下来按照步骤来就行 2.然后就能在桌面上找到打开了 3.先建立一个文件夹 二.Pycharm的汉化