机器学习 - 训练模型

接着这一篇博客做进一步说明:
机器学习 - 选择模型

为了解决测试和预测之间的差距,可以通过更新 internal parameters, the weights set randomly use nn.Parameter() and bias set randomly use torch.randn().
Much of the time you won’t know what the ideal parameters are for a model.
Instead, it is much more fun to write code to see if the model can try and figure them out itself. That is a loss function as well as and optimizer.

FunctionWhat does it do?Where does it live in PyTorch?Common values
Loss functionMeasures how wrong your models predictions (e.g. y_preds) are compared to the truth labels (e.g. y_test). Lower the better.差距越小越好PyTorch has plenty of built-in loss functions in torch.nnMean absolute error (MAE) for regression problems (torch.nn.L1Loss()). Binary cross entropy for binary classification problems (torch.nn.BCELoss())
OptimizerTells your model how to update its internal parameters to best lower the loss.You can find various optimization function implementations in torch.optimStochastic gradient descent (torch.optim.SGD()). Adam.optimizer (torch.optim.Adam())

介绍 MAE: Mean absolute error 也称为平均绝对误差,是一种用于衡量预测值与真实值之间差异的损失函数。MAE计算的是预测值与真实值之间的绝对差值的平均值,即平均误差的绝对值。在 PyTorch 中可以使用 torch.nn.L1Loss 来计算MAE.

介绍Stochastic gradient descent:
这是一种常用的优化算法,用于训练神经网络模型。它是梯度下降算法的变种,在每次更新参数时都使用随机样本的梯度估计来更新参数。SGD的基本思想是通过最小化损失函数来调整模型参数,使得模型的预测结果与真实标签尽可能接近。在每次迭代中,SGD随机选择一小批样本 (mini-batch) 来计算损失函数关于参数的梯度,并使用该梯度来更新参数。由于每次更新只使用了一部分样本,因此SGD通常具有更快的收敛速度和更低的计算成本。在 PyTorch 中,可以使用 torch.optim.SGD(params, lr) 来实现,其中

  • params is the target model parameters you’d like to optimize (e.g. the weights and bias values we randomly set before).
  • lr is the learning rate you’d like the optimizer to update the parameters at, higher means the optimizer will try larger updates (these can sometimes be too large and the optimizer will fail to work), lower means the optimizer will try smaller updates (these can sometimes be too small and the optimizer will take too long to find the ideal values). Common starting values for the learning rate are 0.01, 0.001, 0.0001.

介绍 Adam 优化器:
Adam优化器是一种常用的优化算法,它结合了动量法和自适应学习率调整的特性,能够高效地优化神经网络模型的参数。Adam优化器的基本思想是在梯度下降的基础上引入了动量项和自适应学习率调整项。动量项可以帮助优化器在更新参数时保持方向性,从而加速收敛;而自适应学习率调整项可以根据参数的历史梯度来调整学习率,从而在不同参数上使用不同的学习率,使得参数更新更加稳健。

介绍学习率:
学习率是在训练神经网络时控制参数更新步长的一个超参数。它决定了每次参数更新时,参数沿着梯度方向更新的程度。学习率越大,参数更新的步长越大;学习率越小,参数更新的步长越小。选择合适的学习率通常是训练神经网络时需要调节的一个重要超参数。如果学习率过大,可能导致参数更新过大,导致模型不稳定甚至发散;如果学习率过小,可能导致模型收敛速度过慢,训练时间变长。

代码如下:

import torch

# Create the loss function 
loss_fn = nn.L1Loss()  # MAE loss is same as L1Loss

# Create the optimizer
optimizer = torch.optim.SGD(params = model_0.parameters(),
                            lr = 0.01)


现在创造一个optimization loop
The training loop involves the model going through the training data and learning the relationships between the features and labels.
The testing loop involves going through the testing data and evaluating how good the patterns are that the model learned on the training data (the model never sees the testing data during training).
Each of these is called a “loop” because we want our model to look (loop through) at each sample in each dataset. 所以,得用 for 循环来实现。

