【Python机器学习】决策树集成——梯度提升回归树

理论知识:        

        梯度提升回归树通过合并多个决策树来构建一个更为强大的模型。虽然名字里有“回归”,但这个模型既能用于回归,也能用于分类。与随机森林方法不同,梯度提升采用连续的方式构造树,每棵树都试图纠正前一棵树的错误。默认情况下,梯度提升回归树中没有随机化,而是用到了强预剪枝。梯度提升树通常使用深度很小(1-5之间),这样的模型占用内存小,预测速度也更快。

        梯度提升背后的主要思想是合并许多简单的模型(弱学习器),比如深度较小的树。每棵树只能对部分数据做出比较好的预测,因此添加的树越来越多,可以不断迭代来提高性能。

        梯度提升树通常对参数设置非常敏感,但如果参数设置正确的话,模型精度会更高。

        除了预剪枝和集成树的数量外,梯度提升的另一个重要参数是learning_rate(学习率),用于控制每棵树纠正前一棵树的错误的强度。较高的学习率意味着每棵树都可以做出较强的修正,这样的模型更为复杂。通过增大n_estimators来向集成中添加更多树,也可以增加模型的复杂度,因为模型有更多机会来纠正训练集上的错误。

        默认参数上:树的数量为100、最大深度为3,学习率为0.1

示例:

以乳腺癌数据集为例,用分类模型:

from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import numpy as np

plt.rcParams['font.sans-serif'] = ['SimHei']
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0)
gbrt=GradientBoostingClassifier(random_state=0)
gbrt.fit(X_train,y_train)
print('训练集精度:{:.3f}'.format(gbrt.score(X_train,y_train)))
print('测试集精度:{:.3f}'.format(gbrt.score(X_test,y_test)))

 

由于训练集精度达到100%,所以很可能存在过拟合,为了降低过拟合,可以限制最大深度来加强预剪枝,也可以降低学习率:


gbrt_md1=GradientBoostingClassifier(random_state=0,max_depth=1)
gbrt_md1.fit(X_train,y_train)
print('max_depth=1训练集精度:{:.3f}'.format(gbrt_md1.score(X_train,y_train)))
print('max_depth=1测试集精度:{:.3f}'.format(gbrt_md1.score(X_test,y_test)))

gbrt_lr001=GradientBoostingClassifier(random_state=0,learning_rate=0.01)
gbrt_lr001.fit(X_train,y_train)
print('learning_rate=0.01训练集精度:{:.3f}'.format(gbrt_lr001.score(X_train,y_train)))
print('learning_rate=0.01测试集精度:{:.3f}'.format(gbrt_lr001.score(X_test,y_test)))

 

可以看到,两种方法都降低了训练集精度,而减小树的最大深度显著提升了模型性能。

特征重要性可视化: 

 

import mglearn.plots
from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import numpy as np

def plot_importances(model):
    n_feature=cancer.data.shape[1]
    plt.barh(range(n_feature),model.feature_importances_,align='center')
    plt.yticks(np.arange(n_feature),cancer.feature_names)
    plt.xlabel('特征重要性')
    plt.ylabel('特征')

plt.rcParams['font.sans-serif'] = ['SimHei']
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0)
gbrt=GradientBoostingClassifier(random_state=0)
gbrt.fit(X_train,y_train)

plot_importances(gbrt)
plt.show()

可以看到,梯度提升树的特征重要性与随机森林有些类似,但梯度提升树完全忽略了某些特征。

常用的方法是先尝试随机森林,因为它的鲁棒性很好,如果随机森林的效果好但预测时间太长,或者学习模型精度在小数点后两位的提高也很重要,那么切换成梯度提升树通常比较有用。

优缺点:

梯度提升树是监督学习中最强大也最常用的模型之一,它的主要缺点是需要仔细调参,而且训练时间会比较长;优点是不需要对数据进行缩放就可以表现的很好,而且也适用于二元特征和连续特征同时存在的数据集。与其他基于树的模型相同,梯度提升树通常也不适用于高纬稀疏数据。

梯度提升树的主要参数是树的数量n_estimators和学习率learning_rate。这两个参数高度相关,因为learning_rate越低,就需要更多树来构建具有相似复杂度的模型,随机森林的n_estimators值总是越大越好,但梯度提升树不同,增大n_estimators会导致模型更加复杂,进而可能导致过拟合,通常的做法是根据时间和内存的预算选择合适的n_estimators,然后会不同的learning_rate进行遍历。

另一个重要参数是max_depth,用来降低每棵树的复杂度,一般不超过5。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/318292.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

推荐算法常见的评估指标

推荐算法评估指标比较复杂,可以分为离线和在线两部分。召回、粗排、精排和重排由于定位区别,其评估指标也会有一定区别,下面详细讲解。 1 召回评价体系 召回结果并不是最终推荐结果,其本质是为后续排序层服务的,故核…

学习记录10-L6406E报错处理

