【深度学习笔记】 3_13 丢弃法

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

3.13 丢弃法

除了前一节介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)[1] 来应对过拟合问题。丢弃法有一些不同的变体。本节中提到的丢弃法特指倒置丢弃法(inverted dropout)。

3.13.1 方法

回忆一下,3.8节(多层感知机)的图3.3描述了一个单隐藏层的多层感知机。其中输入个数为4,隐藏单元个数为5,且隐藏单元 h i h_i hi i = 1 , … , 5 i=1, \ldots, 5 i=1,,5)的计算表达式为

h i = ϕ ( x 1 w 1 i + x 2 w 2 i + x 3 w 3 i + x 4 w 4 i + b i ) h_i = \phi\left(x_1 w_{1i} + x_2 w_{2i} + x_3 w_{3i} + x_4 w_{4i} + b_i\right) hi=ϕ(x1w1i+x2w2i+x3w3i+x4w4i+bi)

这里 ϕ \phi ϕ是激活函数, x 1 , … , x 4 x_1, \ldots, x_4 x1,,x4是输入,隐藏单元 i i i的权重参数为 w 1 i , … , w 4 i w_{1i}, \ldots, w_{4i} w1i,,w4i,偏差参数为 b i b_i bi。当对该隐藏层使用丢弃法时,该层的隐藏单元将有一定概率被丢弃掉。设丢弃概率为 p p p,那么有 p p p的概率 h i h_i hi会被清零,有 1 − p 1-p 1p的概率 h i h_i hi会除以 1 − p 1-p 1p做拉伸。丢弃概率是丢弃法的超参数。具体来说,设随机变量 ξ i \xi_i ξi为0和1的概率分别为 p p p 1 − p 1-p 1p。使用丢弃法时我们计算新的隐藏单元 h i ′ h_i' hi

h i ′ = ξ i 1 − p h i h_i' = \frac{\xi_i}{1-p} h_i hi=1pξihi

由于 E ( ξ i ) = 1 − p E(\xi_i) = 1-p E(ξi)=1p,因此

E ( h i ′ ) = E ( ξ i ) 1 − p h i = h i E(h_i') = \frac{E(\xi_i)}{1-p}h_i = h_i E(hi)=1pE(ξi)hi=hi

丢弃法不改变其输入的期望值。让我们对图3.3中的隐藏层使用丢弃法,一种可能的结果如图3.5所示,其中 h 2 h_2 h2 h 5 h_5 h5被清零。这时输出值的计算不再依赖 h 2 h_2 h2 h 5 h_5 h5,在反向传播时,与这两个隐藏单元相关的权重的梯度均为0。由于在训练中隐藏层神经元的丢弃是随机的,即 h 1 , … , h 5 h_1, \ldots, h_5 h1,,h5都有可能被清零,输出层的计算无法过度依赖 h 1 , … , h 5 h_1, \ldots, h_5 h1,,h5中的任一个,从而在训练模型时起到正则化的作用,并可以用来应对过拟合。在测试模型时,我们为了拿到更加确定性的结果,一般不使用丢弃法。

在这里插入图片描述

图3.5 隐藏层使用了丢弃法的多层感知机

3.13.2 从零开始实现

根据丢弃法的定义,我们可以很容易地实现它。下面的dropout函数将以drop_prob的概率丢弃X中的元素。

%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

def dropout(X, drop_prob):
    X = X.float()
    assert 0 <= drop_prob <= 1  #如果 drop_prob 小于0或大于1,这个断言将失败,程序将停止执行并显示一个错误信息。
    keep_prob = 1 - drop_prob
    # 这种情况下把全部元素都丢弃
    if keep_prob == 0:
        return torch.zeros_like(X)
    mask = (torch.rand(X.shape) < keep_prob).float()
    
    return mask * X / keep_prob

我们运行几个例子来测试一下dropout函数。其中丢弃概率分别为0、0.5和1。

X = torch.arange(16).view(2, 8)
dropout(X, 0)

在这里插入图片描述

dropout(X, 0.5)

在这里插入图片描述

dropout(X, 1.0)

在这里插入图片描述

3.13.2.1 定义模型参数

实验中,我们依然使用3.6节(softmax回归的从零开始实现)中介绍的Fashion-MNIST数据集。我们将定义一个包含两个隐藏层的多层感知机,其中两个隐藏层的输出个数都是256。

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)
b1 = torch.zeros(num_hiddens1, requires_grad=True)
W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)
b2 = torch.zeros(num_hiddens2, requires_grad=True)
W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)
b3 = torch.zeros(num_outputs, requires_grad=True)

