机器学习:多元线性回归闭式解(Python)

import numpy as np
import matplotlib.pyplot as plt


class LRClosedFormSol:
    def __init__(self, fit_intercept=True, normalize=True):
        """
        :param fit_intercept: 是否训练bias
        :param normalize: 是否标准化数据
        """
        self.theta = None  # 训练权重系数
        self.fit_intercept = fit_intercept  # 线性模型的常数项。也即偏置bias,模型中的theta0
        self.normalize = normalize  # 是否标准化数据
        if normalize:
            self.feature_mean, self.feature_std = None, None  # 特征的均值,标准方差
        self.mse = np.infty  # 训练样本的均方误差
        self.r2, self.r2_adj = 0.0, 0.0  # 判定系数和修正判定系数
        self.n_samples, self.n_features = 0, 0  # 样本量和特征数

    def fit(self, x_train, y_train):
        """
        模型训练,根据是否标准化与是否拟合偏置项分类讨论
        :param x_train: 训练样本集
        :param y_train: 训练目标集
        :return:
        """
        if self.normalize:
            self.feature_mean = np.mean(x_train, axis=0)  # 按样本属性计算样本均值
            self.feature_std = np.std(x_train, axis=0) + 1e-8  # 样本方差,为避免零除,添加噪声
            x_train = (x_train - self.feature_mean) / self.feature_std  # 标准化
        if self.fit_intercept:
            x_train = np.c_[x_train, np.ones_like(y_train)]  # 添加一列1,即偏置项样本
        # 训练模型
        self._fit_closed_form_solution(x_train, y_train)  # 求闭式解

    def _fit_closed_form_solution(self, x_train, y_train):
        """
        线性回归的闭式解,单独函数,以便后期扩充维护
        :param x_train: 训练样本集
        :param y_train: 训练目标集
        :return:
        """
        # pinv伪逆,即(A^T * A)^(-1) * A^T
        self.theta = np.linalg.pinv(x_train).dot(y_train)  # 非正则化
        # xtx = np.dot(x_train.T, x_train) + 0.01 * np.eye(x_train.shape[1])  # 按公式书写
        # self.theta = np.dot(np.linalg.inv(xtx), x_train.T).dot(y_train)

    def get_params(self):
        """
        返回线性模型训练的系数
        :return:
        """
        if self.fit_intercept:  # 存在偏置项
            weight, bias = self.theta[:-1], self.theta[-1]
        else:
            weight, bias = self.theta, np.array([0])
        if self.normalize:  # 标准化后的系数
            weight = weight / self.feature_std.reshape(-1)  # 还原模型系数
            bias = bias - weight.T.dot(self.feature_mean.reshape(-1))
        return np.r_[weight.reshape(-1), bias.reshape(-1)]

    def predict(self, x_test):
        """
        测试数据预测,x_test:待预测样本集,不包括偏置项1
        :param x_test:
        :return:
        """
        try:
            self.n_samples, self.n_features = x_test.shape[0], x_test.shape[1]
        except IndexError:
            self.n_samples, self.n_features = x_test.shape[0], 1  # 测试样本数和特征数
        if self.normalize:
            x_test = (x_test - self.feature_mean) / self.feature_std  # 测试数据标准化
        if self.fit_intercept:
            x_test = np.c_[x_test, np.ones(shape=x_test.shape[0])]  # 存在偏置项,添加一列1
        return x_test.dot(self.theta)

    def cal_mse_r2(self, y_pred, y_test):
        """
        计算均方误差,计算拟合优度的判定系数R方和修正判定系数
        :param y_pred: 模型预测目标真值
        :param y_test: 测试目标真值
        :return:
        """
        self.mse = ((y_test - y_pred) ** 2).mean()  # 均方误差
        # 计算测试样本的判定系数和修正判定系数
        self.r2 = 1 - ((y_test - y_pred) ** 2).sum() / ((y_test - y_test.mean()) ** 2).sum()
        self.r2_adj = 1 - (1 - self.r2) * (self.n_samples - 1) / (self.n_samples - self.n_features - 1)
        return self.mse, self.r2, self.r2_adj

    def plt_predict(self, y_pred, y_test, is_show=True, is_sort=True):
        """
        绘制预测值与真实值对比图
        :return:
        """
        if self.mse is np.infty:
            self.cal_mse_r2(y_pred, y_test)
        if is_show:
            plt.figure(figsize=(7, 5))
        if is_sort:
            idx = np.argsort(y_test)
            plt.plot(y_pred[idx], "r:", lw=1.5, label="Predictive Val")
            plt.plot(y_test[idx], "k--", lw=1.5, label="Test True Val")
        else:
            plt.plot(y_pred, "r:", lw=1.5, label="Predictive Val")
            plt.plot(y_test, "k--", lw=1.5, label="Test True Val")
        plt.xlabel("Test sample observation serial number", fontdict={"fontsize": 12})
        plt.ylabel("Predicted sample value", fontdict={"fontsize": 12})
        plt.title("The predictive values of test samples \n MSE = %.5e, R2 = %.5f, R2_adj = %.5f"
                  % (self.mse, self.r2, self.r2_adj), fontdict={"fontsize": 14})
        plt.legend(frameon=False)
        plt.grid(ls=":")
        if is_show:
            plt.show()