PyTorch training loop

NumberStep nameWhat does it do?Code example
1Forward passThe model goes through all of the training data once, performing its forward() function calculations.model(x_train)
2Calculate the lossThe model’s outputs (predictions) are compared to the ground truth and evaluated to see how wrong they are.loss = loss_fn(y_pred, y_train)
3Zero gradientsThe optimizers gradients are set to zero (they are accumulated by default) so they can be recalculated for the specific training step.optimizer.zero_grad()
4Perform backpropagation on the lossComputes the gradient of the loss with respect for every model parameter to be updated (each parameter with requires_grad=True). This is known as backpropagation, hence “backwards”.loss.backward()
5Update the optimizer (gradient descent)Update the parameters with requires_grad=True with respect to the loss gradients in order to improve them.optimizer.step()

PyTorch testing loop
As for the testing loop (evaluating the model), the typical steps include:

NumberStep nameWhat does it do?Code example
1Forward passThe model goes through all of the training data once, performing its forward() function calculations.model(x_test)
2Calculate the lossThe model’s outputs (predictions) are compared to the ground truth and evaluated to see how wrong they are.loss = loss_fn(y_pred, y_test)
3Calculate evaluation metrics (optional)Alongside the loss value you may want to calculate other evaluation metrics such as accuracy on the test set.Custom functions

下面是代码实现

# Create the loss function
# 那你。L1Loss() 是用于计算平均绝对误差 (MAE) 的损失函数。
loss_fn = nn.L1Loss()  # MAE loss is same as L1Loss

# Create the optimizer
# torch.optim.SGD() 是用于创建随机梯度下降优化器的函数。
# parameters() 返回一个包含了模型中所有需要进行梯度更新的参数的迭代器
optimizer = torch.optim.SGD(params = model_0.parameters(),
                            lr = 0.01)

# Set the number of epochs (how many times the model will pass over the training data)
epochs = 200

# Create empty loss lists to track values
train_loss_values = []
test_loss_values = []
epoch_count = []

for epoch in range(epochs):
  ### Training 

  # Put model in training mode (this is the default state of a model)
  # train() 函数通常用于将模型设置为训练模式
  model_0.train()

  # 1. Forward pass on train data using the forward() method inside
  y_pred = model_0(X_train)

  # 2. Calculate the loss (how different are our models predictions to the ground truth)
  loss = loss_fn(y_pred, y_train)

  # 3. Zero grad of the optimizer
  optimizer.zero_grad() 

  # 4. Loss backwards
  loss.backward()

  # 5. Progress the optimizer
  # step() 用于执行一步参数更新操作。
  optimizer.step() 

  ### Testing

  # Put the model in evaluation mode
  model_0.eval() 

  with torch.inference_mode():
    # 1. Forward pass on test data 
    test_pred = model_0(X_test)

    # 2. Calculate loss on test data 
    test_loss = loss_fn(test_pred, y_test.type(torch.float))  # predictions come in torch.float datatype, so comparisons need to be done with tensors of the same type 

    # Print out 
    if epoch % 10 == 0:
      epoch_count.append(epoch)
      train_loss_values.append(loss.detach().numpy())
      test_loss_values.append(test_loss.detach().numpy())
      print(f"Epoch: {epoch} | MAE Train Loss: {loss} | MAE Test Loss: {test_loss}")


plt.plot(epoch_count, train_loss_values, label="Train loss")
plt.plot(epoch_count, test_loss_values, label="Test loss")
plt.title("Training and test loss curves")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend()

print("The model learned the following values for weights and bias: ")
print(model_0.state_dict())
print("\nAnd the original values for weights and bias are: ")
print(f"weights: {weight}, bias: {bias}")

