VAE原理 代码详解 pin_memory

VAE代码

import torch
from torch import nn
import torch.nn.functional as F
class VAE(nn.Module):
    def __init__(self, input_dim=784, h_dim=400, z_dim=20):  # 28x28=784,20可能是这个手写体一共有20类?
        super(VAE, self).__init__()

        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        '''编码器要用到的东西'''
        self.fc1 = nn.Linear(input_dim, h_dim)  # 第一个全连接层
        self.fc2 = nn.Linear(h_dim, z_dim)  # mu
        self.fc3 = nn.Linear(h_dim, z_dim)  # log_var

        '''解码器要用到的'''
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, input_dim)

    def encoder(self, x):
        '''
        :param x: image
        :return:  均值mu和方差log_var
        '''
        h = F.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        return mu, log_var

    def reparameterization(self, mu, log_var):
        '''
        reparameterization是重新采样的意思,标准正态分布 epsilon~N(0,1)
        :param mu:
        :param log_var:
        :return: 采样的z
        '''
        sigma = torch.exp(log_var * 0.5)
        eps = torch.randn_like(sigma)
        return mu + sigma * eps

    def decode(self, z):
        '''
        给出一个采样的z,把它解码回图片
        :param z:
        :return:
        '''
        h = F.relu(self.fc4(z))
        x_hat = torch.sigmoid(self.fc5(h))  # 图片归一化后的数值为0-1,不能用ReLU
        return x_hat

    def forward(self, x):
        '''
        :param x: [batch_size,通道,28,28]
        :return:
        '''
        batch_size = x.shape[0]
        # x.shape = [128,1,28,28]
        x = x.view(batch_size, self.input_dim)  # 把[batch_size,1,28,28]合并成 [batch_size,728]
        # 输入图片进行encoder 得到均值和方差
        mu, log_var = self.encoder(x)
        # 重采样得到潜在变量sampled_z
        sampled_z = self.reparameterization(mu, log_var)
        # 把采样的潜层变量解码回图片
        x_hat = self.decode(sampled_z)  # 预测的图片
        # 把形状改为 (batch,通道,28,28)
        x_hat = x_hat.view(batch_size,1,28,28)
        return x_hat, mu, log_var

训练部分代码

import torch
import time
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
from VAE import VAE
import matplotlib.pyplot as plt
import argparse
import os
import shutil
import numpy as np

# 设置运行的设备
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# 设置模型参数
parser = argparse.ArgumentParser(description="Variational Auto-Encoder MNIST Example")
parser.add_argument('--result_dir', type=str, default='./VAEResult', metavar='DIR', help='output directory')
parser.add_argument('--save_dir', type=str, default='./checkPoint', metavar='N', help='model saving directory')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='batch size for training(default: 128)')
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train(default: 200)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed(default: 1)')
parser.add_argument('--resume', type=str, default='', metavar='PATH', help='path to latest checkpoint(default: None)')
parser.add_argument('--test_every', type=int, default=10, metavar='N', help='test after every epochs')
parser.add_argument('--num_worker', type=int, default=1, metavar='N', help='the number of workers')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate(default: 0.001)')
parser.add_argument('--z_dim', type=int, default=20, metavar='N', help='the dim of latent variable z(default: 20)')
parser.add_argument('--input_dim', type=int, default=28 * 28, metavar='N', help='input dim(default: 28*28 for MNIST)')
parser.add_argument('--input_channel', type=int, default=1, metavar='N', help='input channel(default: 1 for MNIST)')
args = parser.parse_args()
# 如果cuda为True,那么添加两个键值对,num_workers和pin_memory(详细作用看下面的补充)
kwargs = {'num_workers': 2, 'pin_memory': True} if cuda else {}

def dataloader(batch_size=128,num_workers =2):
    # 把图片数据转换为tensor
    transform = transforms.Compose([transforms.ToTensor()])
    # 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
    mnist_train = datasets.MNIST('../data',
                                 train=True,
                                 transform=transform,
                                 download=True)
    mnist_test = datasets.MNIST('../data',
                                 train=False,
                                 transform=transform,
                                 download=True)
    mnist_train = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    mnist_test = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
    classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
    return mnist_test, mnist_train, classes

