【深度学习笔记】7_3 小批量随机梯度下降

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

7.3 小批量随机梯度下降

在每一次迭代中,梯度下降使用整个训练数据集来计算梯度,因此它有时也被称为批量梯度下降(batch gradient descent)。而随机梯度下降在每次迭代中只随机采样一个样本来计算梯度。正如我们在前几章中所看到的,我们还可以在每轮迭代中随机均匀采样多个样本来组成一个小批量,然后使用这个小批量来计算梯度。下面就来描述小批量随机梯度下降。

设目标函数 f ( x ) : R d → R f(\boldsymbol{x}): \mathbb{R}^d \rightarrow \mathbb{R} f(x):RdR。在迭代开始前的时间步设为0。该时间步的自变量记为 x 0 ∈ R d \boldsymbol{x}_0\in \mathbb{R}^d x0Rd,通常由随机初始化得到。在接下来的每一个时间步 t > 0 t>0 t>0中,小批量随机梯度下降随机均匀采样一个由训练数据样本索引组成的小批量 B t \mathcal{B}_t Bt。我们可以通过重复采样(sampling with replacement)或者不重复采样(sampling without replacement)得到一个小批量中的各个样本。前者允许同一个小批量中出现重复的样本,后者则不允许如此,且更常见。对于这两者间的任一种方式,都可以使用

g t ← ∇ f B t ( x t − 1 ) = 1 ∣ B ∣ ∑ i ∈ B t ∇ f i ( x t − 1 ) \boldsymbol{g}_t \leftarrow \nabla f_{\mathcal{B}_t}(\boldsymbol{x}_{t-1}) = \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}_t}\nabla f_i(\boldsymbol{x}_{t-1}) gtfBt(xt1)=B1iBtfi(xt1)

来计算时间步 t t t的小批量 B t \mathcal{B}_t Bt上目标函数位于 x t − 1 \boldsymbol{x}_{t-1} xt1处的梯度 g t \boldsymbol{g}_t gt。这里 ∣ B ∣ |\mathcal{B}| B代表批量大小,即小批量中样本的个数,是一个超参数。同随机梯度一样,重复采样所得的小批量随机梯度 g t \boldsymbol{g}_t gt也是对梯度 ∇ f ( x t − 1 ) \nabla f(\boldsymbol{x}_{t-1}) f(xt1)的无偏估计。给定学习率 η t \eta_t ηt(取正数),小批量随机梯度下降对自变量的迭代如下:

x t ← x t − 1 − η t g t . \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \eta_t \boldsymbol{g}_t. xtxt1ηtgt.

基于随机采样得到的梯度的方差在迭代过程中无法减小,因此在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减,例如 η t = η t α \eta_t=\eta t^\alpha ηt=ηtα(通常 α = − 1 \alpha=-1 α=1或者 − 0.5 -0.5 0.5)、 η t = η α t \eta_t = \eta \alpha^t ηt=ηαt(如 α = 0.95 \alpha=0.95 α=0.95)或者每迭代若干次后将学习率衰减一次。如此一来,学习率和(小批量)随机梯度乘积的方差会减小。而梯度下降在迭代过程中一直使用目标函数的真实梯度,无须自我衰减学习率。

小批量随机梯度下降中每次迭代的计算开销为 O ( ∣ B ∣ ) \mathcal{O}(|\mathcal{B}|) O(B)。当批量大小为1时,该算法即为随机梯度下降;当批量大小等于训练数据样本数时,该算法即为梯度下降。当批量较小时,每次迭代中使用的样本少,这会导致并行处理和内存使用效率变低。这使得在计算同样数目样本的情况下比使用更大批量时所花时间更多。当批量较大时,每个小批量梯度里可能含有更多的冗余信息。为了得到较好的解,批量较大时比批量较小时需要计算的样本数目可能更多,例如增大迭代周期数。

7.3.1 读取数据

本章里我们将使用一个来自NASA的测试不同飞机机翼噪音的数据集来比较各个优化算法 [1]。我们使用该数据集的前1,500个样本和5个特征,并使用标准化对数据进行预处理。

%matplotlib inline
import numpy as np
import time
import torch
from torch import nn, optim
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

