Pytorch自动混合精度的计算:torch.cuda.amp.autocast

1 autocast介绍

1.1 什么是AMP?

默认情况下,大多数深度学习框架都采用32位浮点算法进行训练。2017年,NVIDIA研究了一种用于混合精度训练的方法,该方法在训练网络时将单精度(FP32)与半精度(FP16)结合在一起,并使用相同的超参数实现了与FP32几乎相同的精度。

FP16也即半精度是一种计算机使用的二进制浮点数据类型,使用2字节存储。而FLOAT就是FP32。

1.2 autocast作用

torch.cuda.amp.autocast是PyTorch中一种混合精度的技术(仅在GPU上训练时可使用),可在保持数值精度的情况下提高训练速度和减少显存占用。

    def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.float16, cache_enabled : bool = True):

它是一个自动类型转换器,可以根据输入数据的类型自动选择合适的精度进行计算,从而使得计算速度更快,同时也能够节省显存的使用。使用autocast可以避免在模型训练过程中手动进行类型转换,减少了代码实现的复杂性。

在深度学习中,通常会使用浮点数进行计算,但是浮点数需要占用更多的显存,而低精度数值可以在减少精度的同时,减少缓存使用量。因此,对于正向传播和反向传播中的大多数计算,可以使用低精度型的数值,提高内存使用效率,进而提高模型的训练速度。

1.3 autocast原理

autocast的要做的事情,简单来说就是:在进入算子计算之前,选择性的对输入进行cast操作。为了做到这点,在PyTorch1.9版本的架构上,可以分解为如下两步:

  • 在PyTorch算子调用栈上某一层插入处理函数
  • 在处理函数中对算子的输入进行必要操作

核心代码:autocast_mode.cpp

2 autocast优缺点

PyTorch中的autocast功能是一个性能优化工具,它可以自动调整某些操作的数据类型以提高效率。具体来说,它允许自动将数据类型从32位浮点(float32)转换为16位浮点(float16),这通常在使用深度学习模型进行训练时使用。

2.1 autocast优点

  • 提高性能:使用16位浮点数(half precision)进行计算可以在支持的硬件上显著提高性能,特别是在最新的GPU上。

  • 减少内存占用:16位浮点数占用的内存比32位少,这意味着在相同的内存限制下可以训练更大的模型或使用更大的批量大小。

  • 自动管理autocast能够自动管理何时使用16位浮点数,何时使用32位浮点数,这降低了手动管理数据类型的复杂性。

  • 保持精度:尽管使用了较低的精度,但autocast通常能够维持足够的数值精度,对最终模型的准确度影响不大。

2.2 autocast缺点

  • 硬件要求:并非所有的GPU都支持16位浮点数的高效运算。在不支持或优化不足的硬件上,使用autocast可能不会带来性能提升。

  • 精度问题:虽然在大多数情况下精度损失不显著,但在某些应用中,尤其是涉及到小数值或非常大的数值范围时,降低精度可能会导致问题。

  • 调试复杂性:由于autocast在模型的不同部分自动切换数据类型,这可能会在调试时增加额外的复杂性。

  • 算法限制:某些特定的算法或操作可能不适合在16位精度下运行,或者在半精度下的实现可能还不成熟。

  • 兼容性问题:某些PyTorch的特性或第三方库可能还不完全支持半精度运算。

在实际应用中,是否使用autocast通常取决于特定任务的需求、所使用的硬件以及对性能和精度的权衡。通常,对于大多数现代深度学习应用,特别是在使用最新的GPU时,使用autocast可以带来显著的性能优势。

3 使用示例

3.1 autocast混合精度计算

with autocast(): 语句块内的代码会自动进行混合精度计算,也就是根据输入数据的类型自动选择合适的精度进行计算,并且这里使用了GPU进行加速。使用示例如下:

# 导入相关库
import torch
from torch.cuda.amp import autocast

# 定义一个模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        with autocast():
            x = self.linear(x)
        return x

# 初始化数据和模型
x = torch.randn(1, 10).cuda()
model = MyModel().cuda()

# 进行前向传播
with autocast():
    output = model(x)

# 计算损失
loss = output.sum()

# 反向传播
loss.backward()

3.2 autocast与GradScaler一起使用

因为autocast会损失部分精度,从而导致梯度消失的问题,并且经过中间层时可能计算得到inf导致最终loss出现nan。所以我们通常将GradScaler与autocast配合使用来对梯度值进行一些放缩,来缓解上述的一些问题。

