Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

1 model.train() 和 model.eval()用法和区别

1.1 model.train()

model.train()的作用是启用 Batch Normalization 和 Dropout

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

1.2 model.eval()

model.eval()的作用是不启用Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

1.3 分析原因

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

# 定义一个网络
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # 实例化这个网络
    Model = Net()
     
    # 训练模式使用.train()
    Model.train(mode=True)
    
    # 测试模型使用.eval()
    Model.eval()

为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
在这里插入图片描述
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

2.model.eval()和torch.no_grad()的区别

1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/108587.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【网络编程】传输层——UDP协议

文章目录 一、传输层1. 再谈端口号2. 端口号范围划分3. 认识知名端口号4. 两个问题5. netstat 与 pidof 二、UDP协议1. UDP协议格式2. UDP协议的特点3. 面向数据报4. UDP的缓冲区5. UDP使用注意事项6. 基于UDP的应用层协议 一、传输层 传输层 负责负责两台计算机之间的端到端的…

java原子类-Atomic

什么是原子类? java 1.5引进原子类,具体在java.util.concurrent.atomic包下,atomic包里面一共提供了13个类,分为4种类型,分别是: 原子更新基本类型,原子更新数组,原子更新引用&…

Pritunl搭建OpenVPN服务器详细流程,快速实现公网远程连接!

文章目录 前言1.环境安装2.开始安装3.访问测试4.创建连接5.局域网测试连接6.安装cpolar7.配置固定公网访问地址8.远程连接测试 前言 Pritunl是一款免费开源的 VPN 平台软件(但使用的不是标准的开源许可证,用户受到很多限制)。这是一种简单有…

GO学习之 通道(nil Channel妙用)