def loss_function(x_hat, x, mu, log_var):
    """
    Calculate the loss. Note that the loss includes two parts.
    :param x_hat:
    :param x:
    :param mu:
    :param log_var:
    :return: total loss, BCE and KLD of our model
    """
    # 1. the reconstruction loss.
    # We regard the MNIST as binary classification
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # 2. KL-divergence
    # D_KL(Q(z|X) || P(z)); calculate in closed form as both dist. are Gaussian
    # here we assume that \Sigma is a diagonal matrix, so as to simplify the computation
    KLD = 0.5 * torch.sum(torch.exp(log_var) + torch.pow(mu, 2) - 1. - log_var)

    # 3. total loss
    loss = BCE + KLD
    return loss, BCE, KLD

def save_checkpoint(state,is_best,outdir):
    '''
    每当训练一定的epochs后,判断损失函数的值是不是最小的 并保存模型的参数
    :param state: 要保存的模型参数,类型为dict
    :param is_best: 是否为当前最优
    :param outdir: 保存的文件夹
    :return:
    '''
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    checkpoint_file = os.path.join(outdir,'checkpoint.pth') # 把checkpoint.pth保存在outdir中
    best_file = os.path.join(outdir,'model_best.pth')
    torch.save(state,checkpoint_file)
    if is_best:
        # 如果是最优的参数,则把checkpoint_file复制为best_file
        shutil.copyfile(checkpoint_file,best_file)

def test(model,optimizer,mnist_test,epoch,best_test_loss):
    test_avg_loss = 0.0
    with torch.no_grad(): # 测试时不计算梯度
        for test_batch_index,(test_x,_) in enumerate(mnist_test):
            test_x = test_x.to(device)
            # 前向传播
            test_x_hat,test_mu,test_log_var = model(test_x)
            # 计算损失函数
            test_loss,test_BCE,test_KID = loss_function(test_x_hat,test_x,test_mu,test_log_var)
            test_avg_loss += test_loss
        # 对和求平均值,得到每一张图片的平均损失
        test_avg_loss /=len(mnist_test.dataset)

        '''测试随机生成的隐变量'''
        # 在正态分布中随机采样一个个数为batch_size,形状为z_dim的隐变量
        z = torch.randn(args.batch_size,args.z_dim).to(device)
        # 把隐变量输入到解码器生成图片
        random_res = model.decode(z).view(-1,1,28,28)
        # 保存生成的图片
        save_image(random_res,'./%s/random_sampled-%d.png'%(args.result_dir,epoch+1))

        '''保存目前训练好的模型'''
        is_best = test_avg_loss < best_test_loss
        best_test_loss = min(test_avg_loss,best_test_loss)
        save_checkpoint({
            'epoch':epoch,
            'best_test_loss':best_test_loss,
            'state_dict':model.state_dict(),
            'optimizer':optimizer.state_dict(),
        },is_best,args.save_dir)
        return best_test_loss

