Yolov8模型用torch_pruning剪枝

目录

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

原理

 遍历所有分组

高级剪枝器


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxv

原理

传统剪枝方法的缺陷

在复杂的网络结构中, 参数之间可能存在依赖关系, 这种依赖要求算法对这类参数进行同步移除以保证结构正确性,这就涉及到耦合参数的分组问题. 我们的工作通过提供一种自动化机制来对参数进行分组. 具体而言, Torch-Pruning使用伪输入来运 行模型, 跟踪网络计算图, 并记录层之间的依赖关系. 当剪枝某一层时, Torch-Pruning会识别所有耦合层, 并返回包含这些耦合信息的tp.Group.

一种通用的结构化剪枝框架DepGraph(Dependency Graph),可以应用于任意类型的神经网络架构(包括CNN、RNN、GNN和Transformer等)进行结构化剪枝。主要原理如下:

1. 神经网络内部存在着层与层之间的依赖关系,需要同时剪枝依赖的层组,否则会破坏网络结构。

2. 结构化剪枝的优势

结构化剪枝的做法是,找到网络中相互依赖的层组,把整个层组同时全部保留或全部删除,从而保证网络结构的完整性。这种做法虽然灵活性较低,但可以有效避免了网络结构被破坏的问题。

3. DepGraph通过建模层与层之间的依赖关系,明确每一层所属的层组。具体分为两种依赖关系:

   a) 层间依赖(Inter-layer Dependency): 相邻连接的层之间存在依赖   层间不依赖:resnet

   b) 层内依赖(Intra-layer Dependency): 同一层的输入和输出具有相同的剪枝方式时存在依赖   层内不依赖:没有共享权重的

4. 通过图遍历算法在DepGraph上找到最大连接分量作为层组,实现自动化的层组划分。总的来说,DepGraph解决了之前结构化剪枝算法依赖人工设计层组划分规则、缺乏通用性的问题,提出了一种自动建模层组依赖关系和组级剪枝重要性评估的通用框架。

5. DepGraph的工作原理

以ResNet的基本模块为例,如果要删除某个卷积层的滤波器核,由于残差连接的存在,我们必须同时删除该模块中所有层(BN层、ReLU层等)对应的通道。DepGraph通过建模层与层之间的依赖关系,自动将这些相互依赖的层划分到同一个层组中。在剪枝时,整个层组被统一评分,决定是完全保留还是完全删除,从而实现安全的结构化剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. 指定剪枝的通道维度
pruning_idxs = [2, 6, 9]
pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )

print(pruning_group.details())  # or print(pruning_group)

# 3. 检查剩余通道数是否<=0, 并执行剪枝
if DG.check_pruning_group(pruning_group):
    pruning_group.prune()

这个例子演示了使用 DepGraph剪枝的基本流程, resnet.conv1实际上会与多个层耦合在一起.通过打印返回的组, 可以看到组内各个层之间的剪枝是如何互相“触发”的.在以下输出中, “A => B”表示剪枝操作“A”触发剪枝操作“B”.group[0]是用户在DG.get_pruning_group中给出的剪枝操作. 

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #idxs=3
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), #idxs=3
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), #idxs=3
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), #idxs=3
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), #idxs=3
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), #idxs=3
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=3
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), #idxs=3
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #idxs=3
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #idxs=3
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #idxs=3
--------------------------------
 遍历所有分组

可以利用DG.get_all_groups(ignored_layers, root_module_types)来按顺序扫描所有的分组. 每个分组都会以一个"root_module_types"中所指定的层作为起点. 默认情况下, 这些组包含了完整的剪枝索引idxs=[0,1,2,3,...,K], 这个索引列表包含了所有的可修剪参数的索引. 如果我们希望对一个group进行剪枝, 我们需要使用group.prune(idxs=idxs)来指定具体的修剪通道/维度.

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)
高级剪枝器
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# 重要性指标
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2) # p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性   重要性如何计算??什么是重要的?值大还是小?是损失吗

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # 迭代式剪枝, 该示例会分五步完成50%通道剪枝 (10%->20%->...->50%)
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # 整体移除50%通道, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/441654.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C# winform 重启电脑

