【pytorch】深度学习模型调参策略(五):采用贝叶斯工具进行最优参数搜索及最佳步数确认

目录

  • 1.如何决定是否应用某个新的超参数配置
  • 2.参数优化工具optuna确定最终最优配置
    • 为什么在调整的探索阶段使用准随机搜索而不是更复杂的黑盒优化算法?
    • optuna库简介
    • pytorch实现代码
    • 搜索参数详解
    • 输出结果
  • 3.确定每次训练运行的步数
    • 使用学习率扫描选择max_train_steps初始候选值

1.如何决定是否应用某个新的超参数配置

简而言之,当我们决定对模型或训练过程进行更改或采用新的超参数配置时,我们需要考虑到结果中的不同变化来源。
对于试图改进模型的情况,我们可能会发现某个候选变化的验证误差最初比现有配置更好,但重复实验后并没有一致的优势。这些不一致的结果可能来自以下几个方面:

训练过程方差、重新训练方差或试验方差:使用相同的超参数但不同的随机种子进行训练运行之间的变化。

超参数搜索方差或研究方差:由于选择超参数的过程而导致结果的变化。

数据收集和抽样方差:从任何随机拆分到训练、验证和测试数据或由于训练数据生成过程而产生的方差。
为了决定是否采用候选变化,我们需要运行最佳试验N次,以了解试验之间的运行方差。我们应该只采用能够产生优势的变化,而不会增加任何复杂性。

2.参数优化工具optuna确定最终最优配置

我们完成对搜索空间的探索并确定需要调整的超参数后,贝叶斯优化工具是一个非常有吸引力的选择。此时,我们的重点将从学习有关调整问题的更多信息转变为生成一个最佳配置,以供启动或其他用途。因此,应该有一个经过精细调整的搜索空间,足够大,舒适地包含最佳观察试验周围的局部区域并已充分采样。我们的探索工作应该已经揭示了需要调整的最重要的超参数(以及它们的合理范围),可以使用它们来构建一个搜索空间,以尽可能大的调整预算进行最终自动调整研究。由于我们不再关心最大化对调整问题的洞察力,许多准随机搜索的优点不再适用,应该使用贝叶斯优化工具来自动找到最佳的超参数配置。开源的Vizier实现了各种复杂的算法来调整机器学习模型,包括贝叶斯优化算法。如果搜索空间包含大量发散点(损失值为NaN或比平均值高出许多标准差),则使用适当处理发散试验的黑盒优化工具非常重要。同时,我们还需要考虑在测试集上检查性能。原则上,我们甚至可以将验证集合并到训练集中,并使用贝叶斯优化找到的最佳配置进行重新训练。但是,仅当不会有特定工作负载的未来启动时才适用此方法(例如一次性的Kaggle竞赛)。

为什么在调整的探索阶段使用准随机搜索而不是更复杂的黑盒优化算法?

我们更喜欢基于低差异序列的准随机搜索,而不是更复杂的黑盒优化工具,因为它作为迭代调整过程的一部分,旨在最大化对调整问题的洞察力(我们称之为“探索阶段”)。贝叶斯优化和类似的工具更适合于开发阶段。
基于随机移位的低差异序列的准随机搜索可以看作是“抖动、洗牌的网格搜索”,因为它均匀但随机地探索给定的搜索空间,并且比随机搜索更广泛地扩展搜索点。
准随机搜索相对于更复杂的黑盒优化工具(如贝叶斯优化、进化算法)的优点包括:

非自适应采样搜索空间使得可能通过后续分析改变调整目标而不重新运行实验。

准随机搜索的一致性和统计可重复性。

准随机搜索对搜索空间的均匀探索使得更容易推理结果以及它们对搜索空间的建议。
运行不同数量的试验时,准随机搜索(或其他非自适应搜索算法)与自适应算法相比不会产生统计学上的不同结果。
更复杂的搜索算法可能不会始终正确处理不可行点,尤其是如果它们没有设计神经网络超参数调整。
准随机搜索简单易用,在许多调整试验将并行运行时表现尤为出色。
如果计算资源只允许并行运行少量试验,并且我们能够承担运行许多试验的序列,则尽管使调整结果更难以解释,贝叶斯优化变得更具吸引力。

