【论文阅读】Long-Tailed Recognition via Weight Balancing(CVPR2022)

目录

  • 论文
  • 使用方法
    • weight decay
    • MaxNorm
  • 如果使用原来的代码报错的可以看下面这个

论文

问题:真实世界中普遍存在长尾识别问题,朴素训练产生的模型在更高准确率方面偏向于普通类,导致稀有的类别准确率偏低。
key:解决LTR的关键是平衡各方面,包括数据分布、训练损失和学习中的梯度。
文章主要讨论了三种方法: L2normalization, weight decay, and MaxNorm
本文提出了一个两阶段训练的范式
a. 利用调节权重衰减的交叉熵损失学习特征。
b. 通过调节权重衰减和Max Norm使用类平衡损失学习分类器。
一些有用的看法

  1. 研究表明,与联合训练特征学习和分类器学习的模型相比,解耦特征学习和分类器学习导致了显著的改进。
  2. 根据基准测试结果,通过集成专家模型或采用主动数据增强技术的自监督预训练来实现最好精度。
  3. 研究发现,SGD动量导致LTR出现问题,阻碍了进一步改善。
  4. 最近,Kang等人令人信服地证明了阶段性训练对LTR很重要。
  5. 权重衰减有助于学习隐藏层的平衡权重。
  6. 重要的是,我们的探索发现,虽然在分类器上使用L2规范化约束进行训练比简单训练有所改进,但它的表现不如下面描述的其他两个正则化。
  7. 与严格将所有滤波器权重的范数值设置为1的L2归一化不同,MaxNorm放松了这一约束,允许权重在训练期间在范数球内移动。
  8. 权重衰减中,不同数据集的最优λ各不相同——较大的数据集需要较小的权重衰减,直观地说,因为在更多数据上学习有助于泛化,因此需要较少的正则化。
    单阶段使用不平衡损失训练效果不好的原因:虽然他们没有解释为什么具有类平衡损失的单阶段训练表现不佳,但直观地说,这是因为类平衡损失人为地放大了从罕见的类训练数据计算的梯度,这损害了特征表示学习,从而损害了最终的LTR性能。
    本文作者使用了weight decay和max norm两种方法结合,因为发现两个结合效果更好。让模型不同类之间权重相差不会很大的同时,还能让这些权重缓慢增加。
    下面这幅图就是解释了这些方法的特点。
    在这里插入图片描述
    第一个就是普通方法训练的,它常见的类别权重增长快。
    第二个是L2 normalization,它把所有类别的权重都限定在一个常数。
    第三个是权重衰减,它的所有类的权重小,而且权重在增长。
    第四个是MaxNorm,它限制最大的权重。
    第五个是权重衰减和MaxNorm,会导致范数中的权重较小且平衡。

使用方法

weight decay

先定义好权重衰减的值。

weight_decay = 0.1 #weight decay value

然后在优化器中调用。Adam还有其他的都有weight_decay。

optimizer = optim.SGD([{'params': active_layers, 'lr': base_lr}], lr=base_lr, momentum=0.9, weight_decay=weight_decay)

MaxNorm

就是这个论文中的regularizers.py中的代码。只要会使用就好。就是要是不是作者代码中的模型的话,model.encoder.fc还需要根据自己的代码修改。

#使用前先定义好初始化好
pgdFunc = MaxNorm_via_PGD(thresh=thresh)
pgdFunc.setPerLayerThresh(model) # set per-layer thresholds这个是计算模型每一层的权重的阈值,这篇论文中只计算最后线性层的权重,并对最后线性层的权重进行限制

当模型训练一个epoch结束后,对已经更新完毕的模型权重进行限制,如果超过阈值就进行更新,让权重在最大范数的约束下。

 if pgdFunc:# Projected Gradient Descent
     pgdFunc.PGD(model)#对权重进行限制
import torch
import torch.nn as nn
import math
# The classes below wrap core functions to impose weight regurlarization constraints in training or finetuning a network.

