NLP(3)--利用nn反向计算参数

前言

仅记录学习过程,有问题欢迎讨论

获取数据

自定义一个方程,获取一批数据X,Y

import matplotlib.pyplot as pyplot
import math
import sys

X = [0.01 * x for x in range(100)]
Y = [2*x**2 + 3*x + 4 for x in X]
print(X)
print(Y)
pyplot.scatter(X, Y, color='red')
pyplot.show()
input()

利用nn模型计算w

X = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18,
     0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35000000000000003,
     0.36, 0.37, 0.38, 0.39, 0.4, 0.41000000000000003, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47000000000000003, 0.48, 0.49,
     0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.5700000000000001, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66,
     0.67, 0.68, 0.6900000000000001, 0.7000000000000001, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8,
     0.81, 0.8200000000000001, 0.8300000000000001, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93,
     0.9400000000000001, 0.9500000000000001, 0.96, 0.97, 0.98, 0.99]
Y = [4.0, 4.0302, 4.0608, 4.0918, 4.1232, 4.155, 4.1872, 4.2198, 4.2528, 4.2862, 4.32, 4.3542, 4.3888, 4.4238, 4.4592,
     4.495, 4.5312, 4.5678, 4.6048, 4.6422, 4.68, 4.7181999999999995, 4.7568, 4.7958, 4.8352, 4.875, 4.9152000000000005,
     4.9558, 4.9968, 5.0382, 5.08, 5.122199999999999, 5.1648, 5.2078, 5.2512, 5.295, 5.3392, 5.3838, 5.4288, 5.4742,
     5.5200000000000005, 5.5662, 5.6128, 5.6598, 5.7072, 5.755, 5.8032, 5.851800000000001, 5.9008, 5.9502, 6.0, 6.0502,
     6.1008, 6.1518, 6.203200000000001, 6.255000000000001, 6.3072, 6.3598, 6.4128, 6.4662, 6.52, 6.5742, 6.6288, 6.6838,
     6.7392, 6.795, 6.8512, 6.9078, 6.9648, 7.022200000000001, 7.08, 7.138199999999999, 7.1968, 7.2558, 7.3152, 7.375,
     7.4352, 7.4958, 7.5568, 7.6182, 7.680000000000001, 7.7422, 7.8048, 7.867800000000001, 7.9312, 7.994999999999999,
     8.0592, 8.1238, 8.1888, 8.2542, 8.32, 8.3862, 8.4528, 8.5198, 8.587200000000001, 8.655000000000001, 8.7232, 8.7918,
     8.8608, 8.9302]

# 标签
def func(x):
    y = w1 * x**2 + w2 * x + w3
    return y

# 损失函数
def loss(y_pre, y_true):
    return (y_true - y_pre) ** 2


# 随机定义 w
w1, w2, w3 = 1, 2, 3
# 学习率
lr = 0.1

# 训练过程
for epoch in range(500):
    epoch_loss = 0
    for x, y_true in zip(X, Y):
        y_pre = func(x)
        # 本轮loss 总值
        epoch_loss += loss(y_pre, y_true)
        # 梯度计算
        grad_w1 = 2 * (y_pre - y_true) * x ** 2
        grad_w2 = 2 * (y_pre - y_true) * x
        grad_w3 = 2 * (y_pre - y_true)
        # 权重更新
        w1 = w1 - lr * grad_w1  # sgd
        w2 = w2 - lr * grad_w2
        w3 = w3 - lr * grad_w3
    # 本轮结束 计算平均loss
    epoch_loss /= len(X)
    print("第%d轮, loss %f" % (epoch, epoch_loss))
    if epoch_loss < 0.00001:
        break

print(f"训练后权重:w1:{w1} w2:{w2} w3:{w3}")
# #使用训练后模型输出预测值
Yp = [func(i) for i in X]
# 预测值与真实值比对数据分布
pyplot.scatter(X, Y, color="red")
pyplot.scatter(X, Yp)
pyplot.show()

优化梯度、权重计算

X = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18,
     0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35000000000000003,
     0.36, 0.37, 0.38, 0.39, 0.4, 0.41000000000000003, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47000000000000003, 0.48, 0.49,
     0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.5700000000000001, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66,
     0.67, 0.68, 0.6900000000000001, 0.7000000000000001, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8,
     0.81, 0.8200000000000001, 0.8300000000000001, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93,
     0.9400000000000001, 0.9500000000000001, 0.96, 0.97, 0.98, 0.99]
