机器学习复习(2)——线性回归SGD优化算法

目录

线性回归代码

线性回归理论

SGD算法

手撕线性回归算法

模型初始化

定义模型主体部分

定义线性回归模型训练过程

数据demo准备

模型训练与权重参数

定义线性回归预测函数

定义R2系数计算

可视化展示 

预测结果

训练过程 

sklearn进行机器学习

线性回归代码

class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # 矩阵的维度(dimensions) 
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x

线性回归理论

回归算法是相对分类算法而言的,与我们想要预测的目标变量y的值类型有关。

如果目标变量y是分类型变量,如预测用户的性别(男、女),预测月季花的颜色(红、白、黄……),那我们就需要用分类算法去拟合训练数据并做出预测;

如果y是连续型变量,如预测用户的收入(4千,2万,10万……),预测患肺癌的概率(1%,50%,99%……),我们则需要用回归模型。

有时分类问题也可以转化为回归问题。可以用回归模型先预测出患肺癌的概率,然后再给定一个阈值,例如50%,概率值在50%以下为A类,50%以上为B类。

一元线性回归公式:

 具象化含义:

SGD算法

手撕线性回归算法

模型初始化

### 初始化模型参数
def initialize_params(dims):
    '''
    输入:
    dims:训练数据变量维度
    输出:
    w:初始化权重参数值
    b:初始化偏差参数值
    '''
    # 初始化权重参数为零矩阵
    w = np.zeros((dims, 1))
    # 初始化偏差参数为零
    b = 0
    return w, b
w,b=initialize_params(3)#用于测试
print("w初始化是",w)
print("b初始化是",b)

运行结果:

定义模型主体部分

包括线性回归公式、均方损失和参数偏导三部分
def linear_loss(X, y, w, b):
    '''
    输入:
    X:输入变量矩阵
    y:输出标签向量
    w:变量参数权重矩阵
    b:偏差项
    输出:
    y_hat:线性模型预测输出
    loss:均方损失值
    dw:权重参数一阶偏导
    db:偏差项一阶偏导
    '''
    # 训练样本数量
    num_train = X.shape[0]
    # 训练特征数量
    num_feature = X.shape[1]
    # 线性回归预测输出
    y_hat = np.dot(X, w) + b
    # 计算预测输出与实际标签之间的均方损失
    loss = np.sum((y_hat-y)**2)/num_train
    # 基于均方损失对权重参数的一阶偏导数
    dw = np.dot(X.T, (y_hat-y)) /num_train
    # 基于均方损失对偏差项的一阶偏导数
    db = np.sum((y_hat-y)) /num_train
    return y_hat, loss, dw, db

定义线性回归模型训练过程

### 定义线性回归模型训练过程
def linear_train(X, y, learning_rate=0.01, epochs=10000):
    '''
    输入:
    X:输入变量矩阵
    y:输出标签向量
    learning_rate:学习率
    epochs:训练迭代次数
    输出:
    loss_his:每次迭代的均方损失
    params:优化后的参数字典
    grads:优化后的参数梯度字典
    '''
    # 记录训练损失的空列表
    loss_his = []
    # 初始化模型参数
    w, b = initialize_params(X.shape[1])
    # 迭代训练
    for i in range(1, epochs):
        # 计算当前迭代的预测值、损失和梯度
        y_hat, loss, dw, db = linear_loss(X, y, w, b)
#y_hat是预测值,loss是损失,dw是权重参数一阶偏导,db是偏差项一阶偏导
        # 基于梯度下降的参数更新
        w += -learning_rate * dw
        b += -learning_rate * db
        # 记录当前迭代的损失
        loss_his.append(loss)
        # 每1000次迭代打印当前损失信息
        if i % 10000 == 0:
            print('epoch %d loss %f' % (i, loss))
        # 将当前迭代步优化后的参数保存到字典
        params = {
            'w': w,
            'b': b
        }
        # 将当前迭代步的梯度保存到字典
        grads = {
            'dw': dw,
            'db': db
        }     
    return loss_his, params, grads

其中的shape操作说明:

import numpy as np
# 创建一个示例的训练数据集 X
X = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9],
              [10, 11, 12],
              [13, 14, 15]])
# 计算训练样本数量
shape0 = X.shape[0]
shape1 = X.shape[1]
print("shape0是",shape0)
print("shape1是",shape1)

运行结果:

数据demo准备

from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
data = diabetes.data
target = diabetes.target 
print(data.shape)
print(target.shape)
print(data[:5])
print(target[:5])
###########################################
# 导入sklearn diabetes数据接口
from sklearn.datasets import load_diabetes
# 导入sklearn打乱数据函数
from sklearn.utils import shuffle
# 获取diabetes数据集
diabetes = load_diabetes()
# 获取输入和标签
data, target = diabetes.data, diabetes.target 
# 打乱数据集
X, y = shuffle(data, target, random_state=13)
# 按照8/2划分训练集和测试集
offset = int(X.shape[0] * 0.8)
# 训练集
X_train, y_train = X[:offset], y[:offset]
# 测试集
X_test, y_test = X[offset:], y[offset:]
# 将训练集改为列向量的形式
y_train = y_train.reshape((-1,1))
# 将验证集改为列向量的形式
y_test = y_test.reshape((-1,1))
# 打印训练集和测试集维度
print("X_train's shape: ", X_train.shape)
print("X_test's shape: ", X_test.shape)
print("y_train's shape: ", y_train.shape)
print("y_test's shape: ", y_test.shape)

模型训练与权重参数

# 线性回归模型训练
loss_his, params, grads = linear_train(X_train, y_train, 0.01, 200000)
# 打印训练后得到模型参数
print(params)

定义线性回归预测函数

### 定义线性回归预测函数
def predict(X, params):
    '''
    输入:
    X:测试数据集
    params:模型训练参数
    输出:
    y_pred:模型预测结果
    '''
    # 获取模型参数
    w = params['w']
    b = params['b']
    # 预测
    y_pred = np.dot(X, w) + b
    return y_pred
# 基于测试集的预测
y_pred = predict(X_test, params)
# 打印前五个预测值
y_pred[:5]

定义R2系数计算

R2系数,也称为决定系数(Coefficient of Determination),是一种用于评估回归模型拟合优度的统计指标。它表示模型对观测数据的方差解释比例,通常用于衡量回归模型的拟合程度。

R2系数的取值范围在0到1之间,具体含义如下:

  • 如果R2等于0,表示模型未能解释目标变量的任何方差,即模型无法拟合数据。
  • 如果R2等于1,表示模型完美拟合了数据,能够解释目标变量的所有方差。
  • 如果R2在0和1之间,表示模型能够解释一部分目标变量的方差,数值越接近1,说明模型的拟合程度越好。

计算公式如下:

其中:

  • SSR(Sum of Squares of Residuals)表示模型的残差平方和,即实际观测值与模型预测值之间的差异的平方和。
  • SST(Total Sum of Squares)表示总平方和,即实际观测值与观测值的均值之间的差异的平方和。

R2系数越接近1,说明模型对数据的拟合越好,而越接近0则表示模型的拟合效果较差。这个指标对于评估回归模型的性能非常有用,帮助我们了解模型解释数据方差的程度。

### 定义R2系数函数
def r2_score(y_test, y_pred):
    '''
    输入:
    y_test:测试集标签值
    y_pred:测试集预测值
    输出:
    r2:R2系数
    '''
    # 测试标签均值
    y_avg = np.mean(y_test)
    # 总离差平方和
    ss_tot = np.sum((y_test - y_avg)**2)
    # 残差平方和
    ss_res = np.sum((y_test - y_pred)**2)
    # R2计算
    r2 = 1 - (ss_res/ss_tot)
    return r2

可视化展示 

预测结果

import matplotlib.pyplot as plt
f = X_test.dot(params['w']) + params['b']

plt.scatter(range(X_test.shape[0]), y_test)
plt.plot(f, color = 'darkorange')
plt.xlabel('X_test')
plt.ylabel('y_test')
plt.show();

运行结果:

训练过程 

plt.plot(loss_his, color='blue')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()

运行结果:

sklearn进行机器学习

 和torch.nn类似:封装好了linear函数,直接掉包

### sklearn版本为1.0.2
# 导入线性回归模块
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score
# 创建模型实例
regr = linear_model.LinearRegression()
# 模型拟合
regr.fit(X_train, y_train)
# 模型预测
y_pred = regr.predict(X_test)
# 打印模型均方误差
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_pred))
# 打印R2
print('R2 score: %.2f' % r2_score(y_test, y_pred))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/367832.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CSC联合培养博士申请亲历|联系外导的详细过程

在CSC申报的各环节中,联系外导获得邀请函是关键步骤。这位联培博士同学的这篇文章,非常详细且真实地记录了申请过程、心理感受,并提出有益的建议,小编特推荐给大家参考。 2024年国家留学基金委公派留学项目即将开始,其…

网络原理TCP/IP(2)

文章目录 TCP协议确认应答超时重传连接管理断开连接 TCP协议 TCP全称为"传输控制协议(Transmission Control Protocol").⼈如其名,要对数据的传输进⾏⼀个详细 的控制; TCP协议段格式 • 源/目的端口号:表⽰数据是从哪个进程来,到哪个进程去; • 32位序号/32位确认…

会声会影下载 Corel VideoStudio 2023 v26.1.0.268中文激活版

会声会影Corel VideoStudio 2023破解版是领先的视频编辑和转换软件!提供直观友好的功能,让用户能够更快速便捷地制作独特的视频,高质量的效果,各种滤镜、贴纸、过渡、模板等都将让您事半功倍!软件允许您导入自己的剪辑…

社区店加盟:如何选择适合的品牌和项目?

在当下创业热潮中,社区店加盟成为了许多创业者的首选。特别是鲜奶吧这种深受各年龄段人群喜爱的项目,更是备受关注。然而,面对众多品牌和项目,如何选择适合自己的社区店加盟品牌和项目呢? 作为一位资深的鲜奶吧创业者…

关于node.js奇数版本不稳定 将11.x.x升级至16.x.x不成功的一系列问题(一)

据说vue2用16稳定一些 vue3用18好一点(但之前我vue3用的16.18.1也可以) 为维护之前的老项目 先搞定node版本切换 下载nvm node版本管理工具 https://github.com/coreybutler/nvm-windows/releases 用这个nvm-setup.zip安装包 安之前最好先将之前的nod…

西瓜书学习笔记——k近邻学习(公式推导+举例应用)

文章目录 算法介绍实验分析 算法介绍 K最近邻(K-Nearest Neighbors,KNN)是一种常用的监督学习算法,用于分类和回归任务。该算法基于一个简单的思想:如果一个样本在特征空间中的 k k k个最近邻居中的大多数属于某个类别…

git命令上传本地项目到远程仓库的悲惨遭遇

git命令上传本地项目到远程仓库的悲惨遭遇。我想把前端后端合并到一个仓库下2个分支,结果呢,不仅合并没有成功,还把代码丢失了。 如图,原始我写好了完整的后端代码,都丢失了。 远程仓库里也都没有了。奇怪了。 难道远…

二、Java学习 数据类型与变量

目录 一、字面常量 二、数据类型 三、变量 语法格式 四、类型转换 隐式类型转换 强制类型转换 字符串类型 五、类型提升 1.int与long 2.byte与byte 小结 一、字面常量 常量即运行期间,固定不变的量。 字面常量的分类: 1.字符串常量&#xff…

TQ15EG开发板教程:开发板资源介绍

时钟资源 采用时钟芯片CDCM6208提供系统时钟 PL端时钟 PS 收发器时钟 PL收发器时钟 电源 BANK500 BANK501 BANK502 BANK503(专用) 1.8V 1.8V 1.8V 1.8V PS端外设 QSPI 采用2片MT25QU256 拼接成8bit的QSPI存储系统。采用1.8V供电 SD卡 SATA接口 PS端以太网接口 D…

Java宝典-数据类型