from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error


housing = fetch_california_housing()
X, y = housing.data, housing.target  # 获取样本数据和目标数据
X_train, X_test, y_train, y_test = \
    train_test_split(X, y, test_size=0.3, random_state=1, shuffle=True)

lgcfs_obj = LRClosedFormSol(normalize=True, fit_intercept=True)
lgcfs_obj.fit(X_train, y_train)
theta = lgcfs_obj.get_params()  # 获得模型系数
print("线性回归模型拟合系数如下:")
for i, fn in enumerate(housing.feature_names):
    print(fn + ":", theta[i])
print("Const:", theta[-1])

# 模型预测,即针对测试样本进行预测
y_pred = lgcfs_obj.predict(X_test)
lgcfs_obj.plt_predict(y_pred, y_test, is_sort=True)

# 采用sklearn库函数进行线性回归和预测
lr = LinearRegression().fit(X_train, y_train)
print("sklearn截距:", lr.intercept_)  # 打印截距
print("sklearn系数:", lr.coef_)  # 打印模型系数
y_test_predict = lr.predict(X_test)
mse = mean_squared_error(y_test, y_test_predict)
r2 = r2_score(y_test, y_test_predict)
print("sklearn均方误差与判定系数为:", mse, r2)




 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/346657.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue3 Teleport 将组件传送到外层DOM位置

✨ 专栏介绍 在当今Web开发领域中,构建交互性强、可复用且易于维护的用户界面是至关重要的。而Vue.js作为一款现代化且流行的JavaScript框架,正是为了满足这些需求而诞生。它采用了MVVM架构模式,并通过数据驱动和组件化的方式,使…

力扣hot100 合并区间 排序 贪心