Y = [4.0, 4.0302, 4.0608, 4.0918, 4.1232, 4.155, 4.1872, 4.2198, 4.2528, 4.2862, 4.32, 4.3542, 4.3888, 4.4238, 4.4592,
     4.495, 4.5312, 4.5678, 4.6048, 4.6422, 4.68, 4.7181999999999995, 4.7568, 4.7958, 4.8352, 4.875, 4.9152000000000005,
     4.9558, 4.9968, 5.0382, 5.08, 5.122199999999999, 5.1648, 5.2078, 5.2512, 5.295, 5.3392, 5.3838, 5.4288, 5.4742,
     5.5200000000000005, 5.5662, 5.6128, 5.6598, 5.7072, 5.755, 5.8032, 5.851800000000001, 5.9008, 5.9502, 6.0, 6.0502,
     6.1008, 6.1518, 6.203200000000001, 6.255000000000001, 6.3072, 6.3598, 6.4128, 6.4662, 6.52, 6.5742, 6.6288, 6.6838,
     6.7392, 6.795, 6.8512, 6.9078, 6.9648, 7.022200000000001, 7.08, 7.138199999999999, 7.1968, 7.2558, 7.3152, 7.375,
     7.4352, 7.4958, 7.5568, 7.6182, 7.680000000000001, 7.7422, 7.8048, 7.867800000000001, 7.9312, 7.994999999999999,
     8.0592, 8.1238, 8.1888, 8.2542, 8.32, 8.3862, 8.4528, 8.5198, 8.587200000000001, 8.655000000000001, 8.7232, 8.7918,
     8.8608, 8.9302]

# 标签
def func(x):
    y = w1 * x**2 + w2 * x + w3
    return y

# 损失函数
def loss(y_pre, y_true):
    return (y_true - y_pre) ** 2


# 随机定义 w
w1, w2, w3 = 1, 2, 3
# 学习率
lr = 0.1

batch_size = 20
# 训练过程
for epoch in range(500):
    epoch_loss = 0
    count = 0
    grad_w1 = 0
    grad_w2 = 0
    grad_w3 = 0
    for x, y_true in zip(X, Y):
        count += 1
        y_pre = func(x)
        # 本轮loss 总值
        epoch_loss += loss(y_pre, y_true)
        # 梯度计算
        grad_w1 += 2 * (y_pre - y_true) * x ** 2
        grad_w2 += 2 * (y_pre - y_true) * x
        grad_w3 += 2 * (y_pre - y_true)
        # 更新权重
        if count == batch_size:
            count = 0
            # 权重更新
            w1 = w1 - lr * grad_w1/batch_size  # sgd
            w2 = w2 - lr * grad_w2/batch_size
            w3 = w3 - lr * grad_w3/batch_size
    # 本轮结束 计算平均loss
    epoch_loss /= len(X)
    print("第%d轮, loss %f" % (epoch, epoch_loss))
    if epoch_loss < 0.00001:
        break

print(f"训练后权重:w1:{w1} w2:{w2} w3:{w3}")
# #使用训练后模型输出预测值
Yp = [func(i) for i in X]
# 预测值与真实值比对数据分布
pyplot.scatter(X, Y, color="red")
pyplot.scatter(X, Yp)
pyplot.show()

利用轮子优化代码

import numpy
import numpy as np
import torch
import torch.utils.data as Data



true_w = torch.from_numpy(numpy.array([5, 6, 7])).float()
true_b = 7
print(true_w)
# 定义模型
feature = torch.tensor(numpy.random.normal(0, 1, (100, 2)), dtype=float).float()
print(feature)
labels = true_w[0] * feature[:, 0] + true_w[1] * feature[:, 1] + true_b
labels += torch.tensor(numpy.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
lr = 0.2
batch_size = 20
dateset = Data.TensorDataset(feature, labels)
data_iter = Data.DataLoader(dateset, batch_size, shuffle=True)

for X,y in data_iter:
    print(X,y)
    break
net = torch.nn.Sequential(
    torch.nn.Linear(2, 1)
)
print("net =", net)
# 使用net前 需要初始化参数 初始化
torch.nn.init.normal(net[0].weight, mean=0, std=0.01)
torch.nn.init.constant_(net[0].bias, val=0)

loss = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)

