【深度学习_TensorFlow】梯度下降

写在前面

一直不太理解梯度下降算法是什么意思,今天我们就解开它神秘的面纱


写在中间

线性回归方程


如果要求出一条直线,我们只需知道直线上的两个不重合的点,就可以通过解方程组来求出直线

但是,如果我们选取的这两个点不在直线上,而是存在误差(暂且称作观测误差),这样求出的直线就会和原直线相差很大,我们应该怎样做呢?首先肯定不能只通过两个点,就武断地求出这条直线。

在这里插入图片描述

我们通常尽可能多地使用分布在直线周围的点,也可能不存在一条直线完美的穿过所有采样点。那么,退而求其次,我们希望能找到一条比较“好”的位于采样点中间的直线。那么怎么衡量“好”与“不好”呢?一个很自然的想法就是,求出当前模型的所有采样点上的预测值𝑤𝑥(𝑖) + 𝑏与真实值𝑦(𝑖)之间的差的平方和作为总误差 L \mathcal{L} L,然后搜索一组参数 w ∗ , b ∗ w^{*},b^{*} w,b使得 L \mathcal{L} L最小,对应的直线就是我们要寻找的最优直线。

w ∗ , b ∗ = arg ⁡ min ⁡ w , b 1 n ∑ i = 1 n ( w x ( i ) + b − y ( i ) ) 2 w^*,b^*=\arg\min_{w,b}\frac{1}{n}\sum_{i=1}^{n}\bigl(wx^{(i)}+b-y^{(i)}\bigr)^2 w,b=argminw,bn1i=1n(wx(i)+by(i))2

最后再通过梯度下降法来不断优化参数 w ∗ , b ∗ w^{*},b^{*} w,b

有基础的小伙伴们可能知道求误差的方法其实就是均方误差函数,不懂得可以看这篇文章补充养分《误差函数》 ,我们这篇文章就侧重梯度下降。

梯度下降


函数的梯度定义为函数对各个自变量的偏导数组成的向量。不会的话,翻翻高等数学下册书。

举个例子,对于曲面函数𝑧 = 𝑓(𝑥, 𝑦),函数对自变量𝑥的偏导数记为 ∂ z ∂ x \frac{\partial z}{\partial x} xz,函数对自变量𝑦的偏导数记为 ∂ z ∂ y \frac{\partial z}{\partial y} yz,则梯度∇𝑓为向量 ( ∂ z ∂ x , ∂ z ∂ y ) ({\frac{\partial z}{\partial x}},{\frac{\partial z}{\partial y}}) (xz,yz),梯度的方向总是指向当前位置函数值增速最大的方向,函数曲面越陡峭,梯度的模也越大。

函数在各处的梯度方向∇𝑓总是指向函数值增大的方向,那么梯度的反方向−∇𝑓应指向函数值减少的方向。利用这一性质,我们只需要按照下式来更新参数,,其中𝜂用来缩放梯度向量,一般设置为某较小的值,如 0.01、0.001 等。

x ′ = x − η ⋅ d y d x x'=x-\eta\cdot\frac{\mathrm{d}y}{\mathrm{d}x} x=xηdxdy

结合上面的回归方程,我们就可对误差函数求偏导,以循环的方式更新参数 w , b w,b w,b

w ′ = w − η ∂ L ∂ w b ′ = b − η ∂ L ∂ b \begin{aligned}w'&=w-\eta\frac{\partial\mathcal{L}}{\partial w}\\\\b'&=b-\eta\frac{\partial\mathcal{L}}{\partial b}\end{aligned} wb=wηwL=bηbL

函数实现


计算过程都需要包裹在 with tf.GradientTape() as tape 上下文中,使得前向计算时能够保存计算图信息,方便自动求导操作。通过tape.gradient()函数求得网络参数到梯度信息,结果保存在 grads 列表变量中。

GradientTape()函数

