超参数调优:网格搜索,贝叶斯优化(optuna)详解

超参数调优:网格搜索,贝叶斯优化(optuna)详解

  • 数据科学:Scipy、Scikit-Learn笔记
  • LightGBM原生接口和Sklearn接口参数详解
  • XGBoost原生接口和Sklearn接口参数详解
  • 网格搜索
    • 参数
      • 1.estimator
      • 2.param_grid
      • 3.scoring
      • 4.fit_params
      • 5.n_jobs
      • 6.refit
      • 7.cv
      • 8.verbose
      • 9.return_train_score:
    • 属性
      • 1.best_estimator_
      • 2.best_score_
      • 3.best_params_
      • 4.best_index_
      • 5.cv_results_
  • 贝叶斯优化(optuna)

数据科学:Scipy、Scikit-Learn笔记

LightGBM原生接口和Sklearn接口参数详解

XGBoost原生接口和Sklearn接口参数详解

网格搜索

GridSearchCV 是一个在 scikit-learn 库中用于执行网格搜索(grid search)参数调优的方法。网格搜索是一种通过遍历预定义的参数网格来确定机器学习模型最佳超参数组合的技术。它对给定的参数值集合中的所有可能组合进行训练和验证,最终选择具有最高交叉验证得分的参数配置。以下是对 GridSearchCV 类主要参数和属性的详细解析:

参数

1.estimator

  • 类型: estimator object

说明: 需要进行参数调优的基础学习器(模型)。这可以是任何实现了 fit() 方法的 scikit-learn estimator,如 LogisticRegression、SVM、RandomForestClassifier 或 GradientBoostingRegressor 等。

2.param_grid

  • 类型: dict or list of dictionaries

说明: 指定要尝试的参数值的网格。它可以是一个字典,其中键是模型参数的名称,值是该参数可能取值的列表;也可以是一个字典列表,表示多个参数组合的网格搜索。例如:

param_grid = {
    'parameter_name1': [value1, value2, ...],
    'parameter_name2': [valueA, valueB, ...],
    # 更多参数...
}

3.scoring

  • 类型: string, callable, list/tuple, dict or None, default=None

说明: 用于评估模型性能的度量标准。它可以是内置评分字符串(如 ‘accuracy’、‘roc_auc’、‘neg_mean_squared_error’ 等),自定义评分函数,或者对于多输出任务,可以是列表或字典形式的多个评分标准。若为 None,则使用 estimator 的默认评分方法。

4.fit_params

  • 类型: dict, optional

说明: 可选的关键字参数传递给 estimator.fit() 方法。这些参数不会被网格搜索所改变,但可用于控制模型拟合过程,如设置正则化参数的随机种子 (random_state) 或指定特征重要性计算方式等。

5.n_jobs

  • 类型: int, default=1

说明: 并行处理数。若 -1,使用所有可用的CPU核心;若 1(默认),则顺序处理。大于 1 的整数表示使用相应数量的CPU核心。注意并行处理可能受到内存限制和其他因素的影响。

6.refit

  • 类型: bool, default=True

说明: 是否使用在交叉验证过程中找到的最佳参数重新拟合整个训练集。如果 True(默认),best_estimator_ 属性将包含使用最佳参数训练得到的模型。

7.cv

  • 类型: int, cross-validation generator, an iterable, or None,
    default=None

说明: 交叉验证策略。可以是整数(代表折叠数,如 cv=5 表示五折交叉验证),特定的交叉验证生成器(如 KFold、StratifiedKFold),或者自定义的可迭代对象,产生(训练集,验证集)分割。若为 None,则使用 estimator 的默认交叉验证策略。

8.verbose

  • 类型: integer

说明: 日志冗长度。控制输出信息的详细程度:

0: 不输出训练过程信息。
1: 偶尔输出训练过程信息。
>1: 对每个子模型都输出训练过程信息。

9.return_train_score:

  • 类型: bool
  • 默认值: False

