Pytorch剪枝api测试和结果

Pytorch 官方给出的prune接口

下面是基于prune的接口进行剪枝的方法步骤

1、首先prune接口在 torch.nn.utils.prune中,目前支持的剪枝方法有:

  • RandomUnstructured
  • L1Unstructured
  • RandomStructured
  • LnStructured
  • CustomFromMask
    ps:非结构性剪枝不会给剪枝后模型的速度带来提升。

2、选择一个方法,定义好一个model后,将要剪枝的模块,及模块剪枝的部分作为函数的参数传入剪枝参数

from torch.nn.utils import prune 
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
'''
module: 模型的模块名字,如 model.conv1、model.fc1 ,这些跟你在构建模型时有关,可以用 models.state_dict().keys() 查看
name:模块中要剪枝的部分,可以是、weight、bias
amount:指的是模型本次剪枝的概率
n:前面使用的是ln_structured 模型,n表示使用那种剪枝策略,L1、L2、L3
dim:表示对第几个维度进行剪枝,如卷积层可以是维度 0123
'''

3、剪枝完后会产生一个weight_mask的掩码,本身不会直接作用于模型,会产生一个weight的属性,这时候原module是不存在weight的parameter,仅仅是一个attribute
如果此时输出模型的model.state_dict().keys()
之前是 conv1.weight 变成了 conv1.weight_orig ,以及conv1.weight_mask
4、此时模型的参数仍然是没有发生变化的,需要对剪枝后的模型进行保存

prune.remove(module, 'weight')
print(list(module.named_parameters()))

5、此时模型保存的是剪枝之后的权重值,同时weight_orig已经被删除掉了
6、所以直接对每一层需要剪枝的地方选择一个剪枝方法后,直接进行剪枝就可以了,然后保存模型此时的状态参数。

对模型进行全局剪枝,prune只提供了一个全局剪枝的接口global_unstructured()

import torch.nn.utils.prune as pt_prune
pt_prune.global_unstructured(parameters_to_prune,pruning_method=pt_prune.L1Unstructured,amount=amount)
'''
parameters_to_prune:list 待剪枝模块的 名字
pruning_method:全局剪枝的方法
amount:剪枝率
'''

然后对剪枝后的模块进行remove操作即可
但是全局剪枝,只支持非结构性剪枝

prune全局非结构性剪枝测试结果

# 推理模型tiny-yolov4


def model_global_prune(amount: float):
    detect_model = Darknet(
        '/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/cfg/yolov4-tiny.cfg')  # TODO:改成相对路径
    detect_model.load_weights(
        "/Users/wuzhensheng01/Documents/wzs/code/yolov4-tiny-model_pruning/weight/yolov4-tiny.weights")

    parameters_to_prune = list()
    nums = 0
    for i, modules in enumerate(detect_model.models):
        if isinstance(modules, nn.Sequential):  
            for j, module in enumerate(modules):
                if isinstance(detect_model.models[i][j], nn.Conv2d):
                    nums += 1
                    parameters_to_prune.append((detect_model.models[i][j], 'weight'))
                elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):
                    nums += 2
                    parameters_to_prune.append((detect_model.models[i][j], 'weight'))
                    parameters_to_prune.append((detect_model.models[i][j], 'bias'))

    parameters_to_prune = tuple(parameters_to_prune)
    assert (nums == len(parameters_to_prune))
    pt_prune.global_unstructured(parameters_to_prune,
                                 pruning_method=pt_prune.L1Unstructured,
                                 amount=amount)
    for i, modules in enumerate(detect_model.models):
        if isinstance(modules, nn.Sequential):  
            for j, module in enumerate(modules):
                if isinstance(detect_model.models[i][j], nn.Conv2d):
                    pt_prune.remove(detect_model.models[i][j], 'weight')
                elif isinstance(detect_model.models[i][j], nn.BatchNorm2d):
                    pt_prune.remove(detect_model.models[i][j], 'weight')
                    pt_prune.remove(detect_model.models[i][j], 'bias')

    return detect_model

base_line:
base model average time : 0.2082s
bicycle:0.605963
truck:0.814734
dog:0.870323

