霹雳吧啦Wz《pytorch图像分类》-p2AlexNet网络

《pytorch图像分类》p2AlexNet网络基础及代码

  • 一、零碎知识点
    • 1.过拟合
    • 2.使用dropout后的正向传播
    • 3.正则化regularization
    • 4.代码中所用的知识点
  • 二、总体架构分析
    • 1.ReLU激活函数
    • 2.手算
    • 3.模型代码
  • 三、训练花分类课程代码
    • 1.model.py
    • 2.train.py
    • 3.predict.py

一、零碎知识点

1.过拟合

模型假设过于复杂,参数过多,训练数据过少,噪声过多,导致拟合的函数完美的预测训练集,但对新数据的测试集预测结果差。 过度的拟合了训练数据,而没有考虑到泛化能力。
举个栗子:当我们在学习一门新的学科时会做一些例题,我们把这些例题完完整整背下来,但是一旦给出新的题目,还是不会做,这就是学习没有泛化能力,只记住了例题的细节而忽视了更普遍的规律。

2.使用dropout后的正向传播

dropout会在每一层当中随机失活一部分神经元,从而减少了神经元之间的共适应性,防止过拟合。
nn.Dropout(p= )p代表的是随机失活的比例,默认p=0.5

3.正则化regularization

是一种通过添加额外的约束或惩罚项来控制模型的复杂度的技术,其目的也是防止过拟合。
假设我们的模型是一个二次多项式,我们的目标是最小化损失函数(均方误差MSE),让预测的曲线与真实值尽可能接近。
L2正则化将惩罚项加到损失函数中,使用权重的平方和乘以一个正则化系数λ。
带正则化的损失函数为:
L o s s = M S E + λ ∗ ∣ ∣ w ∣ ∣ 2 Loss = MSE + λ * ||w||^2 Loss=MSE+λ∣∣w2
很抽象,以后再深入学习。

4.代码中所用的知识点

  1. 设备设置
    如果有可以使用的gpu设备,默认使用第一个,没有的话即使用cpu。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  1. RandomResizedCrop随机裁剪
    将裁剪后的图像调整为指定的大小(224x224)
transforms.RandomResizedCrop(224)
  1. root=“. ./. .”
    返回上上层目录

  2. CrossEntropyLoss交叉熵损失函数
    它是PyTorch中的一个损失函数,常用于多分类问题的训练中。它结合了Softmax激活函数和负对数似然损失(NLLLoss)
    详情请见我之前写的博客:多分类问题

  3. net.train( )和net.eval( )
    当调用net.train()时,模型将被设置为训练模式。在训练模式下,模型会启用一些特定的操作,如Batch Normalization归一化处理和Dropout随机失活一些神经元,防止过拟合。而调用net.eval()时,模型将被设置为评估(evaluate)模式。

二、总体架构分析

1.ReLU激活函数

ReLU(Rectified Linear Unit)线性整流函数,又称修正线性单元。在PyTorch中,可以使用torch.nn.ReLU类来表示。
其公式为:
f ( x ) = M a x ( 0 , x ) f(x)=Max(0,x) f(x)=Max(0,x)
它将小于零的输入映射为0,大于等于零的输入保持不变,求导简单方便。
我总结了这些天学习的几个激活函数:激活函数总结

2.手算

在这里插入图片描述

3.模型代码

nn.sequential将一系列的层结构打包组合成一个新的结构
classifer包括了三个全连接层
在这里插入图片描述
1.Conv1 第一个卷积层
代码表示为: nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding)
彩色图像有rgb三个通道,所以输入通道数为3,输出通道数=卷积核的个数kernel_num
为了方便训练,卷积核个数只取一半,padding直接取2,训练效果与原本的影响不大

 nn.Conv2d(3,48,kernel_size=11,stride=4,padding=2)

2.Maxpool1 第一个池化层
代码表示为:nn.MaxPool2d(kernel_size,stride)

nn.MaxPool2d(kernel_size=3,stride=2)

3.后续代码

