动手学深度学习(二)线性神经网络

推荐课程:跟李沐学AI的个人空间-跟李沐学AI个人主页-哔哩哔哩视频

回归任务是指对连续变量进行预测的任务。

一、线性回归

线性回归模型是一种常用的统计学习方法,用于分析自变量与因变量之间的关系。它通过建立一个关于自变量和因变量的线性方程,来对未知数据进行预测。

1.1 线性模型

举个例子,房价预测模型

  • 假设1︰影响房价的关键因素是卧室个数,卫生间个数和居住面积,记为x1,x2,x3。
  • 假设2:成交价是关键因素的加权和,y = w_1x_1 + w_2x_2 + w_3x_3 + b

权重w和偏差b的实际值在后面决定。

  • 给定n维输入,x=[x_1,x_2, ....x_n]^T,向量x对应于单个数据样本的特征
  • 线性模型有一个n维权重和一个标量偏差,w =[w_1, w_2, ..., w_n]^Tb权重w决定了每个特征对预测值的影响。偏置b是指当所有的特征都取0时,预测值应为多少。
  • 输出是输入的加权和,\hat{y} = w_1x_1+w_2x_2+ ...+ w_nx_n + b。我们常用\hat{y}表示预测值

则,该房价预测模型为:\hat{y} = w^Tx+ b,这是一个线性预测模型。给定一个数据集(如x),我们的目标就是寻找模型的权重w和偏置b,使得根据模型做出的预测大体符合数据中真实价格y。也是就说最佳的权重w和偏置b有能力使得预测值\hat{y}逼近真实值y,找到最佳的权重w和偏置b这是我们的最终目的。

1.2 损失函数(衡量预估质量)

用于比较真实值和预估值的差异,即以特定规则计算真实值和预估值的差值,例如房屋售价和估价。

假设y是真实值,\hat{y}是预测值,平方差损失\ell(y,\hat{y})=(y-\hat{y})^2,我们以该函数作为损失函数。

设训练集有n个样本,则这n个样本的损失均值

             L(w, b)=\frac{1}{n}\sum_{n}^{i=1}\ell^i(y,\hat{y})=\frac{1}{n}\sum_{n}^{i=1}(y_i-\hat{y_i})^2=\frac{1}{n}\sum_{n}^{i=1}(y_i-w^Tx_i+b)^2

Q:那么损失函数,对我们找到最优的权重w和偏置b有什么帮助呢?

我们可以看到,最佳的预测值与真实值之间的损失值一定是尽可能小的,因此我们只要求得最小的损失值,那么得到这个损失值的权重w和偏置b一定是最优的。

Q:怎么求得最小的损失值呢?

如,平方差损失函数是一个凹函数,那么求解最小的损失值,我们只需要将该函数关于w的偏导数设为0,求导即可。求解得到的w就是最优的权重w。预测出的预估值\hat{y}也就最接近真实值。这类解称为解析解。

二、基础优化算法(梯度下降算法)

在绝大多数的情况下,损失函数是很复杂的(比如逻辑回归),根本无法得到参数估计值的表达式,也就无从获取没有显示解(解析解)

此需要一种对大多数函数都适用的方法,这就引出了“梯度下降算法”,这种方法几乎可以优化所有深度学习模型它通过不断地在损失函数递减的方向上更新参数来降低误差(原理)

2.1 梯度下降公式

首先,我们需要确定初始化模型的参数w_0,接下来重复迭代更新参数t=1、2、3、....、n,更新权重的公式为:

其中,\textup{w}_{t-1}为上一次更新权重的结果,\eta为学习率(这是一个超参数,决定了每次参数更新的步长),\frac{\partial \ell }{\partial \textup{w}_{t-1}}损失函数递增的方向(注意公式中为负)。

2.2 选择学习率

梯度下降的过程宛如一个人在走下山路,一步一步地接近谷底,学习率相当于这个人的步长

学习率的选取不易过大,也不宜过小。学习率选取过大会使得权重更新的过程一直在震荡,而不是真正的在下降。学习率选取过小,会使得权重更新的过程十分缓慢,影响效率。

2.3 小批量随机梯度下降

一个神经网络模型的训练可能需要几分钟至数个小时,我们可以采用小批量随机梯度下降的方式来加快这一过程。

在整个训练集上计算梯度太昂贵了,因此可以随机采用 b 个样本i_1,i_2,...,i_b来求取整个训练集的近似损失(原理)。求近似损失公式为:

 其中,b批量大小,另一个重要的超参数。