GO系列 1、GO学习之Hello World 2、GO学习之入门语法 3、GO学习之切片操作 4、GO学习之 Map 操作 5、GO学习之 结构体 操作 6、GO学习之 通道(Channel) 7、GO学习之 多线程(goroutine) 8、GO学习之 函数(Function) 9、GO学习之 接口(Interface) 10、GO学习之 网络通信(Net/Htt…

STM32F4VGT6-DISCOVERY:uart1驱动

对于这款板子&#xff0c;官方并没有提供串口例程&#xff0c;只能自行添加。 一、PA9/PA10复用成串口1功能不可用 驱动测试代码如下&#xff1a; main.c: #include "main.h" #include <stdio.h>void usart1_init(void) {GPIO_InitTypeDef GPIO_InitStruct…

前端 : 用html ,css,js写一个你画我猜的游戏

1.HTML&#xff1a; <body><div id "content"><div id "box1">计时器</div><div id"box"><div id "top"><div id "box-top-left">第几题:</div><div id "box…

linux-磁盘应用

目录 一、磁盘内容简述 1、一些基本概念 2、分区简述 3、常见文件系统 4、linux硬盘文件 二、对linux系统进行分区 1、用fdisk进行分区 2、用parted进行分区 一、磁盘内容简述 1、一些基本概念 - 扇区大小&#xff1a;512Btyes&#xff0c;0.5KB - 磁盘最小存储单位&…

小黑子—spring:第一章 Bean基础

spring入门1.0 一 小黑子对spring基础进行概述1.1 spring导论1.2 传统Javaweb开发困惑及解决方法1.3 三大的思想提出1.3.1 IOC入门案例1.3.2 DI入门案例 1.4 框架概念1.5 初识spring1.5.1 Spring Framework 1.6 BeanFactory快速入门1.7 ApplicationContext快速入门1.8 BeanFact…

python之计算平面点集的的面积

在当今数据驱动的世界中&#xff0c;计算平面点集的最小外接轮廓面积被广泛应用于各种实际场景中。它是一项重要而魅力十足的任务&#xff0c;旨在找到一个最小的矩形或多边形区域&#xff0c;能够完全包围给定的离散点集。这个看似简单的问题背后隐藏着许多挑战&#xff0c;需…

HTML5+CSS3+Vue小实例:路飞出海的动画特效

实例:路飞出海的动画特效 技术栈:HTML+CSS+Vue 效果: 源码: 【HTML】 <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name="viewport" content=&…

windows下OOM排查

如下有一段代码 package com.lm.demo.arthas.controller;import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController;import java.util.A…

车载音频项目

加我微信hezkz17进数字音频系统研究开发交流答疑群(课题组) ー 1&#xff0e;负责此项目的音频链路的设计及其实现 在ADSP21375上实现音频链路的处理。如噪声门&#xff0c;压限器&#xff0c;高低通&#xff0c;PEQ、各种效果等。 2&#xff0e;负责DSP与MCU端SPI协议实现。M…

Python文件——使用Python读取txt文件

作者&#xff1a;Insist-- 个人主页&#xff1a;insist--个人主页 本文专栏&#xff1a;Python专栏 专栏介绍&#xff1a;本专栏为免费专栏&#xff0c;并且会持续更新python基础知识&#xff0c;欢迎各位订阅关注. 目录 一、文件的编码 1. 什么是编码 2. 常见的编码 二、P…

【纯离线】Ubuntu离线安装ntp时间同步服务

Ubuntu离线安装ntp服务 准备阶段&#xff1a;下载安装包 apt-get download ntp apt-get download ntpdate 一、服务端( 192.166.6.xx) 1、环境准备 先判断是否已安装 systemd-timesyncd systemctl is-active systemd-timesyncd 如果返回结果是 active&#xff0c;则表示…

文件夹批量改名:如何在文件夹名左边添加递增的自动编号

在文件管理的过程中&#xff0c;我们有时需要对文件夹进行重命名&#xff0c;使其更具区分度和可读性。为了实现这一目标&#xff0c;我们可以采用在文件夹名左边添加递增的自动编号的方法。本文将介绍云炫文件管理器如何进行文件夹批量改名&#xff0c;以在文件夹名左边添加递…

mathtype7.4破解永久激活码

MathType(数学公式编辑器)是由Design Science公司研发的一款专业的数学公式编辑工具。MathType功能非常强大&#xff0c;尤其适用于专门研究数学领域的人群使用。使用MathType让你在输入数学公式的时候能够更加的得心应手&#xff0c;各种复杂的运算符号也不在话下。 MathType最…

Vue3.0插槽

用法&#xff1a; 父组件App.vue <template><div><!--将html代码插入到子组件中带默认名称的插槽中--><AChild><!--这段html会插入到AChild组件中<slot></slot>插槽中--><!-- 注意&#xff1a;写在父组件中的html代码只能在父组…

百度网盘使用指南

文章目录 备份篇手机文件备份电脑文件备份 查找篇移动端PC端 文件操作文件解压文件扫描PDF工具图片工具音频操作 备份篇 手机文件备份 在百度网盘APP种点击 我的–设置–自动备份设置 里边有相册备份, 文档备份, 微信文件备份, 手机通讯录, 短信, 通话备份等功能 电脑文件备…

目标检测类项目数据集汇总

一、玩手机数据集及检测 玩手机数据集下载地址分享: https://download.csdn.net/download/qq_34717531/19870205 二、狗的数据集及检测 狗目标检测数据集下载地址分享:https://download.csdn.net/download/qq_34717531/20813390 三、猫数据集及检测 猫数据集下载地址分享: ht…

回归算法|长短期记忆网络LSTM及其优化实现

本期文章将介绍LSTM的原理及其优化实现 序列数据有一个特点&#xff0c;即“没有曾经的过去则不存在当前的现状”&#xff0c;这类数据以时间为纽带&#xff0c;将无数个历史事件串联&#xff0c;构成了当前状态&#xff0c;这种时间构筑起来的事件前后依赖关系称其为时间依赖&…
最新文章