深度学习模型笔记

加载和保存模型参数

保存模型参数

net = MLP()
# 此处省略训练过程,在训练之后,保存模型参数
# 保存字典格式的模型参数,模型参数名
torch.save(net.state_dict(), 'mlp.params') 

加载模型参数

clone = MLP()
# 加载模型参数
clone.load_state_dict(torch.load('mlp.params'))
# 进入评估模式
clone.eval()
# 开始推理
Y_clone = clone(X)

当训练时使用了nn.DataParallel时,正确的加载方法

nn.DataParallel是一种用于多GPU计算的封装模块。它可以将模型和数据分布到多个GPU上进行并行计算,以加速训练过程。在PyTorch中,nn.DataParallel可以接受一个已存在的模型作为输入,并将模型和数据分布到多个GPU上进行并行计算。这种并行计算可以显著减少训练时间,并提高模型的收敛速度。在使用nn.DataParallel时,需要将模型和数据传递给封装后的nn.DataParallel对象,并在训练过程中使用封装后的模型进行前向传播和反向传播。

直接使用上面的加载方式会报错

RuntimeError: Error(s) in loading state_dict for Sequential:

该错误通常与使用了nn.DataParallel进行训练有关

是指模型中的参数key中字符串与torch.load获取的key中字符串不匹配

因此,我们只需要修改torch.load获取的dict,令其匹配。

 

例如:

我torch.save时,参数key中字符串前自动添加了'module.'

因此,在torch.load后,需要去掉'module.'
model = Model()
model_para_dict_temp = torch.load('xxx.pth')
model_para_dict = {}
for key_i in model_para_dict_temp.keys():
    model_para_dict[key_i[7:]] = model_para_dict_temp[key_i]  # 删除掉前7个字符'module.'
del model_para_dict_temp
model.load_state_dict(model_para_dict)

pytorch加载时出现“RuntimeError: Error(s) in loading state_dict for Sequential:”时的解决方案
https://www.cnblogs.com/s-tyou/p/16514693.html

kaggle CIFAR-10图像分类笔记

对数据集的处理

在这里插入图片描述

train: 训练集
train_valid: 训练集中的验证集,每个epoch结束后使用它进行验证。
test: 模型所有轮训练结束后,使用它来进行测试。
valid: 验证集,评估在未见过的数据上的数据集。

损失函数

loss = nn.CrossEntropyLoss(reduction='none')

reduction的字面意思是“减少”或“缩减”,指的是对损失进行某种操作以减小或简化损失的计算。

当reduction='none'时,表示不进行任何归约操作,即每个样本的损失都单独计算并返回。

当reduction='sum'时,表示对所有样本的损失进行求和操作,得到一个标量值,表示所有样本损失的总和。

当reduction='elementwise_mean'时,表示对所有样本的损失进行平均操作,得到一个标量值,表示所有样本损失的平均值。

lr_period, lr_decay = 4, 0.9 的含义

lr_period = 4: lr_period表示学习率更新的周期。设置为4可能意味着在每4个epoch后更新学习率。

lr_decay = 0.9: lr_decay表示学习率的衰减率。这里设置为0.9,意味着每次更新学习率时,学习率会乘以0.9,从而逐渐降低学习率。

图像分类的通用训练函数(训练并保存最优训练结果参数)

该函数的输入为:
train(网络,训练集的dataloader,评估集的dataloader,训练轮数,学习率,权重衰减,设备,更新学习率的周期,学习率衰减)
输出为训练结果的图和损失和准确率的值。

def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,lr_decay):
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # 设定一个参数 统计最优的训练准确率
    best_train_acc = 0
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels,loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[2],None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()
    train_acc = metric[1] / metric[2]
    measures = (f'train loss {metric[0] / metric[2]:.3f}, 'f'train acc {train_acc:.3f}')
    if train_acc>best_train_acc:
        best_train_acc = train_acc
        # 更新模型最优准确率的参数
        torch.save(net.state_dict(),'weights/best_train_acc_params')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
                     f' examples/sec on {str(devices)}')
    print(f'best_train_acc={best_train_acc}')

图像分类模型训练代码

训练并统计训练的时间

