首页 > 编程学习 > 使用网格搜索(GridSearchCV)自动调参

使用网格搜索(GridSearchCV)自动调参

发布时间:2023/4/17 2:50:20

使用网格搜索(GridSearchCV)自动调参

描述

调参对于提高模型的性能十分重要。在尝试调参之前首先要理解参数的含义,然后根据具体的任务和数据集来进行,一方面依靠经验,另一方面可以依靠自动调参来实现。

Scikit-learn 中提供了网格搜索(GridSearchCV)工具进行自动调参,该工具自动尝试预定义的参数值列表,并具有交叉验证功能,最终找到最佳的参数组合。

本任务的主要实践内容:

1、 使用手工枚举来调参

2、 利用GridSearchCV自动调参

添加链接描述

环境

  • 操作系统:Windows10、Ubuntu18.04

  • 工具软件:Anaconda3 2019、Python3.7

  • 硬件环境:无特殊要求

  • 依赖库列表

    scikit-learn	  0.24.2
    

分析

本任务采用SVM算法对鸢尾花数据集进行建模,实践调参过程。任务涉及以下环节:

1)使用手工枚举来确定最佳参数组合

2)拆分出验证集进行调参

3)使用GridSearchCV进行自动调参,确定最佳参数组合

实施

步骤1、使用手工枚举进行调参

首先使用支持向量机(SVM)建立鸢尾花分类模型,模型主要参数为gamma和C,我们使用循环枚举的方式进行调参,具体过程为:

1、定义gamma和C两个参数的取值列表

2、定义循环,使用不同的参数组合创建模型并评估成绩

3、取最佳成绩的参数组合

from sklearn.svm import SVC 
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris 

iris = load_iris() # 加载数据
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0) 
 
best_score = 0 # 最佳成绩

# 对每种参数组合都训练一个模型,评估其成绩,找到最佳参数组合
for gamma in [0.001, 0.01, 0.1, 1, 10, 100]:     
    for C in [0.001, 0.01, 0.1, 1, 10, 100]: 
        svm = SVC(gamma=gamma, C=C).fit(X_train, y_train) 
        score = svm.score(X_test, y_test) # 在测试集上评估SVC
        # 如果得到了更高的分数,则保存该分数和对应的参数         
        if score > best_score:             
            best_score = score             
            best_parameters = {'C': C, 'gamma': gamma} 
 
print("Best score: {:.2f}".format(best_score))  # 输出最佳成绩
print("Best parameters: {}".format(best_parameters)) # 输出最佳参数组合

输出结果:

Best score: 0.97
Best parameters: {'C': 100, 'gamma': 0.001}

可以看到,最佳成绩为97%,最佳组合为 {‘C’: 100, ‘gamma’: 0.001}

步骤2、拆分出验证集用于调参

如果模型调参在测试集上进行,就不能保证测试的客观性(相当于练习题与考试题不分),因此调参时需要拆分出独立的验证集,在验证集上调参,而测试集专门用于最终测试,即测试集不参与训练和调参过程。

2.1 拆分方法:

# 首先,将数据拆分为(训练+验证集)与测试集 
X_train_valid, X_test, y_train_valid, y_test = train_test_split(iris.data, iris.target, random_state=0) 

# 然后,将(训练+验证集)拆分为训练集与验证集 
X_train, X_valid, y_train, y_valid = train_test_split(X_train_valid, y_train_valid, random_state=1) 

2.2 拆分出验证集后,在验证集上进行调参,代码如下:

best_score = 0

# 对每种参数组合都训练一个模型,评估其成绩,找到最佳参数组合
for gamma in [0.001, 0.01, 0.1, 1, 10, 100]:     
    for C in [0.001, 0.01, 0.1, 1, 10, 100]: 
        svm = SVC(gamma=gamma, C=C).fit(X_train, y_train) 
        score = svm.score(X_valid, y_valid) # 注意!这里改为在验证集上评估模型    
        if score > best_score:             
            best_score = score             
            best_parameters = {'C': C, 'gamma': gamma} 

print("Training_set size: {} , Validation_set size: {} , Test_set size: {}"
      .format(X_train.shape[0], X_valid.shape[0], X_test.shape[0])) 

# 利用调参得到的最佳参数组合,在(训练+验证集)上重新构建一个模型,并在测试集上进行评估 
svm = SVC(**best_parameters) 
svm.fit(X_train_valid, y_train_valid) 
test_score = svm.score(X_test, y_test) # 测试集只用做最终评估,不参与调参,保证测试的客观性

print("Best score on validation_set: {:.2f}".format(best_score)) 
print("Best parameters: ", best_parameters) 
print("Test_set score with best parameters: {:.2f}".format(test_score)) # 输出测试成绩

输出结果:

Training_set size: 84 , Validation_set size: 28 , Test_set size: 38
Best score on validation_set: 0.96
Best parameters:  {'C': 10, 'gamma': 0.001}
Test_set score with best parameters: 0.92

强调:在机器学习中,测试集只做测试用,不参与训练与调参,以保证测试的客观性。

步骤3、使用网格搜索(GridSearchCV)自动调参

GridSearchCV整合了自动调参和交叉验证功能,相对于手工枚举更加高效。

from sklearn.model_selection import GridSearchCV # 引入网格搜索类

# 定义参数列表
param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100], 'gamma': [0.001, 0.01, 0.1, 1, 10, 100]} 

# 定义GridSearchCV对象,注意参数
# 参数1-模型
# 参数2-参数列表
# 参数3-交叉验证次数(网格搜索具有交叉验证功能)
grid_search = GridSearchCV(SVC(), param_grid, cv=10)

# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
grid_search.fit(X_train, y_train) # 将训练数据交给GridSearchCV对象

# 输出最佳参数组合及最佳成绩(利用.best_params_和.best_score_属性)
print("Best parameters: {}".format(grid_search.best_params_)) 
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))

输出结果:

Best parameters: {'C': 10, 'gamma': 0.1}
Best cross-validation score: 0.98

说明:需要掌握GridSearchCV的参数及两个重要属性。

Copyright © 2010-2022 mfbz.cn 版权所有 |关于我们| 联系方式|豫ICP备15888888号