def train():
    # Step 1: 载入数据
    mnist_test, mnist_train, classes = dataloader(args.batch_size, args.num_worker)

    # 查看每一个batch图片的规模
    x, label = iter(mnist_train).__next__()  # 取出第一批(batch)训练所用的数据集
    print(' img : ', x.shape)  # img :  torch.Size([batch_size, 1, 28, 28]), 每次迭代获取batch_size张图片,每张图大小为(1,28,28)

    # Step 2: 准备工作 : 搭建计算流程
    model = VAE(z_dim=args.z_dim).to(device)  # 定义VAE模型,并转移到GPU上去
    print('The structure of our model is shown below: \n')
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)  # 生成优化器,需要优化的是model的参数,学习率为0.001

    # Step 3: 选择是否加载保存的参数
    start_epoch = 0
    best_test_loss = np.finfo('f').max
    if args.resume:
        if os.path.isfile(args.resume):
            # 载入已经训练过的模型参数与结果
            print('=> loading checkpoint %s' % args.resume)
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_test_loss = checkpoint['best_test_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint %s' % args.resume)
        else:
            print('=> no checkpoint found at %s' % args.resume)

    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    # Step 4: 开始训练
    loss_epoch = []
    for epoch in range(start_epoch, args.epochs):
        # 训练模型
        # 每一代都要遍历所有的批次
        loss_batch = []
        for batch_index, (x, _) in enumerate(mnist_train):
            # x : [b, 1, 28, 28], remember to deploy the input on GPU
            x = x.to(device)

            # 前向传播
            x_hat, mu, log_var = model(x)  # 模型的输出,在这里会自动调用model中的forward函数
            loss, BCE, KLD = loss_function(x_hat, x, mu, log_var)  # 计算损失值,即目标函数
            loss_batch.append(loss.item())  # loss是Tensor类型

            # 反向传播
            optimizer.zero_grad()  # 梯度清零,否则上一步的梯度仍会存在
            loss.backward()  # 后向传播计算梯度,这些梯度会保存在model.parameters里面
            optimizer.step()  # 更新梯度,这一步与上一步主要是根据model.parameters联系起来了

            # 每100个epoch打印一次
            if (batch_index + 1) % 100 == 0:
                print('Epoch [{}/{}], Batch [{}/{}] : Total-loss = {:.4f}, BCE-Loss = {:.4f}, KLD-loss = {:.4f}'
                      .format(epoch + 1, args.epochs, batch_index + 1, len(mnist_train.dataset) // args.batch_size,
                              loss.item() / args.batch_size, BCE.item() / args.batch_size,
                              KLD.item() / args.batch_size))

            if batch_index == 0:
                # visualize reconstructed result at the beginning of each epoch
                x_concat = torch.cat([x.view(-1, 1, 28, 28), x_hat.view(-1, 1, 28, 28)], dim=3)
                save_image(x_concat, './%s/reconstructed-%d.png' % (args.result_dir, epoch + 1))

        # 把这一个epoch的每一个样本的平均损失存起来
        loss_epoch.append(np.sum(loss_batch) / len(mnist_train.dataset))  # len(mnist_train.dataset)为样本个数

        # 测试模型
        if (epoch + 1) % args.test_every == 0:
            best_test_loss = test(model, optimizer, mnist_test, epoch, best_test_loss)
    return loss_epoch


if __name__ == '__main__':
    '''开始计时'''
    start_time = time.time()

    '''开始训练'''
    loss_epoch = train()

    '''计时结束'''
    end_time = time.time()
    run_time = end_time - start_time
    # 将输出的秒数保留两位小数
    if int(run_time) < 60:
        print(f'{round(run_time, 2)}s')
    else:
        print(f'{round(run_time / 60, 2)}minutes')

    # 绘制迭代结果
    plt.plot(loss_epoch)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()

在这里插入图片描述

补充

VAE不能用transforms.Normalize(0.5,0.5)进行归一化,否则Loss直接变成负数,loss要最小化,会变成越来越小的负数

在这里插入图片描述

F.relu(self.fc1(x))和nn.ReLU(self.fc1(x))有什么区别?

F.relu(self.fc1(x))和nn.ReLU(self.fc1(x))在功能上是相同的,都是使用ReLU(Rectified Linear Unit)作为激活函数来处理self.fc1(x)的结果。它们之间的区别在于调用方式和所属的模块。

F.relu()是PyTorch中torch.nn.functional模块中的一个函数,用于实现激活函数ReLU。这个函数是独立于任何特定的神经网络层的,你可以直接调用它来对张量进行ReLU操作。

nn.ReLU()是PyTorch中torch.nn模块中的一个类,用于构建ReLU激活函数的实例。通过将nn.ReLU()作为一个层添加到神经网络模型中,你可以在模型的前向传播过程中应用ReLU激活函数。

综上所述,F.relu(self.fc1(x))是直接调用了ReLU激活函数功能,而nn.ReLU(self.fc1(x))是通过在神经网络模型中添加一个ReLU层来实现激活函数的功能。

pin_memory参数的作用

pin_memory参数在PyTorch中用于数据加载过程中,特别是在使用GPU进行训练时。当设置pin_memory=True时,数据会被加载到主机(Host)的固定内存区域中,而不是被加载到默认的分页内存(Paged Memory)。这样做的目的是为了将数据从主机内存快速传输到GPU内存,以提高数据加载的效率。

在训练过程中,GPU通常需要频繁地从主机内存中读取数据。如果数据未锁定(pinned)并且位于分页内存中,GPU访问主机内存的速度可能会相对较慢。而将数据锁定在主机内存中,可以避免数据在传输过程中被分页,提高了数据传输的效率,从而减少了数据加载到GPU的时间。

