MogaNet实战:使用MogaNet实现图像分类任务(一)

文章目录

  • 摘要
  • 安装包
    • 安装timm
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

论文:https://arxiv.org/pdf/2211.03295.pdf

作者多阶博弈论交互这一全新视角探索了现代卷积神经网络的表示能力。这种交互反映了不同尺度上下文中变量间的相互作用效果。提出了一种新的纯卷积神经网络架构族,称为MogaNet。MogaNet具有出色的可扩展性,在ImageNet和其他多种典型视觉基准测试中,与最先进的模型相比,其参数使用更高效,且具有竞争力的性能。具体来说,MogaNet在ImageNet上实现了80.0%和87.8%的Top-1准确率,分别使用了5.2M和181M参数,优于ParC-Net-S和ConvNeXt-L,同时节省了59%的浮点运算和17M的参数。源代码可在GitHub上(https://github.com/Westlake-AI/MogaNet)获取。

在这里插入图片描述

本文使用MogaNet模型实现图像分类任务,模型选择最小的moganet_xtiny,在植物幼苗分类任务ACC达到了96%+。

请添加图片描述

请添加图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现MogaNet模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

Mixup 是一种在图像分类任务中常用的数据增强技术,它通过将两张图像以及其对应的标签进行线性组合来生成新的数据和标签。
参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:
     model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device='cpu',
            resume=resume)

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)


# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

MogaNet_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  ├─__init__.py
│  └─moganet.py
├─mean_std.py
├─makedata.py
├─train.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
train.py:训练MogaNet模型
models:来源官方代码。

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/384474.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

tcp 中使用的定时器

定时器的使用场景主要有两种。 (1)周期性任务 这是定时器最常用的一种场景,比如 tcp 中的 keepalive 定时器,起到 tcp 连接的两端保活的作用,周期性发送数据包,如果对端回复报文,说明对端还活着…

【后端高频面试题--设计模式下篇】

🚀 作者 :“码上有前” 🚀 文章简介 :后端高频面试题 🚀 欢迎小伙伴们 点赞👍、收藏⭐、留言💬 后端高频面试题--设计模式下篇 后端高频面试题--设计模式上篇设计模式总览模板方法模式怎么理解模…

2024年腾讯云4核8G12M服务器性能测评,适合哪些使用场景?

腾讯云4核8G服务器适合做什么?搭建网站博客、企业官网、小程序、小游戏后端服务器、电商应用、云盘和图床等均可以,腾讯云4核8G服务器可以选择轻量应用服务器4核8G12M或云服务器CVM,轻量服务器和标准型CVM服务器性能是差不多的,轻…

【c++基础】同构数

