【深度学习】pytorch——线性回归

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~

深度学习专栏链接:
http://t.csdnimg.cn/dscW7

pytorch——线性回归

  • 线性回归简介
  • 公式说明
  • 完整代码
  • 代码解释

线性回归简介

线性回归是一种用于建立特征和目标变量之间线性关系的统计学习方法。它假设特征和目标变量之间存在一个线性的关系,并试图通过拟合最佳的线性函数来预测目标变量。

线性回归模型的一般形式可以表示为:

y = w 0 + w 1 x 1 + w 2 x 2 + … + w n x n y = w_0 + w_1x_1 + w_2x_2 + \ldots + w_nx_n y=w0+w1x1+w2x2++wnxn

其中, y y y 是目标变量(或因变量), x 1 , x 2 , … , x n x_1, x_2, \ldots, x_n x1,x2,,xn 是特征变量(或自变量), w 0 , w 1 , w 2 , … , w n w_0, w_1, w_2, \ldots, w_n w0,w1,w2,,wn 是模型的参数,分别对应截距和各个特征的权重。

线性回归模型的训练过程就是寻找最优的参数 w 0 , w 1 , w 2 , … , w n w_0, w_1, w_2, \ldots, w_n w0,w1,w2,,wn 来使得模型的预测值与实际值之间的差异最小化。

公式说明

以下是代码涉及到的数学公式

  1. 线性回归模型

线性回归模型用于建立特征 x x x 和目标变量 y y y 之间的线性关系。在本代码中,线性回归模型被表示为:

y = w x + b y = wx + b y=wx+b

其中, w w w 是权重(即斜率), b b b 是偏置(即截距), x x x 是输入特征, y y y 是预测值。

  1. 损失函数

损失函数用于衡量模型预测值与实际标签之间的差异。在本代码中,使用的损失函数是均方误差(Mean Squared Error,MSE):

l o s s = 1 2 n ∑ i = 1 n ( y p r e d ( i ) − y ( i ) ) 2 loss = \frac{1}{2n} \sum_{i=1}^{n} (y_{pred}^{(i)} - y^{(i)})^2 loss=2n1i=1n(ypred(i)y(i))2

其中, y p r e d ( i ) y_{pred}^{(i)} ypred(i) 是模型的第 i i i 个样本的预测值, y ( i ) y^{(i)} y(i) 是实际标签, n n n 是样本数量。

  1. 其他运算

代码中还涉及到了矩阵乘法、矩阵转置、元素级别的操作等。例如, x . m m ( w ) x.mm(w) x.mm(w) 表示将输入特征 x x x 与权重 w w w 进行矩阵乘法; x T . m m ( d y _ p r e d ) x^T.mm(dy\_pred) xT.mm(dy_pred) 表示将输入特征 x x x 的转置与梯度 d y _ p r e d dy\_pred dy_pred 进行矩阵乘法。

完整代码

import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display

device = t.device('cpu') #如果你想用gpu,改成t.device('cuda:0')

# 设置随机数种子,保证在不同电脑上运行时下面的输出一致
t.manual_seed(1000) 

def get_fake_data(batch_size=8):
    ''' 产生随机数据:y=x*2+3,加上了一些噪声'''
    x = t.rand(batch_size, 1, device=device) * 5
    y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)
    return x, y

'''
# 产生的x-y分布
x, y = get_fake_data(batch_size=100)
plt.scatter(x.squeeze().cpu().numpy(), y.squeeze().cpu().numpy())
'''


# 随机初始化参数
w = t.rand(1, 1).to(device)
b = t.zeros(1, 1).to(device)

lr =0.02 # 学习率

for ii in range(500):
    x, y = get_fake_data(batch_size=4)
    
    # forward:计算loss
    y_pred = x.mm(w) + b.expand_as(y) 
    loss = 0.5 * (y_pred - y) ** 2 # 均方误差
    loss = loss.mean()
    
    # backward:手动计算梯度
    dloss = 1
    dy_pred = dloss * (y_pred - y)
    
    dw = x.t().mm(dy_pred)
    db = dy_pred.sum()
    
    # 更新参数
    w.sub_(lr * dw)
    b.sub_(lr * db)
    
    if ii%50 ==0:
        # 画图
        display.clear_output(wait=True)
        x = t.arange(0, 6).view(-1, 1)
        y = x.float().mm(w) + b.expand_as(x)
        plt.plot(x.cpu().numpy(), y.cpu().numpy(),color='b') # predicted
        
        x2, y2 = get_fake_data(batch_size=100) 
        plt.scatter(x2.numpy(), y2.numpy(),color='r') # true data
        
        plt.xlim(0, 5)
        plt.ylim(0, 15)
        plt.show()
        plt.pause(0.5)
        
print('w: ', w.item(), 'b: ', b.item())

输出结果为:
在这里插入图片描述
w: 1.9709817171096802 b: 3.1699466705322266

代码解释

  1. 导入需要的库:
