【剪枝】torch-pruning的基本使用

论文:DepGraph: Towards Any Structural Pruning
工程:https://github.com/VainF/Torch-Pruning
算法和库的使用介绍:CVPR 2023 | DepGraph 通用结构化剪枝

1 TP的简介

该算法介绍了DepGraph 如何建模结构化剪枝中的层依赖,实现任意结构的剪枝。对应实现的库为 torch-pruning。
本篇博客对作者的介绍做一个自己的梳理和记录。

  • Torch-Pruning的简单介绍
    Torch-Pruning(TP)是一个结构化剪枝库,与现有框架(例如torch.nn.utils.prune)最大的区别在于,TP会物理地移除参数,同时自动裁剪其他依赖层。TP是一个纯 PyTorch 项目,实现了内置的计算图的追踪(Tracing)、依赖图(DepenednecyGraph, 见论文)、剪枝器等功能,同时支持 PyTorch 1.x 和 2.0 版本。
  • 用 Torch-Pruning 剪枝的好处
    假设正在对一个卷积结构化剪枝,需要减去哪些内容,具体第几个卷积核、对应的偏置、BN中对应的维度、与其直接或间接相连的层的核的channel。我们要实现剪枝,需要对不同模型定制不同的代码实现。Torch-Pruning可让实现者跳脱出对层剪枝时最具体的操作,而关注于整体剪枝的设置。

2 TP的初尝试


2.1 初步尝试

以 ResNet-18 结构化剪枝为例,对【conv1】进行剪枝,同时处理对应的bn、紧临的卷积。

from torchvision.models import resnet18
import torch_pruning as tp
import torch

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # 剪枝前两个通道
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1]) # 尝试修复bn
tp.prune_conv_in_channels(model.layer1[0].conv1, idxs=[0,1]) # 尝试修复紧邻的conv
output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络

会报错如下:问题出在残差结构上。残差的相加操作要求传入的两个tensor具有相同的空间尺寸,也就意味着剪枝后的Tensor通道数62和另一个tensor的通道数64不再匹配。
在这里插入图片描述在这里插入图片描述


2.2 使用TP对 conv1进行剪枝

手动设置DependencyGraph是Torch-Pruning框架的底层算法,设计目标就是"自动寻找耦合层",并自动化处理。
使用TP对ResNet-18的conv1进行剪枝,代码如下:

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. 获取与model.conv1存在依赖的所有层,并指定需要剪枝的通道索引(此处我们剪枝第[2,6,9]个通道)
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
print(model, group)

# 3. 执行剪枝操作
if DG.check_pruning_group(group): # 避免将通道剪枝到0
    group.prune()
print(model, group)
output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络

上述过程一共三步:

  • 1 对网络进行依赖图构建;
  • 2 选取需要剪枝的层,指定剪枝通道,获得分组group;这里的group,是所有与conv1相依赖的层。
  • 3 执行剪枝操作,按组移除通道。那剪枝过程具体操作了哪些层呢?

左图为剪枝前的conv1 的group,右图为剪枝后的conv1 的group。
怎么去看这个group呢,在下图右侧进行了简单的标注,可以发现conv1的group都会进行剪枝,从而适应conv1的卷积核的维度发生的变化
在这里插入图片描述
左图为剪枝前的resnet结构部分,右图为剪枝后的resnet结构部分。
在这里插入图片描述


2.3 使用TP对网络中每个层进行剪枝

在实际实现时,我们希望是对整个网络结构进行剪枝,而非特定的某几层,这就涉及到如何不重复地遍历网络中所有分组的问题。
DepGraph提供了接口DG.get_all_groups来实现这以目标。该接口仅实现层的分组,并不会分辨通道的重要性。该接口包含两个参数

  • ignored_layers:指定忽略 某些希望被剪枝的层。通常包括最后的分类层、以及报错的层(也可以使用其它正确的层进行替换)
  • root_module_types:指定了每个组的起始层的类型。比如想剪枝所有的卷基层,而不想剪枝全连接层,只需要只传入对应的卷积类即可。
    值得注意的是,不同层可能出现在同一分组中,Depgraph会自动去除重复分组。

下面先提前设定好需要剪枝的通道,来展示DG.get_all_groups的使用:

import torch
import torch.nn as nn
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True).eval()

# 1. 构建依赖图
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. 获取与model.conv1存在依赖的所有层,并指定需要剪枝的通道索引(此处我们剪枝第[2,6,9]个通道)
Groups = DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear])

# 3. 执行剪枝操作
for group in Groups:
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)

output = model(torch.randn(1,3,224,224)) # 尝试运行剪枝后的网络

但该段代码剪枝,在TP实际剪枝也是较少使用,这里是展示一个剪枝底层的基本操作。

