神经网络框架的基本设计

一、神经网络框架设计的基本流程

确定网络结构、激活函数、损失函数、优化算法,模型的训练与验证,模型的评估和优化,模型的部署。

二、网络结构与激活函数

1、网络架构

这里我们使用的是多层感知机模型MLP(multilayer prrceptron):

MLP一般分为三层:输入层、隐藏层和输出层。
输入层:接收输入数据。
隐藏层:负责处理数据,可以有一个或多个隐藏层,每个隐藏层包含若干个神经元。
输出层:输出最终结果,通常是一个softmax层,用于多分类任务,或者是一个sigmoid层,用于二分类任务。
MLP的核心思想是通过增加神经元的数量和层次,提高模型的表达能力。由于有多个层,参数需要在这些层之间传递。

每个隐藏层神经元中计算过程如下:
将输入数据传递给第一个隐藏层的神经元。
对于每个神经元,计算其加权和,即将输入与对应的权重相乘并求和,再加上偏置项。
将加权和输入到激活函数中,得到激活值,作为该神经元的输出。
将每个神经元的输出传递到下一层的神经元,直至输出层。
在这个过程中,数据和权重是前向传播的主要传播内容。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28,312),
            nn.ReLU(),
            nn.Linear(312, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)
        return logits

2、激活函数:

用于引入非线性因素,使得神经网络具有更强的表达能力。常见的激活函数有Sigmoid、ReLU、Tanh等。这里我们暂时略过

3、损失函数:

损失函数是用来量化模型预测值与真实值差距的方式(这么说可能过于抽象,我们初中数学应该学过标准差和方差的概念,方差和标准差是量化数组各元素与平均值的差距的。)而能反映二者间差距的值有很多,常见的有均方损失、交叉熵损失、铰链损失、绝对值损失。

具体什么含义可以自行学习。

损失函数的作用是用来评估模型的准确度的,

在上次实验中,我们使用的是均方损失,这一次我们使用交叉熵损失函数:

loss_fu = torch.nn.CrossEntropyLoss()
loss = loss_fu(pred, label_batch)

torch.nn模块:是构建神经网络的基石,提供了各种类型的层,包括卷积层、池化层、激活函数、循环层和全连接层

4、优化算法

本次还是采用:自适应学习率的优化算法,adam优化器。

三、模型的训练

1、数据集调整

此次的数据集依然是mnist数据集,下载与处理可以看Hello World!-CSDN博客。不过对于idx3-udyte文件,如果我们每次都需要这么读一下,那多少有些麻烦,可以将idx3-udyte文件转换成npy文件以便长期存储。

npy文件是NumPy数组的一种存储格式,用于将Python中的NumPy数组保存到磁盘上,并可以在以后需要时加载回来。它不仅存储了数组的数据,还存储了数组的形状、数据类型等信息。这使得.npy文件非常紧凑且占用空间小,并且可以快速地读写。由于它是二进制格式,所以不能直接用文本编辑器打开查看内容。非常适合在需要频繁保存和加载数组数据的场景下使用,例如机器学习中的模型参数保存、数据集的存储等。

我们回到上期的读取文件的部分

