机器学习第5天:多项式回归与学习曲线

文章目录

多项式回归介绍

方法与代码

方法描述

分离多项式

学习曲线的作用

场景

学习曲线介绍

欠拟合曲线

示例

结论

过拟合曲线

示例

​结论 


多项式回归介绍

当数据不是线性时我们该如何处理呢,考虑如下数据

import matplotlib.pyplot as plt
import numpy as np


np.random.seed(42)

x = 8 * np.random.rand(100, 1) - 4
y = 2*x**2+3*x+np.random.randn(100, 1)

plt.scatter(x, y)
plt.show()


方法与代码

方法描述

先讲思路,以这个二元函数为例

y=3*x^{2}+2*x+c

将多项式化为多个单项的,也就是将x的平方和x两个项分离开,然后单独给线性模型处理,求出参数,最后再组合在一起,很好理解,让我们来看一下代码


分离多项式

我们使用机器学习库的PolynomialFeatures来分离多项式

from sklearn.preprocessing import PolynomialFeatures


poly_features = PolynomialFeatures(degree=2, include_bias=False)
x_poly = poly_features.fit_transform(x)
print(x[0])
print(x_poly[0])

运行结果

可以看到,4, 5行代码将原始x和x平方挑选了出来,这时我们再把这个数据进行线性回归

model = LinearRegression()
model.fit(x_poly, y)
print(model.coef_)

 这段代码使用处理后的x拟合y,再打印模型拟合的参数,可以看到模型的两个参数分别是2.9和2左右,而我们的方程的一次参数和二次参数分别是3和2,可见效果还是很好的

把预测的结果绘制出来

model = LinearRegression()
model.fit(x_poly, y)
pre_y = model.predict(x_poly)

# 这里是为了让x升序的排序算法, 可以尝试不加这段代码图会变成什么样
sorted_indices = sorted(range(len(x)), key=lambda k: x[k])
x_sorted = [x[i] for i in sorted_indices]
y_sorted = [pre_y[i] for i in sorted_indices]

plt.plot(x_sorted, y_sorted, "r-")
plt.scatter(x, y)
plt.show()


学习曲线的作用

场景

设想一下,当你需要预测房价,你也有多组数据,包括离学校距离,交通状况等,但是问题来了,你只知道这些特征可能与房价有关,但并不知道这些特征与房价之间的方程关系,这时我们进行回归任务时,就可能导致欠拟合或者过拟合,幸运的是,我们可以通过学习曲线来判断


学习曲线介绍

学习曲线图就是以损失函数为纵坐标,数据集大小为横坐标,然后在图上画出训练集和验证集两条曲线的图,训练集就是我们用来训练模型的数据,验证集就是我们用来验证模型性能的数据集,我们往往将数据集分成训练集与验证集

我们先定义一个学习曲线绘制函数

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression


def plot_learning_curves(model, x, y):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)
    train_errors, val_errors = [], []
    for m in range(1, len(x_train)):
        model.fit(x_train[:m], y_train[:m])
        y_train_predict = model.predict(x_train[:m])
        y_val_predict = model.predict(x_val)
        train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
        val_errors.append(mean_squared_error(y_val, y_val_predict))
    plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")
    plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")
    plt.legend()
    plt.show()

 简单介绍一下,这个函数接收模型参数,x,y参数,然后在for循环中,取不同数据集大小来计算RMSE损失(就是\sqrt{MSE}),然后把曲线绘制出来


欠拟合曲线

我们知道欠拟合就是模拟效果不好的情况,可以想象的到,无论在训练集还是验证集上,他的损失都会比较高

示例

我们将线性模型的学习曲线绘制出来

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression


def plot_learning_curves(model, x, y):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)
    train_errors, val_errors = [], []
    for m in range(1, len(x_train)):
        model.fit(x_train[:m], y_train[:m])
        y_train_predict = model.predict(x_train[:m])
        y_val_predict = model.predict(x_val)
        train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
        val_errors.append(mean_squared_error(y_val, y_val_predict))
    plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")
    plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")
    plt.legend()
    plt.show()


x = np.random.rand(100, 1)
y = 2 * x + np.random.rand(100, 1)

model = LinearRegression()
plot_learning_curves(model, x, y)

 

