Python机器学习17——Xgboost和Lightgbm结合分位数回归(机器学习与传统统计学结合)

最近XGboost支持分位数回归了,我看了一下,就做了个小的代码案例。毕竟学术市场上做这种新颖的机器学习和传统统计学结合的方法还是不多,算的上创新,找个好数据集可以发论文。


代码实现

导入包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error,r2_score
import xgboost as xgb
import lightgbm as lgb
import statsmodels.api as sm
from statsmodels.regression.quantile_regression import QuantReg

xgboost和lightgbm都需要安装的,他们和sklearn库的机器学习方法不是一个库的。怎么安装看我《实用的机器学习》这个栏目的xgb那篇文章。


模拟数据进行分位数回归

先制作一个模拟数据集

def f(x: np.ndarray) -> np.ndarray:
    return x * np.sin(x)

rng = np.random.RandomState(2023)
X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T
expected_y = f(X).ravel()
sigma = 0.5 + X.ravel() / 10.0
noise = rng.lognormal(sigma=sigma) - np.exp(sigma**2.0 / 2.0)
y = expected_y + noise

print(X.shape,y.shape)

然后画图看看:

plt.figure(figsize=(6,2),dpi=100)
plt.scatter(X,y,s=1)
plt.show()

#划分训练集和测试集


X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
print(f"Training data shape: {X_train.shape}, Testing data shape: {X_test.shape}")

这里采用三种模型进行拟合预测对比,分别是线性分位数回归,XGB结合分位数,LightGBM结合分位数:

alphas = np.arange(5, 100, 5) / 100.0
print(alphas)
mse_qr, mse_xgb, mse_lgb = [], [], []
r2_qr, r2_xgb, r2_lgb = [], [], []
qr_pred,xgb_pred,lgb_pred={},{},{}

# Train and evaluate
for alpha in alphas:
    # Quantile Regression
    model_qr = QuantReg(y_train, sm.add_constant(X_train)).fit(q=alpha)
    model_pred=model_qr.predict(sm.add_constant(X_test))
    mse_qr.append(mean_squared_error(y_test,model_pred ))
    r2_qr.append(r2_score(y_test,model_pred))
    
    # XGBoost
    model_xgb = xgb.train({"objective": "reg:quantileerror", 'quantile_alpha': alpha}, 
                          xgb.QuantileDMatrix(X_train, y_train), num_boost_round=100)
    model_pred=model_xgb.predict(xgb.DMatrix(X_test))
    mse_xgb.append(mean_squared_error(y_test,model_pred ))
    r2_xgb.append(r2_score(y_test,model_pred))
    
    # LightGBM
    model_lgb = lgb.train({'objective': 'quantile', 'alpha': alpha,'force_col_wise': True,}, 
                          lgb.Dataset(X_train, y_train), num_boost_round=100)
    
    model_pred=model_lgb.predict(X_test)
    mse_lgb.append(mean_squared_error(y_test,model_pred))
    r2_lgb.append(r2_score(y_test,model_pred))
    
    if alpha in [0.1,0.5,0.9]:
        qr_pred[alpha]=model_qr.predict(sm.add_constant(X_test))
        xgb_pred[alpha]=model_xgb.predict(xgb.DMatrix(X_test))
        lgb_pred[alpha]=model_lgb.predict(X_test)

分位点为0.1,0.5,0.9时记录一下,方便画图查看。

然后画出三种模型在不同分位点下的误差和拟合优度对比:

plt.figure(figsize=(7, 5),dpi=128)
plt.subplot(211)
plt.plot(alphas, mse_qr, label='Quantile Regression')
plt.plot(alphas, mse_xgb, label='XGBoost')
plt.plot(alphas, mse_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('MSE')
plt.title('MSE across different quantiles')

plt.subplot(212)
plt.plot(alphas, r2_qr, label='Quantile Regression')
plt.plot(alphas, r2_xgb, label='XGBoost')
plt.plot(alphas, r2_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('$R^2$')
plt.title('$R^2$ across different quantiles')
plt.tight_layout()
plt.show()

可以看到在分位点为0.5附件,模型的误差都比较小。因为这个数据集没有很多的异常值。然后模型表现上,LGBM>XGB>线性QR。线性模型对于一个非线性的函数关系拟合在这里当然不行。

画出拟合图:
 

name=['QR','XGB-QR','LGB-QR']
plt.figure(figsize=(7, 6),dpi=128)
for k,model in enumerate([qr_pred,xgb_pred,lgb_pred]):
    n=int(str('31')+str(k+1))
    plt.subplot(n)
    plt.scatter(X_test,y_test,c='k',s=2)
    for i,alpha in enumerate([0.1,0.5,0.9]):
        sort_order = np.argsort(X_test, axis=0).ravel()
        X_test_sorted = np.array(X_test)[sort_order]
        #print(np.array(model[alpha]))
        predictions_sorted = np.array(model[alpha])[sort_order]
        plt.plot(X_test_sorted,predictions_sorted,label=fr"$\tau$={alpha}",lw=0.8)
    plt.legend()
    plt.title(f'{name[k]}')
plt.tight_layout()
plt.show()

可以看到分位数回归的明显的区间特点。

还有非参数非线性方法的优势,明显XGB和LGBM拟合得更好。


波士顿数据集

上面是人工数据,下面采用真实的数据集进行对比,就用回归最常用的波士顿房价数据集吧:

data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]
column_names = ['CRIM','ZN','INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO',  'B','LSTAT', 'MEDV']
boston=pd.DataFrame(np.hstack([data,target.reshape(-1,1)]),columns= column_names)

取出X和y,划分测试集和训练集

X = boston.iloc[:,:-1]
y = boston.iloc[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

拟合预测,对比

alphas = np.arange(0.1, 1, 0.1)
mse_qr, mse_xgb, mse_lgb = [], [], []
r2_qr, r2_xgb, r2_lgb = [], [], []
qr_pred,xgb_pred,lgb_pred={},{},{}
# Train and evaluate
for alpha in alphas:
    # Quantile Regression
    model_qr = QuantReg(y_train, sm.add_constant(X_train)).fit(q=alpha)
    model_pred=model_qr.predict(sm.add_constant(X_test))
    mse_qr.append(mean_squared_error(y_test,model_pred ))
    r2_qr.append(r2_score(y_test,model_pred))
    
    # XGBoost
    model_xgb = xgb.train({"objective": "reg:quantileerror", 'quantile_alpha': alpha}, 
                          xgb.QuantileDMatrix(X_train, y_train), num_boost_round=100)
    model_pred=model_xgb.predict(xgb.DMatrix(X_test))
    mse_xgb.append(mean_squared_error(y_test,model_pred ))
    r2_xgb.append(r2_score(y_test,model_pred))
    
    # LightGBM
    model_lgb = lgb.train({'objective': 'quantile', 'alpha': alpha,'force_col_wise': True,}, 
                          lgb.Dataset(X_train, y_train), num_boost_round=100)
    
    model_pred=model_lgb.predict(X_test)
    mse_lgb.append(mean_squared_error(y_test,model_pred))
    r2_lgb.append(r2_score(y_test,model_pred))
    
    if alpha in [0.1,0.5,0.9]:
        qr_pred[alpha]=model_qr.predict(sm.add_constant(X_test))
        xgb_pred[alpha]=model_xgb.predict(xgb.DMatrix(X_test))
        lgb_pred[alpha]=model_lgb.predict(X_test)

画图查看不同分位点的不同模型的误差和拟合优度:

plt.figure(figsize=(8, 5),dpi=128)
plt.subplot(211)
plt.plot(alphas, mse_qr, label='Quantile Regression')
plt.plot(alphas, mse_xgb, label='XGBoost')
plt.plot(alphas, mse_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('MSE')
plt.title('MSE across different quantiles')

plt.subplot(212)
plt.plot(alphas, r2_qr, label='Quantile Regression')
plt.plot(alphas, r2_xgb, label='XGBoost')
plt.plot(alphas, r2_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('$R^2$')
plt.title('$R^2$ across different quantiles')
plt.tight_layout()
plt.show()

可以看到在分位点为0.6附件三个模型表现效果都比较好,然后模型表现来看,XGB>LGBM>QR,还是两个机器学习模型更厉害。


分位数损失函数和平方和损失函数对比

上面我们得到在分位点为0.6的时候,模型效果表现好,那么分位数模型和普通的MSE损失函数的效果比起来怎么样呢?我们继续对比:

# 定义alpha值
alpha = 0.5

# 分位数回归模型
model_qr = sm.regression.quantile_regression.QuantReg(y_train, sm.add_constant(X_train)).fit(q=alpha)
qr_pred = model_qr.predict(sm.add_constant(X_test))

# XGBoost分位数回归
model_xgb = xgb.train({"objective": "reg:quantileerror", 'quantile_alpha': alpha}, 
                      xgb.DMatrix(X_train, label=y_train), num_boost_round=100)
xgb_q_pred = model_xgb.predict(xgb.DMatrix(X_test))

# LightGBM分位数回归
model_lgb = lgb.train({'objective': 'quantile', 'alpha': alpha,'force_col_wise': True}, 
                      lgb.Dataset(X_train, label=y_train), num_boost_round=100)
lgb_q_pred = model_lgb.predict(X_test)

# 普通的最小二乘法线性回归
model_lr = LinearRegression()
model_lr.fit(X_train, y_train)
lr_pred = model_lr.predict(X_test)

# 普通的XGBoost
model_xgb_reg = xgb.train({"objective": "reg:squarederror"}, xgb.DMatrix(X_train, label=y_train), num_boost_round=100)
xgb_pred = model_xgb_reg.predict(xgb.DMatrix(X_test))

# 普通的LightGBM
model_lgb_reg = lgb.train({'objective': 'regression', 'force_col_wise': True}, lgb.Dataset(X_train, label=y_train), num_boost_round=100)
lgb_pred = model_lgb_reg.predict(X_test)

上面是六个模型,非别是基于分位数回归的XGB,LGBM,线性分位数回归。还有三个基于最普通的MSE损失函数的普通XGB,LGBM和最小二乘线性回归。

# 计算6个模型的MSE和R^2 


models = ['QR', 'XGB Quantile', 'LightGBM Quantile', 'Linear Reg', 'XGBoost', 'LightGBM']
preds = [qr_pred, xgb_q_pred, lgb_q_pred, lr_pred, xgb_pred, lgb_pred]
mse_scores = [mean_squared_error(y_test, pred) for pred in preds]
r2_scores = [r2_score(y_test, pred) for pred in preds]

 画柱状图查看:

colors = sns.color_palette("muted", len(models))
fig, axs = plt.subplots(2, 1, figsize=(9,7))
axs[0].bar(models, mse_scores, color=colors)
axs[0].set_title('MSE Comparison')
axs[0].set_ylabel('MSE')
axs[1].bar(models, r2_scores, color=colors)
axs[1].set_title(r'$R^{2}$ Comparison')
axs[1].set_ylabel(r'$R^{2}$')
plt.tight_layout()
plt.show()

可以看到模型效果来看,XGboost由于Lightgbm优于线性模型。但是分位数回归效果没有MSE损失好,说明在这个数据集表现上,就采用最经典的MSE损失的普通的模型效果会更好。。。

确实是这样的,很多学术创新和改进都不一定比最经典和最常见的方法的效果好。

如果是那种异常值很多的数据,具有异方差的数据 ,可能损失函数改用分位数的会更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/104911.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uni-app医院智能导诊系统源码

随着科技的迅速发展,人工智能已经逐渐渗透到我们生活的各个领域。在医疗行业中,智能导诊系统成为了一个备受关注的应用。本文将详细介绍智能导诊系统的概念、技术原理以及在医疗领域中的应用,分析其优势和未来发展趋势。 智能导诊系统通过人工…

如何训练Embedding Model

BGE的技术亮点: 高效预训练和大规模文本微调;在两个大规模语料集上采用了RetroMAE预训练算法,进一步增强了模型的语义表征能力;通过负采样和难负样例挖掘,增强了语义向量的判别力;借鉴Instruction Tuning的…

中央设备状态监控系统CMS如何帮助半导体晶圆厂提高产品良率

中央设备状态监控系统(CMS)在半导体晶圆厂中扮演着关键角色,帮助企业提高产品的良率。本文将介绍CMS是什么、当前半导体晶圆厂产品良率面临的挑战,并重点探讨CMS如何通过实时数据监控、故障预测和预警、以及统计分析和过程改进等方…

【论文阅读】(2023TPAMI)PCRLv2

目录 AbstractMethodMethodnsU-Net中的特征金字塔多尺度像素恢复多尺度特征比较从多剪切到下剪切训练目标 总结 Abstract 现有方法及其缺点:最近的SSL方法大多是对比学习方法,它的目标是通过比较不同图像视图来保留潜在表示中的不变合判别语义&#xff…

SENet 学习

ILSVRC 是一个比赛,全称是ImageNet Large-Scale Visual Recognition Challenge,平常说的ImageNet比赛指的是这个比赛。 使用的数据集是ImageNet数据集的一个子集,一般说的ImageNet(数据集)实际上指的是ImageNet的这个子…

day01_matplotlib_demo

文章目录 折线图plot多个绘图区绘制数学函数图像散点图scatter柱状图bar直方图histogram饼图pie总结 折线图plot import matplotlib.pyplot as pltplt.figure(figsize(15, 6), dpi80) plt.plot([1, 0, 9], [4, 5, 6]) plt.show()### 展现一周天气温度情况 # 创建画布 plt.figu…

NewStarCTF2023week4-More Fast(GC回收)

打开链接,存在很多个类,很明显是php反序列化漏洞利用,需要构造pop链 , 关于pop链构造的详细步骤教学,请参考我之前的博客,真的讲得很详细也容易理解: http://t.csdnimg.cn/wMYNB 如果你是刚接…

Echarts柱状图渐变色问题变通

问题背景 设计稿中给出了如下图的效果,在柱状图的最上面给出了一个白色的小块,起初我一直在思考亦或者搜索相关的问题:如何在Echarts柱状图顶部实现一个24*4的白色矩形块。始终不得其解,在一个吃饭的瞬间冒出来一个想法是否可以用…

MySQL数据xtrabackup物理备份方法

目录 一、物理备份的方式二、xtrabackup物理备份1.安装xtrabackup2.完整备份/恢复流程3.增量备份流程4.差异备份流程5.物理备份总结 一、物理备份的方式 1.完整备份 每次对数据进行完整的备份,即对整个数据库的备份、数据库结构和文件结构的备份,保存的…

Docker:创建主从复制的Redis集群

一、Redis集群 在实际项目里,一般不会简单地只在一台服务器上部署Redis服务器,因为单台Redis服务器不能满足高并发的压力,另外如果该服务器或Redis服务器失效,整个系统就可能崩溃。项目里一般会用主从复制的模式来提升性能&#x…

指针相关面试题目

数组名的意义: 1. sizeof( 数组名 ) ,这里的数组名表示整个数组,计算的是整个数组的大小。 2. & 数组名,这里的数组名表示整个数组,取出的是整个数组的地址。 3. 除此之外所有的数组名都表示首元素的地址。 下…

发卡系统微信小程序源码/云盘发卡系统源码带PC端/自动发卡小程序源码(开源)

源码介绍: 最新开源的发卡系统微信小程序源码,这是一款云盘发卡系统源码,还带了电脑PC端。它是一款实用方便操作自动发卡小程序源码,它使用ERMEB云盘发卡,能为用户提供便捷的发卡服务。 源码框架: 系统采…

软考高级之系统架构师之数据流图和流程图

数据流图 概述 数据流图,DFD,用于表示业务信息系统中的数据流,它表达系统中的据传从输入到存储间所涉及的程序。采用图形方式来表达系统的逻辑功能、数据在系统内部的逻辑流向和逻辑变换过程,是结构化系统分析方法的主要表达工具…

python 查找波峰和波谷

import numpy as np import matplotlib.pyplot as plt from scipy.signal import find_peaks# 生成示例信号 x np.array([1, 3, 7, 1, 2, 6, 0, 4, 3, 2, 5, 1])# 寻找波峰 peaks, _ find_peaks(x)# 寻找波谷(使用信号的负数形式) valleys, _ find_pe…

【html】图片多矩形框裁剪

说明 由于项目中需要对一个图片进行多选择框进行裁剪&#xff0c;所以特写当前的示例代码。 代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><base href"/"><title>图片裁剪</tit…

第1篇 目标检测概述 —(3)目标检测评价指标

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。目标检测评价指标是用来衡量目标检测算法性能的指标&#xff0c;主要包括几个指标&#xff1a;精确率&#xff08;Precision&#xff09;、召回率&#xff08;Recall&#xff09;、交并比&#xff08;IoU&#xff09;、平均…

GitLab升级16.5.0后访问提示502

系统是兼容CentOS8的TencentOS3.1 GitLab原来的版本是16.4.1 使用yum升级时发现GitLab有新版本,决定升级。 升级过程无异常,出现升级成功的提示。 可是意外的时,访问站点时提示502. GitLab比较吃资源,启动的服务较多。之前也有等会就正常的情况。 这次没那么幸运,一…

js创建 ajax 过程

目录 前言&#xff1a;AJAX 技术的重要性 详解&#xff1a;创建 AJAX 请求的步骤 1. 创建 XMLHttpRequest 对象 2. 配置请求 3. 处理响应 4. 发送请求 5. 处理异步请求 解析&#xff1a;AJAX 请求的重要性和限制 总结&#xff1a; 前言&#xff1a;AJAX 技术的重要性 …

FastAPI 快速学习之 Flask 框架对比

目录 一、前言二、FastAPI 优势三、Hello World四、HTTP 方法五、URL 变量六、查询字符串七、POST 请求八、文件上传九、表单提交十、Cookies十一、模块化视图十二、数据校验十三、自动化文档Swagger 风格ReDoc 风格 十四、CORS跨域 一、前言 本文主要对 FastAPI 与 Flask 框架…

驱动开发5 阻塞IO实例、IO多路复用

1 阻塞IO 进程1 #include <stdlib.h> #include <stdio.h> #include <sys/types.h> #include <sys/stat.h> #include <sys/ioctl.h> #include <fcntl.h> #include <unistd.h> #include <string.h>int main(int argc, char co…
最新文章