​
def read_data3(self, roadurl):
        with open(roadurl, 'rb') as f:
            content = f.read()
        # print(content)
 
        fmt_header = '>iiii'  # 网络字节序
        offset = 0
 
        magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, content, offset)
        print('幻数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))
        # 定义一张图片需要的数据个数(每个像素一个字节,共需要行*列个字节的数据)
        img_size = num_rows * num_cols
        # struct.calcsize(fmt)用来计算fmt格式所描述的结构的大小
        offset += struct.calcsize(fmt_header)
        # '>784B'是指用大端法读取784个unsigned byte
        fmt_image = '>' + str(img_size) + 'B'
        # 定义了一个三维数组,这个数组共有num_images个 num_rows*num_cols尺寸的矩阵。
        images = np.empty((num_images, num_rows, num_cols))
 
        for i in range(num_images):
            images[i] = np.array(struct.unpack_from(fmt_image, content, offset)).reshape((num_rows, num_cols))
            offset += struct.calcsize(fmt_image)
 
        return images

​

我们已经接收到了images这个numpy数组,接下来只需要增加一点代码

# 保存到
np.save('文件路径/文件名.npy', array)
# 读取文件
data1 = np.load('文件路径/文件名.npy')

2、模型搭建

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 指定GPU编
import torch
import numpy as np
from tqdm import tqdm

batch_size = 320  # 设定每次训练的批次数
epochs = 1024  # 设定训练次数

# device = "cpu"                         # Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"  # 在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式


# 设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )

    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits


model = NeuralNetwork()
model = model.to(device)  # 将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
# model = torch.compile(model)            # Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()       # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)  # 设定优化函数

# 载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train) // batch_size

# 开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred, label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # 记录每个批次的损失值

    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("epoch:", epoch, "train_loss:", round(train_loss, 2), "accuracy:", round(accuracy, 2))
torch.save(model, './model.pth')
print("模型已更新")

3、训练完成

四、可视化操作

1、代码

import torch
import torch.nn as nn

if __name__ == '__main__':
    model = NeuralNetwork()
    #print(model)

    params = list(model.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

2、使用netron软件进行可视化(推荐)

在github上下载:https://github.com/lutzroeder/netron

 五、模型的部署(暂无)

从前边可知,所谓模型就是一个由无数参数组成的.pth文件,而.pth文件可以通过python指令读取。

torch.load()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/296312.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录 1143. 最长公共子序列

题目 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。 一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情况下删除某些字符(也…

MongoDB 启动时:服务名无效

1.问题场景 电脑睡眠后,再连接服务发现无法连接,启动服务报:服务名无效。 2.打开服务管理: 发现服务中没有MongoDB的服务 3.解决 (1)先找打MongoDB安装路径,把data文件夹下所有文件删除 &a…

Vue中使用Element UI的Table组件实现嵌套表格(最简单示例)

以下是一个简单的示例代码&#xff0c;演示如何在Vue中使用Element UI的Table组件实现嵌套表格&#xff1a; html <template><div><el-table :data"tableData" style"width: 100%"><el-table-column prop"name" label&quo…

Centos服务器安装Certbot以webroot的方式定时申请SSL免费证书

最近发现原先免费一年的SSL证书都改为3个月的有效期了&#xff0c;原先一年操作一次还能接受&#xff0c;现在3个月就要手动续期整的太慢烦了&#xff0c;还是让程序自动给处理下吧&#xff0c; 安装 Certbot yum install epel-release -y yum install certbot -yEPEL是由 Fe…

云计算历年题整理

第一大题 第一大题计算 给出计算连接到EC2节点的EBS的高可用性(HA)的数学公式&#xff0c;如场景中所述&#xff1b;计算EC2节点上的EBS的高可用性(HA)&#xff1b;场景中80%的AWS EC2节点用于并行处理&#xff0c;总共有100个虚拟中央处理单元(vCPUs)用于处理数据&#xff0…

蟹目标检测数据集VOC格式400张

蟹&#xff0c;一种独特的海洋生物&#xff0c;以其强壮的身体和独特的生活习性而闻名。 蟹的身体宽厚&#xff0c;有一对锐利的大钳子&#xff0c;这使得它们在寻找食物和保护自己时非常有力。蟹的外观颜色多样&#xff0c;有绿色、蓝色、棕色和红色等&#xff0c;这使得它们在…

法一(auto-py-to-exe):Pyinstaller将yolov5的detect.py封装成detect.exe

pip install pyinstaller # 安装最新版本的pyinstaller指令# 在dist目录下只生成一个较大xxx.exe文件&#xff0c;所有依赖库全打包到exe中&#xff0c;打包后的exe可单独使用 pyinstaller -F xxx.py # 在dist目录下生成较小的exe文件&#xff0c;其他依赖库全都在dist文件夹下…

[C#]利用opencvsharp实现深度学习caffe模型人脸检测

【官方框架地址】 https://github.com/opencv/opencv/blob/master/samples/dnn/face_detector/deploy.prototxt 采用的是官方caffe模型res10_300x300_ssd_iter_140000.caffemodel进行人脸检测 【算法原理】 使用caffe-ssd目标检测框架训练的caffe模型进行深度学习模型检测 …

【ARMv8架构系统安装PySide2】

ARMv8架构系统安装PySide2 Step1. 下载Qt资源包Step2. 配置和安装Qt5Step3. 检查Qt-5.15.2安装情况Step4. 安装PySide2所需的依赖库Step5. 下载和配置PySide2Step6. 检验PySide2是否安装成功 Step1. 下载Qt资源包 if you need the whole Qt5 (~900MB): wget http://master.qt…

全新盲盒商城源码 /潮乎盲盒源码 /搭建教程/后端采用Laravel框架开发

源码介绍&#xff1a; 全新盲盒商城源码、潮乎盲盒源码&#xff0c;它附有搭建教程&#xff0c;后端采用Laravel框架开发。 采用后端Laravel框架进行开发&#xff0c;前端开发框架则使用了uniappvue。在环境配置方面&#xff0c;我们建议使用php7.4 mysql5.6 nginx1.22 re…

用友U8 Cloud smartweb2.RPC.d XML外部实体注入漏洞

产品介绍 用友U8cloud是用友推出的新一代云ERP&#xff0c;主要聚焦成长型、创新型、集团型企业&#xff0c;提供企业级云ERP整体解决方案。它包含ERP的各项应用&#xff0c;包括iUAP、财务会计、iUFO cloud、供应链与质量管理、人力资源、生产制造、管理会计、资产管理&#…

MATLAB中xcorr函数用法

目录 语法 说明 示例 两个向量的互相关 向量的自相关 归一化的互相关 xcorr函数的功能是返回互相关关系。 语法 r xcorr(x,y) r xcorr(x) r xcorr(___,maxlag) r xcorr(___,scaleopt) [r,lags] xcorr(___) 说明 r xcorr(x,y) 返回两个离散时间序列的互相关。互相…

基于R语言(SEM)结构方程模型教程

详情点击链接&#xff1a;基于R语言&#xff08;SEM&#xff09;结构方程模型教程 01、R/Rstudio (2)R语言基本操作&#xff0c;包括向量、矩阵、数据框及数据列表等生成和数据提取等 (3)R语言数据文件读取、整理&#xff08;清洗&#xff09;、结果存储等&#xff08;含tidve…

助力实体店数字化升级,VR智慧门店打造线上逛店体验

近年来&#xff0c;传统实体店业绩增长过于缓慢&#xff0c;实体门店的销售疲态十分明显&#xff0c;甚至于部分城市已经出现大量线下实体店开始关门的现象&#xff0c;因此顺应实体零售数字化升级趋势已经刻不容缓。越来越多的实体门店开始意识到这个问题&#xff0c;并逐步开…

window服务器thinkphp队列监听服务

经常使用linux的同学们应该对使用宝塔来做队列监听一定非常熟悉&#xff0c;但对于windows系统下&#xff0c;如何去做队列的监听&#xff1f;是一个很麻烦的事情。 本文将通过windows系统的服务来实现队列的监听。 对于thinkphp6 queue如何使用&#xff0c;不再赘述。其它系…

算法29:不同路径问题(力扣62和63题)--针对算法28进行扩展

题目&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff0…

L1-085:试试手气

我们知道一个骰子有 6 个面&#xff0c;分别刻了 1 到 6 个点。下面给你 6 个骰子的初始状态&#xff0c;即它们朝上一面的点数&#xff0c;让你一把抓起摇出另一套结果。假设你摇骰子的手段特别精妙&#xff0c;每次摇出的结果都满足以下两个条件&#xff1a; 1、每个骰子摇出…

设计模式② :交给子类

文章目录 一、前言二、Template Method 模式1. 介绍2. 应用3. 总结 三、Factory Method 模式1. 介绍2. 应用3. 总结 参考内容 一、前言 有时候不想动脑子&#xff0c;就懒得看源码又不像浪费时间所以会看看书&#xff0c;但是又记不住&#xff0c;所以决定开始写"抄书&qu…

C#之反编译之路(一)

本文将介绍微软反编译神器dnSpy的使用方法 c#反编译之路(一) dnSpy.exe区分64位和32位,所以32位的程序,就用32位的反编译工具打开,64位的程序,就用64位的反编译工具打开(个人觉得32位的程序偏多,如果不知道是32位还是64位,就先用32位的打开试试) 目前只接触到wpf和winform的桌…

算法每日一题:在链表中插入最大公约数 | 链表 | 最大公约数

hello&#xff0c;大家好&#xff0c;我是星恒 今天的题目是有关链表和最大公约数的题目&#xff0c;比较简单&#xff0c;核心在于求解最大公约数&#xff0c;我们题解中使用辗转相除法来求解&#xff0c;然后我们会在最后给大家拓展一下求解最大公约数的四个方法&#xff0c;…
最新文章