Problem: 56. 合并区间 复杂度 时间复杂度: O ( n log ⁡ n ) O(n\log{n}) O(nlogn) 空间复杂度: O ( n ) O(n) O(n) Code class Solution {public int[][] merge(int[][] intervals){Arrays.sort(intervals, (int[] a, int[] b) -> {return a[0] - b[0];});// 按照数…

基于原生图数据库的知识图谱存储

目录 前言1 关系模型的局限1.1 语义关联的隐藏1.2 数据多样性的挑战1.3 动态性受限1.4 与自然语言描述失配 2 知识图谱与图数据库2.1 图数据库概述2.2 图的结构特征的优势2.3 跨领域图建模与查询2.4 丰富的关系语义表达与推理能力 3 图数据建模的好处3.1 自然表达3.2 易于扩展3…

uniapp canvas做的刮刮乐解决蒙层能自定义图片

最近给湖南中烟做元春活动&#xff0c;一个月要开发4个小活动&#xff0c;这个是其中一个难度一般&#xff0c;最难的是一个类似鲤鱼跃龙门的小游戏&#xff0c;哎&#xff0c;真实为难我这个“拍黄片”的。下面是主要代码。 <canvas :style"{width:widthpx,height:hei…

uniapp复选框 实现排他选项

选择了排他选项之后 复选框其他选项不可以选择 <view class"reportData" v-for"(val, index) in obj" :key"index"> <view v-if"val.type 3" ><u-checkbox-group v-model"optionValue" placement"colu…

orm-04-Spring Data JPA 入门介绍

拓展阅读 The jdbc pool for java.(java 手写 jdbc 数据库连接池实现) The simple mybatis.&#xff08;手写简易版 mybatis&#xff09; Spring Data JPA Spring Data JPA&#xff0c;作为更大的 Spring Data 家族的一部分&#xff0c;使得基于 JPA 的仓库实现变得更加容易。…

卫星影像离线瓦片如何调用?

我们曾为你分享了按区县购买卫星影像并在线调用的方法。 于是就有朋友问&#xff0c;卫星影像瓦片可以离线调用吗&#xff1f; 当然可以&#xff0c;这里就来分享一下卫星影像瓦片离线调用的方法。 卫星影像离线瓦片如何调用&#xff1f; 这里以OpenLayers、Mapbox和Cesiu…

Google 提出稀疏注意力框架Exphormer,提升图Transformer的扩展性!

引言 Graph Transformer已成为ML的重要架构&#xff0c;它将基于序列的Transformer应用于图结构数据。然而当面对大型图数据集时&#xff0c;使用Graph Transformer会存在扩展性限制。为此&#xff0c;「Google提出了一个稀疏注意力框架Exphormer&#xff0c;它使用扩展图来提…

LeetCode-2865. 美丽塔 I

题面 给你一个长度为 n 下标从 0 开始的整数数组 maxHeights 。 你的任务是在坐标轴上建 n 座塔。第 i 座塔的下标为 i &#xff0c;高度为 heights[i] 。 如果以下条件满足&#xff0c;我们称这些塔是 美丽 的&#xff1a; 1 < heights[i] < maxHeights[i] heights 是…

音频特效SDK,满足内容生产的音频处理需求

美摄科技&#xff0c;作为音频处理技术的佼佼者&#xff0c;推出的音频特效SDK&#xff0c;旨在满足企业内容生产中的音频处理需求。这款SDK内置多种常见音频处理功能&#xff0c;如音频变声、均衡器、淡入淡出、音频变调等&#xff0c;帮助企业轻松应对各种音频处理挑战。 一…

启动mitmproxy报错 ImportError: cannot import name ‘url_quote‘ from ‘werkzeug.urls‘

报错截图 ImportError: cannot import name url_quote from werkzeug.urls (d:\soft\python\python38\lib\site-packages\werkzeug\urls.py) 原因是Werkzeug版本不兼容导致 解决方法 pip install Werkzeug2.2.2

Java-NIO篇章(5)——Reactor反应器模式

前面已经讲过了Java-NIO中的三大核心组件Selector、Channel、Buffer&#xff0c;现在组件我们回了&#xff0c;但是如何实现一个超级高并发的socket网络通信程序呢&#xff1f;假设&#xff0c;我们只有一台内存为32G的Intel-i710八核的机器&#xff0c;如何实现同时2万个客户端…

【云原生】Docker的端口映射、数据卷、数据卷容器、容器互联

目录 一、端口映射&#xff08;相当于添加iptables的DANT&#xff09; 二、数据卷创建&#xff08;宿主机目录或文件挂载到容器中&#xff09; 三、数据卷容器&#xff08;多个容器通过同一个数据卷容器为基点&#xff0c;实现所有容器数据共享&#xff09; 四、容器互联&am…

Docker容器引擎(2)

目录 一.批量删除镜像&#xff0c;容器 二.Docker 网络实现原理 随机映射端口&#xff08;从32768开始&#xff09; 访问自己&#xff1a; 在10服务器上配置路由转发&#xff1a; 指定映射端口&#xff1a; 查看容器的输出和日志信息&#xff1a; 将宿主机目标|文件挂载…

司铭宇老师:员工心态培训:正确对待工作的七种心态:让你在工作中游刃有余

员工心态培训&#xff1a;正确对待工作的七种心态&#xff1a;让你在工作中游刃有余 工作&#xff0c;是我们生活的核心组成部分。它不仅关乎我们的物质生活&#xff0c;更影响着我们的精神世界。如何在工作中找到平衡&#xff0c;实现自我价值&#xff0c;成为许多人心中的困…

oracle19.22的patch已发布

2024年01月16日,oracle发布了19.22的patch 具体patch如下 Reserved for Database - Do not edit or delete (Doc ID 19202401.9) 文档ID规则如下 19(版本)+年份(202x)+(季度首月01,04,07,10).9 往期patch no信息和下载参考文档 oracle 19C Release Update patch num…

Allegro因为精度问题导致走线未连接上的解决办法

Allegro因为精度问题导致走线未连接上的解决办法。还有在用Allegro进行PCB设计时,因为不小心操作移动了芯片,导致导线和引脚未连接上,也可以使用这个方法。 选择菜单Tools→Derive Connectivity(获取连接) 跳出下面的对话框, Convert Line to Connect Lines

SQL注入实战:宽字节注入

一、宽字节概率 1、单字节字符集:所有的字符都使用一个字节来表示&#xff0c;比如 ASCII编码(0-127) 2、多字节字符集:在多字节字符集中&#xff0c;一部分字符用多个字节来表示&#xff0c;另一部分字符(可能没有)用 单个字节来表示。 3、宽字节注入是利用mysql的一个特性…

人工智能在教学领域中有哪些运用?

人工智能在教学领域中可以有以下几种运用&#xff1a; 1. 智能辅助教学&#xff1a;利用人工智能技术开发出智能辅助教学系统&#xff0c;根据学生的学习状态和知识背景&#xff0c;提供个性化的学习路径和推荐的学习资源&#xff0c;帮助学生更好地掌握知识。 2. 自适应评估…

如何使用宝塔面板配置Nginx反向代理WebSocket(wss)

本章教程&#xff0c;主要介绍一下在宝塔面板中如何配置websocket wss的具体过程。 目录 一、添加站点 二、申请证书 三、配置代理 1、增加配置内容 2、代理配置内容 三、注意事项 一、添加站点 二、申请证书 三、配置代理 1、增加配置内容 map $http_upgrade $connection_…
最新文章