描述: 控制是否在网格搜索结果中包含训练得分。若设置为 True,将同时记录模型在训练集上的得分,以便进一步分析模型的过拟合或欠拟合情况。

GridSearchCV.fit(X, y[, groups]) 方法是 sklearn.model_selection.GridSearchCV 类的一个重要方法,用于执行网格搜索过程,即遍历给定的参数网格,并针对每个参数组合利用交叉验证策略训练和评估模型。
调用 GridSearchCV.fit(X, y) 后,实例将保存以下属性,供后续分析和使用:

  • best_params_: 最佳参数组合,即在交叉验证过程中表现最好的参数设置。
  • best_estimator_: 使用最佳参数重新拟合得到的模型实例。
  • best_score_: 在交叉验证过程中,最佳参数组合对应的平均得分(基于指定的 scoring 函数)。
  • cv_results_: 字典形式的详细结果,包含了所有参数组合、得分、训练时间等信息。

属性

1.best_estimator_

  • 类型: estimator object

说明: 使用最佳参数组合训练得到的最优模型实例。仅当 refit=True 时有效。

2.best_score_

  • 类型: float

说明: 在交叉验证过程中观察到的最佳(平均)评分。

3.best_params_

  • 类型: dict

说明: 描述了获得最佳结果的参数组合,即字典形式的超参数及其对应的最优值。

4.best_index_

  • 类型: int

说明: 对应于最佳候选参数设置的索引,即 cv_results_ 数组中的索引位置。

5.cv_results_

  • 类型: dict of numpy (masked) ndarrays

说明: 包含了所有参数组合、交叉验证得分以及相关指标的详尽结果。这是一个丰富的字典结构,包含了关于每个参数组合的详细统计数据,如各个折叠得分、均值、标准差等。

贝叶斯优化(optuna)

Optuna 是一个流行的Python库,专注于高效且直观地进行超参数优化。它旨在自动化机器学习(尤其是深度学习)模型的超参数搜索过程,以找到最优配置以提升模型性能。
以下是对Optuna关键概念与参数的详细解析:

  1. 研究(Study)
    optuna.create_study(): 创建一个研究对象,它是超参数优化任务的容器。研究定义了优化的目标(即要最小化或最大化的指标)、搜索空间、以及优化算法(如随机搜索、TPE等)。
study = optuna.create_study(
    study_name="example_study",
    direction="maximize",  # 或 "minimize",根据目标函数的需求
    sampler=optuna.samplers.TPESampler(),  # 默认为RandomSampler,可指定其他采样器
)
  1. 目标函数(Objective Function)
    用户需提供一个目标函数,该函数接受一个包含待优化超参数的字典作为输入,并返回一个数值表示模型性能。Optuna会尝试不同的超参数组合,通过调用目标函数计算其性能,然后根据研究的方向(最大化或最小化)来指导搜索。
def objective(trial):
    ...
    return accuracy  # 假设我们希望最大化accuracy
  1. 超参数(Hyperparameters)
    trial.suggest_*(): 在目标函数内部,使用Optuna提供的suggest方法来声明和获取超参数。这些方法包括不同类型的分布,如:
    trial.suggest_float(): 定义一个连续浮点数范围。
    trial.suggest_int(): 定义一个整数范围。
    trial.suggest_categorical(): 定义离散的类别选择。
    更多复杂类型,如trial.suggest_loguniform()、trial.suggest_discrete_uniform()等。
def objective(trial):
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    num_layers = trial.suggest_int("num_layers", 1, 5)
    activation = trial.suggest_categorical("activation", ["relu", "sigmoid", "tanh"])
    ...
  1. 采样器(Samplers)
    optuna.samplers.*: 采样器决定了如何从搜索空间中抽取出超参数组合。Optuna提供了多种内置采样器,如:

TPESampler: Tree-structured Parzen Estimator (TPE),基于概率模型的适应性采样器,通常表现优秀。
RandomSampler: 随机搜索,简单但可靠。
GridSampler: 网格搜索,适用于小规模、离散参数空间。
CmaEsSampler: 基于Covariance Matrix Adaptation Evolution Strategy (CMA-ES)的进化算法。