params = [W1, b1, W2, b2, W3, b3]

3.13.2.2 定义模型

下面定义的模型将全连接层和激活函数ReLU串起来,并对每个激活函数的输出使用丢弃法。我们可以分别设置各个层的丢弃概率。通常的建议是把靠近输入层的丢弃概率设得小一点。在这个实验中,我们把第一个隐藏层的丢弃概率设为0.2,把第二个隐藏层的丢弃概率设为0.5。我们可以通过参数is_training来判断运行模式为训练还是测试,并只需在训练模式下使用丢弃法。

drop_prob1, drop_prob2 = 0.2, 0.5

def net(X, is_training=True):
    X = X.view(-1, num_inputs)
    H1 = (torch.matmul(X, W1) + b1).relu()
    if is_training:  # 只在训练模型时使用丢弃法
        H1 = dropout(H1, drop_prob1)  # 在第一层全连接后添加丢弃层
    H2 = (torch.matmul(H1, W2) + b2).relu()
    if is_training:
        H2 = dropout(H2, drop_prob2)  # 在第二层全连接后添加丢弃层
    return torch.matmul(H2, W3) + b3

我们在对模型评估的时候不应该进行丢弃,所以我们修改一下d2lzh_pytorch中的evaluate_accuracy函数:

# 本函数已保存在d2lzh_pytorch
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        if isinstance(net, torch.nn.Module):
            net.eval() # 评估模式, 这会关闭dropout
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            net.train() # 改回训练模式
        else: # 自定义的模型
            if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                # 将is_training设置成False
                acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
            else:
                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() 
        n += y.shape[0]
    return acc_sum / n

注:将上诉evaluate_accuracy写回d2lzh_pytorch后要重启一下jupyter kernel才会生效。

3.13.2.3 训练和测试模型

这部分与之前多层感知机的训练和测试类似。

num_epochs, lr, batch_size = 5, 100.0, 256
loss = torch.nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

输出:

epoch 1, loss 0.0044, train acc 0.574, test acc 0.648
epoch 2, loss 0.0023, train acc 0.786, test acc 0.786
epoch 3, loss 0.0019, train acc 0.826, test acc 0.825
epoch 4, loss 0.0017, train acc 0.839, test acc 0.831
epoch 5, loss 0.0016, train acc 0.849, test acc 0.850

注:这里的学习率设置的很大,原因同3.9.6节。

3.13.3 简洁实现

在PyTorch中,我们只需要在全连接层后添加Dropout层并指定丢弃概率。在训练模型时,Dropout层将以指定的丢弃概率随机丢弃上一层的输出元素;在测试模型时(即model.eval()后),Dropout层并不发挥作用。

net = nn.Sequential(
        d2l.FlattenLayer(),
        nn.Linear(num_inputs, num_hiddens1),
        nn.ReLU(),
        nn.Dropout(drop_prob1),
        nn.Linear(num_hiddens1, num_hiddens2), 
        nn.ReLU(),
        nn.Dropout(drop_prob2),
        nn.Linear(num_hiddens2, 10)
        )

for param in net.parameters():
    nn.init.normal_(param, mean=0, std=0.01)

下面训练并测试模型。

optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

输出:

epoch 1, loss 0.0045, train acc 0.553, test acc 0.715
epoch 2, loss 0.0023, train acc 0.784, test acc 0.793
epoch 3, loss 0.0019, train acc 0.822, test acc 0.817
epoch 4, loss 0.0018, train acc 0.837, test acc 0.830
epoch 5, loss 0.0016, train acc 0.848, test acc 0.839

注:由于这里使用的是PyTorch的SGD而不是d2lzh_pytorch里面的sgd,所以就不存在3.9.6节那样学习率看起来很大的问题了。