'''先训练20轮看看有没有问题'''
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
# lr_period表示学习率的更新周期,设置为4表示,每四个epoch后更新学习率, lr_decay表示权重衰减,每次更新学习率,学习率都会乘0.9
lr_period, lr_decay, net = 4, 0.9, get_net()

# 开始计时
start_time = time.time()
# 开始训练
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,lr_decay)
# 结束计时
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
    print(f'{round(run_time,2)}s')
else:
    print(f'{round(run_time/60,2)}minutes')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/104332.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++】Map和Set -- 详解

一、关联式容器 在初阶阶段&#xff0c;我们已经接触过 STL 中的部分容器&#xff0c;比如&#xff1a;vector、list、deque、forward_list&#xff08;C11&#xff09;等&#xff0c;这些容器统称为 序列式容器 &#xff0c;因为其底层为线性序列的数据结构&#xff0c;里面存…

计算机视觉实战项目3(图像分类+目标检测+目标跟踪+姿态识别+车道线识别+车牌识别+无人机检测+A*路径规划+单目测距与测速+行人车辆计数等)

车辆跟踪及测距 该项目一个基于深度学习和目标跟踪算法的项目&#xff0c;主要用于实现视频中的目标检测和跟踪。该项目使用了 YOLOv5目标检测算法和 DeepSORT 目标跟踪算法&#xff0c;以及一些辅助工具和库&#xff0c;可以帮助用户快速地在本地或者云端上实现视频目标检测和…

双向电平电压转换器TXS0102DCTR应用电路设计

1、TXS0102简介 TXS0102DCTR是一个2位双向电压电平转换器&#xff0c;主要用途是与数据I/O&#xff08;例如I2C或1-wire&#xff09;上的开漏驱动器连接&#xff08;其中数据是双向的且无可用的控制信号&#xff09;&#xff0c;在混合电压系统之间建立数字开关兼容性。它使用…

保存 uboot图像配置

一. 简介 本文学习如何保存经过图像配置&#xff0c;与加载 自己的配置文件。 之前几篇文章学习了&#xff1a;uboot 经过图形化配置 dns 命令功能。地址如下&#xff1a; uboot通过图像化界面配置 dns命令-CSDN博客 uboot通过图像化界面配置 dns命令验证-CSDN博客 二. 保…

【C++基础入门】42.C++中同名覆盖引发的问题

一、父子间的赋值兼容 子类对象可以当作父类对象使用&#xff08;兼容性) 子类对象可以直接赋值给父类对象子类对象可以直接赋值给父类对象父类指针可以直接指向子类对象父类引用可以直接引用子类对象 下面看一个子类对象兼容性的代码&#xff1a; #include <iostream>…

大模型在数据分析场景下的能力评测

“你们能对接国产大模型吗&#xff1f;” “开源的 LLaMA 能用吗&#xff0c;中文支持怎么样&#xff1f;” “私有化部署和在线服务哪个更合适&#xff1f;” 自 7 月 14 日发布 AI 数智助理 Kyligence Copilot 后&#xff0c;我们收到了很多类似上面的咨询&#xff0c;尤其…

如何处理单据保存/审核时提示:“更新即时库存时,基本单位数量与辅单位数量为一正一负,即时库存更新不成功

文章目录 如何处理单据保存/审核时提示:“更新即时库存时,基本单位数量与辅单位数量为一正一负,即时库存更新不成功问题描述前提问题分析&#xff1a;解决方案 如何处理单据保存/审核时提示:“更新即时库存时,基本单位数量与辅单位数量为一正一负,即时库存更新不成功 问题描述…

使用C# RDLC环境搭建

搭建C# RDLC环境 在vs环境中&#xff0c;菜单扩展>管理扩展 用来打开报表文件的 用来新建报表文件的 搜索Microsoft Reporting Services Projects 选择第一个进行下载 安装完以上两个即可进行报表文件的创建和预览 reportview组件 推荐nuget安装&#xff1a;Install-…

前后端交互—跨域与HTTP

跨域 代码下载 同源策略 同源策略(英文全称 Same origin policy)是浏览器提供的一个安全功能。 MDN 官方给定的概念:同源策略限制了从同一个源加载的文档或脚本如何与来自另一个源的资源进行交互。这 是一个用于隔离潜在恶意文件的重要安全机制。 通俗的理解:浏览器规定&a…

