DeepLearning - 余弦退火热重启学习率 CosineAnnealingWarmRestartsLR

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134249925

CosineAnnealingWarmRestartsLR,即 余弦退火热重启学习率,周期性修改学习率的下降和上升,间隔幅度逐渐增大,避免模型的性能抖动。其中核心参数:

  • optimizer 的参数,lr 学习率,默认学习率是 lr * GPU 数量,例如 lr 设置成 0.00001,32卡实际是 0.00032。
  • T_0,衰减的 global step 数,即单卡的运行次数,根据运行时间确定,例如 step 是 28.5 秒一次,(28.5 * 2000) / 3600 = 15.8 小时。
  • T_mult,周期间隔,逐渐加大,例如 T_mult 是 2,则表示,第n次是 T 0 ∗ T m u l t n T_0*T_{mult}^{n} T0Tmultn 步。
  • eta_min,从 LR 衰减的最小步数,可以设置成0。

源码:

optimizer = deepspeed.ops.adam.FusedAdam(self.model.parameters(), lr=learning_rate, eps=eps)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=lr_t_0, T_mult=2, eta_min=0, last_epoch=-1)

LR 曲线如下:

GitHub - SevenZhan/Pytorch: self-used pytorch utilities

源码:CosineAnnealingWarmRestarts