选择采样器时应考虑搜索空间的特性、计算资源限制以及对探索与利用的平衡需求。

  1. 优化(Optimization)
    study.optimize(): 启动超参数优化过程。需要传入目标函数和相关参数,如最大试次数、时间限制等。
study.optimize(objective, n_trials=100, timeout=600)  # 运行100次试验或最长600秒
  1. 存储与重载(Storage & Reloading)
    optuna.storages.*: Optuna支持将研究数据保存到各种后端存储,如SQLite、MySQL、Redis等,便于跨会话或分布式环境中的工作。
study = optuna.create_study(
    storage="sqlite:///example.db",
    study_name="example_study",
    ...
)

之后可以使用相同的存储URL重新加载已存在的研究:

study = optuna.load_study(study_name="example_study", storage="sqlite:///example.db")
  1. 可视化与分析(Visualization & Analysis)
    optuna.visualization.plot_*(): 提供了一系列图表来可视化研究结果,如参数重要性、优化历史、平行坐标图等。
optuna.visualization.plot_param_importances(study)
optuna.visualization.plot_optimization_history(study)

Dashboard: Optuna还提供了一个交互式Web dashboard,可以实时监控优化过程和分析结果。启动dashboard:

optuna-dashboard sqlite:///example.db --study example_study
  1. 约束(Constraints)
    可以通过在目标函数中添加条件判断来设置超参数组合的约束:
def objective(trial):
    ...
    if learning_rate > 0.01 and batch_size < 32:
        raise optuna.structs.TrialPruned  # 若不符合条件,则标记该试验为被剪枝
  1. 早停(Early Stopping)
    optuna.trial.Trial.report() 和 optuna.trial.Trial.should_prune(): 可以在目标函数中报告中间结果,并检查是否应提前终止当前试验(早停)。这有助于节省计算资源。
def objective(trial):
    for epoch in range(num_epochs):
        ...
        intermediate_value = val_loss
        trial.report(intermediate_value, step=epoch)
        if trial.should_prune(epoch):
            raise optuna.structs.TrialPruned
  1. 多目标优化(Multi-objective Optimization)
    optuna.create_study(multivariate=True): 支持多目标优化,目标函数返回一个包含多个目标值的列表。
import optuna

def objective(trial):
    lgb_params = {
        "verbosity": -1,
        'objective': 'regression',
        'metric': 'rmse',
        'boosting_type': 'gbdt',
        'random_state': 6,
        'n_estimators': trial.suggest_int('n_estimators', 50, 200),
        'reg_alpha': trial.suggest_loguniform('reg_alpha', 1e-3, 10.0),
        'reg_lambda': trial.suggest_loguniform('reg_lambda', 1e-3, 10.0),#对数分布的建议值
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1),#浮点数
        'subsample': trial.suggest_float('subsample', 0.5, 1),
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 0.5, log=True),
        'num_leaves' : trial.suggest_int('num_leaves', 8, 64),#整数
        'min_child_samples': trial.suggest_int('min_child_samples', 1, 100),
    }
    X=train_feats.drop(['target'],axis=1).copy()
    y=train_feats['target'].copy()
    test_X=valid_feats.drop(['target'],axis=1).values.copy()
    test_y=valid_feats['target'].values.copy()
    test_preds=np.zeros((5,len(test_X)))
    # 初始化 KFold
    kf = KFold(n_splits=5, shuffle=True,random_state=6)
    # 进行 k 折交叉验证
    for fold, (train_index, valid_index) in (enumerate(kf.split(X))):
        X_train, X_valid = X.iloc[train_index], X.iloc[valid_index]
        y_train, y_valid = y.iloc[train_index], y.iloc[valid_index]

        model=LGBMRegressor(**lgb_params)
        model.fit(X_train,y_train)
        test_preds[fold]=model.predict(test_X)
    test_preds=test_preds.mean(axis=0)
    mean_rmse=metric(test_y,test_preds)
    return mean_rmse
