【PyTorch】多层感知机

文章目录

  • 1. 理论介绍
    • 1.1. 背景
    • 1.2. 多层感知机
    • 1.3. 激活函数
      • 1.3.1. ReLU函数
      • 1.3.2. sigmoid函数
      • 1.3.3. tanh函数
  • 2. 代码实现
    • 2.1. 主要代码
    • 2.2. 完整代码
    • 2.2. 输出结果

1. 理论介绍

1.1. 背景

许多问题要使用线性模型,但无法简单地通过预处理来实现。此时我们可以通过在网络中加入一个或多个隐藏层来克服线性模型的限制, 使其能处理更普遍的函数关系类型。

1.2. 多层感知机

将许多全连接层堆叠在一起。 每一层都输出到上面的层,直到生成最后的输出,我们可以把前层看作表示,把最后一层看作线性预测器。 这种架构通常称为多层感知机,通常缩写为MLP。
多层感知机

1.3. 激活函数

我们需要在仿射变换之后对每个隐藏单元应用非线性的激活函数,这样就不可能再将我们的多层感知机退化成线性模型,使得模型具有更强的表达能力。
激活函数是通过计算加权和并加上偏置来确定神经元是否应该被激活, 并将输入信号转换为输出的可微运算的函数。

1.3.1. ReLU函数

  • 修正线性单元(Rectified linear unit,ReLU)。
  • 最受欢迎的激活函数。
  • 定义: R e L U ( x ) = m a x ( 0 , x ) \mathrm{ReLU}(x)=\mathrm{max}(0,x) ReLU(x)=max(0,x)
    relu
  • 当输入接近0时,sigmoid函数接近线性变换。
    gradofrelu
  • 当输入值精确等于0时,ReLU函数不可导。 在此时,我们默认使用左侧的导数,即当输入为0时导数为0。 我们可以忽略这种情况,因为输入可能永远都不会是0。
  • 变体:参数化的ReLU(Parameterized ReLU,pReLU),允许即使参数是负的,某些信息依然可以通过,其定义如下: p R e L U ( x ) = m a x ( 0 , x ) + α m i n ( 0 , x ) \mathrm{pReLU}(x)=\mathrm{max}(0,x)+\alpha\mathrm{min}(0,x) pReLU(x)=max(0,x)+αmin(0,x)等等。

1.3.2. sigmoid函数

  • 将输入变换为区间(0, 1)上的输出。
  • 在隐藏层中已经较少使用, 它在大部分时候被更简单、更容易训练的ReLU所取代。
  • 定义: s i g m o i d ( x ) = 1 1 + e x p ( − x ) \mathrm{sigmoid}(x)=\frac{1}{1+\mathrm{exp}(-x)} sigmoid(x)=1+exp(x)1
    sigmoid
  • 导数: d d x s i g m o i d ( x ) = s i g m o i d ( x ) ( 1 − s i g m o i d ( x ) ) \frac{\mathrm{d}}{\mathrm{d}x}\mathrm{sigmoid}(x)=\mathrm{sigmoid}(x)(1-\mathrm{sigmoid}(x)) dxdsigmoid(x)=sigmoid(x)(1sigmoid(x))
    gradofsigmoid

1.3.3. tanh函数

  • 将其输入压缩转换到区间(-1, 1)上。
  • 定义: t a n h ( x ) = 1 − e x p ( − 2 x ) 1 + e x p ( − 2 x ) \mathrm{tanh}(x)=\frac{1-\mathrm{exp}(-2x)}{1+\mathrm{exp}(-2x)} tanh(x)=1+exp(2x)1exp(2x)
    tanh
  • 当输入接近0时,tanh函数接近线性变换。
  • 导数: d d x t a n h ( x ) = 1 − t a n h 2 ( x ) \frac{\mathrm{d}}{\mathrm{d}x}\mathrm{tanh}(x)=1-\mathrm{tanh}^2(x) dxdtanh(x)=1tanh2(x)
    gradoftanh

2. 代码实现

2.1. 主要代码

net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    ).cuda()

2.2. 完整代码

import os
import torch
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torch import nn
from tensorboardX import SummaryWriter
from rich.progress import track

def load_dataset():
    """加载数据集"""
    root = "./dataset"
    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = FashionMNIST(
        root=root, 
        train=True, 
        transform=transform, 
        download=True
    )
    mnist_test = FashionMNIST(
        root=root, 
        train=False, 
        transform=transform, 
        download=True
    )

    dataloader_train = DataLoader(
        mnist_train,
        batch_size, 
        shuffle=True,
        num_workers=num_workers
    )
    dataloader_test = DataLoader(
        mnist_test,
        batch_size, 
        shuffle=False,
        num_workers=num_workers
    )
    return dataloader_train, dataloader_test