case 1: 剪枝率0.5 只剪卷积层.
model_pruned average time:0.2122
bicycle:0.597527
truck:0.825150
dog:0.592364

case2: 全局非结构性剪枝 剪枝率0.2
model_pruned average time:0.2078
bicycle:0.637542
truck:0.839107
dog:0.851859

case4:全局非结构性剪枝 剪枝率0.5 只剪bn层
‘’‘精度降为0’‘’

case4:全局非结构性剪枝 剪枝率0.2 只剪bn层
model_pruned average time : 0.2666
truck: 0.714715
truck: 0.594537
cat: 0.435578

case5:全局非结构性剪枝 剪枝率
model_pruned average time:0.2138
bicycle:0.636104
truck:0.840595
dog:0.850322

prune结构性剪枝测试结果

通过L2方法对模型的卷积层进行结构化剪枝(剪枝率0.5、0.4、0.2、0.1),剪枝完后模型的速度并没有变快,相反,模型的精度大幅度的下降,(模型精度下降的问题不知道是不是需要进行重新训练来提升,但是模型的速度并未得到提升)

结论:对于训练好的模型,prune接口只是提供了一种方法去“剪掉”模型每一层中最不重要的结构。而并没有稀疏训练这一步,导致在结构性剪枝中,模型的精度大幅度下降map趋近于0。同时剪枝方法只是使用简单的L1或L2对权重参数进行计算。
此外,接口中的“剪枝”只是找到模型中那些位置不重要参数,生成相应大小的掩膜,把不重要的位置置0,但是并没有删除与这些位置相连的前后层(只针对结构性剪枝而言),最后模型的权重大小并未发生改变,只是不重要的位置的参数大小变为了0,使得模型的速度并未提升。即使模型剪枝率达到95%,模型的速度仍与baseline保持一致。

结论:pytorch的官方接口并不能直接使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/14200.html

如若内容造成侵权/违法违规/事实不符,请联系三亩地网进行投诉反馈,一经查实,立即删除!

相关文章

面试华为,花了2个月才上岸,真的难呀····

花2个月时间面试一家公司,你们觉得值吗? 背景介绍 美本计算机专业,代码能力一般,之前有过两段实习以及一个学校项目经历。第一份实习是大二暑期在深圳的一家互联网公司做前端开发,第二份实习由于大三暑假回国的时间比…

32岁阿里P7,把简历改成不知名小公司,学历改成普通本科,工作内容不变,投简历全挂!...

