《动手学深度学习(PyTorch版)》笔记7.5

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.5 Batch Normalization

批量规范化应用于单个可选层(也可以应用到所有层),其原理如下:
在每次训练迭代中,首先规范化输入,即通过减去其均值并除以其标准差(两者均基于当前小批量处理)。接下来应用比例系数和比例偏移。请注意,如果使用大小为1的小批量应用批量规范化,将无法学到任何东西,这是因为在减去均值之后,每个隐藏单元将为0,所以只有使用足够大的小批量,批量规范化才有效且稳定。用 x ∈ B \mathbf{x} \in \mathcal{B} xB表示一个来自小批量 B \mathcal{B} B的输入,批量规范化 B N \mathrm{BN} BN根据下式转换 x \mathbf{x} x

B N ( x ) = γ ⊙ x − μ ^ B σ ^ B + β . \mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}. BN(x)=γσ^Bxμ^B+β.

在上式中, μ ^ B \hat{\boldsymbol{\mu}}_\mathcal{B} μ^B是小批量 B \mathcal{B} B的样本均值, σ ^ B \hat{\boldsymbol{\sigma}}_\mathcal{B} σ^B是小批量 B \mathcal{B} B的样本标准差。应用标准化后,生成的小批量的平均值为0和方差为1。由于单位方差是一个主观选择的结果,因此通常包含拉伸参数(scale) γ \boldsymbol{\gamma} γ偏移参数(shift) β \boldsymbol{\beta} β,它们的形状与 x \mathbf{x} x相同,且是需要与其他模型参数一起学习的参数。

μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x , σ ^ B 2 = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ . \begin{aligned} \hat{\boldsymbol{\mu}}_\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\ \hat{\boldsymbol{\sigma}}_\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned} μ^Bσ^B2=B1xBx,=B1xB(xμ^B)2+ϵ.

我们在方差估计值中添加一个小的常量 ϵ > 0 \epsilon > 0 ϵ>0,以确保永远不会除以零。此外,估计值 μ ^ B \hat{\boldsymbol{\mu}}_\mathcal{B} μ^B σ ^ B {\hat{\boldsymbol{\sigma}}_\mathcal{B}} σ^B还会通过使用关于平均值和方差的噪声(noise)来抵消缩放问题。乍看起来,这种噪声是一个问题,而事实上它是有益的。

7.5.1 Batch Normalize Layer

批量规范化层和其他层之间的一个关键区别是:由于批量规范化在完整的小批量上运行,因此我们不能像以前在引入其他层时那样忽略批量大小。我们在下面讨论全连接层和卷积层的批量规范化实现。

7.5.1.1 Fully Connected Layer

我们通常将批量规范化层置于全连接层中的仿射变换和激活函数之间。设全连接层的输入为x,权重参数和偏置参数分别为 W \mathbf{W} W b \mathbf{b} b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N \mathrm{BN} BN,那么使用批量规范化的全连接层的输出的计算公式如下:
h = ϕ ( B N ( W x + b ) ) . \mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}) ). h=ϕ(BN(Wx+b)).

7.5.1.2 Convolutional Layer

对于卷积层,可以在卷积操作之后和非线性激活函数之前应用批量规范化。当卷积有多个输出通道时,我们需要对这些通道的每个输出执行批量规范化。每个通道都有自己的拉伸(scale)和偏移(shift)参数,两者都是标量。假设小批量包含 m m m个样本,并且对于每个通道,卷积的输出具有高度 p p p和宽度 q q q,那么对于卷积层,在每个输出通道的 m ⋅ p ⋅ q m \cdot p \cdot q mpq个元素上同时执行每个批量规范化。因此在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。

7.5.2 Batch Normalization During Prediction

批量规范化在训练模式和预测模式下的行为和结果通常不同。在训练过程中,我们无法得知使用整个数据集来估计平均值和方差,所以只能根据每个小批次的平均值和方差不断训练模型,一种常用的方法是通过移动平均(具体方法在batch_norm()函数中定义)估算整个训练数据集的样本均值和方差,并在预测时使用它们得到确定的输出。而在将训练好的模型用于预测时,我们可以根据整个数据集精确计算批量规范化所需的平均值和方差,不再需要样本均值中的噪声以及在微批次上估计每个小批次产生的样本方差了。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        #参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        #这些参数是nn.Parameter类型,因此会被PyTorch的优化器所管理,梯度更新是自动进行的
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,就将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y
    
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