说明 同构数是这样一种数:它出现在它的平方数的右端。例如:5的平方是25,5就是同构数,25的平方是625,25也是同构数。 再比如:100以内的同构数有1 5 6 25 76这5个整数。 请编程计算出1~N之间(包…

算法村目录

大家好我是苏麟 , 这是算法村使用目录 . 算法通关村 从链表到动态规划的实战 目录 算法村开篇第一关 了解链表第二关 链表专题第三关 数组专题第四关 栈专题第五关 队列专题第六关 树专题第七关 二叉树遍历专题第八关 二叉树专题第九关 二分查找与二叉树专题第十关 快速排序与归…

传统推荐算法库使用--mahout初体验

文章目录 前言环境准备调用混合总结 前言 郑重声明:本博文做法仅限毕设糊弄老师使用,不建议生产环境使用!!! 老项目缝缝补补又是三年,本来是打算直接重写写个社区然后给毕设使用的。但是怎么说呢&#xff…

事理与事件知识图谱

目录 前言1 事件定义与事理逻辑1.1 事件定义1.2 事理逻辑 2 事理知识图谱与传统知识图谱的区别和联系2.1 事理知识图谱与传统知识图谱的区别2.2 事理知识图谱与传统知识图谱的联系 3 事理知识图谱中的关系3.1 顺承关系3.2 因果关系3.3 条件关系3.4 并发关系3.5 上下位关系 4 事…

HP Pavilion Laptop 15-cs3xxx原装出厂Win10.20H1系统

惠普笔记本HP Pavilion - 15-cs3030tx原厂Windows10系统镜像下载 链接:https://pan.baidu.com/s/1LmdJoN7F3BGvt49ovq-eww?pwdzgmt 提取码:zgmt 适用型号: 15-cs3001tx,15-cs3030tx,15-cs3031tx,15-cs…

每日一练:LeeCode-654、最大二叉树【二叉树+DFS+分治】

本文是力扣LeeCode-654、最大二叉树【二叉树DFS分治】 学习与理解过程,本文仅做学习之用,对本题感兴趣的小伙伴可以出门左拐LeeCode。 给定一个不重复的整数数组 nums 。 最大二叉树 可以用下面的算法从 nums 递归地构建: 创建一个根节点,其…

最全面的Docker安装部署,配置镜像加速

安装Docker 卸载旧版 首先如果系统中已经存在旧的Docker,则先卸载: yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logrotate \docker-logrotate \docker-engine 配置Docker的yum仓库 首先…

Codeforces Round 924 (Div. 2)B. Equalize(思维+双指针)

文章目录 题面链接题意题解代码 题面 链接 B. Equalize 题意 给一个数组 a a a,然后让你给这个数组加上一个排列,求出现最多的次数 题解 赛时没过不应该。 最开始很容易想到要去重,因为重复的元素对于答案是没有贡献的。 去重后排序。&a…

HTTP 超文本传送协议

1 超文本传送协议 HTTP HTTP 是面向事务的 (transaction-oriented) 应用层协议。 使用 TCP 连接进行可靠的传送。 定义了浏览器与万维网服务器通信的格式和规则。 是万维网上能够可靠地交换文件(包括文本、声音、图像等各种多媒体文件)的重要基础。 H…

LLM大模型常见问题解答(2)

对大模型基本原理和架构的理解 大型语言模型如GPT(Generative Pre-trained Transformer)系列是基于自注意力机制的深度学习模型,主要用于处理和生成人类语言。 基本原理 自然语言理解:模型通过对大量文本数据的预训练&#xff…

LLM之RAG实战(二十五)| 使用LlamaIndex和BM25重排序实践

本文,我们将研究高级RAG方法的中的重排序优化方法以及其与普通RAG相比的关键差异。 一、什么是RAG? 检索增强生成(RAG)是一种复杂的自然语言处理方法,它包括两个不同的步骤:信息检索和生成语言建模。这种方…

【开源】JAVA+Vue.js实现车险自助理赔系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 角色管理模块2.3 车辆档案模块2.4 车辆理赔模块2.5 理赔照片模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 角色表3.2.2 车辆表3.2.3 理赔表3.2.4 理赔照片表 四、系统展示五、核心代码5.1 查询车…

【PyQt】10 QLineEdit

文章目录 前言一、回显模式(EchoMode)1.1 四种回显模式1.2 代码展示运行结果 二、校验器2.1 代码2.2 运行结果 三、通过掩码限制输入3.1 代码3.2 运行结果 总结 前言 1、QLineEdit 可以输入单行文字 2、回显模式 3、校验器 4、掩码输入 一、回显模式&am…

图片懒加载:从低像素预览到高清加载

老生常谈的问题,图片太多太大的网站,往往由于图片加载过慢而导致页面白屏时间过长。本次年前最后一更,来讲一个加载方法来处理这种情况。 在使用 Next.js 时,发现其支持模糊图片占位符加载的方式,本文就手动实现一个 图…

最近vscode链接Autodl出现的问题

最近vscode链接Autodl出现的问题 一、问题的概述 在使用vscode连接autodl远程服务器的时候,在vscode的右下角出现了,以下的问题提示: 远程主机可能不符合glibc和libstdc VS Code服务器的先决条件 二、问题的原因 vscode版本过高的问题&…

远程git仓库已有仓库,若想本地代码覆盖远程代码,操作如下

1.首先 2.其次 3.然后 4.最后 git push -f origin "master" -f:强制推送,若远程是空项目,可以换成-u

基于Java (spring-boot)的宿舍管理系统

一、项目介绍 基于Java (spring-boot)的宿舍管理系统功能:登录界面、宿舍管理、学生管理、班级管理、宿舍楼管理、维修记录、晚归记录、请假记录、用户管理、角色管理、菜单管理、日志管理、我收到的、退宿审核,等等等 二、作品包含 三、项目技术 后端语…
最新文章