import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display

导入PyTorch库以及绘图相关的库,%matplotlib inline是Jupyter Notebook中的魔法命令,用于在Notebook中显示绘图。

  1. 设置随机数种子:
t.manual_seed(1000)

这行代码设置随机数种子,保证每次运行结果的随机数生成过程一致。

  1. 定义生成随机数据的函数:
def get_fake_data(batch_size=8):
    ''' 产生随机数据:y=x*2+3,加上了一些噪声'''
    x = t.rand(batch_size, 1, device=device) * 5
    y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)
    return x, y

该函数用于产生随机的输入特征x和对应的标签y,其中y满足线性关系y = x * 2 + 3,并添加了一些随机噪声。

  1. 初始化模型参数:
w = t.rand(1, 1).to(device)
b = t.zeros(1, 1).to(device)

这里使用随机数初始化模型参数wb,并指定在CPU上进行计算。

  1. 设置学习率:
lr = 0.02

学习率lr控制每次参数更新的步长。

  1. 进行模型训练:
for ii in range(500):
    # 生成随机数据
    x, y = get_fake_data(batch_size=4)
    
    # forward:计算损失
    y_pred = x.mm(w) + b.expand_as(y)
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.mean()
    
    # backward:手动计算梯度
    dloss = 1
    dy_pred = dloss * (y_pred - y)
    
    dw = x.t().mm(dy_pred)
    db = dy_pred.sum()
    
    # 更新参数
    w.sub_(lr * dw)
    b.sub_(lr * db)

这里使用一个循环进行模型的训练,每次迭代都包含以下步骤:

  • 生成随机数据;
  • 前向传播:计算预测值y_pred和损失函数loss
  • 反向传播:手动计算梯度dwdb
  • 更新参数:根据梯度和学习率更新参数wb
  1. 可视化模型训练过程:
if ii % 50 == 0:
    display.clear_output(wait=True)
    x = t.arange(0, 6).view(-1, 1)
    y = x.float().mm(w) + b.expand_as(x)
    plt.plot(x.cpu().numpy(), y.cpu().numpy(), color='b') # predicted line
    
    x2, y2 = get_fake_data(batch_size=100)
    plt.scatter(x2.numpy(), y2.numpy(), color='r') # true data
    
    plt.xlim(0, 5)
    plt.ylim(0, 15)
    plt.show()
    plt.pause(0.5)

这部分代码用于可视化模型训练的过程,每50次迭代将当前参数下的预测结果以蓝色线条的形式绘制出来,并将随机生成的100个样本以红色散点图显示出来。

  1. 输出最终训练得到的参数:
print('w: ', w.item(), 'b: ', b.item())

输出训练得到的参数wb的值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/116583.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript处理字符串

字符串(String)是不可变的、有限数量的字符序列,字符包括可见字符、不可见字符和转义字符。在程序设计中,经常需要处理字符串,如复制、替换、连接、比较、查找、截取、分割等。在JavaScript中,字符串是一类简单值,直接…

NLP之Bert多分类实现案例(数据获取与处理)

文章目录 1. 代码解读1.1 代码展示1.2 流程介绍1.3 debug的方式逐行介绍 3. 知识点 1. 代码解读 1.1 代码展示 import json import numpy as np from tqdm import tqdmbert_model "bert-base-chinese"from transformers import AutoTokenizertokenizer AutoToken…

AI:57-基于机器学习的番茄叶部病害图像识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

体验SOLIDWORKS钣金切口工具增强 硕迪科技

在工业生产制造中,钣金加工是一种常用的加工方式,在SOLIDWORKS2024新版本中,钣金切口工具再次增强了,从SOLIDWORKS 2024 开始, 您可以使用切口工具在空心或薄壁圆柱体和圆锥体中生成切口。 只需在现有空心或薄壁圆柱体…

每天五分钟计算机视觉:搭建手写字体识别的卷积神经网络

本文重点 我们学习了卷积神经网络中的卷积层和池化层,这二者都是卷积神经网络中不可缺少的元素,本例中我们将搭建一个卷积神经网络完成手写字体识别。 卷积和池化的直观体现 手写字体识别 手写字体的图片大小是32*32*3的,它是一张 RGB 模式的图片,现在我们想识别它是从 …

Leetcode刷题详解——求根节点到叶节点数字之和

1. 题目链接:129. 求根节点到叶节点数字之和 2. 题目描述: 给你一个二叉树的根节点 root ,树中每个节点都存放有一个 0 到 9 之间的数字。 每条从根节点到叶节点的路径都代表一个数字: 例如,从根节点到叶节点的路径 1…

软通杯算法竞赛--周赛题目(一)

目录 一、S属性大爆发 二、日期杯 三、 三人行必由我师 四、集合之差 五、咱们计算机不懂烷烃 六、适度跑步健康长寿 一、S属性大爆发 测试用例 5 esS qwert codeforces PoSgju LkkJKkO 输出案例 二、日期杯 输入案例: 3 2022 2022 11 1900 2100 15 1989 20…