optuna库简介

optuna是一个用于超参数优化的Python库。它提供了一种简单、高效的方法来搜索模型的最优超参数组合,以最大化或最小化目标函数(例如验证集准确率或损失函数)。optuna支持多种搜索算法,包括随机搜索、网格搜索和基于贝叶斯优化的搜索。它还支持并行化搜索,并提供了可视化工具来帮助用户分析搜索结果。使用optuna可以帮助用户节省大量时间和精力,从而更快地找到最优的超参数组合,并提高模型的性能。optuna不仅可以用于机器学习领域,也可以用于优化任何需要调整超参数的任务。

在默认情况下,Optuna 使用的是 TPE 算法进行贝叶斯优化搜索。该算法将搜索空间中的概率分布建模为两个条件概率分布:先验概率分布和似然概率分布。先验概率分布是指在未进行试验之前,对超参数的概率分布的假设。似然概率分布是指在已有的试验结果的基础上,对超参数的概率分布的推断。通过不断更新这两个概率分布,TPE 算法能够更快地收敛到全局最优解,并适应不同类型的搜索空间。
需要注意的是,虽然贝叶斯优化算法能够更快地收敛到全局最优解,但是其搜索效率与搜索空间的形状、维度、先验分布等因素都有关系。在使用 Optuna 进行超参数搜索时,我们应该根据具体的问题和实验情况,选择合适的搜索算法和搜索空间,以达到最优的搜索效果。
安装:

pip3 install optuna

pytorch实现代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import optuna
# 定义超参数搜索空间
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def objective(trial):
    # 搜索学习率
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
    # 搜索动量因子
    momentum = trial.suggest_uniform('momentum', 0.0, 1.0)
    # 搜索权重衰减因子
    weight_decay = trial.suggest_loguniform('weight_decay', 1e-5, 1e-1)
    # 搜索网络结构超参数
    num_layers = trial.suggest_int('num_layers', 1, 5)
    hidden_size = trial.suggest_categorical('hidden_size', [32, 64, 128, 256])
    dropout = trial.suggest_uniform('dropout', 0.0, 0.5)
    # 加载CIFAR-10数据集
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
    # 定义模型
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(32 * 32 * 3, hidden_size),
        nn.ReLU(),
        nn.Dropout(dropout),
        *(nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(dropout)) for _ in range(num_layers - 1)),
        nn.Linear(hidden_size, 10),
    )
    if torch.cuda.device_count() > 1:
        #如果gpu设备数目大于1个,同时使用两个GPU
        model = nn.DataParallel(model, [0, 1])
    model=model.to(device)
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    # 训练模型
    max_accuracy = float('-inf')
    accuracy_list = []
    for epoch in range(10):
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        # 在测试集上测试模型性能
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        trial.report(accuracy, epoch)
        max_accuracy = max(max_accuracy, accuracy)
        accuracy_list.append(accuracy)
        if len(accuracy_list) > 5:
            accuracy_list.pop(0)
        max_accuracy_over_last_5_epochs = max(accuracy_list)
        # 提前停止条件:如果最近5个epoch的性能较最优值都没有提高,则停止训练
        if max_accuracy > max_accuracy_over_last_5_epochs:
           break
    return accuracy