class MaxNorm_via_PGD():
    def __init__(self, thresh=1.0, LpNorm=1, tau=1):
        self.thresh = thresh
        self.LpNorm = LpNorm
        self.tau = tau
        self.perLayerThresh = []

    def setPerLayerThresh(self, model):#根据指定的模型设置每层的阈值
        #set pre-layer thresholds
        self.perLayerThresh = []

        for curLayer in [model.encoder.fc.weight, model.encoder.fc.bias]:#遍历模型的最后两层
            curparam = curLayer.data#获取当前层的数据
            if len(curparam.shape) <= 1:#如果层只有一个维度,是一个偏置或者是一个1D的向量,则设置这一层的阈值为无穷大,继续下一层
                self.perLayerThresh.append(float('inf'))
                continue
            curparam_vec = curparam.reshape((curparam.shape[0], -1))#如果不是,把权重张量展开
            neuronNorm_curparam = torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1).detach().unsqueeze(-1)#沿着第一维计算P番薯,结果存储
            curLayerThresh = neuronNorm_curparam.min() + self.thresh*(neuronNorm_curparam.max() - neuronNorm_curparam.min())#计算每一层的阈值及神经元范数的最小值加上最大值和最小值之间的缩放差
            self.perLayerThresh.append(curLayerThresh)#每层阈值存储

    def PGD(self, model):#定义PGD函数,用于在模型的参数上执行投影梯度下降,试试最大范数约束
        if len(self.perLayerThresh) == 0:#如果每层的阈值是空,用setPerLayerThresh方法初始化
            self.setPerLayerThresh(model)
        for i, curLayer in enumerate([model.encoder.fc.weight, model.encoder.fc.bias]):#遍历模型的最后两层
            curparam = curLayer.data#获取当前层的数据张量值
            curparam_vec = curparam.reshape((curparam.shape[0], -1))#变成一维
            neuronNorm_curparam = (torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1)**self.tau).detach().unsqueeze(-1)#在最后加一维
            #计算权重张量中每行神经元番薯的tau次方
            scalingVect = torch.ones_like(curparam)#创建一个形状与当前层数据相同的张量,用1初始化
            curLayerThresh = self.perLayerThresh[i]#获取阈值

            idx = neuronNorm_curparam > curLayerThresh#创建bool保存超过阈值的神经元
            idx = idx.squeeze()#
            tmp = curLayerThresh / (neuronNorm_curparam[idx].squeeze())**(self.tau)#根据每层的阈值和超过阈值的神经元番薯计算缩放因子
            for _ in range(len(scalingVect.shape)-1):#扩展缩放因子以匹配当前层数据的维度
                tmp = tmp.unsqueeze(-1)

            scalingVect[idx] = torch.mul(scalingVect[idx],tmp)
            curparam[idx] = scalingVect[idx] * curparam[idx]
            curparam[idx] = scalingVect[idx] * curparam[idx]#通过缩放值更新当前层的数据,以便对超过阈值的神经元进行缩放。完成权重更新


如果使用原来的代码报错的可以看下面这个

我的网络只有一层是线性层idx = idx.squeeze(),idx是(1,1)形状的,squeeze就没了,所以报错,如果有这个原因的可以改成idx = idx.squeeze(1)。maxnorm只改最后两层/一层权重所以,定义了一个列表存储线性层只取最后两层或者一层。

class MaxNorm_via_PGD():
    # learning a max-norm constrainted network via projected gradient descent (PGD)
    def __init__(self, thresh=1.0, LpNorm=2, tau=1):
        self.thresh = thresh
        self.LpNorm = LpNorm
        self.tau = tau
        self.perLayerThresh = []

    def setPerLayerThresh(self, model):
        # set per-layer thresholds
        self.perLayerThresh = []#存储每一层的阈值
        self.last_two_linear_layers = []#提取线性层
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                self.last_two_linear_layers.append(module)

        for linear_layer in self.last_two_linear_layers[-min(2, len(self.last_two_linear_layers)):]:  # here we only apply MaxNorm over the last two layers
            curparam = linear_layer.weight.data
            if len(curparam.shape) <= 1:
                self.perLayerThresh.append(float('inf'))
                continue
            curparam_vec = curparam.reshape((curparam.shape[0], -1))
            neuronNorm_curparam = torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1).detach().unsqueeze(-1)
            curLayerThresh = neuronNorm_curparam.min() + self.thresh * (
                        neuronNorm_curparam.max() - neuronNorm_curparam.min())
            self.perLayerThresh.append(curLayerThresh)

    def PGD(self, model):
        if len(self.perLayerThresh) == 0:
            self.setPerLayerThresh(model)
        for i, curLayer in enumerate([self.last_two_linear_layers[-min(2,
                                                             len(self.last_two_linear_layers))]]):  # here we only apply MaxNorm over the last two layers

            curparam = curLayer.weight.data

            curparam_vec = curparam.reshape((curparam.shape[0], -1))
            neuronNorm_curparam = (
                        torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1) ** self.tau).detach().unsqueeze(-1)
            scalingVect = torch.ones_like(curparam)
            curLayerThresh = self.perLayerThresh[i]

            idx = neuronNorm_curparam > curLayerThresh
            idx = idx.squeeze(1)
            tmp = curLayerThresh / (neuronNorm_curparam[idx].squeeze()) ** (self.tau)
            for _ in range(len(scalingVect.shape) - 1):
                tmp = tmp.unsqueeze(-1)

            scalingVect[idx] = torch.mul(scalingVect[idx], tmp)
            curparam[idx] = scalingVect[idx] * curparam[idx]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/357464.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