小结

  • 我们可以通过使用丢弃法应对过拟合。
  • 丢弃法只在训练模型时使用。

参考文献

[1] Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., & Salakhutdinov, R. (2014). Dropout: a simple way to prevent neural networks from overfitting. JMLR


注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/408808.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

“点击查看显示全文”遇到的超链接默认访问的问题

今天在做一个例子&#xff0c;就是很常见的点击展开全文。 我觉得这是一个很简单的效果&#xff0c;也就几行代码的事&#xff0c;结果点击了以后立刻隐藏不见&#xff0c;控制台代码也不报错&#xff0c;耽误了我很长时间&#xff0c;最后才发现问题出在超链接身上。 “展开全…

k8s-kubeapps部署 20

部署kubeapps应用&#xff0c;为Helm提供web UI界面管理&#xff1a; 下载最新版本的kubeapps并修改其values.yaml文件 下载并拉取所需镜像&#xff1a; 部署应用 添加解析 修改svc暴露方式为LoadBalancer 得到分配地址 访问http://192.168.182.102 授权并获取token 1.24前的…

osmnx笔记:从OpenStreetMap中提取点和边的shp文件(FMM文件准备内容)

1 导入库 import osmnx as ox import time from shapely.geometry import Polygon import os import numpy as np 2 提取Openstreetmap 的graph Gox.graph_from_place(Huangpu,Shanghai,China,network_typedrive,simplifyTrue) ox.plot_graph(G) 3 提取graph中的点和边 gdf…

pytest如何在类的方法之间共享变量?

在pytest中&#xff0c;setup_class是一个特殊的方法&#xff0c;它用于在类级别的测试开始之前设置一些初始化的状态。这个方法会在类中的任何测试方法执行之前只运行一次。 当你在setup_class中使用self来修改类属性时&#xff0c;你实际上是在修改类的一个实例属性。在Pyth…

《模仿游戏》:天才团队如何破解密码学之谜

引言 计算机科学相关的电影不少&#xff0c;有探索人工智能的《黑客帝国》、还有逻辑和结构学的《盗梦空间》、还有互联网创业的《社交网络》和《硅谷海盗》、还有探索虚拟世界的《源代码》&#xff0c;更甚有国产计算机科学科幻启蒙儿童电视剧《快乐星球》。上述电影充满科技和…

函数——递归6(c++)

角谷猜想 题目描述 日本一位中学生发现一个奇妙的 定理&#xff0c;请角谷教授证明&#xff0c;而教授 无能为力&#xff0c;于是产生了角谷猜想。 猜想的内容&#xff1a;任给一个自然数&#xff0c; 若为偶数则除以2&#xff0c;若为奇数则乘 3加1&#xff0c;得到一个新的…

JDK8新特性全解析:Java8变革之旅

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

云原生之容器编排实践-ruoyi-cloud项目部署到K8S:MySQL8

背景 前面搭建好了 Kubernetes 集群与私有镜像仓库&#xff0c;终于要进入服务编排的实践环节了。本系列拿 ruoyi-cloud 项目进行练手&#xff0c;按照 MySQL &#xff0c; Nacos &#xff0c; Redis &#xff0c; Nginx &#xff0c; Gateway &#xff0c; Auth &#xff0c;…

测试圈的网红工具:Jmeter到底难在哪里?!

小欧的公司最近推出了一款在线购物应用&#xff0c;吸引了大量用户。然而随着用户数量的增加&#xff0c;应用的性能开始出现问题。用户抱怨说购物过程中页面加载缓慢&#xff0c;甚至有时候无法完成订单&#xff0c;小欧作为负责人员迫切需要找到解决方案。 在学习JMeter之前…

[VNCTF2024]-Web:CheckIn解析

查看网页 一款很经典的游戏&#xff0c;而且是用js写的 在调试器里面我们可以看见&#xff0c;如果游戏通关的话&#xff0c;它会进行一系列操作&#xff0c;包括使用console.log(_0x3d9d[0]);输出_0x3d9d[0]到控制台&#xff0c;那我们就直接在点击在控制台求出它的值

基于SpringBoot实现的医院药品管理系统