需要注意的是,使用pin_memory=True会占用更多的主机内存资源,因此只有在确实需要提高数据加载效率的情况下才建议使用该参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/90597.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微信开放注册微信小号功能,工作人群福音!

微信&#xff0c;这个坐拥数亿用户的社交巨头&#xff0c;最近终于开放了注册微信小号的功能。这个功能对于需要多个微信账号进行工作的人来说&#xff0c;无疑是一场及时雨&#xff0c;极大地提高了工作便利性。 在之前的版本中&#xff0c;每个微信账号都绑定了一个手机号&am…

主从、哨兵、集群模式有什么区别 ?

目录 1.Redis 多机部署的方式 2.主从、哨兵、集群模式有什么区别 2.1 主从同步 2.2 哨兵模式 2.3 集群模式 1.Redis 多机部署的方式 Redis 多机部署主要有 3 种方式&#xff1a; 1. 主从同步&#xff1a;主要存储数据的节点叫做主节点&#xff08;master&#xff09;&…

限时 180 天,微软为 RHEL 9 和 Ubuntu 22.04 推出 SQL Server 2022 预览评估版

导读近日消息&#xff0c;微软公司今天发布新闻稿&#xff0c;宣布面向 Red Hat Enterprise Linux&#xff08;RHEL&#xff09;9 和 Ubuntu 22.04 两大发行版&#xff0c;以预览模式推出 SQL Server 2022 评估版。 近日消息&#xff0c;微软公司今天发布新闻稿&#xff0c;宣布…

网络安全(黑客)零基础自学

网络安全是什么&#xff1f; 网络安全&#xff0c;顾名思义&#xff0c;网络上的信息安全。 随着信息技术的飞速发展和网络边界的逐渐模糊&#xff0c;关键信息基础设施、重要数据和个人隐私都面临新的威胁和风险。 网络安全工程师要做的&#xff0c;就是保护网络上的信息安…

数字 IC 设计职位经典笔/面试题(三)

共100道经典笔试、面试题目&#xff08;文末可全领&#xff09; 1. IC 设计中同步复位与异步复位的区别&#xff1f; 同步复位在时钟沿变化时&#xff0c;完成复位动作。异步复位不管时钟&#xff0c;只要复位信号满足条件&#xff0c;就完成复位动作。异步复位对复位信号要求…

开始MySQL之路——MySQL的DataGrip图形化界面

下载DataGrip 下载地址&#xff1a;Download DataGrip: Cross-Platform IDE for Databases & SQL 安装DataGrip 准备好一个文件夹&#xff0c;不要中文和空格 C:\Develop\DataGrip 激活DataGrip 激活码&#xff1a; VPQ9LWBJ0Z-eyJsaWNlbnNlSWQiOiJWUFE5TFdCSjBaIiwibGl…

用 Audacity 比较两段音频差异

工作中遇到相同的处理流程&#xff0c;处理同一段音频&#xff0c;看看处理结果是否一致&#xff0c;可以用audacity来处理。 假设待比较的音频分别为 1.wav 2.wav 1、用Audacity打开1.wav 2、用Audacity打开2.wav&#xff0c;选中音频&#xff0c;然后用 效果 -> 反向&am…

Linux内核学习(九)—— 虚拟文件系统(基于Linux 2.6内核)

虚拟文件系统&#xff08;VFS&#xff09;作为内核子系统&#xff0c;为用户空间程序提供了文件和文件系统相关的接口。通过虚拟文件系统&#xff0c;程序可以利用标准的 Unix 系统调用对不同的文件系统&#xff08;甚至不同介质上的文件系统&#xff09;进行读写操作。 一、通…

【算法系列篇】前缀和

文章目录 前言什么是前缀和算法1.【模板】前缀和1.1 题目要求1.2 做题思路1.3 Java代码实现 2. 【模板】二维前缀和2.1 题目要求2.2 做题思路2.3 Java代码实现 3. 寻找数组的中心下标3.1 题目要求3.2 做题思路3.3 Java代码实现 4. 除自身以外的数组的乘积4.1 题目要求4.2 做题思…

C++:构造方法(函数);拷贝(复制)构造函数:浅拷贝、深拷贝;析构函数。