class CosineAnnealingWarmRestarts(LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
        if not isinstance(eta_min, (float, int)):
            raise ValueError(f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}")
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

[docs]    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> # xdoctest: +SKIP("Undefined vars")
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(f"Expected non-negative epoch, but got {epoch}")
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                return self

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

WandB 测试效果:

WandB

参考:

  • 知乎 - PyTorch中学习率调度器可视化介绍

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/119123.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

苹果Mac电脑fcpx视频剪辑:Final Cut Pro中文最新 for mac

Final Cut Pro是苹果公司开发的一款专业视频剪辑软件&#xff0c;它为原生64位软件&#xff0c;基于Cocoa编写&#xff0c;支持多路多核心处理器&#xff0c;支持GPU加速&#xff0c;支持后台渲染。Final Cut Pro在Mac OS平台上运行&#xff0c;适用于进行后期制作。 Final Cu…

c语言实现http下载功能,显示进度条和下载速率

#include <stdio.h>//printf #include <string.h>//字符串处理 #include <sys/socket.h>//套接字 #include <arpa/inet.h>//ip地址处理 #include <fcntl.h>//open系统调用 #include <unistd.h>//write系统调用 #include <netdb.h>//…

java版小程序商城免费搭建-直播商城平台规划及常见的营销模式有哪些?电商源码/小程序/三级分销

1. 涉及平台 平台管理、商家端&#xff08;PC端、手机端&#xff09;、买家平台&#xff08;H5/公众号、小程序、APP端&#xff08;IOS/Android&#xff09;、微服务平台&#xff08;业务服务&#xff09; 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis 3. 前端框架…

【机器学习2】模型评估

模型评估主要分为离线评估和在线评估两个阶段。 针对分类、 排序、 回归、序列预测等不同类型的机器学习问题&#xff0c; 评估指标的选择也有所不同。 1 评估指标 1.1准确率 准确率是指分类正确的样本占总样本个数的比例 但是准确率存在明显的问题&#xff0c;比如当负样本…

深度学习(CNN+RNN)笔记2

文章目录 第五课&#xff1a;序列模型(Sequence Models)第一周&#xff1a;循环神经网络&#xff08;Recurrent Neural Networks&#xff09;【序列模型、语言模型序列生成、对新序列采样。RNN、GRU、LSTM、双向RNN、深度RNN】第二周&#xff1a;自然语言处理与词嵌入&#xff…

easyConnect虚拟网卡未安装,导致连接失败(虚拟网卡安装失败)

前言 使用easyConnect&#xff0c;但是一直连接失败&#xff0c;看到提示错误 虚拟网卡未安装&#xff0c;请确保虚拟网卡安装成功 我的错误原因是因为我自己装过VM虚拟机&#xff0c;用过虚拟网卡然后产生的虚拟网卡冲突 解决方式 1.打开网络设置2.选择你的网络&#xff08…

OpenAI首席科学家:ChatGPT已经出现意识,人类未来将与AI融合

OpenAI首席科学家在最近的专访中抛出了很多惊人言论。在他看来&#xff0c;ChatGPT背后的神经网络已经产生了意识&#xff0c;而且未来人类会与人工智能融合起来&#xff0c;出现新的形态。而他现在工作的重点&#xff0c;已经不是去创建那个必然会出现的通用人工智能&#xff…

从零开始搭建React+TypeScript+webpack开发环境-使用iconfont构建图标库

创建iconfont项目 进入iconfont官网&#xff0c;完成注册流程&#xff0c;即可创建项目。 无法访问iconfont可尝试将电脑dns改为阿里云镜像223.5.5.5和223.6.6.6 添加图标 在图标库里选择图标&#xff0c;加入购物车 将图标添加到之前创建的项目中 生成代码 将代码配置到项目…

Modbus转Profinet网关在暖通空调系统中应用案例

在过去&#xff0c;空调系统一般采用传统的控制方式&#xff0c;通常需要使用独立的控制模块和传感器来监测和控制温度、湿度等参数。这种传统的控制方式不仅复杂&#xff0c;而且容易出现故障和误差&#xff0c;给用户的使用和维护带来了一定的困扰。 然而&#xff0c;通过P…

Linux 实现原理 — NUMA 多核架构中的多线程调度开销与性能优化

前言 NOTE&#xff1a;本文中所指 “线程” 均为可执行调度单元 Kernel Thread。 NUMA 体系结构 NUMA&#xff08;Non-Uniform Memory Access&#xff0c;非一致性存储器访问&#xff09;的设计理念是将 CPU 和 Main Memory 进行分区自治&#xff08;Local NUMA node&#x…

立冬将至,别忘记吃饺子!Don‘t forget to eat dumplings

立冬通常是每年的11月7日或8日。每年这个时候&#xff0c;河水就开始结冰。“Winter Begins” arrives on November 7 or November 8 each year. At this time of the year, some rivers in China start to freeze. 立冬是冬季的第一个节气&#xff0c;进入这一时节&#xff0…

电机应用-直流有刷电机

目录 直流有刷电机 工作原理 直流有刷减速电机的重要参数 电路原理与分析 驱动芯片分析 L298N驱动芯片 直流有刷减速电机控制实现 控制速度原理 硬件设计 L298N 野火直流有刷电机驱动板-MOS管搭建板 软件设计1&#xff1a;两个直流有刷减速电机按键控制 开发设计 …

Android Studio(项目打包成APK)

打包流程 直接上图即可 按照上面操作后&#xff0c;即可以开始打包&#xff0c;一般第一次打包都需要几分钟&#xff08;我第一次打包花了七八分钟&#xff09;&#xff0c;如果打包错误了也别担心&#xff0c;可以查看错误分析一下原因&#xff0c;实在不行可以把错误放到网站…

ElasticSearch 实现 全文检索 支持(PDF、TXT、Word、HTML等文件)通过 ingest-attachment 插件实现 文档的检索

一、Attachment 介绍 Attachment 插件是 Elasticsearch 中的一种插件&#xff0c;允许将各种二进制文件&#xff08;如PDF、Word文档等&#xff09;以及它们的内容索引到 Elasticsearch 中。插件使用 Apache Tika 库来解析和提取二进制文件的内容。通过使用 Attachment 插件&a…

Qt全局定义

一、QtGlobal头文件 头文件中包含了Qt类库的一些全局定义&#xff0c;包括&#xff1a; 基本数据类型全局函数宏定义 二、基本数据类型 三、全局函数 四、宏定义 1.Qt版本相关的宏 1.1 QT_VERSION 这个宏展开为数值形式 0xMMNNPP (MM major, NN minor, PP patch) 表示…

Hadoop知识点全面总结

文章目录 什么是HadoopHadoop发行版介绍Hadoop版本演变历史Hadoop3.x的细节优化Hadoop三大核心组件介绍HDFS体系结构NameNode介绍总结 SecondaryNameNode介绍DataNode介绍DataNode总结 MapReduce介绍分布式计算介绍MapReduce原理剖析MapReduce之Map阶段MapReduce之Reduce阶段 实…

Verilog HDL语言基础知识

目录 Verilog HDL语言基础知识 6.1.2 Verilog HDL模块的结构 6.1.3 逻辑功能定义 6.2.1 常量 6.3 运算符及表达式 6.4.2 条件语句 Verilog HDL语言基础知识 先来看两个Verilog HDL程序。 例6.1 一个8位全加器的 Verilog HDL源代码 module adder8(cout,sum,ina,…

Si4010 一款带有MCU SoC RF发射机芯片 无线遥控器

Si4010是一款完全集成的SoC RF发射机&#xff0c;带有嵌入式CIP-51 8051 MCU&#xff0c;专为1GHz以下ISM频带设计。该芯片针对电池供电的应用进行了优化&#xff0c;工作电压为1.8至3.6 V&#xff0c;待机电流小于10 nA的超低电流消耗。高功率放大器可提供高达10 dBm的输出功率…

pytorch复现_UNet

什么是UNet U-Net由收缩路径和扩张路径组成。收缩路径是一系列卷积层和汇集层&#xff0c;其中要素地图的分辨率逐渐降低。扩展路径是一系列上采样层和卷积层&#xff0c;其中特征地图的分辨率逐渐增加。 在扩展路径中的每一步&#xff0c;来自收缩路径的对应特征地图与当前特征…

前端框架Vue学习 ——(四)Axios

文章目录 Axios 介绍Axios 入门Vue项目中使用 Axios Axios 介绍 介绍: Axios 对原生的 Ajax 进行了封装&#xff0c;简化书写&#xff0c;快速开发。&#xff08;异步请求&#xff09; 官网: https://www.axios-http.cn/ 官网介绍&#xff1a;Axios 是一个基于 promise 网络请…
最新文章