3 TP对完整网络的剪枝


3.1 常用的结构化剪枝原理

结构化和非结构化剪枝方向,已发表的有较多的论文。但在工业上较常用的为结构化剪枝。实际中最常用的结构化剪枝方法:

  • 利用权值进行filter剪枝:Pruning Filters for Efficient ConvNets
    在这张图中,我们可以找到两个卷积参数矩阵(Kernel Matrix):第一个卷积层以 x i x_i xi 作为输入,输出特征图 x i + 1 x_{i+1} xi+1;第二个卷积层以 x i + 1 x_{i+1} xi+1作为输入,生成特征图 x i + 2 x_{i+2} xi+2
    在结构化剪枝中,这两个卷基层之间存在非常直观的依赖关系,即当我们调整第一层的输出通道时,第二个卷积层的输入通道也需要相应的进行调整,这使得蓝色高亮的参数需要同时被剪枝。
    在这里插入图片描述
    此外,作者指出网络中可能存在更复杂的依赖,例如残差结构依赖:
    在这里插入图片描述
  • 利用bn进行剪枝:Learning Efficient Convolutional Networks through Network Slimming。
    在这里插入图片描述
    BN会按通道对输入特征进行归一化,使得不同的特征处于比较接近的范围内。我们将缩放因子(从批量归一化层重用)与卷积层中的每个通道相关联。稀疏正则化在训练期间被施加在这些缩放因子上,以自动识别不重要的通道。缩放因子值较小(橙色)的通道将被修剪(左侧)。修剪后,我们获得紧凑模型(右侧),然后对其进行微调,以实现与正常训练的全网络相当(甚至更高)的精度。
    在任何一个网络中,BN的scale参数都具备一定的绝对值大小(也就是不会过小),这意味着各个通道都具有不可忽略的重要性。解决这类问题的一种有效方法是使用稀疏训练,通过对scale参数施加正则化项来稀疏化一部分通道。在slimming论文中,作者对scale参数施加了一个额外的L1正则化项,从而实现了这一过程。整个流程如下所示:稀疏训练–>剪枝–>微调
    在这里插入图片描述

3.2 TP剪枝示例

网络中存在大量复杂依赖的情况下,如何进行剪枝呢?
【1】计算网络每个group中每层的重要性

  • Torch-Puning 库内置了处理依赖的功能,并提供了可扩展的接口用于自定义剪枝器。
    tp.importance.Importance 要求我们实现一个非常简单的接口 __call__
    • 入参为一个 group,包含了多个相互耦合的层。
    • 输出为一个一维的重要性的得分向量,其含义是每个通道的重要性,因此他的维度和通道数量是相同的。
      由于输入的group通常会包含多个可剪枝层,因此我们首先对这些层进行独立的重要性计算,然后通过求平均值得到最终结果。
  • Torch-Puning也提供了常用重要性评估策略:
    tp.importance.MagnitudeImportance(p=2)p=2表示使用L2正则,对每个group中的每个层的权值,独立的计算重要性 。
    tp.importance.BNScaleImportance():利用BN计算每个group中的每个层的权值的重要性
    tp.importance.GroupNormImportance():与继承于MagnitudeImportance,且没做任何的添加和修改。

【2】对网络进行剪枝

  • Torch-Pruning库定义了一个元剪枝器 tp.pruner.MetaPruner,能够完成除了重要性评估之外的所有工作。一般常在自定义的重要性评估后,执行剪枝时使用
  • Torch-Puning也提供了常用的剪枝策略
    tp.pruner.MagnitudePruner()
    tp.pruner.BNScalePruner()
    tp.pruner.GroupNormPruner() Depgraph 提出的基于全局重要性的剪枝

【3】例子
为了增加难度,这里我们对一个DenseNet模型进行剪枝。
这里只展示了稀疏训练和微调使用的位置,仅剪枝部分能够有效跑通。

import torch
import torch.nn as nn
import torch_pruning as tp
from torchvision.models import densenet121

