【剪枝实战】使用VGGNet训练、稀疏训练、剪枝、微调等,剪枝出只有3M的模型

摘要

本次剪枝实战是基于下面这篇论文去复现的,主要是实现对BN层的γ/gamma进行剪枝操作,本文用到的代码和数据集都可以在我的资源中免费下载到。

在这里插入图片描述

相关论文:Learning Efficient Convolutional Networks through Network Slimming (ICCV 2017)

剪枝的原理

在一个卷积-BN-激活模块中,BN层可以实现通道的缩放。如下:
在这里插入图片描述
BN层的具体操作有两部分:
在这里插入图片描述
在归一化后会进行线性变换,那么当系数gamma很小时候,对应的激活(Zout)会相应很小。这些响应很小的输出可以裁剪掉,这样就实现了bn层的通道剪枝。

剪枝的过程

第一步、使用VGGNet训练模型。保存训练结果,方便将来的比对!
第二步、在BN层网络中加入稀疏因子,训练模型。
第三步、剪枝操作。
第四步、fine-tune模型,提高模型的ACC。

接下来,我们一起实现对VGGNet的剪枝。

一、项目结构

Slimming_Demo
├─checkpoints
│ ├─vgg
│ ├─vgg_pruned
│ └─vgg_sp
├─data
│ ├─train
│ │ ├─Black-grass
│ │ ├─Charlock
│ │ ├─Cleavers
│ │ ├─Common Chickweed
│ │ ├─Common wheat
│ │ ├─Fat Hen
│ │ ├─Loose Silky-bent
│ │ ├─Maize
│ │ ├─Scentless Mayweed
│ │ ├─Shepherds Purse
│ │ ├─Small-flowered Cranesbill
│ │ └─Sugar beet
│ └─val
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ └─Shepherds Purse
├─vgg.py
├─train.py
├─train_sp.py
├─prune.py
└─train_prune.py

train.py:训练脚本,训练VGGNet原始模型 vgg.py:模型脚本 train_sp.py:稀疏训练脚本。 prune.py:模型剪枝脚本。 train_prune.py:微调模型脚本。

二、正常训练

首先,我们进行正常训练出一个模型,用来和剪枝后的模型进行对比。原始模型有67.8M,体积非常大。

运行train.py,正常训练结果保存为checkpoints/vgg/best.pth

通过Tensorboard工具,我们观察到,此时的BN层的gamma系数并不稀疏,很难判断出各个通道的重要性,不利于剪枝!

在这里插入图片描述

三、稀疏训练

运行train_sp.py,稀疏化训练结果保存为checkpoints/vgg_sp/best.pth

使用tensorboard查看runs/路径下保存的log,查看稀疏化训练结果。

可以看到通过稀疏化训练,==此时的gamma系数已经呈现趋近0的分布,==我们可以对gamma接近0的channel进行剪枝!

在这里插入图片描述

核心代码
在这里插入图片描述

四、剪枝

有了稀疏化训练的模型,我们便可以通过读取checkpoints/vgg_sp/best.pth,根据gamma系数来判断网络中各个通道的重要性,剪去不重要的通道。

运行prune.py,剪枝的结果保存为checkpoints/vgg_pruned/best.pth

查看剪枝后的VGG模型,只有3.58M,体积非常小,剪枝瘦身效果非常明显

在这里插入图片描述

五、微调

对模型剪枝后,为了恢复模型的性能,我们还需要对剪枝的模型进行Finetune

运行train_prune.py

checkpoints/vgg_pruned/best.pth进行微调,结果生成在checkpoints/vgg_pruned_finetune/best.pth

在这里插入图片描述

六、总结

我们通过复现Learning Efficient Convolutional Networks through Network Slimming (ICCV 2017)论文里提到的网络大瘦身剪枝算法,完成对VGG模型进行剪枝。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/457966.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【小白学机器学习9】自己纯手动计算验证,EXCEL的一元线性回归的各种参数值