hr靠什么来招人? 一位猎头讲述了自己和朋友打赌的故事: 朋友在阿里云,32岁,P7,他把简历上的公司改成不知名,学历改成普通本科,工作内容不变,结果投其他公司(比如京东&…

Spring Boot异步任务、异步消息

目录 1.异步任务 1.1.概述 1.2.使用 2.异步消息 2.1.概述 2.2.使用 1.异步任务 1.1.概述 举一个例子,我现在有一个网上商城,客户在界面点击下单后,后台需要完成两步: 1.创建客户订单 2.发短信通知客户订单号 这里面第2…

【hello Linux】理解文件系统

目录 创建文件的过程: 删除文件的过程: 创建目录的过程: 查看inode编号: 硬链接 软链接 Linux🌷 我们知道文件所有数据 文件内容 文件属性信息; 未打开的文件是被存放到磁盘/固态硬盘中的; …

《前端bug齁逼多,真假开发说》2023/4/10-2023/4/18问题汇总

1 高德地图 运行抱错 INVALID_USER_SCODE 这里是错误信息对应原因 错误信息列表-参考手册-地图 JS API | 高德地图API 这里是高德地图api设置说明 准备-入门-教程-地图 JS API | 高德地图API 如果你自己能排查出错误 那不用看我的,如果都写的对还是抱错…

1686_MATLAB处理Excel文件

全部学习汇总: GreyZhang/g_matlab: MATLAB once used to be my daily tool. After many years when I go back and read my old learning notes I felt maybe I still need it in the future. So, start this repo to keep some of my old learning notes servral …

非常详细的阻抗测试基础知识

编者注:为什么要测量阻抗呢?阻抗能代表什么?阻抗测量的注意事项... ...很多人可能会带着一系列的问题来阅读本文。不管是数字电路工程师还是射频工程师,都在关注各类器件的阻抗,本文非常值得一读。全文13000多字&#…

基于html+css的图片展示17

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

2023年4月-近期看书

复习书记 用于读书 文章目录 复习书记一、(2001)控制工程基础二、(3001)交通管理与控制三、(1001)英语 一、(2001)控制工程基础 学习这本书的前6章节。 参看视频链接: https://www.bilibili.com/video/BV1Sb411q7jU?p8&spm_id_frompageDriver&vd_source…

数字化转型危与机,20年老厂的升级之路

“投资大、周期长、见效慢”,是每一家企业在考虑数字化战略时,都会纠结的问题。 打江山容易,守江山难 企业在快速扩张的过程中,往往可以不需要过多的考虑细节的问题,跑马圈地的打法会更加有效。 但是市场占有量开始饱…

瀚高股份吕新杰:创新开源双驱动,躬耕国产数据库

作者 | 伍杏玲 近年来,国际形势不断变幻,也给人们带来巨大警示:关键核心技术是买不来、讨不来的,中国科技企业需寻找研发自强之路。 瀚高基础软件股份有限公司(简称瀚高股份)专注数据库十八年,始…

大厂面试-算法优化:冒泡排序你会优化吗?

关注公众号:”奇叔码技术“ 回复:“java面试题大全”或者“java面试题” 即可领取资料 原文:冒泡排序及优化代码 https://blog.csdn.net/weixin_43989347/article/details/122025689原文:十大经典排序算法 https://frxcat.fun/p…

史上最详细的八大排序详解!(建议收藏)

🚀write in front🚀 📜所属专栏:初阶数据结构 🛰️博客主页:睿睿的博客主页 🛰️代码仓库:🎉VS2022_C语言仓库 🎡您的点赞、关注、收藏、评论,是对…

问题排查记录-ffmpeg链接libavfilter和libavcodec:未定义的引用

目录 一、问题背景 二、问题现象 2.1 ffmpeg测试例程 2.2 编译脚本 2.3 错误提示 三、问题排查 3.1 关于提示找不到“stdio" "iostream"头文件的问题 3.1.1查看工具链头文件检索位置 3.1.2 根据工具链路径查找头文件 3.1.3 在编译脚本中指定头文件路径…

第一章 Maven概述

第一节 为什么要学习Maven? maven-作为依赖管理工具 ①jar 包的规模 随着我们使用越来越多的框架,或者框架封装程度越来越高,项目中使用的jar包也越来越多。项目中,一个模块里面用到上百个jar包是非常正常的。 比如下面的例子…

Flex布局

flex是 W3C 提出的一种新的布局方案 当我将某一元素设置为 display:flex 时,这个元素所包含的直接子元素就成为了我的子民 但是我发现我无法控制我的子民, 首先我要解决的是我要控制子民的方向 flex-direction: row 以行排列row-reverse…

Linux-初学者系列2——用户组管理和权限管理

用户组管理和权限管理 Linux-初学者系列2_用户组管理和权限管理一、所有者1、查看文件的所有者指令 2、修改文件所有者指令实操 二、组创建语法指令:实操: 三、所在组1、查看文件/目录所在组基本指令:实操: 2、修改文件所在组基本…

在当前互联网行情下,Android想转音视频开发,会有前景吗?

前言 近年来,由于三年疫情的影响,很多公司都开始陆陆续续的在裁员,Android开发工作岗位也是,可能有些从事Android开发的朋友还没有意识到,Android开发岗位正在变少,求职者,僧多粥少&#xff0c…

020:Mapbox GL加载高德地图(影像瓦片图)

第020个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+mapbox中加载高德地图(影像瓦片图)。 直接复制下面的 vue+mapbox源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方式示例源代码(共80行)相关API参考:专栏目标示例效果 配置方式 1)查看基础设置:…

在金融领域使用机器学习的 9个技巧

机器学习已经倍证明可以预测结果和发掘隐藏的数据模式。但是必须小心使用,并遵循一些规则,否则就会在数据的荒野中徘徊而无所获。使用机器学习进行交易的道路充满了陷阱和挑战,只有那些勤奋认真地遵循规则的人才能从中获得收益。下面是一些技…