spikingjelly训练自己的网络---量化 --测试

在这里插入图片描述
在这里插入图片描述
第二个=================

在这里插入图片描述
在这里插入图片描述
但是我发现,都要反量化,因为pytorch是只能支持浮点数的。

在这里插入图片描述

https://blog.csdn.net/lai_cheng/article/details/118961420
Pytorch的量化大致分为三种:模型训练完毕后动态量化、模型训练完毕后静态量化、模型训练中开启量化,本文从一个工程项目(Pose Estimation)给大家介绍模型训练后静态量化的过程。

我又提问了
我要在这个上面进行16比特量化的修改,应该怎么修改?【class SNN(nn.Module):
def init(self, tau):
super().init()

    self.layer = nn.Sequential(
        layer.Flatten(),
        layer.Linear(28 * 28, 10, bias=False),
        neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        )

def forward(self, x: torch.Tensor):
    return self.layer(x)】

在这里插入图片描述
在这里插入图片描述

=

=

=

=

=

=

=
测试【我将模型测试的部分单独写在一个程序中,应该怎么写?】
在这里插入图片描述

import torch
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import time
from main import SNN  # 确保从你的 main.py 或其他文件中正确导入 SNN 类和 encoder
from torch.utils.tensorboard import SummaryWriter

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer

#python -m main -tau 2.0 -T 50 -device cuda:0 -b 64 -epochs 3 -data-dir \mnist -opt adam -lr 1e-3 -j 2

def test_model(model_path, data_dir, device='cuda:0', T=50,epoch_test = 3):
    start_epoch = 0
    out_dir = '.\\out_dir'
    writer = SummaryWriter(out_dir, purge_step=start_epoch)

    # 加载模型
    net = SNN(tau=2.0)  # 使用适当的参数初始化你的模型
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    net.to(device)
    net.eval()

    # 加载测试数据集
    test_dataset = torchvision.datasets.MNIST(
        root=data_dir,
        train=False,
        transform=ToTensor(),
        download=True
    )
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    for epoch in range(start_epoch, epoch_test):  
        # 初始化性能指标
        test_loss = 0
        test_acc = 0
        test_samples = 0
        start_time = time.time()

        encoder = encoding.PoissonEncoder()

        with torch.no_grad():
            for img, label in test_loader:
                img = img.to(device)
                label = label.to(device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                for t in range(T):
                    encoded_img = encoder(img)  # 确保 encoder 已经定义
                    out_fr += net(encoded_img)
                out_fr = out_fr / T
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                # 注意:如果你的网络需要在每次迭代后重置状态,请在这里调用重置函数

        test_time = time.time() - start_time
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}')
        print(f'Test completed in {test_time:.2f} seconds.')

if __name__ == '__main__':
    model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'  # 模型路径
    data_dir = 'data'  # 数据集路径
    test_model(model_path, data_dir)
Test Loss: 0.0167, Test Accuracy: 0.9198
Test completed in 5.56 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9186
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9185
Test completed in 4.77 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9194
Test completed in 4.79 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.72 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9193
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9189
Test completed in 4.74 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9188
Test completed in 4.76 seconds.
test_samples=10000
Test Loss: 0.0168, Test Accuracy: 0.9192
Test completed in 4.74 seconds.
test_samples=10000

T=5时候,结果如下T只是影响网络看见 了什么,越长不一定越好,趋于稳定

Test Loss: 0.0205, Test Accuracy: 0.9064
Test completed in 2.04 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9050
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9055
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.0203, Test Accuracy: 0.9080
Test completed in 1.24 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9074
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9045
Test completed in 1.37 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9058
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0206, Test Accuracy: 0.9049
Test completed in 1.40 seconds.
test_samples=10000
Test Loss: 0.0205, Test Accuracy: 0.9063
Test completed in 1.47 seconds.
test_samples=10000
Test Loss: 0.0207, Test Accuracy: 0.9047
Test completed in 1.35 seconds.
test_samples=10000
量化
import torch

with open('model_params.txt', 'r') as file:
    lines = file.readlines()

with open('model_params_quantized.txt', 'w') as file:
    for line in lines:
        # 去除换行符并按逗号和空格拆分字符串
        values = line.strip().split(',')
        for val in values:
            float_val = float(val.strip())
            quantized_val = int(round(float_val * 10000))  # 量化为int32
            file.write(f"{quantized_val}\n")


量化后再把数字写入进去
import torch

# 加载原始的checkpoint_max.pth文件
model_path = 'logs\\T50_b64_adam_lr0.001\\checkpoint_max.pth'
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# 读取量化后的数据
with open('model_params_quantized.txt', 'r') as file:
    quantized_values = [int(line.strip()) for line in file.readlines()]