nn.Conv2d(48,128,kernel_size=5,padding=2)
nn.MaxPool2d(kernel_size=3,stride=2)
nn.Conv2d(128,192,kernel_size=3,padding=1)
nn.Conv2d(192,192,kernel_size=3,padding=1)
nn.Conv2d(192,128,kernel_size=3,padding=1)
nn.MaxPool2d(kernel_size=3,stride=2)

三、训练花分类课程代码

1.model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2.train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=True,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

3.predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "1..jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

1.jpg是郁金香
在这里插入图片描述
2.jpg是向日葵
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/284736.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FPGA项目(14)——基于FPGA的数字秒表设计

1.功能设计 设计内容及要求: 1.秒表最大计时范围为99分59. 99秒 2.6位数码管显示,分辨率为0.01秒 3.具有清零、启动计时、暂停及继续计时等功能 4.控制操作按键不超过二个。 2.设计思路 所采用的时钟为50M,先对时钟进行分频,得到100HZ频率…

【Maven】下载配置maven以及IDEA配置maven详情

目录 1、下载maven 2、配置settings.xml 2.1、配置本地仓库 2.2、配置阿里云镜像仓库 2.3、配置JDK 3、配置环境变量 4、IDEA配置maven 1、下载maven maven官网&#xff1a;https://maven.apache.org/ 2、配置settings.xml 2.1、配置本地仓库 <localRepository>C:\…

oracle 9i10g编程艺术-读书笔记1

根据书中提供的下载代码链接地址&#xff0c;从github上找到源代码下载地址。 https://github.com/apress下载好代码后&#xff0c;开始一段新的旅行。 设置 SQL*Plus 的 AUTOTRACE 设置 SQL*Plus 的 AUTOTRACE AUTOTRACE 是 SQL*Plus 中一个工具&#xff0c;可以显示所执行…

GPT4-AIl本地部署-chat AI本地使用

文章目录 GPT4-AIl本地部署GPT4客户端下载地址&#xff1a;对应的下载下载后的文件点击安装&#xff0c;改一下文件存放路径&#xff0c;下面都是默认下一步进度条100%后&#xff0c;点击完成 安装完桌面生成图标&#xff0c;点击选择都是NO&#xff0c;不进行数据上传点击后&a…

Python编程新技能:如何优雅地实现水仙花数?

水仙花数&#xff08;Narcissistic number&#xff09;也被称为阿姆斯特朗数&#xff08;Armstrong number&#xff09;或自恋数等&#xff0c;它是一个非负整数&#xff0c;其特性是该数的每个位上的数字的n次幂之和等于它本身&#xff0c;其中n是该数的位数。简单来说&#x…

一起学Elasticsearch系列-写入原理

本文已收录至Github&#xff0c;推荐阅读 &#x1f449; Java随想录 微信公众号&#xff1a;Java随想录 文章目录 写入过程写操作写流程写一致性策略 写入原理RefreshMergeFlushTranslog图解写入流程 ES作为一款开源的分布式搜索和分析引擎&#xff0c;以其卓越的性能和灵活的扩…

29 UVM Command Line Processor (CLP)

随着设计和验证环境的复杂性增加&#xff0c;编译时间也增加了&#xff0c;这也影响了验证时间。因此&#xff0c;需要对其进行优化&#xff0c;以便在不强制重新编译的情况下考虑新的配置或参数。我们已经看到了function or task如何基于传递参数进行行为。类似地&#xff0c;…

均方差损失推导

一、损失函数&#xff08;Cost function&#xff09; 定义&#xff1a;用于衡量模型预测结果与真实结果之间差距的函数。&#xff08;有的地方称之为代价函数&#xff0c;但是个人感觉损失函数这个名称更贴近实际用途&#xff09; 理解&#xff1a;&#xff08;以均方差损失函…

Python浪漫520表白代码

系列文章 序号文章目录直达链接表白系列1浪漫520表白代码https://want595.blog.csdn.net/article/details/1306668812满屏飘字表白代码https://want595.blog.csdn.net/article/details/1349149703无限弹窗表白代码https://want595.blog.csdn.net/article/details/1297945184跳…

使用.Net nanoFramework 驱动ESP32的OLED显示屏