# 结果如下:
Epoch: 0 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 10 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 20 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 30 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 40 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 50 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 60 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 70 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 80 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 90 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 100 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 110 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 120 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 130 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 140 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 150 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 160 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 170 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 180 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
Epoch: 190 | MAE Train Loss: 0.008932482451200485 | MAE Test Loss: 0.005023092031478882
The model learned the following values for weights and bias: 
OrderedDict([('weights', tensor([0.6990])), ('bias', tensor([0.3093]))])

And the original values for weights and bias are: 
weights: 0.7, bias: 0.3

Loss is the measure of how wrong your model is. Loss 的值越低,效果越好。

效果图

都看到这了,点个赞支持下呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/470941.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Unity 实现双屏或多屏内容展示

在某些应用场景,一个应用可能需要使用多个显示器显示。 Unity支持最多8个不同显示器同时显示应用程序中八个摄像头的视图,如下图: 具体实现如下: 1、在Hiearchy面板上点击鼠标右键->Camera,创建多一个Camera,如图&#xff1a…

postman 用上一个请求的响应体中的字段设置下一个请求的请求参数

文章目录 IntroPostman usagePre-request ScriptTests javascripts API Intro 这一切都是为了增加自动化动作所占的比例(减少人手工操作复制粘贴可能会造成的错误)。 Postman usage 最常用的:选HTTP方法类型、写URL,在Headers中…

EvaRichter勒索病毒来袭

目前勒索病毒仍然是全球最大的威胁,最近一年针对企业的勒索病毒攻击越来越多,大部分勒索病毒是无法解密的,并且不断有新型的勒索病毒出现,各企业一定要保持高度的重视,马上放假了,一款新型勒索病毒来袭....…

【Linux】编译器-gcc/g++的使用(预处理、编译、汇编、连接)

目录 01.预处理(宏替换) 02.编译(生成汇编) 03.汇编(生成机器可识别码) 04.连接(生成可执行文件或库文件) 05.选项 编译器在编译代码时包含以下四个步骤:1.预处理 2…

数据结构(三)复杂度的深层次剖析