for epoch in range(10):
    for X,y_ture in data_iter:
        y_pre = net(X)
        l = loss(y_pre,y_ture.view(-1,1))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch %d, loss: %f' % (epoch, l.item()))

    # 比较学到的模型参数和真实的模型参数
    print('result ==================')
    dense = net[0]
    print(true_w, dense.weight)
    print(true_b, dense.bias)




本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/572759.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

pwn--realloc [CISCN 2019东南]PWN5

首先学习一下realloc这个函数&#xff0c;以下是文心一言的解释&#xff1a; realloc是C语言库函数之一&#xff0c;用于重新分配内存空间。它的主要功能是调整一块内存空间的大小。当需要增加内存空间时&#xff0c;realloc会分配一个新的更大的内存块&#xff0c;然后将原内…

QT从入门到实战x篇_22_番外1_Qt事件系统

文章目录 1. Qt事件系统简介1.1 事件的来源和传递1.2 事件循环和事件分发1.2.1 QT消息/事件循环机制1.2.1.1 机制解释1.2.1.2 两个问题 1.2.2 事件分发 2. 事件过滤基础2.1 什么是事件过滤器&#xff08;Event Filter&#xff09;&#xff1f;2.2 如何安装事件过滤器 3. 事件过…

《QT实用小工具·四十》显示帧率的控件

1、概述 源码放在文章末尾 该项目实现了可以显示帧率的控件&#xff0c;项目demo演示如下所示&#xff1a; 、 项目部分代码如下所示&#xff1a; #ifndef FPSITEM_H #define FPSITEM_H#include <QQuickItem>class FpsItem : public QQuickItem {Q_OBJECTQ_PROPERTY(i…

制作自己的YOLOv8数据集

制作自己的YOLO8数据集 前言 该数据集的格式参照于coco数据集结构✨ 步骤一&#xff1a;收集图像数据 从互联网上下载公开的数据集&#xff0c;也可以使用摄像头或其他设备自行采集图像&#xff0c;确保你的图像数据覆盖了你感兴趣的目标和场景 步骤二&#xff1a;安装Labe…

提升工作效率必备,桌面待办事项提醒软件

在快节奏的现代社会&#xff0c;提升工作效率成为众多上班族的共同追求。有效的时间管理、合理的工作计划和正确的工具选择&#xff0c;是实现高效工作的三大关键。尤其是选择一款优秀的待办事项管理软件&#xff0c;能够极大地助力我们提升工作效率。 而我在网上找到了一款提…

11 - 在k8s官方文档上,经常搜索不到内容的问题

使用k8s官方文档时&#xff0c;会出现首页可以正常打开&#xff0c;但是输入搜索关键字之后&#xff0c;搜索不到内容的情况&#xff0c;如下图&#xff1a; 这是由于相关搜索组件被墙的原因&#xff0c;处理方法如下&#xff1a; 谷歌浏览器&#xff1a; 火狐浏览器&#x…

AI大模型重塑新媒体变现格局:智能写作技术助力腾飞!

文章目录 一、AI大模型&#xff1a;新媒体变革的引擎二、智能写作&#xff1a;内容生产的新范式三、精准推送&#xff1a;增强用户粘性的关键四、新媒体变现&#xff1a;插上AI翅膀的飞跃五、挑战与机遇并存&#xff1a;AI与新媒体的未来展望AI智能写作: 巧用AI大模型让新媒体变…

智慧园区引领未来产业趋势:科技创新驱动园区发展,构建智慧化产业新体系

目录 一、引言 二、智慧园区引领未来产业趋势 1、产业集聚与协同发展 2、智能化生产与服务 3、绿色可持续发展 三、科技创新驱动园区发展 1、创新资源的集聚与整合 2、创新成果的转化与应用 3、创新文化的培育与弘扬 四、构建智慧化产业新体系 1、优化产业布局与结构…

鸿蒙OpenHarmony【集成三方SDK】 (基于Hi3861开发板)

OpenHarmony致力于打造一套更加开放完善的IoT生态系统&#xff0c;为此OpenHarmony规划了一组目录&#xff0c;用于将各厂商的SDK集成到OpenHarmony中。本文档基于Hi3861开发板&#xff0c;向平台开发者介绍将SDK集成到OpenHarmony的方法。 规划目录结构 三方SDK通常由静态库…