if __name__ == '__main__':
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=100)
    print('Best trial:')
    trial = study.best_trial
    print('  Value: {}'.format(trial.value))
    print('  Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))

上述代码使用了 Optuna 提供的 suggest 方法来搜索超参数。具体来说,suggest_loguniform 方法用于搜索一个对数均匀分布中的数值,suggest_uniform 方法用于搜索一个均匀分布中的数值,suggest_int 方法用于搜索一个整数范围中的数值,suggest_categorical 方法用于搜索一组离散的数值。下面分别解释这些方法的用法:

搜索参数详解

trial.suggest_loguniform(‘lr’, 1e-5, 1e-1):搜索一个对数均匀分布中的数值,用于学习率超参数 lr。对数均匀分布是指在对数尺度上均匀分布的分布,它通常用于搜索连续的正数超参数。这里我们搜索的范围是 [1e-5, 1e-1]。

trial.suggest_uniform(‘momentum’, 0.0, 1.0):搜索一个均匀分布中的数值,用于动量因子超参数 momentum。均匀分布是指在给定的范围内均匀分布的分布,它通常用于搜索连续的超参数。这里我们搜索的范围是 [0.0, 1.0]。

trial.suggest_loguniform(‘weight_decay’, 1e-5, 1e-1):搜索一个对数均匀分布中的数值,用于权重衰减因子超参数 weight_decay。这里我们搜索的范围是 [1e-5, 1e-1]。

trial.suggest_int(‘num_layers’, 1, 5):搜索一个整数范围中的数值,用于网络层数超参数 num_layers。这里我们搜索的范围是 [1, 5]。

trial.suggest_categorical(‘hidden_size’, [32, 64, 128, 256]):搜索一组离散的数值,用于隐藏层神经元个数超参数 hidden_size。这里我们搜索的数值是 [32, 64, 128, 256] 中的一个。

trial.suggest_uniform(‘dropout’, 0.0, 0.5):搜索一个均匀分布中的数值,用于 dropout 超参数 dropout。这里我们搜索的范围是 [0.0, 0.5]。

在每次试验中,Optuna 会根据超参数搜索空间中的分布,为每个超参数生成一个候选值,供模型训练使用。模型训练过程中,我们使用这些超参数的候选值来训练模型,并将模型在验证集上的表现作为试验的结果报告给 Optuna。Optuna 根据这些结果,不断更新超参数搜索空间中的概率分布,以提高搜索效率,并逐步找到最优的超参数组合。

输出结果

在这里插入图片描述
最终输出:

[I 2023-03-28 08:27:02,336] Trial 99 finished with value: 0.5207 and parameters: {'lr': 0.0413617890693661, 'momentum': 0.6396519061675595, 'weight_decay': 3.840413828120234e-05, 'num_layers': 1, 'hidden_size': 256, 'dropout': 0.023284222799990092}. Best is trial 66 with value: 0.5493.
Best trial:
  Value: 0.5493
  Params: 
    lr: 0.07048135305238855
    momentum: 0.45078184234468277
    weight_decay: 6.0485982461392196e-05
    num_layers: 2
    hidden_size: 256
    dropout: 0.017699021560325236

3.确定每次训练运行的步数

工作负载可以分为两种:计算受限和非计算受限的。当训练是计算受限的时候,训练受到的限制是我们愿意等待的时间,而不是我们有多少训练数据或其他因素。
在这种情况下,如果我们能够以某种方式更长时间或更有效地训练,我们应该会看到更低的训练损失,并且通过适当的调整,可以改善验证损失。换句话说,加快训练等同于改善训练,而“最佳”训练时间始终是“尽可能长”。
尽管如此,只因为工作负载受计算限制,并不意味着训练更长/更快是改善结果的唯一方法。
当训练不受计算限制时,我们可以承受尽可能长时间的训练,但在某些时候,训练时间更长并没有太大的帮助(甚至会导致问题的过拟合)。
在这种情况下,我们应该期望能够训练到非常低的训练损失,直到训练时间更长可能会略微降低训练损失,但不会实质性地降低验证损失。特别是当训练不受计算限制时,更慷慨的训练时间预算可以使调整更容易,特别是在调整学习率衰减计划时,因为它们与训练预算有着特别强的相互作用。
换句话说,非常吝啬的训练时间预算可能需要调整到完美的学习率衰减计划,以实现良好的错误率。
无论给定的工作负载是否受计算限制,增加梯度的方差(跨批次)的方法通常会导致训练进度变慢,因此可能会增加达到特定验证损失所需的训练步数。高梯度方差可能是由以下原因引起:

(1)使用较小的批量大小
(2)添加数据增强
(3)添加某些类型的正则化(例如dropout)

使用学习率扫描选择max_train_steps初始候选值

该算法假定可以使用恒定的学习率计划不仅“完美地”拟合训练集,还可以找到一个最优的max_train_steps值。首先,找到任何最优配置(具有某个max_train_steps值)并将其用作起点。然后,在不使用数据增强和正则化的情况下,运行恒定的学习率扫描,在N步训练每个试验的情况下网格搜索学习率。扫描中最快试验达到完全训练性能所需的步骤数是max_train_steps的初始猜测。但需要注意的是,不良的搜索空间可能会导致自我欺骗。因此,至少应检查最优搜索空间是否足够宽广。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/4117.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

设置鼠标右键打开方式,添加IDEA的打开方式

一、问题描述 已下载IDEA,但是右键打开之前保存的项目文件,无法显示以IDEA方式打开。 二、解决步骤 1. 打开注册表 winR键输入regedit 2、查找路径为计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Classes\Directory\shell (我找了半天没看到Class…

在芯片设计行业,从项目的初期到交付,不同的岗位的工程师主要负责什么?

大家都知道在芯片设计行业,项目是至关重要的一环。从项目的初期到交付,不同的岗位的工程师在项目的各环节主要负责什么?他们是怎样配合的?下面看看资深工程师怎么说。 一个项目,从初期到交付的过程是比较漫长的。我们知道最早的时候&#…

deskvideosys 办公行为管理软件的部署架构

deskvideosys 办公行为管理软件服务器端使用的是 B/S 架构,采用 golangvue 框架来编程,agent 端直接使用的是 vc编程框架,然后通过tcp协议连接服务器端,所以deskvideosys架构 可以作为终端安全管理,上网行为管理&#…

小程序 table组件

最近有在小程序中用table的需求,但是没有找到有符合要求的组件,所以自己弄了一个,能满足基本需求。 组件下载:https://download.csdn.net/download/weixin_67585820/85047405 引入 "usingComponents": {"table": "…

基于springboot和Web实现社区医院管理服务系统【源码+论文】分享

基于springboot和Web的社区医院管理服务系统演示开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Mave…

记录--Vue 3 中的极致防抖/节流(含常见方式防抖/节流)

这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 今天给大家带来的是Vue 3 中的极致防抖/节流(含常见方式防抖/节流)这篇文章,文章中不仅会讲述原来使用的防抖或节流方式,还会带来新的一种封装方式,使用起来更简单、…

diffusion 之 cifar/mnist 数据集

diffusion 之 mnist 数据集mnist数据集ddpm/script_utils.pyscripts/train_mnist.py展示采样结果代码出处:https://github.com/abarankab/DDPMwandb的问题解决方法: step1: 按照这个https://blog.csdn.net/weixin_43164054/article/details/1…

基于kubernetes部署gitlab

目录前提下载镜像部署服务前提 已经搭建完kubernets集群并可提供服务。 下载镜像 去docker hub 下载具体版本镜像,当使用最新版本时,也建议具体制定版本号,而不是使用latest. 如 gitlab/gitlab-ce:15.10.0-ce.0 当然可以pull到本地&#x…

Linux拒绝俄罗斯开发者合入

最近在Linux社区看到这样的信息https://lore.kernel.org/all/20230314103316.313e5f61kernel.org/我们不愿意接受你们的补丁。关于上面的内容,看到有一篇这样的文章https://www.phoronix.com/news/Linux-STMAC-Russian-Sanctions由于美国对俄罗斯实施制裁&#xff0…

一次内存泄露排查

前因: 因为测试 长时间压测导致 接口反应越来越慢,甚至 导致服务器 崩溃 排查过程 1、top 查看是 哪个进程 占用 内存过高 2、根据 进程 id 去查找 具体是哪个 程序的问题 ps -ef| grep 41356 可以看到 具体的 容器位置 排查该进程 对象存活 状态…

大数据学习路线图(2023完整高清版超详细)

送福利了!超详细的大数据学习路线图来啦,2023版是首发哟!大数据学习路线图分为7个阶段,包含: 数据仓库基础-->Linux &Hadoop生态-->Hadoop-->数据仓库与ETL技术-->BI数据分析与可视化-->自研数据仓…

计算机科学与技术专业-大三-学年设计-题目

大三-学年设计题目 西南大学 计算机与信息科学学院 周竹荣 课程概述 学年设计是重要的综合性设计训练,安排在修完相关专业平台课后进行。旨在培养学生综合运用所学的基础理论和专业知识,分析、解决实际问题的能力,理论联系实际。是一次系统的…

HTML 标签和属性

一些标签 单双标签 双标签。双标签指标签是成对出现的&#xff0c;也就是有一个开始标签和一个结束标签&#xff0c;开始标签用 <标签名> 表示&#xff0c;结束标签用 </标签名> 表示&#xff0c;只有一对标签一起使用才能表示一个具体的含义。例如 <html>&…

关于线程池你了解些什么?

前言学习线程池的思维导图线程池是什么?它有什么用?虽然线程比进程更轻量级,但是每个进程所占的资源空间是有限,如果我们频繁创建和销毁线程也会消耗很多CPU资源,那么我们该如何解决这个问题呢?官方解释:线程池是一种多线程处理形式,其处理过程可以将多个任务添加到阻塞队列…

电子招标采购系统:营造全面规范安全的电子招投标环境,促进招投标市场健康可持续发展

营造全面规范安全的电子招投标环境&#xff0c;促进招投标市场健康可持续发展 传统采购模式面临的挑战 一、立项管理 1、招标立项申请 功能点&#xff1a;招标类项目立项申请入口&#xff0c;用户可以保存为草稿&#xff0c;提交。 2、非招标立项申请 功能点&#xff1a;非招标…

Google巨大漏洞让Win10、11翻车,小姐姐马赛克白打了

早年间电脑截图这项技能未被大多数人掌握时&#xff0c;许多人应该都使用过手机拍屏幕这个原始的方式。 但由于较低的画面质量极其影响其他用户的观感&#xff0c;常常受到大家的调侃。 但到了 Win10、11 &#xff0c;预装的截图工具让门槛大幅降低。 WinShiftS 就能快速打开…

专业护眼灯什么牌子好?分享最专业护眼灯品牌排行

在日常生活中&#xff0c;照明灯具为人们提供了很多便利&#xff0c;但是对于长时间在灯光下用眼的学生跟办公族来说&#xff0c;容易导致用眼疲劳&#xff0c;甚至头晕等现象&#xff0c;所以现在普遍许多家庭都必备有护眼灯&#xff0c;能对眼睛起到缓解疲劳的作用&#xff0…

Docker安装Mysql集群(主从复制)

Docker安装Mysql集群(主从复制) 配置阿里云镜像 sudo vim /etc/docker/daemon.json插入如下镜像 {"registry-mirrors": ["https://sdiz8d27.mirror.aliyuncs.com"] }重启docker sudo systemctl daemon-reloadsudo systemctl restart docker保证images有…

python

好处 相对其他编程语言 比较简洁丰富的第三方库【做爬虫、机器学习、深度学习】 numpypandasmatplotlit用处 数据分析web开发游戏开发AI【比较广泛】安装部署python环境 1.官网下载python安装包【原生部署】 官网&#xff1a;python.org2.安装anaconda 1.自带python环境2.有丰富…

Mybatis-Plus进阶使用

一、逻辑删除 曾经我们写的删除代码都是物理删除。 逻辑删除&#xff1a;删除转变为更新 update user set deleted1 where id 1 and deleted0 查找: 追加 where 条件过滤掉已删除数据,如果使用 wrapper.entity 生成的 where 条件也会自动追加该字段 查找: select id,name,dele…
最新文章