目录 1.变量与常量2.Java中的数据类型3.整型3.1 字节型byte3.2 短整型short3.3 整型int3.4 长整型long 4.浮点型4.1 单精度浮点型float4.2 双精度浮点型double 5.字符型6.布尔型7.类型转换7.1 隐式类型转换7.2 显示类型转换(强制类型转换) 8.类型提升 大家好,我是你们的Vampire…

了解UDP发送过快导致的问题和对应解决方案

在当今这个以数据为核心的时代,企业对于数据传输的速度和稳定性有着日益增长的需求。UDP凭借其低延迟和高效率的特性,在实时通信和大规模数据传输领域扮演着关键角色。然而,UDP的无连接特性和缺乏可靠性也给数据传输带来了挑战,尤…

【python错误】Pytorch1.9 ImportError: cannot import name ‘zero_gradients‘

错误:Pytorch1.9 ImportError: cannot import name ‘zero_gradients’ 错误提示: ImportError: cannot import name ‘zero_gradients’ from ‘torch.autograd.gradcheck’ (/root/miniconda3/envs/d2l/lib/python3.9/site-packages/torch/autograd/g…

3种JWT验证和续签的策略

3 种JWT验证和续签的策略 好文推荐:一文教你搞定所有前端鉴权与后端鉴权方案,让你不再迷惘 - 掘金 (juejin.cn) 3 种jwt 验证的策略 通过解析去验证:每次访问api时parse jwt 判断是否vaild jwt有效 正常调用api jwt无效 返回401 缺点&a…

AVR 328pb串口基本介绍和使用

AVR 328pb串口基本介绍和使用 📍相关篇《AVR 328pb定时器0基本介绍和使用》 🔖基于Atmel Studio 7.0开发环境。 📍结合参考同架构lgt8f328p中文文档:http://www.prodesign.com.cn/wp-content/uploads/2023/03/LGT8FX8P_databook…

北朝隋唐文物展亮相广西,文物预防性保护网关保驾护航

一、霸府名都——太原博物馆收藏北朝隋朝文物展 2月1日,广西民族博物馆与太原博物馆携手,盛大开启“霸府名都——太原博物馆北朝隋文物展”。此次新春展览精选了北朝隋唐时期150多件晋阳文物珍品。依据“巍巍雄镇”“惊世古冢”“锦绣名都”三个单元&am…

多线程编程6——使用 volatile 解决问题可见性问题

一、内存可见性问题 内存可见性问题是出现线程安全问题的原因之一。 1、什么是内存可见性问题? 一个线程针对一个变量进行读取操作,另一个线程针对这个变量进行修改操作,此时读到的值不一定是修改后的值,出现了线程安全问题&a…

学习Android的第三天

目录 Android LinearLayout 线性布局 XML 属性 LinearLayout 几个重要的 XML 属性 LinearLayout.LayoutParams XML 属性 divider (分割线) Android RelativeLayout 相对布局 RelativeLayout 布局属性 TableLayout ( 表格布局 ) TableRow 子控件的主要属性 Android Lin…

爬虫入门到精通_基础篇4(BeautifulSoup库_解析库,基本使用,标签选择器,标准选择器,CSS选择器)

1 Beautiful说明 BeautifulSoup库是灵活又方便的网页解析库,处理高效,支持多种解析器。利用它不用编写正则表达式即可方便地实线网页信息的提取。 安装 pip3 install beautifulsoup4解析库 解析器使用方法优势劣势Python标准库BeautifulSoup(markup,…

ADB的配置和使用及刷机root

ADB的配置和使用 ADB即Android Debug Bridge,安卓调试桥,是谷歌为安卓开发者提供的开发工具之一,可以让你的电脑以指令窗口的方式控制手机。可以在安卓开发者网页中的 SDK 平台工具页面下直接下载对应系统的 adb 配置文件,大小只…

05、全文检索 -- Solr -- Solr 全文检索之图形界面的文档管理(文档的添加、删除,如何通过关键字等参数查询文档)

目录 Solr 全文检索之文档管理添加文档使用 JSON 添加文档:使用 XML 添加文档: 删除文档使用 JSON 删除文档:使用 XML 删除文档: 查询文档查询文档的详细参数fq(Filter Query):过滤sort:排序sta…