之前发布了数据结构(一),很多同学反响不够清晰,那今天就发一篇对复杂度专题的博客,希望对大家理解复杂度提供一些帮助。 时间复杂度 我们先来一个理解一个复杂度,二分查找的复杂度(之前写过二…

app调用系统接口示意图

1.linux下app不能直接访问内核。 用户态和内核态 2. 系统调用是应用程序和系统内核之间的接口。 (1)app访问内核通过调用glibc中的系统调用接口(open()、read()、write()…

html5cssjs代码 024 响应式布局示例

html5&css&js代码 024 响应式布局示例 一、代码二、解释 该HTML代码重点在于构建一个带有响应式设计的两栏布局网页,包含页头、导航条、主要内容区(左右两列)和底部区域,并运用CSS样式设置页面元素的布局、颜色、字体、间…

C++特性三:多态---案例三(电脑组装)

案例描述: 电脑主要组成部件为 CPU(用于计算),显卡(用于显示),内存条(用于存储) 将每个零件封装出抽象基类,并且提供不同的厂商生产不同的零件,例…

PwnLab靶场PHP伪协议OSCP推荐代码审计命令劫持命令注入

下载链接:PwnLab: init ~ VulnHub 安装: 打开vxbox直接选择导入虚拟电脑即可 正文: 先用nmap扫描靶机ip nmap -sn 192.168.1.1/24 获取到靶机ip后,对靶机的端口进行扫描,并把结果输出到PwnLab文件夹下,命名…

杰发科技AC7801——Keil编译的Hex大小如何计算

编译结果是Keil里面前三个数据的总和: 即CodeRoDataRWData的总和。 通过ATCLinkTool工具查看内存,发现最后一个字节正好是5328 注意读内存数据时候需要强转成32位,加1000的 增加1024的地址只需要加256即可

【技术栈】Redis 删除策略

SueWakeup 个人主页:SueWakeup 系列专栏:学习技术栈 个性签名:保留赤子之心也许是种幸运吧 本文封面由 凯楠📸 友情提供 目录 相关传送门 前言 1. 删除策略的目标 2. 数据删除策略 2.1 定时删除 2.2 惰性删除 2.3 定期删除…

【S5PV210】 | GPIO编程

【S5PV210】 | GPIO编程 时间:2024年3月17日22:02:32 目录 文章目录 【`S5PV210`】 | `GPIO`编程目录1.参考2.`DataSheet`2.1.概述2.1.1.特色2.1.2 输入/输出配置2.1.3 `S5PV210` 输入/输出类型2.1.4 IO驱动强度**2.1.4.1 类型A IO驱动强度****2.1.4.2 类型A IO驱动强度****2…

Windows电源管理调节-Powercfg命令应用

Windows电源管理调节 PowerCfg命令介绍 在Windows下我们使用 powercfg.exe命令 来控制电源计划(也称为电源方案),以使用可用的睡眠状态、控制单个设备的电源状态,以及分析系统中常见的能效和电池寿命问题。 语法 Powercfg 命令行使用以下语法: powercfg /option [arg…

使用WordPress在US Domain Center上建立多语言网站的详细教程

第一部分:介绍多语言网站 多语言网站是一种可以用多种语言呈现内容的网站。它能够满足不同国家或地区用户的语言需求,提升网站的用户体验和可访问性。在WordPress中,您可以轻松地创建一个多语言网站,并通过插件来管理多语言内容&…

Go 1.22 - 更加强大的 Go 执行跟踪

原文:Michael Knyszek - 2024.03.14 runtime/trace 包含了一款强大的工具,用于理解和排查 Go 程序。这个功能可以生成一段时间内每个 goroutine 的执行追踪。然后,你可以使用 go tool trace 命令(或者优秀的开源工具 gotraceui&a…

漏洞发现-漏扫项目篇Poc开发Yaml语法插件一键生成匹配结果交互提取

知识点 1、Nuclei-Poc开发-环境配置&编写流程 2、Nuclei-Poc开发-Yaml语法&匹配提取 3、Nuclei-Poc开发-BurpSuite一键生成插件 章节点: 漏洞发现-Web&框架组件&中间件&APP&小程序&系统 扫描项目-综合漏扫&特征漏扫&被动漏扫…

C语言经典算法-8

文章目录 其他经典例题跳转链接41.基数排序法42.循序搜寻法(使用卫兵)43.二分搜寻法(搜寻原则的代表)44.插补搜寻法45.费氏搜寻法 其他经典例题跳转链接 C语言经典算法-1 1.汉若塔 2. 费式数列 3. 巴斯卡三角形 4. 三色棋 5. 老鼠…

Flutter开发进阶之使用Socket实现主机服务(二)

Flutter开发进阶之使用Socket实现主机服务(二) Flutter开发进阶之使用Socket实现主机服务(一) 在完成局域网内设备定位后就可以进入微服务的实操了。 I、构建Socket连接池 一、定义Socket 使用socket_io_client socket_io_client: ^2.0.3+1导入头文件 import packag…

LiveGBS流媒体平台GB/T28181功能-HTTPS 服务支持配置开启什么时候需要开启HTTPS测试SSL证书配置HTTPS测试证书

LiveGBS功能支持HTTPS 服务支持配置开启什么时候需要开启HTTPS测试SSL证书配置HTTPS测试证书 1、配置开启HTTPS1.1、准备https证书1.1.1、选择Nginx类型证书下载 1.2、配置 LiveCMS 开启 HTTPS1.2.1 web页面配置1.2.2 配置文件配置 2、HTTPS测试证书3、验证HTTPS服务4、为什么要…

图书馆管理系统 2.后台系统管理模块编写

后端 1.实体类编写 用户实体类 package jkw.pojo;import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import lombok.Data;import java.io.Serializable; import java.util.List;/*** 用户*/ Data public class …
最新文章