GradientTape(persistent=False, watch_accessed_variables=True)

  • persistent: 布尔值,用来指定新创建的gradient
    tape是否是可持续性的。默认是False,意味着只能够调用一次GradientTape()函数,再次使用会报错

  • watch_accessed_variables:布尔值,表明GradientTape()函数是否会自动追踪任何能被训练的变量。默认是True。要是为False的话,意味着你需要手动去指定你想追踪的那些变量。

tape.watch()函数

tape.watch()用于跟踪指定类型的tensor变量。

  • 由于GradientTape()默认只对tf.Variable类型的变量进行监控。如果需要监控的变量是tensor类型,则需要tape.watch()来监控,否则输出结果将是None

tape.gradient()函数

tape.gradient(target, source)

  • target:求导的因变量

  • source:求导的自变量

import tensorflow as tf

w = tf.constant(1.)
x = tf.constant(2.)
y = x * w

with tf.GradientTape() as tape:
    tape.watch([w])
    y = x * w

grads = tape.gradient(y, [w])
print(grads)

写在最后

👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/60161.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GD32F103VE外部中断

GD32F103VE外部中断线线0~15,对应外部IO口的输入中断。它有7个中断向量,外部中断线0 ~ 4分别对应EXTI0_IRQn ~ EXTI4_IRQn中断向量;外部中断线 5 ~ 9 共用一个 EXTI9_5_IRQn中断向量;外部中断线10~15 共用一个 EXTI15_10_IRQn中断…

MySQL数据库:表的约束

表的约束,实质上就是用数据类型去约束字段,但是数据类型的约束手法很单一,比如,我们在设置身份证号这个字段,数据类型唯一起的约束是它属于char类型或者varchar类型,不能是浮点型也不能是日期时间类型&…

.net 6 efcore一个model映射到多张表(非使用IEntityTypeConfiguration)