目录 0 目标 1 构造模型 1.1 构造模型的思路 1.2 具体模型构造的EXCEL公式和过程 2 直接用EXCEL画图,然后生成趋势线的方式进行回归分析 2.1 先选择“观测值Y”的数据,用散点图或者折线图作图 2.2 然后添加趋势线和设置趋势线格式 2.3 生成趋…

服务器Debian 12.x中安装Jupyer并配置远程访问

服务器系统:Debian 12.x;IP地址:10.100.2.138 客户端:Windows 10;IP地址:10.100.2.38 利用ssh登录服务器: 1.安装python3 #apt install python3 2.安装pip #apt install python3-pip … 3.安装virtualen…

HBase分布式数据库的原理和架构

一、HBase简介 HBase是是一个高性能、高可靠性、面向列的分布式数据库,它是为了在廉价的硬件集群上存储大规模数据而设计的。HBase利用Hadoop HDFS作为其文件存储系统,且Hbase是基于Zookeeper的。 二、HBase架构 *图片引用 Hbase采用Master/Slave架构…

PTA-练习1

目录 实验2-3-8 计算火车运行时间 实验2-4-4 求简单交错序列前N项和 实验2-4-5 输出华氏-摄氏温度转换表 实验3-4 统计字符[2] 实验3-5 查询水果价格 实验3-11 求一元二次方程的根 实验4-1-1 统计数字字符和空格 实验2-3-8 计算火车运行时间 时钟数有两种情况&#xff1…

使用BBDown下载bilibili视频的方法

一款命令行式哔哩哔哩下载器. Bilibili Downloader. 下载地址 https://github.com/nilaoda/BBDown 功能 番剧下载(Web|TV|App) 课程下载(Web) 普通内容下载(Web|TV|App) 合集/列表/收藏夹/个人空间解析 多分P自动下载 选择指定分P进行下载 选择指定清晰度进行下载 下载外挂字幕…

解决驱动开发中<stdlib.h> no such file 的问题

前言 在进行驱动开发时&#xff0c;需要使用malloc等函数&#xff0c;导入C库<stdlib.h>出现bug。 嵌入式驱动学习专栏将详细记录博主学习驱动的详细过程&#xff0c;未来预计四个月将高强度更新本专栏&#xff0c;喜欢的可以关注本博主并订阅本专栏&#xff0c;一起讨论…

java并发编程之 volatile关键字

1、简单介绍一下JMM Java 内存模型&#xff08;Java Memory Model 简称JMM&#xff09;是一种抽象的概念&#xff0c;并不真实存在&#xff0c;指一组规则或规范&#xff0c;通过这组规范定义了程序中各个变量的访问方式。java内存模型(JMM)屏蔽掉各种硬件和操作系统的内存访问…

OpenvSwitch VXLAN 隧道实验

OpenvSwitch VXLAN 隧道实验 最近在了解 openstack 网络&#xff0c;下面基于ubuntu虚拟机安装OpenvSwitch&#xff0c;测试vxlan的基本配置。 节点信息&#xff1a; 主机名IP地址OS网卡node1192.168.95.11Ubuntu 22.04ens33node2192.168.95.12Ubuntu 22.04ens33 网卡信息&…

XUbuntu22.04之关闭todesk开机自启动(二百二十一)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

Python数据分析-4

1.对于一组电影数据&#xff0c;呈现出rating,runtime的分布情况&#xff1a; #encodingutf-8 import pandas as pd import numpy as np from matplotlib import pyplot as plt file_path "./youtube_video_data/IMDB-Movie-Data.csv" df pd.read_csv(file_path) …

React低代码平台实战:构建高效、灵活的应用新范式

文章目录 每日一句正能量前言一、React与低代码平台的结合优势二、基于React的低代码平台开发挑战三、基于React的低代码平台开发实践后记好书推荐编辑推荐内容简介作者简介目录前言为什么要写这本书 读者对象如何阅读本书 赠书活动 每日一句正能量 人生之美&#xff0c;不在争…