本文介绍如何使用.Net nanoFramework 驱动ESP32的OLED显示屏。我们将会从最基础的部分开始&#xff0c;逐步深入&#xff0c;让你能够理解并实现整个过程。无论你是初学者还是有一定经验的开发者&#xff0c;这篇文章都会对你有所帮助。 1. 硬件准备 1.1 ESP32开发板 这里我们…

搭建flink集群 —— 筑梦之路

Apache Flink 是一个框架和分布式处理引擎&#xff0c; 用于在无边界和有边界数据流上进行有状态的计算。 Flink 能在所有常见集群环境中运行&#xff0c;并能以内存速度和任意规模进行计算。 Flink并没有依靠自身实现所有分布式系统需要解决的问题&#xff0c; 而是在已有集群…

c语言之将输入的十进制转换成二进制数并打印原码反码补码

十进制转二进制 首先&#xff0c;我们要知道的是十进制转换成二进制数的方法。我们一般采用的除二取余的方法&#xff0c;在这里我用32位数组来进行转换。 int main() {printf("请输入一个十进制数\n");int n 0;scanf("%d", &n);int arr[32];int* p…

实战入门 K8s剩下三个模块

1.Label Label是kubernetes系统中的一个重要概念。它的作用就是在资源上添加标识&#xff0c;用来对它们进行区分和选择。 Label的特点&#xff1a; 一个Label会以key/value键值对的形式附加到各种对象上&#xff0c;如Node、Pod、Service等等 一个资源对象可以定义任意数量…

VMware虚拟机和Centos7镜像安装

文章目录 安装VMware虚拟机1、下载2、激活 安装Centos7镜像启动虚拟机 安装VMware虚拟机 1、下载 建议还是安装16版本 VMware16下载 https://www.123pan.com/s/HQeA-aX1Sh VMware15 链接&#xff1a;https://pan.baidu.com/s/11UD1hb6IydbxNNPxmh-MqA?pwd0630 提取码&am…

【熔断限流组件resilience4j和hystrix】

文章目录 &#x1f50a;博主介绍&#x1f964;本文内容起因resilience4j落地实现pom.xml依赖application.yml配置接口使用 hystrix 落地实现pom.xml依赖启动类上添加注解接口上使用 &#x1f4e2;文章总结&#x1f4e5;博主目标 &#x1f50a;博主介绍 &#x1f31f;我是廖志伟…

使用YOLOv8和Grad-CAM技术生成图像热图

目录 yolov8导航 YOLOv8&#xff08;附带各种任务详细说明链接&#xff09; 概述 环境准备 代码解读 导入库 定义letterbox函数 调整尺寸和比例 计算填充 应用填充 yolov8_heatmap类定义和初始化 后处理函数 绘制检测结果 类的调用函数 热图生成细节 参数解释 we…

QT上位机开发(带配置文件的倒计时软件)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 前面我们用qt写过倒计时软件&#xff0c;但是那个时候界面只有分钟和秒钟&#xff0c;这一次我们希望在之前的基础上拓展一下。第一&#xff0c;可…

ThreadLocal 是什么?它的实现原理是什么?

文章目录 ThreadLocal 是什么&#xff1f;它的实现原理是什么&#xff1f; ThreadLocal 是什么&#xff1f;它的实现原理是什么&#xff1f; ThreadLocal 是一种线程隔离机制&#xff0c;它提供了多线程环境下对于共享变量访问的安全性。 在多线程访问共享变量的场景中&#…

2023-12-14 LeetCode每日一题(用邮票贴满网格图)

2023-12-14每日一题 一、题目编号 2132. 用邮票贴满网格图二、题目链接 点击跳转到题目位置 三、题目描述 给你一个 m x n 的二进制矩阵 grid &#xff0c;每个格子要么为 0 &#xff08;空&#xff09;要么为 1 &#xff08;被占据&#xff09;。 给你邮票的尺寸为 stam…

【心得】PHP反序列化高级利用(phar|session)个人笔记

目录 ①phar反序列化 ②session反序列化 ①phar反序列化 phar 认为是java的jar包 calc.exe phar能干什么 多个php合并为独立压缩包&#xff0c;不解压就能执行里面的php文件&#xff0c;支持web服务器和命令行 phar协议 phar://xxx.phar $phar->setmetadata($h); m…
最新文章