前馈神经网络dropout实例

直接看代码。

(一)手动实现


import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


#下载MNIST手写数据集  
mnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  

#读取数据  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  


#初始化参数  
num_inputs,num_hiddens,num_outputs =784, 256,10

num_epochs=30

lr = 0.001

def init_param():
    W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  
    b1 = torch.zeros(1, dtype=torch.float32)  
    W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  
    b2 = torch.zeros(1, dtype=torch.float32)  
    params =[W1,b1,W2,b2]
    for param in params:  
        param.requires_grad_(requires_grad=True)  
    return W1,b1,W2,b2

def dropout(X, drop_prob):
    X = X.float()
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    if keep_prob == 0:
        return torch.zeros_like(X)
    mask = (torch.rand(X.shape) < keep_prob).float()
    print(mask)
    return mask * X / keep_prob

def net(X, is_training=True):
    X = X.view(-1, num_inputs)
    H1 = (torch.matmul(X, W1.t()) + b1).relu()
    if is_training:
        H1 = dropout(H1, drop_prob)
    return (torch.matmul(H1,W2.t()) + b2).relu()


def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None):
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        ls, count = 0, 0
        for X,y in train_iter:
            l=loss(net(X),y)
            
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            
            ls += l.item()
            count += y.shape[0]
            
        train_ls.append(ls)
        
        ls, count = 0, 0
        
        for X,y in test_iter:
            
            l=loss(net(X,is_training=False),y)
            
            ls += l.item()
            count += y.shape[0]
            
        test_ls.append(ls)
        
        if(epoch+1)%10==0:
            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
            
    return train_ls,test_ls


drop_probs = np.arange(0,1.1,0.1)

Train_ls, Test_ls = [], []

for drop_prob in drop_probs:

    W1,b1,W2,b2 = init_param()
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)
    train_ls, test_ls =  train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer)   
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
    
    
x = np.linspace(0,len(train_ls),len(train_ls))

plt.figure(figsize=(10,8))

for i in range(0,len(drop_probs)):
    plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    
# plt.legend()
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()

运行结果:

在这里插入图片描述

(二)torch.nn实现

import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

mnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  


class LinearNet(nn.Module):
    def __init__(self,num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob1,drop_prob2):
        super(LinearNet,self).__init__()
        self.linear1 = nn.Linear(num_inputs,num_hiddens1)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop_prob1)
        self.linear2 = nn.Linear(num_hiddens1,num_hiddens2)
        self.drop2 = nn.Dropout(drop_prob2)
        self.linear3 = nn.Linear(num_hiddens2,num_outputs)
        self.flatten  = nn.Flatten()
    
    def forward(self,x):
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.drop1(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.drop2(x)
        x = self.linear3(x)
        y = self.relu(x)
        return y
    
    
    
    
def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        ls, count = 0, 0
        for X,y in train_iter:
            l=loss(net(X),y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            ls += l.item()
            count += y.shape[0]
        train_ls.append(ls)
        ls, count = 0, 0
        for X,y in test_iter:
            l=loss(net(X),y)
            ls += l.item()
            count += y.shape[0]
        test_ls.append(ls)
        if(epoch+1)%5==0:
            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
    return train_ls,test_ls    
    
    
    
    
    
num_inputs,num_hiddens1,num_hiddens2,num_outputs =784, 256,256,10
num_epochs=20
lr = 0.001
drop_probs = np.arange(0,1.1,0.1)
Train_ls, Test_ls = [], []

for drop_prob in drop_probs:
    net = LinearNet(num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob,drop_prob)
    for param in net.parameters():
        nn.init.normal_(param,mean=0, std= 0.01)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),lr)
    train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,net.parameters,lr,optimizer)
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
    
    
    
    
x = np.linspace(0,len(train_ls),len(train_ls))
plt.figure(figsize=(10,8))
for i in range(0,len(drop_probs)):
    plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)
    plt.xlabel('epoch')
    plt.ylabel('loss')
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()


input = torch.randn(2, 5, 5)
m = nn.Sequential(
nn.Flatten()
)
output = m(input)
output.size()

运行结果:

在这里插入图片描述