class Accumulator:
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())


if __name__ == "__main__":
    # 全局参数设置
    batch_size = 256
    num_epochs = 10
    num_workers = 3

    lr = 0.1
    device = torch.device('cuda:0')

    # 创建记录器
    def log_dir():
        root = "runs"
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'{root}/exp{order}'
    writer = SummaryWriter(log_dir=log_dir())

    # 加载数据集
    dataloader_train, dataloader_test = load_dataset()

    # 定义模型
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    ).to(device)
    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, val=0)
    net.apply(init_weights)
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)

    train_metrics = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数
    test_metrics = Accumulator(2)   # 测试准确度总和、样本数
    for epoch in track(range(num_epochs), description='多层感知机'):
        for X, y in dataloader_train:
            X, y = X.to(device), y.to(device)
            loss = criterion(net(X), y)
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

        train_metrics.reset()
        for X, y in dataloader_train:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)
            train_metrics.add(loss.sum(), accuracy(y_hat, y), y.numel())
        train_loss, train_acc = train_metrics[0]/train_metrics[2], train_metrics[1]/train_metrics[2]

        test_metrics.reset()
        with torch.no_grad():    
            for X, y in dataloader_test:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                test_metrics.add(accuracy(y_hat, y), y.numel())
        test_acc = test_metrics[0] / test_metrics[1]
        writer.add_scalars("metrics", {
            'train_loss': train_loss, 
            'train_acc': train_acc, 
            'test_acc': test_acc
            }, epoch)
        
    writer.close()

2.2. 输出结果

多层感知机

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/228660.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

暖雪-终业游戏攻略 开荒职业无量尊者圣物搭配(60亿秒伤)

本攻略基本没有用到dlc2的圣物,便于前期开荒,远程攻击难度低(本体也能用这套搭配) 圣物搭配&面板展示 开局可以选择遗物的可以选:堕龙(放三号位) 核心: 灵玉/青龙力量: 玄武/青龙/飞蝗剑/灵玉/朱雀敏捷: 堕龙功效: 憎恨之心 圣物优先级越靠前的越好 武器选择 回魂-搭…

Sublime Text 卡顿

复制下方代码,保存后重启Sublime Text {"non_blocking" : "true","live_mode" : "false" }

28. Python Web 编程:Django 基础教程

目录 安装使用创建项目启动服务器创建数据库创建应用创建模型设计路由设计视图设计模版 安装使用 Django 项目主页:https://www.djangoproject.com 访问官网 https://www.djangoproject.com/download/ 或者 https://github.com/django/django Windows 按住winR 输…

基于SpringBoot+Vue学生成绩管理系统前后端分离(源码+数据库)

一、项目简介 本项目是一套基于SpringBootVue学生成绩管理系统,主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含:项目源码、数据库脚本等,该项目可以直接作为bishe使用。 项目都经过严格调试,确…

OpenCL学习笔记(二)手动编译开发库(win10+vs2019)

前言 有时需求比较特别,可能需要重新编译opencl的sdk库。本文档简单记录下win10下,使用vs2019编译的过程,有需要的小伙伴可以参考下 一、获取源码 项目地址:GitHub - KhronosGroup/OpenCL-SDK: OpenCL SDK 可以直接使用git命令…

PostgreSQL从小白到高手教程 - 第38讲:数据库备份

PostgreSQL从小白到专家,是从入门逐渐能力提升的一个系列教程,内容包括对PG基础的认知、包括安装使用、包括角色权限、包括维护管理、、等内容,希望对热爱PG、学习PG的同学们有帮助,欢迎持续关注CUUG PG技术大讲堂。 第38讲&#…

【MATLAB源码-第97期】基于matlab的能量谷优化算法(EVO)机器人栅格路径规划,输出做短路径图和适应度曲线。

操作环境: MATLAB 2022a 1、算法描述 能量谷优化算法(Energy Valley Optimization, EVO)是一种启发式优化算法,灵感来源于物理学中的“能量谷”概念。它试图模拟能量在不同能量谷中的转移过程,以寻找最优解。 在EVO…

当然热门的原创改写改写大全【2023最新】

在信息时代,随着科技的不断发展,改写软件逐渐成为提高文案质量和写作效率的重要工具。本文将专心分享一些好用的改写软件,其中包括百度文心一言智能写作以及147SEO改写软件。这些工具不仅支持批量改写,而且在发布到各大平台后能够…

