SGD随机梯度下降

一、补充概念:

  1. 目标函数(Objective Function):这个术语通常指的是整个优化问题中需要最小化(或最大化)的函数。在机器学习和优化中,目标函数可以包括损失函数以及正则化项等目标函数的最优化过程旨在找到使目标函数取得最小值或最大值的参数值。

  2. 损失函数(Loss Function):这个术语通常指的是在监督学习中用来衡量模型预测值与真实标签之间差异的函数。损失函数是目标函数的一部分,通常作为目标函数的组成部分出现。在训练过程中,损失函数的值被用来作为优化算法的目标,以便通过调整模型参数来最小化损失函数。

      3.损失函数:某些情况下为目标函数。

      4.梯度:梯度通常是指损失函数关于模型参数的偏导数

在机器学习和深度学习中,训练模型的目标是通过最小化损失函数来优化模型参数,而梯度是一种用于指导参数更新的重要工具。

二、SGD

SGD 是随机梯度下降(Stochastic Gradient Descent)的缩写。梯度下降是一种优化算法,用于最小化损失函数,通过迭代更新参数来逐步调整模型以最小化损失。在机器学习和深度学习中,梯度下降被广泛用于训练模型。

随机梯度下降梯度下降的一种变体,其基本思想是每次迭代时随机选择一个样本来计算梯度并更新参数,而不是使用整个训练集来计算梯度。相比于传统的批量梯度下降(Batch Gradient Descent),随机梯度下降的计算代价更低尤其在大规模数据集上更为高效

优点:随机梯度下降通常用于训练大规模数据集和深度神经网络,因为它能够以较低的计算成本和内存消耗实现模型的训练。

缺点:由于随机梯度下降对梯度的估计是基于单个样本的,因此可能会导致参数更新的不稳定性,需要采用一些技巧来调整学习率和控制收敛速度,如学习率衰减、动量等。

三、简单代码举例

当使用 PyTorch 进行随机梯度下降的实现时,可以通过 PyTorch 提供的优化器类 torch.optim.SGD 来实现。以下是一个简单的示例代码,展示了如何使用 PyTorch 来实现随机梯度下降算法:

import torch
import torch.nn as nn
import torch.optim as optim

# 假设有一些训练数据
# 这里我们创建一个简单的线性回归问题
# 输入特征维度为 1,输出维度为 1
# 我们的目标是拟合一个简单的线性函数 y = 2x + 1

# 构建训练数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])

# 构建一个简单的线性模型
model = nn.Linear(1, 1)

# 定义损失函数,这里使用均方误差损失
criterion = nn.MSELoss()

# 定义优化器,这里使用随机梯度下降
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 进行模型训练
num_epochs = 100
for epoch in range(num_epochs):
    # 前向传播
    outputs = model(x_train)
    # 计算损失
    loss = criterion(outputs, y_train)
    # 梯度清零
    optimizer.zero_grad()
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 测试模型
x_test = torch.tensor([[5.0]])
predicted = model(x_test)
print(f'Predicted value for x = 5: {predicted.item():.4f}')

运行结果截图:

图1 运行结果

这个示例代码中,我们首先定义了训练数据 x_trainy_train,然后构建了一个简单的线性模型 model。接着定义了损失函数 criterion,这里使用均方误差损失。然后使用 torch.optim.SGD 定义了优化器 optimizer,学习率设置为 0.01。

在训练过程中,我们对模型进行多轮的迭代,每一轮中首先进行前向传播计算输出,然后计算损失,接着梯度清零,进行反向传播计算梯度,最后更新参数。最后我们对模型进行测试,对一个新的输入样本进行预测。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/532053.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python程序设计 列表

教学案例八 列表 1. 计算并显示斐波那契数列 输入n,计算并显示斐波那契数列前n项.一行打印5项,每项显示宽度为6 什么是斐波那契数列 斐波那契数列(Fibonacci sequence),又称黄金分割数列、 因数学家莱昂纳多斐波那契&#xff…

基于SSM+Jsp+Mysql的农产品供销服务系统

开发语言:Java框架:ssm技术:JSPJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包…

windows系统安装mysql5.7

1、下载 下载路径:https://downloads.mysql.com/archives/community/ 2、创建配置文件my.ini 下载压缩包解压到安装目录(本机解压后在D:\mysql-5.7.44-winx64) 在bin的同级目录下创建my.ini文件 my.ini文件 [mysql] # 设置mysql客户端默认字符…

接口自动化测试(python+pytest+requests)

一、选取自动化测试用例 优先级高:先实现业务流程用例、后实现单接口用例功能较稳定的接口优先开展测试用例脚本的实现二、搭建自动化测试环境 核心技术:编程语言:python;测试框架:pytest;接口请求:requests安装/验证requests:命令行终端分别输入 pip install requests / p…

6.1Python之字典的初识

【1】字典的创建与价值 字典(Dictionary)是一种在Python中用于存储和组织数据的数据结构。元素由键和对应的值组成。其中,键(Key)必须是唯一的,而值(Value)则可以是任意类型的数据。…

国内如何实现GPT升级付款

本来想找国外的朋友代付的,但是他告诉我他的信用卡已经被绑定了他也升级了所以只能自己想办法了。就在一位博主下边发现了这个方法真的可以。只是需要与支付宝验证信息。刚开始不敢付款害怕被骗哈哈,我反诈骗意识绝对杠杠的 该方法就是我们办理一张虚拟…