关于dropout的原理,网上资料很多,一般都是用一个正态分布的矩阵,比较矩阵元素和(1-dropout),大于(1-dropout)的矩阵元素值的修正为1,小于(1-dropout)的改为1,将输入的值乘以修改后的矩阵,再除以(1-dropout)。

疑问:

  1. 数值经过正态分布矩阵的筛选后,还要除以 (1-dropout),这样做的原因是什么?
  2. Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/83626.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一百六十三、Kettle——Linux上安装Kettle9.2(亲测有效,附截图)

一、目的 由于之前发现kettle8.2和kettle9.3这两个版本&#xff0c;或多或少的存在问题 比如kettle8.2的本地服务没问题&#xff0c;但在Linux上创建共享资源库时就有问题&#xff1b; 比如kettle9.3由于不自带shims驱动包&#xff0c;目前在新的下载官网上无法找到下载路径…

PCIE 信息

PCIe&#xff08;外围组件互连快件&#xff09;是用于连接高速组件的接口标准。每台台式电脑主板有许多 PCIe 插槽&#xff0c;可用于添加通用显卡&#xff0c;各种外设卡&#xff0c;无线网卡或固态硬盘等等。PC 中可用的 PCIe 插槽类型将取决于你购买的主板. PCIe 插槽有不同…

LeetCode 542. 01 Matrix【多源BFS】中等

本文属于「征服LeetCode」系列文章之一&#xff0c;这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁&#xff0c;本系列将至少持续到刷完所有无锁题之日为止&#xff1b;由于LeetCode还在不断地创建新题&#xff0c;本系列的终止日期可能是永远。在这一系列刷题文章…

ubuntu18.04安装keil5(踩坑)看完再享用,别直接上手

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、安装winewine的总结 二、安装Keil5总结 前言 切记看完再享用&#xff0c;别直接上手&#xff0c;不然安装的时候会和我一样踩坑的&#xff08;走了很多弯路…

【汇编语言】CS、IP寄存器

文章目录 修改CS、IP的指令转移指令jmp问题分析 修改CS、IP的指令 理论&#xff1a;CPU执行何处的指令&#xff0c;取决于CS:IP应用&#xff1a;程序员可以通过改变CS、IP中的内容&#xff0c;进行控制CPU即将要执行的目标指令&#xff1b;问题&#xff1a;如何改变CS、IP中的…

数据在内存中的储存·大小端(文字+画图详解)(c语言·超详细入门必看)

前言&#xff1a;Hello&#xff0c;大家好&#xff0c;我是心跳sy&#x1f618;&#xff0c;本节我们介绍c语言的两种基本的内置数据类型&#xff1a;数值类型和字符类型在内存中的储存方法&#xff0c;并对大小端进行详细介绍&#xff08;附两种大小端判断方法&#xff09;&am…

打印技巧——word中A4排版打印成A3双面对折翻页

在进行会议文件打印时&#xff0c;我们常会遇到需要将A4排版的文件&#xff0c;在A3纸张上进行双面对折翻页打印&#xff0c;本文对设置方式进行介绍&#xff1a; 1、在【布局】选项卡中&#xff0c;点击右下角小箭头&#xff0c;打开页面设置选项卡 1.1在【页边距】中将纸张…

【校招VIP】java语言考点之List和扩容

考点介绍&#xff1a; List是最基础的考点&#xff0c;但是很多同学拿不到满分。本专题从两种实现子类的比较&#xff0c;到比较复杂的数组扩容进行分析。 『java语言考点之List和扩容』相关题目及解析内容可点击文章末尾链接查看&#xff01;一、考点题目 1、以下关于集合类…

k8s容器加入host解析字段

一、通过edit或path来修改 kubectl edit deploy /xxxxx. x-n cattle-system xxxxx为你的资源对象名称 二、添加字段 三、code hostAliases:- hostnames:- www.rancher.localip: 10.10.2.180

Linux面试笔试题(1)

1、以长格式列目录时&#xff0c;若文件test的权限描述为&#xff1a;drwxrw-r–&#xff0c;则文件test的类型及文件主的权限是__A____。 A.目录文件、读写执行 B.目录文件、读写 C.普通文件、读写 D.普通文件、读 在这个问题中&#xff0c;我们需要解析文件权限的描述&…