一文详解设备维护管理软件:降本增效的关键利器

设备维护管理软件是一种专门为优化和简化设备维护流程而设计的工具。随着技术的进步和业务的扩展,各类企业和机构不得不面对规模日益庞大和复杂的设备和基础设施,如何高效地维护和管理这些设备成为企业发展中面临的一项重要挑战。 在这个背景下&#xff…

uniapp实战 —— 开发微信小程序的调试技巧

手机真机调试微信小程序 开发版和体验版的小程序,域名没有备案时想调试接口访问效果,可以按下述方式操作: 在手机上点右上方三个点,点击“开发调试”,开启调试模式,即可真机访问接口(跳过域名校…

一篇文章理解 Taro1/3原理

最近在进行 Taro项目的升级,由 Taro1升级到 Taro3,但是细想一下一直都是在使用 Taro,却没有进行深入的了解,本篇文章就深入解析下 什么是 Taro? 为什么要用 Taro 在使用任何一种技术之前,都要了解这两个问题&#xf…

LangChain的函数,工具和代理(六):Conversational agent

关于langchain的函数、工具、代理系列的博客我之前已经写了五篇,还没有看过的朋友请先看一下,这样便于对后续博客内容的理解: LangChain的函数,工具和代理(一):OpenAI的函数调用 LangChain的函数,工具和代…

05-微服务架构构建之六边形架构

文章目录 前言一、六边形架构的概念二、六边形架构的特点三、微服务架构的良好实践总结 前言 通过前面的学习,我们掌握了微服务架构的基本组件等内容。在选择适合每个微服务的架构时,六边形架构“天然”成为每个微服务构建的最佳选择。 一、六边形架构的…

HttpComponents: 概述

文章目录 1. 概述2. 生态位 1. 概述 早期的Java想要实现HTTP客户端需要借助URL/URLConnection或者自己手动从Socket开始编码,需要处理大量HTTP协议的具体细节,不但繁琐还容易出错。 Apache Commons HttpClient的诞生就是为了解决这个问题,它…

数据资源和数据资产的区别是什么?

数据资源:狭义的数据资源是指数据本身,即企业运作中积累下来的各种各样的数据记录,如客户记录、销售记录、人事记录、采购记录、财务数据和库存数据等。广义的数据资源涉及数据的产生、处理、传播、交换的整个过程,包括数据本身、…

Vue3小兔鲜电商前台项目总结

1.code地址 https://github.com/15347113049/vue-rabbit.git 2.项目基础栈 Vue3全家桶:create-vue Pinia ElementPlus Vue3Setup Vue-Router VueUse 3.主要业务 (1)整体路由搭建 (2)layout布局 (3)Home页一级分类 (4)二级分类详情页 (5)登录功能 (6)购物车(头部…

基于POSIX标准的Linux进程间通信

文章目录 1 管道(匿名管道)1.1 管道抽象1.2 接口——pipe1.3 管道的特征1.4 管道的四种情况1.5 匿名管道用例 2 命名管道2.1 创建一个命名管道——mkfifo2.2 关闭一个管道文件——unlink2.3 管道和命名管道的补充2.4 命名管道用例 3 共享内存3.1 原理3.2…

持续集成交付CICD:Jenkins使用GitLab共享库实现前后端项目Sonarqube

目录 一、实验 1.Jenkins使用GitLab共享库实现后端项目Sonarqube 2.优化GitLab共享库 3.Jenkins使用GitLab共享库实现前端项目Sonarqube 4.Jenkins通过插件方式进行优化 二、问题 1.sonar-scanner 未找到命令 2.npm 未找到命令 一、实验 1.Jenkins使用GitLab共享库实现…

【MATLAB源码-第98期】基于matlab的能量谷优化算法(EVO)无人机三维路径规划,输出做短路径图和适应度曲线。

操作环境: MATLAB 2022a 1、算法描述 能量谷优化算法(Energy Valley Optimization, EVO)是一种启发式优化算法,灵感来源于物理学中的“能量谷”概念。它试图模拟能量在不同能量谷中的转移过程,以寻找最优解。 在EVO…

vue 一直运行 /sockjs-node/info?及 /sockjs-node/info报错解决办法

sockjs-node介绍 sockjs-node 是一个JavaScript库,提供跨浏览器JavaScript的API,创建了一个低延迟、全双工的浏览器和web服务器之间通信通道。 服务端:sockjs-node(https://github.com/sockjs/sockjs-node) 客户端&a…
最新文章