# 将量化后的数据写回到模型参数中
index = 0
for name, param in checkpoint['net'].items():
    if isinstance(param, torch.Tensor):
        numel = param.numel()
        quantized_param = torch.tensor(quantized_values[index:index+numel]).view(param.size())
        checkpoint['net'][name] = quantized_param
        index += numel

# 保存新的checkpoint文件
torch.save(checkpoint, 'logs\\T50_b64_adam_lr0.001\\checkpoint_max_quantized.pth')


model_state_dict = checkpoint['net']
for name, param in model_state_dict.items():
    print(f"{name}: {param}")
    print(f"{name}: {param.size()}")

量化为int32之后的准确率  下降
Test Loss: 0.1182, Test Accuracy: 0.6758
Test completed in 2.10 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6765
Test completed in 1.23 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6789
Test completed in 1.25 seconds.
test_samples=10000
Test Loss: 0.1180, Test Accuracy: 0.6785
Test completed in 1.30 seconds.
test_samples=10000
Test Loss: 0.1181, Test Accuracy: 0.6755
Test completed in 1.35 seconds.
test_samples=10000
Test completed in 1.39 seconds.
test_samples=10000
Test Loss: 0.1183, Test Accuracy: 0.6800
Test completed in 1.35 seconds.
test_samples=10000
Test Loss: 0.1185, Test Accuracy: 0.6750
Test completed in 1.38 seconds.
test_samples=10000

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/530008.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

苍穹外卖11(Apache ECharts前端统计,营业额统计,用户统计,订单统计,销量排名Top10)

目录 一、Apache ECharts【前端】 1. 介绍 2. 入门案例 二、营业额统计 1. 需求分析和设计 1 产品原型 2 业务规则 3 接口设计 2. 代码开发 3. 功能测试 三、用户统计 1. 需求分析和设计 1 产品原型 2 业务规则 3 接口设计 2. 代码开发 3. 功能测试 四、订单统…

MacOS初识SIP——解决快捷指令sh脚本报错Operation not permitted

前言 因为一些原因,设计了一套快捷指令,中间涉及到一个sh脚本的运行,通过快捷指令运行时就会报错:operation not permitted 奇怪的是在快捷指令窗口下运行一切正常,但是从其他地方直接调用,例如通过Comma…

网络安全:重要性与应对措施

1. 网络安全的重要性 随着互联网的普及和信息技术的快速发展,网络安全问题已经变得日益突出。网络攻击者可以通过各种手段窃取个人信息、破坏系统、传播病毒等,给个人和社会带来巨大的损失。因此,网络安全已经成为信息化时代的重要问题之一。…

上门服务小程序|上门服务系统|上门服务软件开发流程

在如今快节奏的生活中,上门服务小程序的需求越来越多。它们向用户提供了方便、高效的服务方式,解决了传统服务行业中的很多痛点。如果你也想开发一个上门服务小程序,以下是开发流程和需要注意的事项。 1、确定需求:在开始开发之前…

SCI一区 | Matlab实现OOA-TCN-BiGRU-Attention鱼鹰算法优化时间卷积双向门控循环单元融合注意力机制多变量时间序列预测

SCI一区 | Matlab实现OOA-TCN-BiGRU-Attention鱼鹰算法优化时间卷积双向门控循环单元融合注意力机制多变量时间序列预测 目录 SCI一区 | Matlab实现OOA-TCN-BiGRU-Attention鱼鹰算法优化时间卷积双向门控循环单元融合注意力机制多变量时间序列预测预测效果基本介绍模型描述程序…

如何将h5网页打包成iOS苹果IPA文件

哈喽,大家好呀,淼淼又来和大家见面啦,最近有很多小伙伴都被难住了,是什么问题给他们都难住了呢,许多小伙伴都说想要把h5网页打包成iOS苹果IPA文件,但是却不知道具体怎么操作,是怎么样的一个流程…

蓝桥杯每日一题(背包dp,线性dp)

//3382 整数拆分 将 1,2,4,8看成一个一个的物品&#xff0c;以完全背包的形式放入。 一维形式&#xff1a;f]0]1; #include<bits/stdc.h> using namespace std; //3382整数拆分 const int N1e610, M5e510; int mod1e9; int f[N],n; int main() {cin>>n;//转化为完…

appium+jenkins实例构建

自动化测试平台 Jenkins简介 是一个开源软件项目&#xff0c;是基于java开发的一种持续集成工具&#xff0c;用于监控持续重复的工作&#xff0c;旨在提供一个开放易用的软件平台&#xff0c;使软件的持续集成变成可能。 前面我们已经开完测试脚本&#xff0c;也使用bat 批处…