def get_data_ch7():  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    data = np.genfromtxt('../../data/airfoil_self_noise.dat', delimiter='\t')
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    return torch.tensor(data[:1500, :-1], dtype=torch.float32), \
    torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)

features, labels = get_data_ch7()
features.shape # torch.Size([1500, 5])

7.3.2 从零开始实现

3.2节(线性回归的从零开始实现)中已经实现过小批量随机梯度下降算法。我们在这里将它的输入参数变得更加通用,主要是为了方便本章后面介绍的其他优化算法也可以使用同样的输入。具体来说,我们添加了一个状态输入states并将超参数放在字典hyperparams里。此外,我们将在训练函数里对各个小批量样本的损失求平均,因此优化算法里的梯度不需要除以批量大小。

def sgd(params, states, hyperparams):
    for p in params:
        p.data -= hyperparams['lr'] * p.grad.data

下面实现一个通用的训练函数,以方便本章后面介绍的其他优化算法使用。它初始化一个线性回归模型,然后可以使用小批量随机梯度下降以及后续小节介绍的其他算法来训练模型。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_ch7(optimizer_fn, states, hyperparams, features, labels,
              batch_size=10, num_epochs=2):
    # 初始化模型
    net, loss = d2l.linreg, d2l.squared_loss
    
    w = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),
                           requires_grad=True)
    b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)

    def eval_loss():
        return loss(net(features, w, b), labels).mean().item()

    ls = [eval_loss()]
    data_iter = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)
    
    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            l = loss(net(X, w, b), y).mean()  # 使用平均损失
            
            # 梯度清零
            if w.grad is not None:
                w.grad.data.zero_()
                b.grad.data.zero_()
                
            l.backward()
            optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())  # 每100个样本记录下当前训练误差
    # 打印结果和作图
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    d2l.set_figsize()
    d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    d2l.plt.xlabel('epoch')
    d2l.plt.ylabel('loss')

当批量大小为样本总数1,500时,优化使用的是梯度下降。梯度下降的1个迭代周期对模型参数只迭代1次。可以看到6次迭代后目标函数值(训练损失)的下降趋向了平稳。

def train_sgd(lr, batch_size, num_epochs=2):
    train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)

train_sgd(1, 1500, 6)

输出:

loss: 0.243605, 0.014335 sec per epoch

在这里插入图片描述

当批量大小为1时,优化使用的是随机梯度下降。为了简化实现,有关(小批量)随机梯度下降的实验中,我们未对学习率进行自我衰减,而是直接采用较小的常数学习率。随机梯度下降中,每处理一个样本会更新一次自变量(模型参数),一个迭代周期里会对自变量进行1,500次更新。可以看到,目标函数值的下降在1个迭代周期后就变得较为平缓。

train_sgd(0.005, 1)

输出:

loss: 0.243433, 0.270011 sec per epoch

在这里插入图片描述

虽然随机梯度下降和梯度下降在一个迭代周期里都处理了1,500个样本,但实验中随机梯度下降的一个迭代周期耗时更多。这是因为随机梯度下降在一个迭代周期里做了更多次的自变量迭代,而且单样本的梯度计算难以有效利用矢量计算。

当批量大小为10时,优化使用的是小批量随机梯度下降。它在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

train_sgd(0.05, 10)

输出:

loss: 0.242805, 0.078792 sec per epoch

在这里插入图片描述

7.3.3 简洁实现

在PyTorch里可以通过创建optimizer实例来调用优化算法。这能让实现更简洁。下面实现一个通用的训练函数,它通过优化算法的函数optimizer_fn和超参数optimizer_hyperparams来创建optimizer实例。

# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字
# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={"lr": 0.05}
def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,
                    batch_size=10, num_epochs=2):
    # 初始化模型
    net = nn.Sequential(
        nn.Linear(features.shape[-1], 1)
    )
    loss = nn.MSELoss()
    optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)

    def eval_loss():
        return loss(net(features).view(-1), labels).item() / 2

    ls = [eval_loss()]
    data_iter = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)

    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            # 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2
            l = loss(net(X).view(-1), y) / 2 
            
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())
    # 打印结果和作图
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    d2l.set_figsize()
    d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    d2l.plt.xlabel('epoch')
    d2l.plt.ylabel('loss')