C# 根据两点名称,寻找两短路程的最优解,【有数据库设计,完整代码】

前言 如果我们遇到路径问题&#xff0c;可以使用点点连线&#xff0c;给定一个点&#xff0c;可以到达另外几个点&#xff0c;寻找最优解 例&#xff1a;如下图所示&#xff0c;如果要从A1-C1,可以有三条路 1.A1-B1-C1 2.A1-B2-C1 3.A1-B3-C1 最优解肯定是A1-B1-C1&#xff0c…

15. jwt认证中间件

在上一篇登录功能的实现中&#xff0c;我们使用了jwt作为鉴权组件&#xff0c;其中登录后会颁发token。前端在访问后续请求时&#xff0c;可以带上这个token。对于一些需要权限校验的请求&#xff0c;我们就需要验证这个token&#xff0c;从token中获取到用户id&#xff08;有了…

Unity Timeline学习笔记(1) - 创建TL和添加动画片段

Timeline在刚出的时候学习了一下&#xff0c;但是因为一些原因一直都没用在工作中使用。 版本也迭代了很久不用都不会用了&#xff0c;抽时间回顾和复习一下&#xff0c;做一个笔记后面可以翻出来看。 创建Timeline 首先我们创建一个场景&#xff0c;放入一个Plane地板&#…

【机器学习智能硬件开发全解】(四)—— 政安晨:嵌入式系统基本素养【后摩尔时代】

随着物联网、大数据、人工智能时代的到来&#xff0c;海量的数据分析、大量复杂的运算对CPU的算力要求越来越高。 CPU内部的大部分资源用于缓存和逻辑控制&#xff0c;适合运行具有分支跳转、逻辑复杂、数据结构不规则、递归等特点的串行程序。 在集成电路工艺制程将要达到极…

PgSQL技术内幕 - 优化器如何估算行数

PgSQL技术内幕 - 优化器如何估算行数 PgSQL优化器根据统计信息估算执行计划路径的代价&#xff0c;从而选择出最优的执行计划。而这些统计信息来自pg_statistic&#xff0c;当然这个系统表是由ANALYZE或者VACUUM进行样本采集而来。关于该系统表的介绍详见&#xff1a;PgSQL技术…

水泵房远程监控物联网系统

随着物联网技术的快速发展&#xff0c;越来越多的行业开始利用物联网技术实现设备的远程监控与管理。水泵房作为城市供水系统的重要组成部分&#xff0c;其运行状态的监控与管理至关重要。HiWoo Cloud作为专业的物联网云服务平台&#xff0c;为水泵房远程监控提供了高效、稳定、…

2.1HTML5基本结构

HTML5实际上不算是一种编程语言&#xff0c;而是一种标记语言。HTML5文件是由一系列成对出现的元素标签嵌套组合而成&#xff0c;这些标签以<元素名>的形式出现&#xff0c;用于标记文本内容的含义。浏览器通过元素标签解析文本内容并将结果显示在网页上&#xff0c;而元…

基于centos7的k8s最新版v1.29.2安装教程

k8s概述 Kubernetes 是一个可移植、可扩展的开源平台&#xff0c;用于管理容器化的工作负载和服务&#xff0c;可促进声明式配置和自动化。 Kubernetes 拥有一个庞大且快速增长的生态&#xff0c;其服务、支持和工具的使用范围相当广泛。 Kubernetes 这个名字源于希腊语&…

CentOS无法解析部分网站(域名)

我正在安装helm软件&#xff0c;参考官方文档&#xff0c;要求下载 get-helm-3 这个文件。 但是我执行该条命令后&#xff0c;报错 连接被拒绝&#xff1a; curl -fsSL -o get_helm.sh https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 # curl: (7) Fai…