结论

可以看到,在只有一点数据时,模型在训练集上效果很好(因为就是开始这一些数据训练出来的),而在验证集上效果不好,但随着训练集增加(模型学习到的越多),验证集上的误差逐渐减小,训练集上的误差增加(因为是学到了一个趋势,不会完全和训练集一样了)

这个图的特征是两条曲线非常接近,且误差都较大(差不多在0.3) ,这是欠拟合的表现(模型效果不好)


过拟合曲线

过拟合就是完全以数据集来模拟曲线,泛化能力很差

示例

我们来试试将一次函数模拟成三次函数,再来看看学习曲线(毫无疑问过拟合了)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline


def plot_learning_curves(model, x, y):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)
    train_errors, val_errors = [], []
    for m in range(1, len(x_train)):
        model.fit(x_train[:m], y_train[:m])
        y_train_predict = model.predict(x_train[:m])
        y_val_predict = model.predict(x_val)
        train_errors.append(mean_squared_error(y_train[:m], y_train_predict))
        val_errors.append(mean_squared_error(y_val, y_val_predict))
    plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")
    plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")
    plt.legend()
    plt.show()


np.random.seed(10)
x = np.random.rand(200, 1)
y = 2 * x + np.random.rand(200, 1)

poly_regression = Pipeline([
    ("Poly", PolynomialFeatures(degree=3, include_bias=False)),
    ("Line", LinearRegression())
])

plot_learning_curves(poly_regression, x, y)

结论 

这条曲线的特征是训练集的效果比验证集好(两条线之间有一定间距),这往往是过拟合的表现(在训练集上效果好,验证集差,表面泛化能力差) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/159358.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于SSM的中小型企业财务管理设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

spring cloud openfeign 使用注意点

近期在做项目时给自己挖了一个坑,问题重现如下 使用的组件版本如下 spring boot 2.7.15,对应的 spring cloud 版本为 2021.0.5,其中 spring cloud 适配的 openfeign 版本是 3.1.5。 项目中使用的 feign 接口如下 public interface QueryApi…

使用express连接MySQL数据库编写基础的增、删、改、查、分页等接口

使用express连接MySQL数据库编写基础的增、删、改、查、分页接口 安装express-generator生成器 cnpm install -g express-generator通过生成器创建项目 express peifang-server切换至serverAPI目录 cd peifang-server下载所需依赖 cnpm install 运行项目 npm start访问项…

echarts 实现同一组legend控制两个饼图示例

实现同一组legend控制两个饼图示例: 该示例有如下几个特点: ①饼图不同值实现分割 ②实现tooltip自定义样式(echarts 实现tooltip提示框样式自定义-CSDN博客) ③自定义label内容 ④不同值颜色渐变 代码如下: this.o…

CentOS挂载:解锁文件系统的力量

目录 引言1 挂载简介2 挂载本地分区3 挂载网络共享文件系统4 使用CIFS挂载结论 引言 在CentOS(一种基于Linux的操作系统)上挂载文件系统是一项常见而重要的任务,无论是将新的磁盘驱动器添加到系统,还是挂载网络共享资源&#xff…

一个iOS tableView 滚动标题联动效果的实现

效果图 情景 tableview 是从屏幕顶部开始的,现在有导航栏,和栏目标题视图将tableView的顶部覆盖了 分析 我们为了达到滚动到某个分区选中标题的效果,就得知道 展示最顶部的cell或者区头在哪个分区范围内 所以我们必须首先获取顶部的位置 …

【bigo前端】egret中的对象池浅谈

本文首发于:https://github.com/bigo-frontend/blog/ 欢迎关注、转载。 egret是一款小游戏开发引擎,支持跨平台开发,之前使用这款引擎开发了一款捕鱼游戏,在这里简单聊下再egret中关于对象池的使用,虽然该引擎已经停止…

第六十二周周报

学习目标: 一、实验 二、论文 学习时间: 2023.11.11-2023.11.17 学习产出: 实验 1、CB模块实验效果出来了,加上去效果不太行,后续实验考虑是否将CB模块换到其他地方 2、CiFAR100实验已完成,效果比Vi…

大模型之十二十-中英双语开源大语言模型选型