Q:如何选择批量大小?

选择批量大小不能太小,也不能太大。批量大小选择过小,则每次计算量太小,不适合并行来最大利用计算资源。批量大小选择过大,内存消耗增加浪费计算,例如如果所有样本都是相同的。

三、线性回归的从零开始实现(代码实现)

3.1 生成数据集

首先,我们根据带有噪声的线性模型构造一个人造数据集,我们的目的是通过这个数据集来还原线性模型中正确的参数。

我们使用线性模型参数 \textup{w}=[2,-3.4]^Tb = 4.2​ 和噪声项 \varepsilon 生成数据集及其标签。

# 生成数据集
def synthetic_data(w, b, num_examples):
    """生成 y=Xw + b + 噪声"""
    X = torch.normal(0, 1, (num_examples, len(w))) # 正态分布(均值为0,标准差为1)
    y = torch.matmul(X, w) + b # 矩阵相乘
    y += torch.normal(0, 0.01, y.shape) # 加入噪声项
    # 得到的y为行向量的形式,为了使其变为一列的形式需要进行reshape
    return X, y.reshape((-1, 1))

3.2 传输数据集

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # 这些样本是随机读出的,没有特定的顺序
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i:min(i+batch_size,num_examples)])
        
        yield features[batch_indices],labels[batch_indices]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/59455.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Matlab的信号频谱分析——FFT变换

Matlab的信号频谱分析——FFT变换 Matlab的信号频谱分析 FFT是离散傅立叶变换的快速算法,可以将一个时域信号变换到频域。 有些信号在时域上是很难看出什么特征的。但是如果变换到频域之后,就很容易看出特征了。 这就是很多信号分析采用FFT变换的原因…

企业微信小程序在调用wx.qy.login时返回错误信息qy.login:fail

原因是大概是绑定了多个企业但是在开发者工具中没有选择正确的企业 解决方法: 重新选择企业后即可成功获取code

媒介易讲解体育冠军助力品牌解锁市场营销新玩法

在当今竞争激烈的市场中,品牌推广成为企业取得商业成功的重要一环。然而,随着传统市场推广方式的日益饱和,企业急需创新的市场营销策略来吸引消费者的关注和认可。在这样的背景下,体育冠军助力品牌成为了一种备受瞩目的市场营销新…

Android Jetpack

Jetpack 是一个由多个库组成的套件,可帮助开发者遵循最佳实践、减少样板代码并编写可在各种 Android 版本和设备中一致运行的代码,让开发者可将精力集中于真正重要的编码工作。 1.基础组件 (1)AppCompat:使得支持较低…

SharpShooter Gauges.Professional Crack

SharpShooter Gauges.Professional Crack 将真实仪器硬件的外观添加到.NET应用程序中。 SharpShooter Gauges旨在创建和使用任何行业特定的KPI、计算机辅助培训系统、SCADA系统、数字仪表板和任何其他用于显示和监控关键任务数据的应用程序。为了实现所有这些目的,一…

概率论与数理统计:第一章:随机事件及其概率