3D可视化技术亮相高铁站,引领智慧出行新潮流

在科技飞速发展的今天,我们的生活正经历着前所未有的变革。高铁站作为现代交通的重要枢纽,也在不断地创新和进步。 3D可视化技术通过三维立体的方式,将高铁站内部和外部的结构、设施、流线等以更加直观、生动的形式呈现出来。乘客们只需通过手…

2、java语法之循环、数组与方法(找工作版)

写在前面:整个系列文章是自己学习慕课相关视频,进行的一个总结。文章只是为了记录学习课程的整个过程,方便以后查漏补缺,找到对应章节。 文章目录 一、Java循环结构1、while循环2、do-while循环3、for循环4、嵌套循环5、break语句…

我认识的建站公司老板都躺平了,每年维护费都大几十万。

这些老板们过的悠哉游哉,大富大贵没有,达到中产,活得舒服,没毛病。 企业官网每年需要交维护费主要是因为以下几个原因: 网站服务器和域名费用:企业官网需要通过服务器进行托管和访问,同时需要…

三种常见webshell工具的流量特征分析

又来跟师傅们分享小技巧了,这次简单介绍一下三种常见的webshell流量分析,希望能对参加HW蓝队的师傅们有所帮助。 什么是webshell webshell就是以asp、php、jsp或者cgi等网页文件形式存在的一种代码执行环境,主要用于网站管理、服务器管理、…

蓝桥杯第十二届c++大学B组详解

目录 1.空间 2.直线 3.路径 4.卡片 5.货物摆放 6.时间显示 7.砝码称重 8.杨辉三角 9.双向排序 10.括号序列 1.空间 题目解析:1Byte 8bit 1kb 1024B 1MB 1024kb; 先将256MB变成Byte 256 * 1024 * 1024; 再将32位 变成Byte就是 32 / 8 4;…

4.进程相关 2

8.内存映射 8.1 内存映射相关定义 创建一个文件,将保存在磁盘中的文件映射到内存中,后期两个进程之间对内存中的数据进行操作,大大减少了访问磁盘的时间,也是一种最快的 IPC ,因为进程之间可以直接对内存进行存取 8.…

第十二届蓝桥杯省赛真题(C/C++大学B组)

目录 #A 空间 #B 卡片 #C 直线 #D 货物摆放 #E 路径 #F 时间显示 #G 砝码称重 #H 杨辉三角形 #I 双向排序 #J 括号序列 #A 空间 #include <bits/stdc.h> using namespace std;int main() {cout<<256 * 1024 * 1024 / 4<<endl;return 0; } #B 卡片…

JVM性能监控与调优——命令行工具

文章目录 1、概述2、jps:查看正在运行的Java进程3、jstat&#xff1a;查看JVM统计信息4、jinfo&#xff1a;实时查看和修改JVM配置参数5、jmap&#xff1a;导出内存映像文件和内存使用情况6、jhat:JDK自带堆分析工具7、jstack&#xff1a;打印JVM中线程快照8、jcmd&#xff1a;…

第一个Swift程序

要创建第一个Swift项目,请按照以下步骤操作: 打开Xcode。如果您没有安装Xcode,可以在App Store中下载并安装它。在Xcode的欢迎界面上,选择“Create a new Xcode project”(创建新Xcode项目)。在模板选择界面上,选择“App”(应用程序)。在应用模板选择界面上,选择“Si…

【静态分析】静态分析笔记01 - Introduction

参考&#xff1a; BV1zE411s77Z [南京大学]-[软件分析]课程学习笔记(一)-introduction_南京大学软件分析笔记-CSDN博客 ------------------------------------------------------------------------------------------------------ 1. program language and static analysis…

【JavaWeb】Jsp基本教程

目录 JSP概述作用一个简单的案例&#xff1a;使用JSP页面输出当前日期 JSP处理过程JSP 生命周期编译阶段初始化阶段执行阶段销毁阶段案例 JSP页面的元素JSP指令JSP中的page指令Include指令示例 taglib指令 JSP中的小脚本与表达式JSP中的声明JSP中的注释HTML的注释JSP注释 JSP行…

llinux进程控制

学习进程创建,fork/vfork 学习到进程等待 学习到进程程序替换, 微型shell&#xff0c;重新认识shell运行原理 学习到进程终止,认识$? fork函数 在linux中fork函数时非常重要的函数&#xff0c;它从已存在进程中创建一个新进程。新进程为子进程&#xff0c;而原进程为父进程…

PostgreSQL入门到实战-第十五弹

PostgreSQL入门到实战 PostgreSQL数据过滤(八)官网地址PostgreSQL概述PostgreSQL中LIKE命令理论PostgreSQL中LIKE命令实战更新计划 PostgreSQL数据过滤(八) 如何使用PostgreSQL LIKE运算符基于模式查询数据。 官网地址 声明: 由于操作系统, 版本更新等原因, 文章所列内容不一…

Commitizen:规范化你的 Git 提交信息

简介 在团队协作开发过程中&#xff0c;规范化的 Git 提交信息可以提高代码维护的效率&#xff0c;便于追踪和定位问题。Commitizen 是一个帮助我们规范化 Git 提交信息的工具&#xff0c;它提供了一种交互式的方式来生成符合约定格式的提交信息。 原理 Commitizen 的核心原…
最新文章