Java继承:抽取相同共性,实现代码复用

👑专栏内容:Java⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、继承的概念二、继承的语法三、父类成员访问1、子类中访问父类成员变量Ⅰ、子类和父类不存在同名成员变量Ⅱ、子类和父类成员…

Zabbix监控联想服务器的配置方法

简介 图片 随着科技的发展,对于数据的敏感和安全大部分取决于对硬件性能、故障预判的监测,由此可见实时监测保障硬件的安全很重要,从而衍生了很多对硬件的监测软件,Zabbix就一个不错的选择。开源 开源 开源! zabbix是…

树结构及其算法-二叉运算树

目录 树结构及其算法-二叉运算树 C代码 树结构及其算法-二叉运算树 二叉树的应用实际上相当广泛,例如表达式之间的转换。可以把中序表达式按运算符优先级的顺序建成一棵二叉运算树(Binary Expression Tree,或称为二叉表达式树)…

【文生图】Stable Diffusion XL 1.0模型Full Fine-tuning指南(U-Net全参微调)

文章目录 前言重要教程链接以海报生成微调为例总体流程数据获取POSTER-TEXTAutoPosterCGL-DatasetPKU PosterLayoutPosterT80KMovie & TV Series & Anime Posters 数据清洗与标注模型训练模型评估生成图片样例宠物包商品海报护肤精华商品海报 一些TipsMata:…

UUNet训练自己写的网络

记录贴写的很乱仅供参考。 自己写的Unet网络不带深度监督,但是NNUNet默认的训练方法是深度监督训练的,对应的模型也是带有深度监督的。但是NNUNetV2也贴心的提供了非深度监督的训练方法在该目录下: 也或者说我们想要自己去定义一个nnUNWtTra…

使用自定义函数拟合辨识HPPC工况下的电池数据(适用于一阶RC、二阶RC等电池模型)

该程序可以离线辨识HPPC工况下的电池数据,只需要批量导入不同SOC所对应的脉冲电流电压数据,就可以瞬间获得SOC为[100% 90% 80% 70% 60% 50% 40% 30% 20% 10% 0%]的所有电池参数,迅速得到参数辨识的结果并具有更高的精度,可以很大程度上降低参…

C++对象模型

思考:对于实现平面一个点的参数化。C的class封装看起来比C的struct更加的复杂,是否意味着产生更多的开销呢? 实际上并没有,类的封装不会产生额外的开销,其实,C中在布局以及存取上的额外开销是virtual引起的…

MySQL 表的增删查改(CRUD)

MySQL 表的增删查改(CRUD) 文章目录 MySQL 表的增删查改(CRUD)1. 新增(Create)2. 查询(Retrieve)2.1 全列查询2.2 指定列查询2.3 查询字段为表达式2.4 别名2.5 去重:DISTINCT2.6 排序:ORDER BY2.7 条件查询2.8 分页查询: LIMIT 3. 修改(Update)4. 删除(D…

vuepress使用及拓展(骚操作)

官网 文章目录 背景问题思考方案思索实现方案实现结果存在问题 背景 当前开放平台文件静态保存在前端项目,每次修改都需要通过修改文件发版的方式,很不便利。 1、需要前端手动维护 2、每次小的修改都要发版 随着对接业务的增多,对接文档的变…

第8章_聚合函数

文章目录 1 聚合函数介绍1.1 AVG和SUM函数1.2 MIN和Max函数1.3 COUNT函数演示代码 2 GROUP BY2.1 基本使用2.2 使用多个列分组2.3 演示代码 3 HAVING3.1 基本使用3.2 WHERE和HAVING的对比3.3 演示代码 4 SELECT的执行过程4.1 查询的结构4.2 SELECT执行顺序4.3 SQL的执行原理演示…

4 Tensorflow图像识别模型——数据预处理

上一篇:3 tensorflow构建模型详解-CSDN博客 本篇开始介绍识别猫狗图片的模型,内容较多,会分为多个章节介绍。模型构建还是和之前一样的流程: 数据集准备数据预处理创建模型设置损失函数和优化器训练模型 本篇先介绍数据集准备&am…

newstarctf2022week2

Word-For-You(2 Gen) 和week1 的界面一样不过当时我写题的时候出了个小插曲 连接 MySQL 失败: Access denied for user rootlocalhost 这句话印在了背景,后来再进就没了,我猜测是报错注入 想办法传参 可以看到一个name2,试着传参 发现有回显三个字段…

【CMU15445】Fall 2019, Project 3: Query Execution 实验记录

目录 实验准备实验测试Task 1: CREATING A CATALOG TABLE SQL 执行是由数据库解析器转化为一个由多个 executor 组成的 Query Plan 来完成的,本实验选择了火山模型来完成 query execution,这一次的 project 就是实现各种 exeutor,从而可以通过…