print(net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,)))

#简洁实现
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/373056.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

rclone基础命令解析及实战

rclone命令解析及实战 1 rclone介绍:远程同步工具 rclone是一个开源的远程数据同步工具,由Golang编写,旨在在不同平台的文件系统和多种类型的对象存储产品之间提供数据同步功能。 它支持超过 40 种不同的云存储服务,包括 Amazon S…

常用Hallmark及KEGG、GO基因查询

文献:The Molecular Signatures Database (MSigDB) hallmark gene set collection - PMC (nih.gov) GSEA | MSigDB | Browse Human Gene Sets (gsea-msigdb.org)通过msigdb数据库可以查看各个Hallmark、KEGG、GO具体包含的基因细节。 Hallmark nameProcess categor…

C# Socket通信从入门到精通(21)——Tcp客户端判断与服务器断开连接的三种方法以及C#代码实现

前言 我们开发的tcp客户端程序在连接服务器以后,经常会遇到服务器已经关闭但是作为客户端的我们不知道,这时候应该应该有一个机制我们可以实时监测客户端和服务器已经断开连接,如果已经断开了连接,我们应该及时报警提示用户客户端和服务器已经断开连接,本文介绍三种可以监…

DAY12之滑动窗口最大值

今天内容有点超乎我的能力 直接放卡哥的讲解了 239. 滑动窗口最大值 - 力扣&#xff08;LeetCode&#xff09; 先看超时的暴力解法 class Solution { public:vector<int> maxSlidingWindow(vector<int>& nums, int k) { vector<int>result; for(int …

新手养猫怎么挑选宠物空气净化器?猫用空气净化器测评推荐!

对于养猫的朋友来说&#xff0c;猫咪掉毛绝对是一个令人头痛的问题。猫毛和皮屑在室内飘散&#xff0c;不仅遍布各个角落&#xff0c;而且清理起来也相当费劲。尤其是那些顽固的猫毛&#xff0c;更是令人烦恼。更糟糕的是&#xff0c;这些毛发可能引起人体过敏反应&#xff0c;…

6.s081 学习实验记录(五)traps

文章目录 一、RISC-V assembly简介问题 二、Backtrace简介注意实验代码实验结果 三、Alarm简介注意实验代码实验结果 一、RISC-V assembly 简介 git checkout traps&#xff0c;切换到traps分支user/call.c 文件在我们输入 make fs.img 之后会被汇编为 call.asm 文件&#xf…

libev-ev_timer定时器的理解

1.相关说明 本文主要自己对于libev的ev_timer定时器的代码流程梳理&#xff0c;主要有ev_timer结构体定义变量的初始化&#xff0c;定时器变量的参数设置&#xff0c;定时器变量的使用 2.相关代码流程 下面是图片 3.相关实现代码 main.c #include <stdio.h> #include…

流浪动物救助|基于Springboot的流浪动物救助平台设计与实现(源码+数据库+文档)

流浪动物救助平台目录 目录 基于Springboot的流浪动物救助平台设计与实现 一、前言 二、系统功能设计 三、系统实现 1、用户信息管理 2、动物信息管理 3、商品评论管理 4、公告信息管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设…

使用html2canvas截图踩坑总结

年底的移动端H5需求中&#xff0c;再次用到了html2canvas这个插件&#xff0c;这个插件主要是用来对网页进行截图&#xff0c;在项目需求中&#xff0c;有个交互的点&#xff0c;就是通过用户操作&#xff0c;将页面的内容截图保存下来&#xff0c;方便用户传播扩散。 H5说明&…

【初读论文】

这里写目录标题 万字长文解析深度学习中的术语面向小白的深度学习论文术语&#xff08;持续更新&#xff09;deepsolo不懂的知识pipelinebaselineRoI(Region of Interest)分类问题中的正例负例指示函数&#xff08;indicator function&#xff09;模型性能评估指标&#xff08;…