ViT模型架构和CNN区别

目录 Vision Transformer如何工作 ViT模型架构 ViT工作原理解析 步骤1&#xff1a;将图片转换成patches序列 步骤2&#xff1a;将patches铺平 步骤3&#xff1a;添加Position embedding 步骤4&#xff1a;添加class token 步骤5&#xff1a;输入Transformer Encoder 步…

「Qt」文件读写操作

0、引言 我们知道 C 和 C 都提供了文件读写的类库&#xff0c;不过 Qt 也有一套自己的文件读写操作&#xff1b;本文主要介绍 Qt 中进行文件读写操作的类 —— QFile。 1、QFileDialog 文件对话框 一般的桌面应用程序&#xff0c;当我们想要打开一个文件时&#xff0c;通常会弹…

【广州虚拟现实开发】VR智能中控系统进一步提高VR教学管理水平

随着科技的不断发展&#xff0c;虚拟现实(VR)技术已经逐渐走进了人们的生活。在教育领域&#xff0c;VR技术也得到了广泛的应用&#xff0c;尤其是在教学终端中控系统方面。那么&#xff0c;广州华锐互动开发的VR智能中控系统对学校有何益处呢&#xff1f; 首先&#xff0c;VR智…

C# 学习笔记

此笔记极水~ &#xff0c;来自两年前的库存。 是来自 B站 刘铁猛大佬 的视频&#xff0c;因为 好奇学了学。 其他 c# 变量的 内联赋值 vs. 构造函数内赋值 (引用自&#xff1a;https://www.iteye.com/blog/roomfourteen224-2208838) 上下文&#xff1a;c#中变量的内联赋值其…

Windows Server --- RDP远程桌面服务器激活和RD授权

RDP远程桌面服务器激活和RD授权 一、激活服务器二、设置RD授权 系统&#xff1a;Window server 2008 R2 服务&#xff1a;远程桌面服务 注&#xff1a;该方法适合该远程桌面服务器没网络状态下&#xff08;离线&#xff09;&#xff0c;激活服务器。 一、激活服务器 1.打开远…

css学习4(背景)

1、CSS中&#xff0c;颜色值通常以以下方式定义: 十六进制 - 如&#xff1a;"#ff0000"RGB - 如&#xff1a;"rgb(255,0,0)"颜色名称 - 如&#xff1a;"red" 2、background-image 属性描述了元素的背景图像. 默认情况下&#xff0c;背景图像进…

机器人操作系统【02】:如何在 ROS2 中对点云数据进行建模

一、说明 RViz和Gazebo中RADU的模拟进展顺利。在上一篇文章中&#xff0c;我们学习了如何启动机器人并使用远程节点进行操作。在本文中&#xff0c;我们将添加两个视觉传感器。首先&#xff0c;一个图像摄像机&#xff0c;用于在机器人四处移动时查看机器人的实时馈送。其次&am…

浅析深浅拷贝

我们在对对象进行复制时就用到深浅拷贝。 一、普通复制 <script>const people{name:tim,age:22}const testpeople;console.log(test);//tim 22test.age20;console.log(test);//tim 20console.log(people);//tim 20 </script> 控制台打印结果&#xff1a; 之所以…

spad芯片学习总结

一、时间相关单光子计数法TCSPC(Time correlated single photon counting) 1> 如果spad接收用单次发射、峰值检测会怎么样 首先spad是概率性触发的器件&#xff0c;探测到的概率远小于1&#xff0c;而且不仅接收信号的光子可以触发&#xff0c;环境光噪声一样会被spad接收到…

使用 Ploomber、Arima、Python 和 Slurm 进行时间序列预测

推荐&#xff1a;使用 NSDT场景编辑器助你快速搭建可二次编辑的3D应用场景 简短的笔记本说明 笔记本由 8 个任务组成&#xff0c;如下图所示。它包括建模的大多数基本步骤 - 获取数据清理、拟合、超参数调优、验证和可视化。作为捷径&#xff0c;我拿起笔记本并使用Soorgeon工具…