一、系统架构 前端&#xff1a;html | layui | js | css 后端&#xff1a;springboot | mybatis-plus 环境&#xff1a;jdk1.6 | mysql | maven 二、代码及数据库 三、功能介绍 01. 登录页 02. 药品库存管理-登记出入口信息 03. 药品库存管理-问题药品信息 …

软考45-上午题-【数据库】-数据操纵语言DML

一、INSERT插入语句 向SQL的基本表中插入数据有两种方式&#xff1a; ①直接插入元组值 ②插入一个查询的结果值 1-1、直接插入元组值 【注意】&#xff1a; 列名序列是可选的&#xff0c;若是所有列都要插入数值&#xff0c;则可以不写列名序列。 示例&#xff1a; 1-2、插…

基于ZYNQ的PCIE高速数据采集卡的设计(二)总体设计与上位机

采集卡总体设计及相关技术 2.1 引言 本课题是来源于雷达辐射源识别项目&#xff0c;需要对雷达辐射源中频信号进行采集传输 和存储。本章基于项目需求&#xff0c;介绍采集卡的总体设计方案。采集卡设计包括硬件设计 和软件设计。首先对采集卡的性能和指标进行分析&#x…

ELK 简介安装

1、概念介绍 日志介绍 日志就是程序产生的&#xff0c;遵循一定格式&#xff08;通常包含时间戳&#xff09;的文本数据。 通常日志由服务器生成&#xff0c;输出到不同的文件中&#xff0c;一般会有系统日志、 应用日志、安全日志。这些日志分散地存储在不同的机器上。 日志…

如何使用移动端设备在公网环境远程访问本地黑群晖

文章目录 前言本教程解决的问题是&#xff1a;按照本教程方法操作后&#xff0c;达到的效果是前排提醒&#xff1a; 1. 搭建群晖虚拟机1.1 下载黑群晖文件vmvare虚拟机安装包1.2 安装VMware虚拟机&#xff1a;1.3 解压黑群晖虚拟机文件1.4 虚拟机初始化1.5 没有搜索到黑群晖的解…

深度学习500问——Chapter01:数学基础

文章目录 前言 1.1 向量和矩阵 1.1.1 标量、向量、矩阵、张量之间的联系 1.1.2 张量与矩阵的区别 1.1.3 矩阵和向量相乘结果 1.1.4 向量和矩阵的范数归纳 1.1.5 如何判断一个矩阵为正定 1.2 导数和偏导数 1.2.1 导数偏导计算 1.2.2 导数和偏导数有什么区别 1.3 特征值和特征向量…

第6.3章:StarRocks查询加速——Bucket Shuffle Join

目录 一、StarRocks数据划分 1.1 分区 1.2 分桶 二、Bucket Shuffle Join实现原理 2.1 Bucket Shuffle Join概述 2.2 Bucket Shuffle Join工作原理 2.3 Bucket Shuffle Join规划规则 三、应用案例 注&#xff1a;本篇文章阐述的是StarRocks-3.2版本的Bucket Shuffle Jo…

rtsp推拉流

1.搭建视频服务器 smart-rtmpd: smart_rtmpd 是一款 rtmp、rtsp 服务器&#xff0c;非常好用&#xff0c;解压既运行&#xff0c;支持跨平台&#xff0c;无任何依赖&#xff0c;性能和 SRS 相比不分上下 2.推拉流 下载windows版本ffmpeg,并设置环境变量. 推流 ffmpeg -re -st…

26.java-单元测试xml注解

单元测试&xml&注解 单元测试 单元测试就是针对最小的功能单元编写测试代码&#xff0c;Java程序最小的功能单元是方法&#xff0c;因此&#xff0c;单元测试就是针对 Java 方法的测试&#xff0c;进而检查方法的正确性。 简单理解 : 就是一个测试代码的工具 目前测试…

BUU [CISCN2019 华东南赛区]Web4

BUU [CISCN2019 华东南赛区]Web4 题目描述&#xff1a;Click to launch instance. 开题&#xff1a; 点击链接&#xff0c;有点像SSRF 使用local_file://协议读到本地文件&#xff0c;无法使用file://协议读取&#xff0c;有过滤。 local_file://协议&#xff1a; local_file…
最新文章