nginx+flask+Gunicorn反代理服务拿不到真实IP的解决

背景 本人在宝塔linux环境,要部署flask的简单后端并且用Ngnix反代理,用Gunicorn框架部署。(o(╥﹏╥)o中间磕磕绊绊总算部署上去了,需要了解Gunicorn怎么部署的朋友,评论区留言,我加补一篇介绍)。但是但是,我发现 其 accesslog日志里竟然是 127.0.0.1。这怎么能…

模拟钉钉官网动画

实现思路&#xff1a;利用粘性定位sticky&#xff0c;以及滚动事件实现。首先我们应该设置滚动动画开始位置和结束位置 &#xff0c;然后根据位置计算透明度或者transform&#xff0c;scale的值。 首先根据上述图线计算属性值&#xff0c;代码如下&#xff1a; function creat…

Python基础知识:Python模块

所谓模块(Module)&#xff0c;就是一种以“.py”为命名后缀的Python 文件&#xff0c;里面包含着很多集成的函数&#xff0c;可以很方便的被其他程序和脚本导入并使用。 如果模块理解为一辆汽车&#xff0c;我们使用汽车可以完成驾驶等工作&#xff0c;那么代码就是一个个细小…

Linux内存管理:(十二)Linux 5.0内核新增的反碎片优化

文章说明&#xff1a; Linux内核版本&#xff1a;5.0 架构&#xff1a;ARM64 参考资料及图片来源&#xff1a;《奔跑吧Linux内核》 Linux 5.0内核源码注释仓库地址&#xff1a; zhangzihengya/LinuxSourceCode_v5.0_study (github.com) 外碎片化发生时&#xff0c;页面分配…

day21网页布局

文章目录 块元素和行内元素列表标签表格标签媒体元素页面结构分析iframe内联框架 块元素和行内元素 块元素&#xff1a;无论内容多少&#xff0c;该元素独占一行。(p标签、h1~h6…) 行内元素&#xff1a;内容撑开宽度&#xff0c;左右都是行内元素的可以排在一行。&#xff08…

Java基础---IO知识以及常用的API整理

Java 中的 I/O&#xff08;输入/输出&#xff09;是一个核心概念&#xff0c;涉及数据的读取和写入。Java I/O API 提供了丰富的类和接口&#xff0c;用于处理不同类型的数据流。了解 Java I/O 的基础知识对于面试和日常编程都非常重要。所以今天来整理一下&#xff1a; 1. Fi…

day42_jdbc

今日内容 0 复习昨日 1 JDBC概述 2 JDBC开发步骤 3 完成增删改操作 4 ResultSet 5 登录案例 0 复习昨日 1 写出JQuery,通过获得id获得dom,并给input输入框赋值的语句 $(“#id”).val(“值”) 2 mysql内连接和外连接的区别 内连接只会保留完全符合关联条件的数据 外连接会保留表…

LeetCode、746. 使用最小花费爬楼梯【简单,动态规划 线性DP】

文章目录 前言LeetCode、746. 使用最小花费爬楼梯【简单&#xff0c;动态规划 线性DP】题目与分类思路 资料获取 前言 博主介绍&#xff1a;✌目前全网粉丝2W&#xff0c;csdn博客专家、Java领域优质创作者&#xff0c;博客之星、阿里云平台优质作者、专注于Java后端技术领域。…

idea 快捷键ctrl+shift+f失效的解决方案

文章目录 搜狗输入法快捷键冲突微软输入法快捷键冲突 idea的快捷键ctrlshiftf按了没反应&#xff0c;理论上是快捷键冲突了&#xff0c;检查搜狗输入法和微软输入法快捷键。 搜狗输入法快捷键冲突 不需要简繁切换的快捷键&#xff0c;可以关闭它&#xff0c;或修改快捷键。 微…

Linux Shell编程系列--开篇

一、目的 从本篇开始介绍Linux Shell脚本编程&#xff0c;为简单起见&#xff0c;本篇中以一个显示当前时间的shell脚本来帮助大家理解shell脚本的组成。 SHELL脚本中可以包含变量、函数、命令等部分。 二、介绍 我们通过vim新建一个myshell.sh的脚本&#xff0c;然后输入以下…
最新文章