#创建的研究命名,找最小值.
study = optuna.create_study(direction='minimize', study_name='Optimize boosting hyperparameters')
#目标函数,尝试的次数
study.optimize(objective, n_trials=50)
lgb_params=study.best_trial.params

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/569181.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深入理解分布式事务① ---->分布式事务基础(四大特性、五大类型、本地事务、MySQL并发事务问题、MySQL事务隔离级别命令设置)详解

目录 深入理解分布式事务① ---->分布式事务基础&#xff08;四大特性、五大类型、本地事务、MySQL并发事务问题、MySQL事务隔离级别命令设置&#xff09;详解事务的基本概念1、什么是事务&#xff1f;2、事务的四大特性2-1&#xff1a;原子性&#xff08;Atomic&#xff09…

STM32点灯大师(中断法)

一、使用CubeMX配置 新增加了RCC进行配置 二、代码 需要重写虚函数&#xff0c;给自己引用

Python打怪升级(4)

在计算机领域常常有说"合法"和"非法"指的是:是否合理&#xff0c;是否有效&#xff0c;并不是指触犯了法律。 random.randint(begin,end) 详细讲解一下这个random是指模板&#xff0c;也就是别人写好的代码直接来用&#xff0c;在Python当中&#xff0c;…

《R语言与农业数据统计分析及建模》学习——ggplot2绘图基础

一、农业科研数据可视化常用图形及用途 1、数据可视化的重要性 通过可视化&#xff0c;我们可以更直观地理解和分析数据的特征和趋势。 2、常用图表类型及其概述 散点图&#xff1a;用于展示两个变量之间的关系&#xff0c;可用于观察数据的分布、趋势和异常值。 折线图&…

网络安全之CSRFSSRF漏洞(上篇)(技术进阶)

目录 一&#xff0c;CSRF篇 二&#xff0c;认识什么是CSRF 三&#xff0c;实现CSRF攻击的前提 四&#xff0c;实战演练 【1】案例1 【2】案例2 【3】案例3 【4】案例4&#xff08;metinfo&#xff09; 一&#xff0c;CSRF篇 二&#xff0c;认识什么是CSRF CSRF&#x…

YesPMP众包平台最新项目

YesPMP一站式互联网众包平台&#xff0c;最新外包项目&#xff0c;有感兴趣的用户可进入平台参与竞标。 &#xff08;竞标后由项目方直接与服务商联系&#xff0c;双方直接对接&#xff09; 1.查看项目&#xff1a;个人技术-YesPMP平台 2.查看项目&#xff1…

【003_音频开发_基础篇_Linux进程通信(20种你了解几种?)】

003_音频开发_基础篇_Linux进程通信&#xff08;20种你了解几种&#xff1f;) 文章目录 003_音频开发_基础篇_Linux进程通信&#xff08;20种你了解几种&#xff1f;)创作背景Linux 进程通信类型fork() 函数fork() 输出 2 次fork() 输出 8 次fork() 返回值fork() 创建子进程 方…

zkVM选型要点

1. 引言 当选择ZK工具&#xff0c;来做可验证链下计算来扩容区块链时&#xff0c;需考虑&#xff1a; 1&#xff09;为何应选择zkVM&#xff1f;2&#xff09;zkVM有哪些基本功能&#xff1f;3&#xff09;哪些zkVM可提供这些基本功能&#xff1f; 2. 为何应选择zkVM&#x…

OpenCV——图像分块局部阈值二值化

目录 一、算法原理1、算法概述2、参考文献 二、代码实现三、结果展示 OpenCV——图像分块局部阈值二值化由CSDN点云侠原创&#xff0c;爬虫自重。如果你不是在点云侠的博客中看到该文章&#xff0c;那么此处便是不要脸的爬虫。 一、算法原理 1、算法概述 针对目前局部阈值二值…

消息队列 Kafka 入门篇(二) -- 安装启动与可视化工具