model = densenet121(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. 使用我们上述定义的重要性评估
# imp = tp.importance.MagnitudeImportance(p=2)
# imp = tp.importance.BNScaleImportance()
imp = tp.importance.GroupNormImportance()

# 2. 忽略无需剪枝的层,例如最后的分类层
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

# 3. 初始化剪枝器
iterative_steps = 5 # 迭代式剪枝,重复5次Pruning-Finetuning的循环完成剪枝。  
# pruner = tp.pruner.MagnitudePruner(
# pruner = tp.pruner.BNScalePruner(
pruner = tp.pruner.GroupNormPruner(
    model,
    example_inputs, # 用于分析依赖的伪输入
    importance=imp, # 重要性评估指标
    iterative_steps=iterative_steps, # 迭代剪枝,设为1则一次性完成剪枝
    ch_sparsity=0.5, # 目标稀疏性,这里我们移除50%的通道 ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers, # 忽略掉最后的分类层
)

# 4. 稀疏训练(为了节省时间我们假装在训练,实际应用时只需要在optimizer.step前插入regularize即可)
for _ in range(100):
    pass
    # optimizer.zero_grad() 
    # ...
    # loss.backward()
    # pruner.regularize(model, reg=1e-5) # <== 插入该行进行稀疏化
    # optimizer.step()
    
# 4. Pruning-Finetuning的循环
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step() # 执行裁剪,本例子中我们每次会裁剪10%,共执行5次,最终稀疏度为50%
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("  Iter %d/%d, Params: %.2f M => %.2f M" % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6))
    print("  Iter %d/%d, MACs: %.2f G => %.2f G"% (i+1, iterative_steps, base_macs / 1e9, macs / 1e9))
    # finetune your model here
    # finetune(model)
    # ...