从零开始学习:如何使用Selenium和Python进行自动化测试?

安装selenium 打开命令控制符输入&#xff1a;pip install -U selenium 火狐浏览器安装firebug&#xff1a;www.firebug.com&#xff0c;调试所有网站语言&#xff0c;调试功能 Selenium IDE 是嵌入到Firefox 浏览器中的一个插件&#xff0c;实现简单的浏览器操 作的录制与回…

【微服务】------微服务架构技术栈

目前微服务早已火遍大江南北&#xff0c;对于开发来说&#xff0c;我们时刻关注着技术的迭代更新&#xff0c;而项目采用什么技术栈选型落地是开发、产品都需要关注的事情&#xff0c;该篇博客主要分享一些目前普遍公司都在用的技术栈&#xff0c;快来分享一下你当前所在用的技…

PS入门|如何使用“主体”功能进行抠图?

前言 前段时间讲到给各种图标和LOGO抠图的办法&#xff0c;分别使用的是 钢笔工具蒙版 PS入门&#xff5c;规规矩矩的图形怎么抠出来&#xff1f; 魔棒工具蒙版 PS入门&#xff5c;黑白色的图标怎么抠成透明背景 色阶蒙版 PS入门&#xff5c;目标比较复杂&#xff0c;但背景…

数据中台系统架构的探索之路:生产管理企业的数字化转型引擎-亿发

当前制造业面临着诸多问题。 1、系统繁杂&#xff0c;涉及多个子系统和应用&#xff0c;导致信息孤岛和数据孤立现象普遍存在。 2、其次是业务流程冗长&#xff0c;造成生产过程中的信息传递和协同困难&#xff0c;影响效率和质量。 3、数据应用问题也十分突出&#xff0c;包…

android平台下opencv的编译--包含扩展模块

由于项目需要使用安卓平台下opencv的扩展库&#xff0c;对于通用的opencv库&#xff0c; opencv官网提供了android的SDK 但未能提供扩展库&#xff0c;因此需要自己进行源码编译。本文记录android平台下opencv及其扩展库的交叉编译。 前提&#xff1a;主机已安装android-ndk交…

mybatis-plus与mybatis同时使用别名问题

在整合mybatis和mybatis-plus的时候发现一个小坑&#xff0c;单独使用mybatis&#xff0c;配置别名如下&#xff1a; #配置映射文件中指定的实体类的别名 mybatis.type-aliases-packagecom.jk.entity XML映射文件如下&#xff1a; <update id"update" paramete…

vue2 使用vue-org-tree demo

1.安装 npm i vue2-org-tree npm install -D less-loader less安装 less-loader出错解决办法&#xff0c;直接在package.json》devDependencies下面加入less和less-loader版本&#xff0c;然后执行npm i &#xff0c;我用的nodejs版本是 16.18.0&#xff0c;“webpack”: “^4…

Redis的双写一致性问题

双写一致性问题 1.先删除缓存或者先删除数据库都可能出现脏数据。 2.删除两次缓存&#xff0c;可以在一定程度上降低脏数据的出现。 3.延时是因为数据库一般采用主从分离&#xff0c;读写分离。延迟一会是让主节点把数据同步到从节点。 1.读写锁保证数据的强一致性 因为一般放…

java Web在线考试管理系统用eclipse定制开发mysql数据库BS模式java编程jdbc

一、源码特点 JSP 在线考试管理系统是一套完善的web设计系统&#xff0c;对理解JSP java 编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,eclipse开发&#xff0c;数据库为Mysql5.0&#xff0c;使…

DDoS攻击包含哪些层面?如何防护?

DDoS攻击&#xff08;分布式拒绝服务攻击&#xff09;是一种通过向目标服务器发送大量流量或请求&#xff0c;以使其无法正常工作的网络攻击手段。DDoS攻击涉及多个层面&#xff0c;在实施攻击时对网络基础架构、网络协议、应用层等进行攻击。下面将详细介绍DDoS攻击的层面。 1…

CentOS 7 升级 5.4 内核

MatrixOne 推荐部署使用的操作系统为 Debian 11、Ubuntu 20.04、CentOS 9 等 Kernel 内核版本高于 5.0 的操作系统。随着 CentOS 7 的支持周期接近尾声&#xff0c;社区不少小伙伴都在讨论用以替换的 Linux 操作系统&#xff0c;经过问卷调查&#xff0c;我们发现小伙伴们的操作…

eclipse .project

.project <?xml version"1.0" encoding"UTF-8"?> <projectDescription> <name>scrm-web</name> <comment></comment> <projects> </projects> <buildSpec> <buil…
最新文章