一、重启电脑指令 windows7系统的启动文件夹为“开始菜单”——“所有程序”里面就有“启动”文件夹&#xff0c;其位置是 “C:\Users\Administrator\AppData\Roaming\Microsoft\Windows\Start Menu\Programs\Startup” 如果没有&#xff0c;则需要将其中的"administrator…

flutter逆向 ACTF native ap

言 算了一下好长时间没打过CTF了,前两天看到ACTF逆向有道flutter逆向题就过来玩玩啦,花了一个下午做完了.说来也巧,我给DASCTF十月赛出的逆向题其中一道也是flutter,不过那题我难度降的相当之低啦,不知道有多少人做出来了呢~ 还原函数名 flutter逆向的一大难点就是不知道lib…

容器安全是什么?

容器安全定义 容器安全是指保护容器的完整性。这包括从其保管的应用到其所依赖的基础架构等全部内容。容器安全需要完整且持续。通常而言&#xff0c;企业拥有持续的容器安全涵盖两方面&#xff1a; 保护容器流水线和应用保护容器部署环境和基础架构 如何将安全内置于容器流…

【MySQL 系列】MySQL 语句篇_DQL 语句

DQL&#xff08;Data Query Language&#xff09;&#xff0c;即数据查询语言&#xff0c;用来查询数据记录。DQL 基本结构由 SELECT FROM、WHERE、JOIN 等子句构成。 DQL 语句并不会改变数据库&#xff0c;而是让数据库将查询结果发送结果集给客户端&#xff0c;返回的结果是一…

代理IP以及动态拨号VPS的关系是什么?

在数字时代&#xff0c;网络安全和隐私保护已成为全球关注的热点话题。代理IP和动态拨号VPS作为提升网络匿名性和安全的重要技术&#xff0c;它们在维护网络隐私中扮演着至关重要的角色。虽然这两种技术在表面上看似相似&#xff0c;实际上它们在功能、应用场景以及用户需求满足…

生成对抗网络 (GAN)

生成对抗网络&#xff08;Generative Adversarial Networks&#xff0c;GAN&#xff09;是由Ian Goodfellow等人在2014年提出的一种深度学习模型。GAN由两部分组成&#xff1a;一个生成器&#xff08;Generator&#xff09;和一个判别器&#xff08;Discriminator&#xff09;&…

window vscode安装node.js

window vscode安装node.js 官网下好vscode 和nodejs 选.msi的安装 点这个安装 下载完 继续安装 完毕后倒杯水喝个茶等2分钟 重启VScode 或者在cmd 运行 npm -v node -v 显示版本号则成功

借助Aspose.html控件,在 Java 中将 URL 转换为 PDF

如果您正在寻找一种将实时 URL 中的网页另存为 PDF文档的方法&#xff0c;那么您来对地方了。在这篇博文中&#xff0c;我们将学习如何使用 Java 将 URL 转换为 PDF。从实时 URL转换HTML网页可以像任何其他文档一样保存所需的网页以供离线访问。将网页保存为 PDF 格式可以轻松突…

金智维售前总监屈文浩,将出席“ISIG-RPA超级自动化产业发展峰会”

3月16日&#xff0c;第四届「ISIG中国产业智能大会」将在上海中庚聚龙酒店拉开序幕。本届大会由苏州市金融科技协会指导&#xff0c;企智未来科技&#xff08;RPA中国、AIGC开放社区、LowCode低码时代&#xff09;主办。大会旨在聚合每一位产业成员的力量&#xff0c;深入探索R…

windows关闭copilot预览版

如果用户不想在windows系统当中启用Copilot&#xff0c;可以通过以下三种方式禁用。 第一种&#xff1a;隐藏Copilot 按钮 右键点击任务栏&#xff0c;取消勾选“显示 Copilot&#xff08;预览版&#xff09;按钮”&#xff0c;任务栏则不再显示&#xff0c;用户可以通过快捷键…

[N1CTF 2018]eating_cms 不会编程的崽