print(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/175677.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

redis的集群

高可用方案 1、持久化 2、高可用 主从复制 哨兵模式 集群 主从复制: 主从复制是redis实现高可用的基础&#xff0c;哨兵模式和集群都是在主从复制的基础之上实现高可用 主从复制实现数据的多机备份&#xff0c;以及读写分离&#xff08;主服务器负责写&#xff0c;从服务器…

云HIS系统源码,医院管理系信息统源码,融合B/S版四级电子病历系统

医院管理信息系统是以推进公共卫生、医疗、医保、药品、财务监管信息化建设为着力点&#xff0c;整合资源&#xff0c;加强信息标准化和公共服务信息平台建设&#xff0c;逐步实现统一高效、互联互通的管理系统。 SaaS模式Java版云HIS系统&#xff0c;在公立二甲医院应用三年…

代餐粉产业分析:中国市场销售额增长至116.94亿元

近年来&#xff0c;随着人们生活节奏的加快和健康意识的增强&#xff0c;代餐粉市场规模逐渐壮大。在这个忙碌的时代&#xff0c;快捷、营养而又方便的代餐粉成为了许多人选择的首选。 随着健康理念的不断普及和推广&#xff0c;人们开始更加重视日常饮食的健康与营养。代餐粉作…

Vellum —— 简介

目录 一&#xff0c;介绍 二&#xff0c;原理 三&#xff0c;PBD算法 一&#xff0c;介绍 Vellum是一个解算模拟框架&#xff0c;使用更高级的PBD&#xff08;XPBD&#xff0c;extended position based dynamics&#xff09;&#xff0c;是2nd Order Integration&#xff08…

Go 实现网络代理

使用 Go 语言开发网络代理服务可以通过以下步骤完成。这里&#xff0c;我们将使用 golang.org/x/net/proxy 包来创建一个简单的 SOCKS5 代理服务作为示例。 步骤 1. 安装 golang.org/x/net/proxy 包 使用以下命令安装 golang.org/x/net 包&#xff0c;该包包含 proxy 子包&am…

2023亿发数字化智能工单,专业管理工单处理全流程,助力企业转型腾飞

伴随着智能化和信息化的不断深入&#xff0c;企业数字化转型势如腾飞。在这个过程中&#xff0c;工单管理成为生产、家电、后勤等多个管理场景下频繁应用的关键环节。如何满足管理方对设备、服务等智能化管理的需求&#xff0c;提升工单管理效率、规范管理流程&#xff0c;并实…

问题:vue2+elementui,tabs切换显示表格并设置表格选中行高亮失败

错误示范&#xff1a; 1.直接setCurrentRow失败&#xff08;this.currentRow是之前保存的表格当前选中行的数据&#xff09; this.$refs.table.setCurrentRow(this.currentRow);2.以为是表格没生成就执行了setCurrentRow导致设置不成功&#xff0c;所以使用了this.$nextTick&…

英国国家量子计算中心与IBM签署重要协议!英国进入实用量子时代

​&#xff08;图片来源&#xff1a;网络&#xff09; 近日&#xff0c;英国国家量子计算中心&#xff08;NQCC&#xff09;与IBM达成了一项重要协议。根据该协议&#xff0c;NQCC将为英国研究人员提供IBM量子高级计划的云访问权限&#xff0c;其中包括IBM的量子计算系统舰队。…

SpringBoot Admin

前言 Spring Boot Admin 是一个管理和监控 Spring Boot 应用程序的开源项目&#xff0c;它提供了一个简洁的 Web 界面来监控 Spring Boot 应用程序的状态和各种运行时指标。Spring Boot Admin 可以帮助开发者快速了解应用程序的状态&#xff0c;并快速定位错误或性能问题。下面…

赛氪荣幸受邀参与中国联合国采购促进会第五次会员代表大会

11 月21 日 &#xff08;星期二&#xff09; 下午14:00&#xff0c;在北京市朝阳区定福庄东街1号中国传媒大学&#xff0c;赛氪荣幸参与中国联合国采购促进会第五次会员代表大会。 2022年以来&#xff0c;联合国采购杯全国大学生英语大赛已经走上了国际舞台&#xff0c;共有来自…

HC32L110小华半导体SWD模式切换的问题

在将SWD配置为普通引脚并配置为输出后&#xff0c;如果需要重新配置为SWD&#xff0c;需要将其配置为输入才行&#xff0c;如下&#xff1a; Clk_SetFunc(ClkFuncSwdPinIOEn, TRUE); //配置SWD引脚为普通引脚模式 Gpio_InitIOExt(SWCLK_PORT, SWCLK_PIN, GpioDirOut, TRUE,…

垃圾收集器的种类及概述

1.JVM参数 1.1标准参数所有jdk版本通用参数 -version -help -server -cp 1.2-X参数 非标准参数&#xff0c;也就是在JDK各个版本中可能会变动 -Xint 解释执行 -Xcomp 第一次使用就编译成本地代码 -Xmixed 混合模式&#xff0c;JVM自己来决定 1.3 -XX参数 使用得最多…

一个测试驱动的Spring Boot应用程序开发

文章目录 系统任务用户故事搭建开发环境Web应用的框架Spring Boot 自动配置三层架构领域建模域定义与领域驱动设计领域类 业务逻辑功能随机的Challenge验证 表示层RESTSpring Boot和REST API设计API第一个控制器序列化的工作方式使用Spring Boot测试控制器 小结 这里采用面向需…

悄悄上线:CSS @starting-style 新规则

最近 Chrome 117&#xff0c;CSS 又悄悄推出了一个新的的规则&#xff0c;叫做starting-style。从名称上来看&#xff0c;表示定义初始样式。那么&#xff0c;具体是做什么的&#xff1f;有什么用&#xff1f;一起了解一下吧 一、快速了解 starting-style 通常做一个动画效果…

vue3引入vuex基础

一&#xff1a;前言 使用 vuex 可以方便我们对数据的统一化管理&#xff0c;便于各组件间数据的传递&#xff0c;定义一个全局对象&#xff0c;在多组件之间进行维护更新。因此&#xff0c;vuex 是在项目开发中很重要的一个部分。接下来让我们一起来看看如何使用 vuex 吧&#…

OpenLayers入门,OpenLayers6的WebGLPointsLayer图层样式和运算符详解,四种symbolType类型案例

专栏目录: OpenLayers入门教程汇总目录 前言 本章讲解使用OpenLayers6的WebGL图层显示大量点情况下,列举出所有WebGLPointsLayer图层所支持的所有样式运算符大全。 补充说明 本篇主要介绍OpenLayers6.x版本的webgl图层,OpenLayers7.x和OpenLayers8.x主要更新内容就是webgl…

任意文件下载漏洞(CVE-2021-44983)

简介 CVE-2021-44983是Taocms内容管理系统中的一个安全漏洞&#xff0c;可以追溯到版本3.0.1。该漏洞主要源于在登录后台后&#xff0c;文件管理栏存在任意文件下载漏洞。简言之&#xff0c;这个漏洞可能让攻击者通过特定的请求下载系统中的任意文件&#xff0c;包括但不限于敏…

单链表相关面试题--5.合并有序链表

5.合并有序链表 21. 合并两个有序链表 - 力扣&#xff08;LeetCode&#xff09; /* 解题思路&#xff1a; 此题可以先创建一个空链表&#xff0c;然后依次从两个有序链表中选取最小的进行尾插操作进行合并。 */ typedef struct ListNode Node; struct ListNode* mergeTwoList…

C++ Boost 实现异步端口扫描器

端口扫描是一种用于识别目标系统上哪些网络端口处于开放、关闭或监听状态的网络活动。在计算机网络中&#xff0c;端口是一个虚拟的通信端点&#xff0c;用于在计算机之间传输数据。每个端口都关联着特定类型的网络服务或应用程序。端口扫描通常是网络管理员、安全专业人员用来…

PyTorch微调终极指南1:预训练模型调整

如今&#xff0c;在训练深度学习模型时&#xff0c;通过根据自己的数据微调预训练模型来进行迁移学习&#xff08;transfer learning&#xff09;已成为首选方法。 通过微调这些模型&#xff0c;我们可以利用他们的专业知识并使它们适应我们的特定任务&#xff0c;从而节省宝贵…
最新文章