文章目录 概率论Ch1. 随机事件及其概率1.基本概念(1)随机试验、随机事件、样本空间(2)事件的关系和运算①定义:互斥(互不相容)、对立②运算法则:德摩根率 (3)概率的定义(4)概率的性质(5)概率计算排列组合 2.等可能概型1.古典概型 (离散)2.几何概型 (连续…

glut太阳系源码修改和对cpu占用观察

#include <GL/glut.h> static int day 100; // day 的变化&#xff1a;从 0 到 359 void myDisplay(void) {//glEnable(GL_DEPTH_TEST);glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);glMatrixMode(GL_PROJECTION);glLoadIdentity();gluPerspective(75, 1, 1, 40…

Docker-Compose编排与部署(lnmp实例)

第四阶段 时 间&#xff1a;2023年8月3日 参加人&#xff1a;全班人员 内 容&#xff1a; Docker-Compose编排与部署 目录 一、Docker Compose &#xff08;一&#xff09;概述 &#xff08;二&#xff09;Compose适用于所有环境&#xff1a; &#xff08;三&#xf…

Java阶段五Day16

Java阶段五Day16 文章目录 Java阶段五Day16问题解析启动servlet冲突问题nacos注册中心用户信息验证失败前端效果不对前端请求到后台服务的流转过程 远程dubbo调用业务需求dubbo配置xml配置domain层代码 补充远程调用 师傅详情接口抽象开发WorkderServerControllerWorkerServerS…

51单片机学习-AT24C02数据存储秒表(定时器扫描按键数码管)

首先编写I2C模块&#xff0c;根据下面的原理图进行位声明&#xff1a; sbit I2C_SCL P2^1; sbit I2C_SDA P2^0;再根据下面的时序结构图编写函数&#xff1a; /*** brief I2C开始* param 无* retval 无*/ void I2C_Start(void) {I2C_SDA 1; I2C_SCL 1; I2C_SDA 0;I2C_S…

redis原理 6:小道消息 —— PubSub

前面我们讲了 Redis 消息队列的使用方法&#xff0c;但是没有提到 Redis 消息队列的不足之处&#xff0c;那就是它不支持消息的多播机制。 img 消息多播 消息多播允许生产者生产一次消息&#xff0c;中间件负责将消息复制到多个消息队列&#xff0c;每个消息队列由相应的消费组…

【Leetcode刷题】位运算

本篇文章为 LeetCode 位运算模块的刷题笔记&#xff0c;仅供参考。 位运算的常用性质如下&#xff1a; a ^ a 0 a ^ 0 a a ^ 0xFFFFFFFF ~a目录 一. 基本位运算Leetcode29.两数相除Leetcode89.格雷编码 二. 位运算的性质Leetcode136.只出现一次的数字Leetcode137.只出现一…

好用的数据库管理软件之idea(idea也有数据库???)

1.建立maven项目&#xff08;maven项目添加依赖&#xff0c;对于后期连接数据库很方便&#xff09; 2.连接数据库。。。 这里一定注意端口号&#xff0c;不要搞错了 和上一张图片不一样哦 3.数据库测试代码。。。 然后你就可以在这里边写MySQL代码了&#xff0c;这个工具对于新…

RunnerGo条件控制器使用方法

在做性能测试时我们需要根据业务需求、业务场景来配置测试脚本&#xff0c;举个例子&#xff1a;在登录注册场景中&#xff0c;可能会有账号密码全部正确、账号格式错误、密码错误等多种情况&#xff0c;这里的“登录/注册”事件可以视为一个场景。一个真实业务中的场景&#x…

人工智能学习07--pytorch23--目标检测:Deformable-DETR训练自己的数据集

参考 https://blog.csdn.net/qq_44808827/article/details/125326909https://blog.csdn.net/dystsp/article/details/125949720?utm_mediumdistribute.pc_relevant.none-task-blog-2~default~baidujs_baidulandingword~default-0-125949720-blog-125326909.235^v38^pc_releva…

【Python】Web学习笔记_flask(3)——上传文件

用GET、POST请求上传图片并呈现出来 首先还是创建文件上传的模板 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>上传图片</title> </head> <body> <form action""…

MySQL 远程操作mysql

可以让别人在他们的电脑上操作我电脑上的数据库 create user admin identified with mysql_native_password by admin; //设置账号密码都为admingrant all on *.* to admin; //给admin账号授权 授权完成

直播课 | 大橡科技研发总监丁端尘博士“类器官芯片技术在新药研发中的应用”

从类器官到类器官芯片&#xff0c;正在生物科学领域大放异彩。 药物研发需要新方法 众所周知&#xff0c;一款新药是一个风险大、周期长、成本高的艰难历程&#xff0c;国际上有一个传统的“双十”说法——10年时间&#xff0c;10亿美金&#xff0c;才可能成功研发出一款新药…

【安全测试】Web应用安全之XSS跨站脚本攻击漏洞

目录 前言 XSS概念及分类 反射型XSS(非持久性XSS) 存储型XSS(持久型XSS) 如何测试XSS漏洞 方法一&#xff1a; 方法二&#xff1a; XSS漏洞修复 原则&#xff1a;不相信客户输入的数据 处理建议 资料获取方法 前言 以前都只是在各类文档中见到过XSS&#xff0c;也进…

微信小程序animation动画,微信小程序animation动画无限循环播放

需求是酱紫的&#xff1a; 页面顶部的喇叭通知&#xff0c;内容不固定&#xff0c;宽度不固定&#xff0c;就是做走马灯&#xff08;轮播&#xff09;效果&#xff0c;从左到右的走马灯&#xff08;轮播&#xff09;&#xff0c;每播放一遍暂停 1500ms &#xff5e; 2000ms 刚…