C语言联合体详解

下午好诶&#xff0c;今天小眼神给大家带来一篇C语言联合体详解的文章~ 目录 联合体 1. 联合体类型的声明 2. 联合体的特点 代码一&#xff1a; 代码二&#xff1a; 3. 相同成员的结构体和联合体对比 ​编辑4. 联合体大小的计算 5. 联合体的优点 联合体 1. 联合体…

干货:html中的标签属性大全25个,收藏起来吧。

<meta> 元素是 HTML 中的一个标签&#xff0c;用于提供关于文档的元数据信息。它通常位于 <head> 标中&#xff0c;不会直接在页面上显示内容&#xff0c;而是用于告诉浏览器和搜索引擎一些关于页面的信息。 <meta charset"字符编码">&#xff1a;…

【算法入门-Python】02_递归

一、递归 递归的两个特点&#xff1a;调用自身&#xff1b;结束条件。 def func1(x):print(x)func1(x-1)没有结束条件&#xff0c;si递归。不是递归。 def func2(x)&#xff1a;if x > 0:print(x)func2(x1)递归调用的x1&#xff0c;没有结束条件。不是递归 def func3(x)…

【保姆级教程】Windows 远程登陆 Linux 服务器的两种方式:SSH + VS Code,开发必备

0. 前言 很多情况下代码开发需要依赖 Linux 系统&#xff0c;远程连接 Linux 服务器进行开发和维护已成为一种常态。对于使用Windows系统的开发者来说&#xff0c;掌握如何通过 SSH 安全地连接到 Linux 服务器&#xff0c;并利用 VS Code 编辑器进行开发&#xff0c;是一项必备…

Unix 进程基本信息

目录 一、程序执行流程二、进程的执行状态三、进程信息记录3.1 proc结构体3.2 user结构体 四、内存分配4.1 代码段代码段如何管理&#xff1f;4.2 数据段4.3 虚拟地址空间4.4 交换地址APR构成APR数量APR切换 内容来源&#xff1a;《Unix内核源码剖析》 一、程序执行流程 为程序…

python学习笔记(集合)

知识点思维导图 # 直接使用{}进行创建 s{10,20,30,40} print(s)# 使用内置函数set()创建 sset() print(s)# 创建一个空的{}默认是字典类型 s{} print(s,type(s))sset(helloworld) print(s) sset([10,20,30]) print(s) s1set(range(1,10)) print(s1)print(max:,max(s1)) print(m…

Java web第四次作业

要求&#xff1a;读取xml文件并在页面中显示出来。 采用三种方式实现&#xff0c;并体会其中的原理&#xff1a; 1.常规方式&#xff0c;controlller控制器不分层 代码&#xff1a;RestController public class PoetController { RequestMapping("/listPoet") pu…

STL::string简单介绍

目录 1、什么是STL STL6大组件:仿函数、算法、容器、空间配置器、迭代器、配接器 推荐文档&#xff08;必须学会看文档&#xff09; 2、string常用接口 a、初始化 1、什么是STL 标准模板库 STL&#xff08;Standard Template Library&#xff09;&#xff0c;主要是数据结构…

如何带好一个开发小团队?

俗话说&#xff1a;授人以鱼不如授人以渔&#xff0c;这句话强调的是教会别人解决问题的方法比单纯给予一次性帮助更有价值。提倡教育和培养团队成员&#xff0c;使其具备自我解决问题的能力。带领一个开发小团队需要综合考虑管理、沟通和技术能力等方面。以下是一些建议&#…

2024年电子商务与大数据经济国际会议 (EBDE 2024)

2024年电子商务与大数据经济国际会议 (EBDE 2024) 2024 International Conference on E-commerce and Big Data Economy 【会议简介】 2024年电子商务与大数据经济国际会议即将在厦门召开。本次会议旨在汇聚全球电子商务与大数据经济领域的专家学者&#xff0c;共同探讨电子商务…

实验五 Spark SQL编程初级实践

Spark SQL编程初级实践 Spark SQL基本操作 将下列JSON格式数据复制到Linux系统中&#xff0c;并保存命名为employee.json。 { "id":1 , "name":" Ella" , "age":36 } { "id":2, "name":"Bob","a…