使用PyTorch重复上一个实验。

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, 10)

输出:

loss: 0.245491, 0.044150 sec per epoch

在这里插入图片描述

小结

  • 小批量随机梯度每次随机均匀采样一个小批量的训练样本来计算梯度。
  • 在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减。
  • 通常,小批量随机梯度在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

参考文献

[1] 飞机机翼噪音数据集。https://archive.ics.uci.edu/ml/datasets/Airfoil+Self-Noise


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/451775.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

折扣价和折扣实时转换

背景 : react 项目 问题 : 在折扣数中输入折扣2.333333,中间会多很多0,输入2.222,不能正常输入到第三位 如下图 原因 : toFixed()数字转字符串时可能会导致精度问题 解决思路 : parseFloat来解析浮点数,Number.isFinite判断给定的值是否为有…

es 查询案例分析

场景描述: 有这样一种场景,比如我们想搜索 title:Brown fox body:Brown fox 文章索引中有两条数据,兔子和狐狸两条数据 PUT /blogs/_bulk {"index": {"_id": 1}} {"title": "…

DayDreamInGIS 之 ArcGIS Pro二次开发 锐角检查

功能:检查图斑中所有的夹角,如果为锐角,在单独的标记图层中标记。生成的结果放在默认gdb中,以 图层名_锐角检查 的方式命名 大体实现方式:遍历图层中的所有要素(多部件要素分别处理)&#xff0…

【C语言】qsort函数的使用

1.使用qsort函数排序整型数据 #include <stdio.h> #include <string.h> #include <stdlib.h>//void qsort(void* base, //指针&#xff0c;指向的是待排序的数组的第一个元素 // size_t num, //是base指向的待排序数组的元素个数 // siz…

力扣每日一题 在受污染的二叉树中查找元素 哈希 DFS 二进制

Problem: 1261. 在受污染的二叉树中查找元素 思路 &#x1f468;‍&#x1f3eb; 灵神题解 &#x1f496; 二进制 时间复杂度&#xff1a;初始化为 O ( 1 ) O(1) O(1)&#xff1b;find 为 O ( m i n ( h , l o g 2 t a r g e t ) O(min(h,log_2target) O(min(h,log2​targ…

Django 学习笔记(Day1)

「写在前面」 本文为千锋教育 Django 教程的学习笔记。本着自己学习、分享他人的态度&#xff0c;分享学习笔记&#xff0c;希望能对大家有所帮助。 目录 0 课程介绍 1 Django 快速入门 1.1 Django 介绍 1.2 Django 安装 1.3 创建 Django 项目 1.4 运行 Django 项目 1.5 数据迁…

【C#】.net core 6.0 使用第三方日志插件Log4net,日志输出到控制台或者文本文档

欢迎来到《小5讲堂》 大家好&#xff0c;我是全栈小5。 这是《C#》系列文章&#xff0c;每篇文章将以博主理解的角度展开讲解&#xff0c; 特别是针对知识点的概念进行叙说&#xff0c;大部分文章将会对这些概念进行实际例子验证&#xff0c;以此达到加深对知识点的理解和掌握。…

【C++】stack/queue

链表完了之后就是我们的栈和队列了&#xff0c;当然我们的STL中也有实现&#xff0c;下面我们先来看一下简单用法&#xff0c;跟我们之前C语言实现的一样&#xff0c;stack和queue有这么几个重要的成员函数 最主要的就是这么几个&#xff1a;empty&#xff0c;push&#xff0c;…

python读取大型csv文件,降低内存占用,提高程序处理速度

文章目录 简介读取前多少行读取属性列逐块读取整个文件总结参考资料 简介 遇到大型的csv文件时&#xff0c;pandas会把该文件全部加载进内存&#xff0c;从而导致程序运行速度变慢。 本文提供了批量读取csv文件、读取属性列的方法&#xff0c;减轻内存占用情况。 import pand…

git commit --amend

git commit --amend 1. 修改已经输入的commit 1. 修改已经输入的commit 我已经输入了commit fix: 删除无用代码 然后现在表示不准确&#xff0c;然后我通过命令git commit --amend修改commit

鼓楼夜市管理wpf+sqlserver

鼓楼夜市管理系统wpfsqlserver 下载地址:鼓楼夜市管理系统wpfsqlserver 说明文档 运行前附加数据库.mdf&#xff08;或sql生成数据库&#xff09; 主要技术&#xff1a; 基于C#wpf架构和sql server数据库 功能模块&#xff1a; 登录注册 鼓楼夜市管理系统主界面所有店铺信…

企业电子招投标系统源码-从源码到实践:深入了解鸿鹄电子招投标系统与电子招投标

在数字化采购领域&#xff0c;企业需要一个高效、透明和规范的管理系统。通过采用Spring Cloud、Spring Boot2、Mybatis等先进技术&#xff0c;我们打造了全过程数字化采购管理平台。该平台具备内外协同的能力&#xff0c;通过待办消息、招标公告、中标公告和信息发布等功能模块…

Tictoc3例子

在tictoc3中&#xff0c;实现了让 tic 和 toc 这两个简单模块之间传递消息&#xff0c;传递十次后结束仿真。 首先来介绍一下程序中用到的两个函数&#xff1a; 1.omnetpp中获取模块名称的函数 virtual const char *getName() const override {return name ? name : "&q…

Python 一步一步教你用pyglet制作汉诺塔游戏(终篇)

目录 汉诺塔游戏 完整游戏 后期展望 汉诺塔游戏 汉诺塔&#xff08;Tower of Hanoi&#xff09;&#xff0c;是一个源于印度古老传说的益智玩具。这个传说讲述了大梵天创造世界的时候&#xff0c;他做了三根金刚石柱子&#xff0c;并在其中一根柱子上从下往上按照大小顺序摞…

js【详解】Promise

为什么需要使用 Promise &#xff1f; 传统回调函数的代码层层嵌套&#xff0c;形成回调地狱&#xff0c;难以阅读和维护&#xff0c;为了解决回调地狱的问题&#xff0c;诞生了 Promise 什么是 Promise &#xff1f; Promise 是一种异步编程的解决方案&#xff0c;本身是一个构…

套接字的地址结构,IP地址转换函数,网络编程的接口

目录 一、套接字的地址结构 1.1 通用socket地址结构 1.2 专用socket地址结构 1.2.1 tcp协议族 1.2.3 IP协议族 二、IP地址转换函数 三、网络编程接口 3.1 socket() 3.2 bind() 3.3 listen() 3.4 accept() 3.5 connect() 3.6 close() 3.7 recv()、send() 3.8 recv…

手写简易操作系统(五)--获得物理内存容量

前情提要 上一章中我们进入了保护模式&#xff0c;并且跳转到了32位模式下执行。这一章较为简单&#xff0c;我们来获取物理内存的实际容量。 一、获得内存容量的方式 在Linux中有多种方法获取内存容量&#xff0c;如果一种方法失败&#xff0c;就会试用其他方法。其本质上是…

考研数学|汤家凤《1800》vs 张宇《1000》,怎么选?

汤家凤的1800题和张宇的1000题都是备考数学考研的热门选择&#xff0c;但究竟哪个更适合备考呢&#xff1f;下面分享一些见解。 首先&#xff0c;让我们来看看传统习题册存在的一些问题。虽然传统习题册通常会覆盖考试的各个知识点和题型&#xff0c;但其中一些问题在于它们可…

JDBC连接MysqL

import java.sql.*;public class Demo {public static void main(String[] args) throws ClassNotFoundException, SQLException {//1.注册驱动&#xff0c;加载驱动&#xff1b;Class.forName("com.mysql.jdbc.Driver");//2.获得连接,返回connection类型的对象&…

汤唯短发造型:保留经典和适合自己的风格,也许才是最重要的

汤唯短发造型&#xff1a;保留经典和适合自己的风格&#xff0c;也许才是最重要的 汤唯短发造型登上Vogue四月刊封面&#xff0c;引发网友热议。#李秘书讲写作#说说是怎么回事&#xff1f; 这次Vogue四月刊的封面大片&#xff0c;汤唯以一头短发亮相&#xff0c;身穿五颜六色的…