1.构造方法(函数) 构造方法是一种特殊的成员方法&#xff0c;与其他成员方法不同: 构造方法的名字必须与类名相同&#xff1b; 无类型、可有参数、可重载 会自动生成&#xff0c;可自定义 一般形式:类名(形参)&#xff1b; 例: Stu(int age); 当用户没自定义构造方法时&…

apache的ab工具测试网页优化效果速度以及服务器承载

今天为大家介绍一款apache自带的一种的测试网页优化效果速度以及服务器承载的工具——ab.exe。 大家在工作中或者开发中可以使用apache的ab工具来测试自己的网站并发量大小&#xff0c;和某个页面的访问时间。 一、基本用法 如果你是用的是apache的话&#xff0c;那么只要进…

基于swing的校园茶餐厅java jsp点餐订餐管理mysql源代码

本项目为前几天收费帮学妹做的一个项目&#xff0c;Java EE JSP项目&#xff0c;在工作环境中基本使用不到&#xff0c;但是很多学校把这个当作编程入门的项目来做&#xff0c;故分享出本项目供初学者参考。 一、项目描述 基于swing的校园茶餐厅 系统有1权限 二、主要功能 …

sql server删除历史数据

1 函数 datediff函数: DATEDIFF ( datepart , startdate , enddate )datepart的取值可以是year,quarter,Month,dayofyear,Day,Week,Hour,minute,second,millisecond startdate 是从 enddate 减去。如果 startdate 比 enddate 晚&#xff0c;返回负值。 2 例子 删除2023年以…

最新PHP短网址生成系统/短链接生成系统/URL缩短器系统源码

全新PHP短网址系统URL缩短器平台&#xff0c;它使您可以轻松地缩短链接&#xff0c;根据受众群体的位置或平台来定位受众&#xff0c;并为缩短的链接提供分析见解。 系统使用了Laravel框架编写&#xff0c;前后台双语言使用&#xff0c;可以设置多域名&#xff0c;还可以开设套…

《Zookeeper》源码分析(二十二)之 客户端核心类

目录 CliCommand数据结构parse()exec() ZooKeeperHostProviderZKClientConfigClientCnxnSocket数据结构构造函数 ClientCnxn数据结构构造函数start() CliCommand 数据结构 CliCommand定义了两个抽象方法&#xff0c;以CreateCommand为例来看下它的parse()和exec()方法。 先看…

腾讯云V265/TXAV1直播场景下的编码优化和应用

// 编者按&#xff1a;随着视频直播不断向着超高清、低延时、高码率的方向发展&#xff0c; Apple Vision的出现又进一步拓展了对3D, 8K 120FPS的视频编码需求&#xff0c;视频的编码优化也变得越来越具有挑战性。LiveVideoStackCon 2023上海站邀请到腾讯云的姜骜杰老师分享腾…

docker 重装提示 Exising installation is up to date 解决方法

Windows Docker 重装提示 Exising installation is up to date 解决方法 出现这个问题是因为卸载Docker没有卸载干净&#xff0c;导致无法重装 解决方法&#xff1a; 按下WindowR唤起命令输入界面&#xff0c;输入 regedit 打开注册表编辑在地址栏输入HKEY_LOCAL_MACHINE\SOFTW…

【Hadoop】Hadoop入门概念简介

&#x1f341; 博主 "开着拖拉机回家"带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——&#x1f390;开着拖拉机回家_Linux,Java基础学习,大数据运维-CSDN博客 &#x1f390;✨&#x1f341; &#x1fa81;&#x1f341; 希望本文能够给您带来一定的…

postmarketOS

nexus 5 nexu5 刷入最新固件、android系统 步骤0. 使用旧的 platform-tools_r27: mkdir ~/nexus5; cd ~/nexus5;#https://android.googlesource.com/platform/tools/google_prebuilts/studio/sdk/remote//a66136ae1bfeb1b08a42319158a7652938c648d3 #此页面有:dl.google.com…

浅谈泛在电力物联网发展形态与技术挑战

安科瑞 华楠 摘 要&#xff1a;泛在电力物联网是当前智能电网发展的一个方向。首先&#xff0c;总结了泛在电力物联网的主要作用和价值体现&#xff1b;其次&#xff0c;从智能电网各个环节概述了物联网技术在电力领域的已有研究和应用基础&#xff1b;进而&#xff0c;构思并…
最新文章