现在有两张表,结构一模一样,我又不想创建两个一模一样的model,就想一个model映射到两张表 废话不多说直接上代码 安装依赖包 创建model namespace oneModelMultiTable.Model {public class Test{public int id { get; set; }public string…

Linux服务器大量日志如何快速定位

Linux服务器大量日志如何快速定位 在生产环境,定位问题,经常会遇到日志文件特别多的情况,经常会遇到日志比较难拿的情况,所以有什么方法可以快速拿日志?除了在代码里很好的打印关键日志信息外,也需要掌握L…

RabbitMQ 教程 | 第10章 网络分区

👨🏻‍💻 热爱摄影的程序员 👨🏻‍🎨 喜欢编码的设计师 🧕🏻 擅长设计的剪辑师 🧑🏻‍🏫 一位高冷无情的编码爱好者 大家好,我是 DevO…

第20节 R语言医学分析:某保险医疗事故赔偿因素分析

文章目录 某保险医疗事故赔偿因素分析源码源文件下载某保险医疗事故赔偿因素分析 我们分析数据集“诉讼”的第一个方法是确定样本数量、变量类型、缩放/编码约定(如果有)用于验证数据清理。 接下来,数据集看起来很干净,没有缺失值,并且对于分类变量,将编码约定替换为实际…

智慧工地云平台源码,基于微服务+Java+Spring Cloud +UniApp +MySql开发

智慧工地可视化系统利用物联网、人工智能、云计算、大数据、移动互联网等新一代信息技术,通过工地中台、三维建模服务、视频AI分析服务等技术支撑,实现智慧工地高精度动态仿真,趋势分析、预测、模拟,建设智能化、标准化的智慧工地…

LeetCode151.反转字符串中的单词

151.反转字符串中的单词 目录 151.反转字符串中的单词题目描述解法一:调用API解法二:原生函数编写 题目描述 给你一个字符串s,请你反转字符串中单词的顺序。 单词是由非空格字符组成的字符串,s中使用至少一个空格将字符串中的单…

嵌入式一开始该怎么学?学习单片机

学习单片机: 模电数电肯定必须的,玩单片机大概率这两门课都学过,学过微机原理更好。 直接看野火的文档,芯片手册,外设手册。 学单片机不要纠结于某个型号,我认为stm32就OK,主要是原理和感觉。…

IPSEC VPN知识点总结

具体的实验:使用IPSEC VPN实现隧道通信 使用IPSEC VPN在有防火墙和NAT地址转换的场景下实现隧道通信 DS VPN实验 目录 1.什么是数据认证,有什么作用,有哪些实现的技术手段? 2.什么是身份认证,有什么作用,有哪些实现…

SAP标准搜索帮助(Search Help)改造之标准增强点

1. 搜索帮助加载前 包含程序:LWDTMO01 行:40 标准搜索帮助输出前的控制(影响标准Search Help CDS View Search Help(如果在标准Search Help搜索帮助出口函数上修改控制参数,则不会影响 CDS View Search Help&#xf…

一百四十五、Kettle——查看Kettle在Windows本地和在Linux上生成的.kettle文件夹位置

(一)目的 查看kettle连数据库后自动生成的.kettle文件夹在Windows本地和在Linux中的位置, 这个文件很重要!!! (二).kettle文件夹在Windows本地的位置 C:\Users\Administrator\.k…

轻松搭建酒店小程序

酒店小程序的制作并不需要编程经验,只需要按照以下步骤进行操作,就能很快地搭建自己的小程序商城。 第一步,注册登录账号进入操作后台,找到并点击【商城】中的【去管理】进入商城的后台管理页面,然后再点击【小程序商城…

使用langchain与你自己的数据对话(四):问答(question answering)

之前我已经完成了使用langchain与你自己的数据对话的前三篇博客,还没有阅读这三篇博客的朋友可以先阅读一下: 使用langchain与你自己的数据对话(一):文档加载与切割使用langchain与你自己的数据对话(二):向量存储与嵌入使用langc…

orangepi 4lts ubuntu安装RabbitMQ

4lts的emmc 系统安装选文件系统格式 ext4 需先安装erlang: sudo apt install erlang 安装RabbitMQ: sudo apt install rabbitmq-server - 添加用户以便远程访问: - 账号密码都是admin: sudo rabbitmqctl add_user admin admin -sudo rabbitmqct…

嵌入式面试刷题(day3)

文章目录 前言一、怎么判断两个float是否相同二、float数据可以移位吗三、数据接收和发送端大小端不一致怎么办四、怎么传输float类型数据1.使用联合进行传输2.使用字节流3.强制类型转换 总结 前言 本篇文章我们继续讲解嵌入式面试刷题,给大家继续分享嵌入式中的面…

Python web实战之Django用户认证详解

关键词: Python Web 开发、Django、用户认证、实战案例 概要 今天来探讨一下 Django 的用户认证吧!在这篇文章中,我将为大家带来一些有关 Django 用户认证的最佳实践。 1. Django 用户认证 在开发 Web 应用程序时,用户认证是一个…

深度学习实战46-基于CNN的遥感卫星地图智能分类,模型训练与预测

大家好,我是微学AI,今天给大家介绍一下深度学习实战46-基于CNN的遥感卫星地图智能分类,模型训练与预测。随着遥感技术和卫星图像获取能力的快速发展,卫星图像分类任务成为了计算机视觉研究中一个重要的挑战。为了促进这一领域的研究进展,EuroSAT数据集应运而生。本文将详细…

Python小白学习:超级详细的字典介绍(字典的定义、存储、修改、遍历元素和嵌套)

目录 一、字典简介1.1 创建字典1.2 访问字典中的值1.3 添加键值对1.4 修改字典中的值实例 1.5 删除键值对1.6 由多个类似对象组成的字典1.7 使用get()访问值1.8 练习题 二、遍历字典2.1 遍历所有键值对实例 2.2 遍历字典中的所有键2.3 按照特定顺序遍历字典中的所有键2.4 遍历字…

AWS Amplify 部署node版本18报错修复

Amplify env:Amazon Linux:2 Build Error : Specified Node 18 but GLIBC_2.27 or GLIBC_2.28 not found on build 一、原因 报错原因是因为默认情况下,AWS Amplify 使用 Amazon Linux:2 作为其构建镜像,并自带 GLIBC 2.26。不过,…
最新文章