力扣题集(第一弹)

一日练,一日功;一日不练十日空。 学编程离不开刷题&#xff0c;接下来让我们来看几个力扣上的题目。 1. 242. 有效的字母异位词 题目描述 给定两个字符串 s 和 t &#xff0c;编写一个函数来判断 t 是否是 s 的字母异位词。 注意&#xff1a;若 s 和 t 中每个字符出现的次数…

JS图片二维码识别

前言 js识别QR图片&#xff0c;基于jsQR.js 代码 <!DOCTYPE html> <html> <head><meta charset"utf-8" /><title>图片二维码识别</title><script src"https://cdn.bootcss.com/jquery/3.4.1/jquery.min.js">…

什么是消息队列?

消息用队列的模式发送&#xff0c; 把要传输的数据放在队列中&#xff0c; 产生消息的叫做生产者&#xff0c; 从队列里取出消息的叫做消费者。 一、组成 生产者&#xff1a;Producer 消息的产生者与调用端 主要负责消息所承载的业务信息的实例化 是一个队列的发起方 代理…

网站小程序分类目录网源码系统+会员注册登录功能 附带完整的搭建教程

随着互联网的发展&#xff0c;小程序分类目录网站已经成为了人们获取各类信息的重要渠道。而在这个领域中&#xff0c;罗峰给大家分享一款网站小程序分类目录网源码系统以其强大的功能和易用性&#xff0c;脱颖而出。本系统集成了会员注册登录功能&#xff0c;让用户能够更加便…

uniapp H5 实现上拉刷新 以及 下拉加载