拆贡献+统计非法可能不统计非法贡献:ARC150D

https://atcoder.jp/contests/arc150/tasks/arc150_d 先拆贡献成每个点&#xff0c;然后就只需要考虑这条链上的情况了 我们现在要求的是&#xff1a; 在所有点选完之前&#xff0c;最后一个点被选了多少次 我们发现这很难做&#xff0c;但有个性质&#xff1a; 在所有点选…

CLion使用SSH远程连接Linux服务器

最近要一直用实验室的服务器写Linux下的C代码, 本来一直用VScode(SSH)连接服务器, 但是我以前还是用JetBrains的IDE用的多, 毕竟他家的IDE代码提示和功能在某些细节上更加丰富。所以这次我使用了Clion里的远程连接(同样也是SSH工具)连接上了我的服务器, 实现了和VScode上同样的…

计算机网络-计算机网络体系结构-应用层

目录 一、网络应用模型 客户/服务器模型(Client/Server) P2P模型(Peer-to-peer) 二、域名解析系统(DNS) 域名 域名服务器 解析过程 三、文件传输协议(FTP) FTP控制原理 四、电子邮件 组成结构 协议 SMTP MIME POP3 IMAP 五、万维网和HTTP协议 概述 HTTP 报…

Python---for循环中的两大关键字break和continue

之前在while循环中&#xff0c;也是用到两个关键字。 相关链接&#xff1a; 所以&#xff0c;在循环结构中都存在两个关键字&#xff1a;break和continue break&#xff1a;主要功能是终止整个循环 break&#xff1a;代表终止整个循环结构 continue&#xff1a;代表中止当…

selenium元素定位之xpath

一、找父级节点parent xpath&#xff1a;//span[text()保存]/parent::button 说明&#xff1a;先找到span标签&#xff0c;再找到父级button 一、找同级的上方标签preceding-sibling xpath&#xff1a;//span[text()保存]/parent::button/preceding-sibling::button[1] 说明…

基于Java的足球赛会管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09; 代码参考数据库参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者&am…

计算机网络——计算机网络体系结构(4/4)-计算机网络体系结构中的专用术语(实体、协议、服务,三次握手‘三报文握手’、数据包术语)

目录 分类一——实体 实体 对等实体 分类二——协议 协议 协议的三要素 分类三——服务 服务 服务访问点 数据包术语 计算机网络体系结构中的专用术语 本篇所讲的专用术语来源于OSI的七层协议体系结构&#xff0c;但也适用于TCP/IP的四层体系结构和五层协议原理体系…

JVM(二)

一,运行时数据区 Java虚拟机在运行Java程序过程中管理的内存区域,称之为运行时数据区。 1.1 程序计数器 程序计数器(Program Counter Register)也叫PC寄存器,每个线程会通过程序计数器记录当前要执行的的字 节码指令的地址。 在加载阶段,虚拟机将字节码文件中的指令读…

Qt篇——子控件QLayoutItem与实际控件的强转

方法&#xff1a;使用qobject_cast<QLabel*>() &#xff0c;将通过itemAt(i)获取到的子控件(QLayoutItem)强转为子控件的实际类型(如QLineEdit、QLabel等)。 场景举例&#xff1a; QLabel *label qobject_cast<QLabel*>(ui->horizontalLayout_40->itemAt(0…

最新SQL注入漏洞修复建议

点击星标&#xff0c;即时接收最新推文 本文选自《web安全攻防渗透测试实战指南&#xff08;第2版&#xff09;》 点击图片五折购书 SQL注入漏洞修复建议 常用的SQL注入漏洞的修复方法有两种。 1&#xff0e;过滤危险字符 多数CMS都采用过滤危险字符的方式&#xff0c;例如&…

Biotech - 环状 mRNA 的 LNP 递送系统 与 成环框架

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/133992971 环状 RNA&#xff08;或 circRNA &#xff09;是一种单链 RNA&#xff0c;与线性 RNA 不同&#xff0c;形成一个共价闭合的连续环。在环…
最新文章