from torch.cuda.amp import autocast, GradScaler

dataloader = ...
model = Model.cuda(0)
optimizer = ...
scheduler = ...
scaler = GradScaler()  # 新建GradScale对象,用于放缩
for epoch_idx in range(epochs):
    for batch_idx, (dataset) in enumerate(dataloader):
        optimizer.zero_grad()
        dataset = dataset.cuda(0)
        with autocast():  # 自动混精度
            logits = model(dataset)
            loss = ...
        scaler.scale(loss).backward()  # scaler实现的反向误差传播
        scaler.step(optimizer)  # 优化器中的值也需要放缩
        scaler.update()  # 更新scaler
  scheduler.step()
...

4 可能出现的问题

使用autocast技术进行混精度训练时loss经常会出现'nan',有以下三种可能原因:

  • 精度损失,有效位数减少,导致输出时数据末位的值被省去,最终出现nan的现象。该情况可以使用GradScaler(上文所示)来解决。
  • 损失函数中使用了log等形式的函数,或是变量出现在了分母中,并且训练时,该数值变得非常小时,混精度可能会让该值更接近0或是等于0,导致了数学上的log(0)或是x/0的情况出现,从而出现'inf'或'nan'的问题。这种时候需要针对该问题设置一个确定值。例如:当log(x)出现-inf的时候,我们直接将输出中该位置的-inf设置为-100,即可解决这一问题。
  • 模型内部存在的问题,比如模型过深,本身梯度回传时值已经非常小。这种问题难以解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/146294.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Dart利用私有构造函数_()创建单例模式

文章目录 类的构造函数_()函数dart中构造函数定义 类的构造函数 类的构造函数有两种: 1)默认构造函数: 当实例化对象的时候,会自动调用的函数,构造函数的名称和类的名称相同,在一个类中默认构造函数只能由…

post 和get参数 请求