uniapp H5 实现上拉刷新 以及 下拉加载 1. 先上图 下拉加载 2. 上代码 <script>import DragableList from "/components/dragable-list/dragable-list.vue";import {FridApi} from /api/warn.jsexport default {data() {return {tableList: [],loadingHi…

Redis核心技术与实战【学习笔记】 - 6.Redis 的统计操作处理

1.前言 在 Web 业务场景中&#xff0c;我们经常保存这样一种信息&#xff1a;一个 key 对应了一个数据集合。比如&#xff1a; 手机 APP 中的每天用户登录信息&#xff1a;一天对应一系列用户 ID。电商网站上商品的用户评论列表&#xff1a;一个商品对应了一些列的评论。用户…

12 数据仓库理论

数仓基本概述 数据仓库基本概念 数据仓库是一个为数据分析而设计的企业级数据管理系统。数据仓库可集中 、整合多个信息源的大量数据。 数仓核心架构 数据仓库建模概述 数据仓库建模意义 数据模型就是数据组织和存储方法&#xff0c;它强调从业务、数据存取和使用角度合理…

Django配置websocket时的错误解决

基于移动群智感知的网络图谱构建系统需要手机app不断上传数据到服务器并把数据推到前端标记在百度地图上&#xff0c;由于众多手机向同一服务器发送数据&#xff0c;如果使用长轮询&#xff0c;则实时性差、延迟高且服务器的负载过大&#xff0c;而使用websocket则有更好的性能…

链表与二叉树-数据结构

链表与二叉树-数据结构 创建叶子node节点建立二叉树三元组&#xff1a;只考虑稀疏矩阵中非0的元素&#xff0c;并且存储到一个类&#xff08;三元组&#xff09;的数组中。 创建叶子node节点 class Node{int no;Node next;public Node(int no){this.nono;} } public class Lb…

YOLOv8改进 | 可视化热力图 | 支持YOLOv8最新版本密度热力图,和视频热力图

一、本文介绍 本文给大家带来的机制是集成了YOLOv8最新版本的可视化热力图功能,热力图作为我们论文当中的必备一环,可以展示出我们呈现机制的有效性,本文的内容支持YOLOv8最新版本的根据密度呈现的热力图,同时支持视频检测,根据视频中的密度来绘画热力图。 在开始之前给…

薅运营商羊毛?封杀!

最近边小缘在蓝点网上看到一则消息 “浙江联通也开始严格排查PCDN和PT等大流量行为 被检测到可能会封停宽带”。 此前中国联通已经在四川和上海等多个省市严查家庭宽带 (部分企业宽带也被查) 使用 PCDN 或 PT&#xff0c;当用户的宽带账户存在大量上传数据的情况&#xff0c;中…

数据库管理-第141期 DG PDB - Oracle DB 23c(20240129)

数据库管理141期 2024-01-29 第141期 DG PDB - Oracle DB 23c&#xff08;20240129&#xff09;1 概念2 环境说明3 操作3.1 数据库配置3.2 配置tnsname3.3 配置强制日志3.4 DG配置3.5 DG配置建立联系3.6 启用所有DG配置3.7 启用DG PDB3.8 创建源PDB的DG配置3.9 拷贝pdbprod1文件…

【C++】I/O多路转接详解(一)

目录 1. 背景引入1.1 IO的过程1.2 五种IO模型1.2.1 阻塞IO1.2.2 非阻塞IO1.2.3 信号驱动IO1.2.4 IO多路转接1.2.5 异步IO 1.3 同步通信 与 异步通信1.4 阻塞 与 非阻塞1.4.1 阻塞与非阻塞区别1.4.2 设置非阻塞IO 2. select2.1 接口使用2.2 select执行过程2.3 select代码实践 3.…

<网络安全>《9 入侵防御系统IPS》

1 概念 IPS&#xff08; Intrusion Prevention System&#xff09;是电脑网络安全设施&#xff0c;是对防病毒软件&#xff08;Antivirus Programs&#xff09;和防火墙&#xff08;Packet Filter, Application Gateway&#xff09;的补充。 入侵预防系统&#xff08;Intrusio…

JS第一课简单看看这是啥东西

1.什么是JavaScript JS是一门编程语言&#xff0c;是一种运行在客户端(浏览器)的编程语言&#xff0c;主要是让前端的画面动起来&#xff0c;注意HTML和CSS不是编程语言&#xff0c;他俩是一种标记语言。JS只要有浏览器就能运行不用跟Python或者Java一样上来装一个jdk或者Pyth…

2023年算法SAO-CNN-BiLSTM-ATTENTION回归预测(matlab)

2023年算法SAO-CNN-BiLSTM-ATTENTION回归预测&#xff08;matlab&#xff09; SAO-CNN-BiLSTM-Attention雪消融优化器优化卷积-长短期记忆神经网络结合注意力机制的数据回归预测 Matlab语言。 雪消融优化器( SAO) 是受自然界中雪的升华和融化行为的启发&#xff0c;开发了一种…

Docker入门篇(二)—— 命令

Docker入门篇&#xff08;二&#xff09;—— 命令 插播&#xff01;插播&#xff01;插播&#xff01;亲爱的朋友们&#xff0c;我们的Cmake/Makefile/Shell这三个课程上线啦&#xff01;感兴趣的小伙伴可以去下面的链接学习哦~ 构建工具大师-CSDN程序员研修院 一、引言 当…

二叉搜索树的后序遍历序列

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 学习必须往深处挖&…

利用Knife4j注解实现Java生成接口文档

文章目录 1、简介2、生成文档3、系列注解3.1、Api3.2、ApiResponses和ApiResponse3.3、ApiOperation3.4、Pathyvariable⭐3.5、RequestBody3.6、ApiOperationSupport3.7、ApiImplicitParams 和 ApiImplicitParam3.8、ApiModel3.9、ApiModelProperty ​&#x1f343;作者介绍&am…

动手学RAG:汽车知识问答

原文&#xff1a;动手学RAG&#xff1a;汽车知识问答 - 知乎 Part1 内容介绍 在自然语言处理领域&#xff0c;大型语言模型&#xff08;LLM&#xff09;如GPT-3、BERT等已经取得了显著的进展&#xff0c;它们能够生成连贯、自然的文本&#xff0c;回答问题&#xff0c;并执行…
最新文章