前言 问题:在编译时报如下错误 ..\OBJ\LCD.axf: Error: L6406E: No space in execution regions with .ANY selector matching main.o(.constdata). ..\OBJ\LCD.axf: Error: L6406E: No space in execution regions with .ANY selector matching lcd_init.o(i.LCD…

rke2 Offline Deploy Rancher v2.8.0 latest (helm 离线部署 rancher v2.8.0)

文章目录 1. 预备条件2. 为什么是三个节点?​3. 配置私有仓库4. 介质清单5. 安装 helm6. 安装 cert-manager6.1 下载介质6.2 镜像入库6.3 helm 部署6.4 cert-manager 卸载 7. 安装 rancher7.1 镜像入库7.2 helm 安装 8. 验证9. 界面预览10. 卸载 1. 预备条件 所有支…

【SAP】如何删除控制范围

经历就是财富,可你终将遗忘。期望文字打败时间。 本周心惊胆战地在配置系统删除了一个控制范围,还是有些收获,特此记录一下。 背景:在删除控制范围之前,我主要做了如下配置。 定义控制范围(自动生成了成本…

层叠布局(Stack)

目录 1、概述 2、开发布局 3、对齐方式 3.1、TopStart 3.2、Top 3.3、TopEnd 3.4、Start 3.5、Center 3.6、End 3.7、BottomStart 3.8、Bottom 3.9、BottomEnd 4、Z序控制 5、场景示例 1、概述 层叠布局(StackLayout)用于在屏幕上预留一…

ChatGPT能帮助我们人类做什么

一、ChatGPT可以在多个方面帮助人类: 回答问题: ChatGPT可以回答各种问题,提供信息和解释概念。 创造性写作: 它可以生成文章、故事、诗歌等创意性文本。 学术辅助: ChatGPT可以辅助学术研究,提供解释、背…

如何生成文本: 通过 Transformers 用不同的解码方法生成文本

如何生成文本: 通过 Transformers 用不同的解码方法生成文本 假设 $p0.92$,Top-p 采样对单词概率进行降序排列并累加,然后选择概率和首次超过 $p92%$ 的单词集作为采样池,定义为 $V_{\text{top-p}}$。在 $t1$ 时 $V_{\text{top-p}}$ 有 9 个…

串行Nor Flash的结构和参数特性

文章目录 引言1、Nor Flash的结构2、Nor Flash的类别3.标准Serial Nor Flash的特征属性1.Wide Range VCC Flash2.Permanent Lock3.Default Lock Protection4.Standard Serial Interface5.Multi-I/O6.Multi-I/O Duplex (DTR)7.XIP(片上执行) 4.标准Serial…

Java SE入门及基础(11)

程序调试 1. 什么是程序调试 当程序出现问题时,我们希望程序能够暂停下来,然后通过我们操作使代码逐行执行,观察整个过程中变量的变化是否按照我们设计程序的思维变化,从而找问题并解决问题,这个过程称之为程序调试…

Openstack组件glance对接swift

2、glance对接swift (1)可直接在数据库中查看镜像存放的位置、状态、id等信息 (2)修改glance-api的配置文件,实现对接swift存储(配置文件在/etc/glance/glance-api.conf,建议先拷贝一份&#x…

基于反卷积方法的重大突破:结构光系统中的测量误差降低3倍

作者:小柠檬 | 来源:3DCV 在公众号「3DCV」后台,回复「原论文」可获取论文pdf 结构光三维测量技术在工业自动化、逆向工程和图形学领域越来越受欢迎。然而,现有的测量系统在成像过程中存在不完美,会导致在不连续边缘周…

nuxt pm2使用、启动、问题解决方案

pm2简介 pm2是一个进程管理工具,可以用它来管理node进程,并查看node进程的状态,当然也支持性能监控,进程守护,负载均衡等功能,在前端和nodejs的世界中用的很多 pm2安装 安装pm2: $ npm install -g pm2查看pm2的安装…

Kafka基本介绍

消息队列 产生背景 消息队列:指的数据在一个容器中,从容器中一端传递到另一端的过程 消息(message): 指的是数据,只不过这个数据存在一定流动状态 队列(queue): 指的容器,可以存储数据,只不过这个容器具备FIFO(先进…

Linux技术,winSCP连接服务器超时故障解决方案

知识改变命运,技术就是要分享,有问题随时联系,免费答疑,欢迎联系! 故障现象 使用 sftp 协议连接主机时, 明显感觉缓慢且卡顿,并且时常出现如下报错: 点击重新连接后,又有概率重新连接上; 总之在"连接上"和&…

蓝桥杯练习题(六)

📑前言 本文主要是【算法】——蓝桥杯练习题(六)的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 …

分布式系统的三字真经CAP

文章目录 前言C(Consistency 数据一致性)A(Availability 服务可用性)P(Partition Tolerance 分区容错性)CAP理论最后 前言 你好,我是醉墨居士,我一起探索一下分布式系统的三字真经C…

Linux完全卸载Anaconda3和MiniConda3

如何安装Anaconda3和MiniConda3请看这篇文章: 安装Anaconda3和MiniConda3_minianaconda3-CSDN博客文章浏览阅读474次。MiniConda3官方版是一款优秀的Python环境管理软件。MiniConda3最新版只包含conda及其依赖项如果您更愿意拥有conda以及超过720个开源软件包&…

怎么安装es、kibana(单点安装)

1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络: docker network create es-net1.2.加载镜像 这里我们采用elasticsearch的7.12.1版本的镜像,这个镜像体积非常大&#xff0c…

增删改查管理系统 总结1

//提醒: 管理员也要有增删改查 新增员工代码完善2可能需要用到 目录 细节1 pom文件出现奇怪页面? 细节2 如何联系DataGrip与idea? 细节3 Yapi?接口文档?如何有以下画面? ​细节4 如何将时间转化为好看的时间&…

PostgreSQL认证考试PGCA、PGCE、PGCM

PostgreSQL认证考试PGCA、PGCE、PGCM 【重点!重点!重点!】PGCA、PGCE、PGCM 直通车快速下正,省心省力,每2个月一次考试 PGCE考试通知 (2024) 一、考试概览 (一) 报名要…
最新文章