从ChatGPT火爆出圈到现在纷纷开源的大语言模型,众多出入门的学习者以及跃跃欲试的公司不得不面临的是开源大语言模型的选型问题。 基于开源商业许可的开源大语言模型可以极大的节省成本和加速业务迭代。 当前(2023年11月17日)开源的大语言模型如下&#…

基于DE10-Standard Cyclone V SoC FPGA学习---开发板简介

基于DE10-Standard Cyclone V SoC FPGA学习---开发板简介 简介产品规格基于 ARM 的 HPS配置与调试存储器件通讯连接头显示器音频视频输入模数转换器开关、按钮、指示器传感器电源 DE10-Standard 开发板系统框图Connect HTG 组件配置设计资源其他资源 简介 开发板资料 见 DE10-…

什么是CRM管理系统

什么是CRM管理系统 市场竞争的日益激烈,企业对于客户关系的重视程度不断提升。为了更好地管理和维护客户关系,很多企业开始引入CRM(Customer Relationship Management)管理系统。那么,什么是CRM管理系统呢&#xff1f…

Jenkins代码检测和本地静态检查

1:Jenkins简介 Jenkins是一个用Java编写的开源的持续集成工具;Jenkins自动化部署可以解决集成、测试、部署等重复性的工作,工具集成的效率明显高于人工操作;并且持续集成可以更早的获取代码变更的信息,从而更早的进入测…

Java 之拼图小游戏

声明 此项目为java基础的阶段项目,此项目涉及了基础语法,面向对象等知识,具体像语法基础如判断,循环,数组,字符串,集合等…; 面向对象如封装,继承,多态,抽象类,接口,内部类等等…都有涉及。此项目涉及的内容比较多,作为初学者可以很好的将前面的知识串起来。此项目拿来练手以及…

golang学习笔记——基础01

文章目录 golang概述Go 语言特色Go 语言用途 Go 语言结构执行 Go 程序 Go 语言包管理01Go 语言包管理02Go 语言基础语法Go 标记行分隔符注释标识符字符串连接关键字、预定义标识符Go 语言的空格格式化字符串 Go 语言数据类型数字类型浮点型其他数字类型 Go 语言变量变量声明零值…

Linux下安装部署redis(离线模式)

一、准备工作 1.下载redis的安装包 下载地址:Index of /releases/ 大家可以自行选择redis的版本,笔者选择的是最新的 2.上传到服务器 前提是我先在服务器上创建了一个目录redis7.2.3,我直接上传到这个目录下 二、安装redis 1.解压redis t…

03-瑞吉外卖关于菜品/套餐分类表的增删改查

新增菜品/套餐分类 页面原型 当我们在后台系统中添加菜品/套餐时,需要选择一个菜品/套餐分类,在移动端也会按照菜品分类和套餐分类来展示对应的菜品和套餐 第一步: 用户点击确定按钮执行submitForm函数发送Ajax请求,将新增菜品/套餐表单中输入的数据以json形式提交给服务端,…

代码随想录算法训练营第24天|77. 组合

JAVA代码编写 77. 组合 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,4],[2,3],[1,2],[1,3],[1,4], ]示例 2: 输入…

IIC协议保姆级教学

目录 1.IIC协议概述 2.IIC总线传输 3.IIC-51单片机应用 1.起始信号 2.终止信号 3.应答信号 4.数据发送 4.IIC-32单片机应用 用到的库函数: 1.IIC协议概述 IIC全称Inter-Integrated Circuit (集成电路总线)是由PHILIPS公司在80年代开发的两线式串行总线&am…

hive sql 取当周周一 str_to_date(DATE_FORMAT(biz_date, ‘%Y%v‘), ‘%Y%v‘)

select str_to_date(DATE_FORMAT(biz_date, %Y%v), %Y%v)方法拆解 select DATE_FORMAT(now(), %Y%v), str_to_date(202346, %Y%v)

和鲸科技创始人范向伟受邀出席“凌云出海,来中东吧”2023华为云上海路演活动

11月9日,华为云“凌云出海,来中东吧”系列路演活动第二场在上海正式开启。聚焦“创业全球化”,本次活动由华为云携手阿布扎比投资办公室(ADIO)举办,旨在与渴望出海发展的优秀创业者们共探出海中东新商机。 …