json参数 post请求格式 RestController public class HelloController { //json参数 post 请求RequestMapping("/jsonParam")public String jsonParam(RequestBody User user){System.out.println(user);return "OK";} } postman 接口测试工具…

spring cloud alibaba 简介

微服务搭建组件选型 1.服务注册中心 Nacos(spring-cloud-alibaba) 2.服务通信 OpenFeign(spring-cloud) 3.服务熔断、降级、限流 Sentinel(spring-cloud-alibaba) 4.网关 Gateway(spring-cloud) 5.服务配置中心 …

MySQL被攻击后创建数据库报错1044 - Access denied for user ‘root‘@‘%‘ to database ‘xxx‘

MySQL被攻击后创建数据库报错1044 - Access denied for user root% to database xxx 一、问题二、解决过程1、正常过程2、踩坑(已经解决问题的可以不看) 一、问题 最近数据库被攻击了,业务数据库都没了 还好也不是有重要数据,但再…

【HUST】网安纳米|2023年研究生纳米技术考试参考

目录 1 纳米材料是什么 2 纳米材料的结构特性 3 纳米结构的其他特性 4 纳米结构的检测技术 5 纳米材料的应用 打印建议:PPT彩印(这样重点比较突出),每面12张PPT,简单做一下关键词目录,亲测可以看清。如…

Uniapp开发 购物商城源码 在线电商商城源码 适配移动终端项目及各小程序

lilishop电商商城系统 商城移动端,使用Uniapp开发,可编译为所有移动终端项目及各小程序 源码下载:https://download.csdn.net/download/m0_66047725/88487579 源码下载2:关注我留言

02 # 类型基础:强类型与弱类型

宽泛的定义 在强类型语言中,当一个对象从调用函数传递到被调用函数时,其类型必须与被调用函数中声明的类型兼容 – Liskov, Zilles 1974 通俗定义 强类型语言不允许改变变量的数据类型,除非进行强制类型转换 比如下面 Java 里不能将布尔类…

最小二乘法及参数辨识

文章目录 一、最小二乘法1.1 定义1.2 SISO系统运用最小二乘估计进行辨识1.3 几何解释1.4 最小二乘法性质 二、加权最小二乘法三、递推最小二乘法四、增广最小二乘法 一、最小二乘法 1.1 定义 1974年高斯提出的最小二乘法的基本原理是未知量的最可能值是使各项实际观测值和计算…

〖大前端 - 基础入门三大核心之JS篇㉟〗- JavaScript 的DOM简介

说明:该文属于 大前端全栈架构白宝书专栏,目前阶段免费,如需要项目实战或者是体系化资源,文末名片加V!作者:不渴望力量的哈士奇(哈哥),十余年工作经验, 从事过全栈研发、产品经理等工作&#xf…

基于Python实现汽车销售数据可视化【500010086】

导入模块 import numpy as np import pandas as pd import plotly.graph_objects as go import plotly.express as px获取数据 df1 pd.read_excel(r"./data/中国汽车总体销量.xlsx") print(df1.head(5))df1.info()df1[年份] df1[时间].dt.year df1[月份] df1[时…

科研学习|科研软件——有序多分类Logistic回归的SPSS教程!

一、问题与数据 研究者想调查人们对“本国税收过高”的赞同程度:Strongly Disagree——非常不同意,用“0”表示;Disagree——不同意,用“1”表示;Agree--同意,用“2”表示;Strongly Agree--非常…

【VBA】基于EXCEL生成Insert语句工具

工具介绍 基于Excel生成INSERT语句工具是一个辅助工具,用于帮助用户根据Excel数据生成INSERT语句。通常,在数据库中插入大量数据时,手动编写INSERT语句会非常繁琐和耗时。而使用这个工具,可以通过Excel中的数据自动生成相应的INS…

第五章ARM处理器的嵌入式硬件系统设计——课后习题

1ARM处理器的工作状态 ARM处理器有两种工作状态。具体而言,ARM处理器执行32位ARM指令集时,工作在ARM状态,当ARM处理器执行16位thumb指令集时候,工作在thumb状态。 1ARM指令特点 1一个大的,统一的寄存器文件。 2基于…

nginx四层tcp负载均衡及主备、四层udp负载均衡及主备、7层http负载均衡及主备配置(wndows系统主备、负载均衡)

准备工作 服务器上安装、配置网络负载平衡管理器 windows服务器热备、负载均衡配置-CSDN博客 在windows服务器上安装vmware17 win10 上安装vmware17-CSDN博客 在windows上利用vmware17 搭建centos7 mini版 在windows上利用vmware17 搭建centos7 mini版本服务器-CSDN博客 …

20231114在HP笔记本的ubuntu20.04系统下向RealmeQ手机发送PDF文件

20231114在HP笔记本的ubuntu20.04系统下向RealmeQ手机发送PDF文件 2023/11/14 14:11 手机:Realme Q 笔记本电脑:HP https://item.jd.com/100012583174.html 惠普(HP)战66 三代AMD版 14英寸轻薄笔记本电脑(锐龙7nm 六核…

【Shell脚本11】Shell 函数

Shell 函数 linux shell 可以用户定义函数,然后在shell脚本中可以随便调用。 shell中函数的定义格式如下: [ function ] funname [()]{action;[return int;]}说明: 1、可以带function fun() 定义,也可以直接fun() 定义,不带任何…

【工艺库】SMIC数字后端工艺库

工艺库文件 Calibredigital文件夹apollolefprimetimesynopsys TD系列文件夹 本来是想找一个工艺库,想要其包含逻辑综合和SPICE Model相关的库文件,但是找了很久也没有直接找到想要的,主要原因还是自己对工艺库文件的构成不是很清楚&#xff0…

flink 8081 web页面无法被局域网内其他机器访问

实现 http://localhost:8081/#/overview 可以被局域网其他机器访问

高德地图系列(一):vue项目如何使用高德地图、入门以及基本控件使用

目录 第一章 前言 第二章 准备工作 2.1 账号注册 2.2 高德地图开发平台文档 2.3 创建应用 第三章 使用地图 3.1 地图使用步骤 3.2 理解几个地图基础控件 3.3 基础类理解 第一章 前言 小编都是在vue项目中使用高德地图的,每一个功能都会亲测可用之后才会…

【JavaEE】Servlet API 详解(HttpServlet类)

一、HttpServlet 写 Servlet 代码的时候, 首先第一步就是先创建类, 继承自HttpServlet, 并重写其中的某些方法 1.1 HttpServlet核心方法 1.2 Servlet生命周期 这些方法的调用时机, 就称为 “Servlet 生命周期”. (也就是描述了一个 Servlet 实例从生到死的过程) 1.3 处理G…
最新文章