一、Windows 10 环境安装 1、下载与解压 首先&#xff0c;访问Apache Kafka的官方下载地址&#xff1a; https://kafka.apache.org/downloads 在本教程中&#xff0c;我们将使用kafka_2.13-2.8.1版本作为示例。下载完成后&#xff0c;解压到您的工作目录的合适位置&#xff…

目标检测——YOLOv6算法解读

论文&#xff1a;YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications (2022.9.7) 作者&#xff1a;Chuyi Li, Lulu Li, Hongliang Jiang, Kaiheng Weng, Yifei Geng, Liang Li, Zaidan Ke, Qingyuan Li, Meng Cheng, Weiqiang Nie, Yiduo Li, Bo …

企业商业活动如何获得央级媒体的采访报道?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 企业想要获得央级媒体的采访报道&#xff0c;确实需要精心策划和准备&#xff1a; 一、如何巧妙给媒体报选题 精准定位&#xff1a;首先要明确企业的核心价值、创新点或行业影响力&…

【C++】手撕list(list的模拟实现)

目录 01.节点 02.迭代器 迭代器运算符重载 03.list类 &#xff08;1&#xff09;构造与析构 &#xff08;2&#xff09;迭代器相关 &#xff08;3&#xff09;容量相关 &#xff08;4&#xff09;访问操作 &#xff08;5&#xff09;插入删除 我们在学习数据结构的时候…

StartAI智能绘图软件出现“缺少Python运行库”怎么办?

StartAI做为一款国产AI界的新秀&#xff0c;是一款贴合AIGC新手的智能绘图软件。新手安装遇见“缺少Python运行库”怎么办”&#xff1f;小编一招搞定~ 解决方法&#xff1a;手动下载【resource文件】&#xff0c;将文件添加到安装目录下。 点击链接进行手动下载噢~ 确保 Star…

图像处理之模板匹配(C++)

图像处理之模板匹配&#xff08;C&#xff09; 文章目录 图像处理之模板匹配&#xff08;C&#xff09;前言一、基于灰度的模板匹配1.原理2.代码实现3.结果展示 总结 前言 模板匹配的算法包括基于灰度的匹配、基于特征的匹配、基于组件的匹配、基于相关性的匹配以及局部变形匹…

Spring-IOC之组件扫描

版本 Spring Framework 6.0.9​ 1. 前言 通过自动扫描&#xff0c;Spring 会自动从扫描指定的包及其子包下的所有类&#xff0c;并根据类上的特定注解将该类装配到容器中&#xff0c;而无需在 XML 配置文件或 Java 配置类中逐一声明每一个 Bean。 支持的注解 Spring 支持一系…

Mysql索引详解(索引分类)

文章目录 概述索引对查询速度的影响索引的优缺点索引类型一级索引和二级索引的区别MySQL 回表联合索引&#xff08;最左前缀原则主键索引和唯一索引的区别BTree索引和Hash索引的区别 覆盖索引索引下推加索引能够提升查询效率原因MySQL 索引结构采用 B树原因索引失效的场景MySQL…

JAVASE基础语法(异常、常用类)

一、异常 1.1 什么是异常 异常就是指不正常。是指代码在运行过程中可能发生错误&#xff0c;导致程序无法正常运行。 package com.atguigu.exception;public class TestException {public static void main(String[] args) {int[] arr {1,2,3,4,5};System.out.println(&quo…

前端css中filter(滤镜)的使用

前端css中filter的使用 一、前言二、补充内容说明三、模糊&#xff08;一&#xff09;、模糊效果&#xff0c;源码1&#xff08;二&#xff09;、源码1运行效果1.视频演示2.截图演示 四、阴影&#xff08;一&#xff09;、阴影效果&#xff0c;源码2&#xff08;二&#xff09;…

Linux文件系统与日志

一、inode和block 文件数据包括元信息与实际数据&#xff0c;文件存储在硬盘上&#xff0c;硬盘最小存储单位是扇区&#xff0c;每个扇区存储512字节 1.block(块)&#xff1a;文件系统中用于存储文件实际数据的最小单位&#xff0c;由文件系统进行分配和管理&#xff0c;并通…