题倒是不难&#xff0c;但是实在是恶心到了。 上来就是登录框&#xff0c;页面源代码也没什么特别的。寻思抓包看一下&#xff0c;数据包直接返回了sql查询语句。到以为是sql注入的题目&#xff0c;直到我看到了单引号被转义。。。挺抽象&#xff0c;似乎sql语句过滤很严格。又…

C语言分析基础排序算法——插入排序

目录 插入排序 直接插入排序 希尔排序 希尔排序基本思路解析 希尔排序优化思路解析 完整希尔排序文件 插入排序 直接插入排序 所谓直接插入排序&#xff0c;即每插入一个数据和之前的数据进行大小比较&#xff0c;如果较大放置在后面&#xff0c;较小放置在前面&#x…

Flutter使用auto_updater实现windows/mac桌面应用版本升级功能

因为windows应用一般大家都是从网上下载的&#xff0c;后期版本肯定会更新&#xff0c;那用flutter开发windows应用&#xff0c;怎么实现应用内版本更新功能了&#xff1f;可以使用auto_updater库&#xff0c; 这个插件允许 Flutter 桌面 应用自动更新自己 (基于 sparkle 和 wi…

生活的色彩--爱摸鱼的美工(17)

题记 生活不如意事十之八九&#xff0c; 恶人成佛只需放下屠刀&#xff0c;善人想要成佛却要经理九九八十一难。而且历经磨难成佛的几率也很小&#xff0c;因为名额有限。 天地不仁以万物为刍狗&#xff01; 小美工记录生活&#xff0c;记录绘画演变过程的一天。 厨房 食…

单调栈(例题+解析)

1、应用场景 找出一个数的左面离概述最近的且小于该数的数&#xff08;同理右面也可以&#xff09; 例如&#xff1a; 数组a[i] 3 4 2 7 5 答案&#xff1a; -1 3 -1 2 2 2、如何实现找到规律 暴力写法&#xff1a; for(int i0;i<n;i) {for(int ji-1;j>0;j--){i…

在ubuntu上使用vscode+gcc-arm-none-eabi+openocd工具开发STM32

文章目录 所需工具安装调试搭建过程中遇到的问题 写在前面 老大上周让我用vscode开发STM32&#xff0c;我爽快的答应了&#xff0c;心想大学四年装了这么多环境了这不简简单单&#xff0c;更何况vscode这两年还用过&#xff0c;然而现实总是令人不快的——我竟然花了差不多两周…

UDP实现文件的发送、UDP实现全双工的聊天、TCP通信协议

我要成为嵌入式高手之3月7日Linux高编第十七天&#xff01;&#xff01; ———————————————————————————— 回顾 重要程序 1、UDP实现文件的发送 发端&#xff1a; #include "head.h"int main(void) {int sockfd 0;struct sockaddr_i…

PHP在线图像处理程序:基于Photoshop的网页版图片处理源码

PHP在线PS修图网页版源码&#xff1a;实现照片图片处理的便捷工具 众所周知&#xff0c;许多朋友都喜欢使用PS进行图像编辑。然而&#xff0c;PS需要下载软件并对电脑配置要求较高。今天我们为大家带来一款基于浏览器的在线PS网页版源码&#xff0c;让您轻松实现在线P图和作图…

Kafka MQ 主题和分区

Kafka MQ 主题和分区 Kafka 的消息通过 主题 进行分类。主题就好比数据库的表&#xff0c;或者文件系统里的文件夹。主题可以被分为若干个 分区 &#xff0c;一个分区就是一个提交日志。消息以追加的方式写入分区&#xff0c;然 后以先入先出的顺序读取。要注意&#xff0c;由…

【论文阅读】Mamba:选择状态空间模型的线性时间序列建模(一)

文章目录 Mamba:选择状态空间模型的线性时间序列建模介绍状态序列模型选择性状态空间模型动机&#xff1a;选择作为一种压缩手段用选择性提升SSM 选择性SSM的高效实现先前模型的动机选择扫描总览&#xff1a;硬件感知状态扩展 Mamba论文 Mamba:选择状态空间模型的线性时间序列建…