AlexNet(pytorch)

AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络,分类准确率由传统的 70%+提升到 80%+

该网络的亮点在于:

(1)首次利用 GPU 进行网络加速训练。

(2)使用了 ReLU 激活函数,而不是传统的 Sigmoid 激活函数以及 Tanh 激活函数。

(3)使用了 LRN 局部响应归一化。

(4)在全连接层的前两层中使用了 Dropout 随机失活神经元操作,以减少过拟合

模型:

模型参数表:

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]: (55-3+0)/4 + 1=27
            
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
           
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    #前期的网络还是用的Normalize标准化,之后的网络会用到BN批标准化
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../../"))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    #注意这里的数据加载还是直接用的torchvision.datasets.ImageFolder加载,
    #并不需要定义数据加载的脚本,可能是数据比较简单吧
    #定义数据集时候直接定义数据处理方法,之后torch.utils.data.DataLoader加载数据集加载时候直接调用这里定义的数据处理参数的方法
    #train文件夹下还有五种花的文件夹,这个具体处理看下面的代码,可能是ImageFolder直接加载文件夹里的图片文件
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    #训练集图片的个数
    train_num = len(train_dataset)


    #train_dataset.class_to_idx 是一个字典,将类别名称映射到相应的索引。
    #下行注释就是flower_list具体内容
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    # cla_dict是一个反转字典,将原始字典 flower_list 的键和值进行交换
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # json.dumps() 将 cla_dict 转换为格式化的 JSON 字符串。
    # 最后,将 JSON 字符串写入名为 class_indices.json 的文件中
    # indent 参数表示有几类
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32

    #这个代码片段的目的是为了确定在并行计算时使用的最大工作进程数,并确保不超过系统的逻辑 CPU 核心数量和其他限制
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    #一个epoch训练多少批次的数据,一批数据32个CWH,即32张图片
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        #这段代码使用了 tqdm 库来创建一个进度条,用于迭代训练数据集 train_loader 中的批次数据
        #file=sys.stdout 的作用是将进度条的输出定向到标准输出流,即将进度条显示在终端窗口中
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            #更新进度条的描述信息,显示当前训练的轮数、总轮数和损失值
            #这个loss是批次损失,在进度条上显示出来
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # 验证是训练完一个epoch后进行在验证集上验证,验证准确率
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                #val_bar 的类型是 tqdm.tqdm,它是 tqdm 库中的一个类。该类提供了迭代器的功能,
                # 可以用于包装迭代器对象,并在循环中显示进度条和相关信息
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))   #outputs:[batch_size,num_classes]
                predict_y = torch.max(outputs, dim=1)[1]  #torch.max  返回的第一个元素是张量数值,第二个是对应的索引
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        #验证完后计算验证集里所有的正确个数/总个数
        val_accurate = acc / val_num

        #总损失/训练总批次,求得平均每批的损失
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

训练过程:

using cuda:0 device.
Using 8 dataloader workers every process
using 3306 images for training, 364 images for validation.
train epoch[1/10] loss:1.215: 100%|██████████| 104/104 [00:23<00:00,  4.38it/s]
100%|██████████| 91/91 [00:15<00:00,  5.73it/s]
[epoch 1] train_loss: 1.342  val_accuracy: 0.478
train epoch[2/10] loss:1.111: 100%|██████████| 104/104 [00:19<00:00,  5.30it/s]
100%|██████████| 91/91 [00:15<00:00,  5.75it/s]
[epoch 2] train_loss: 1.183  val_accuracy: 0.533
train epoch[3/10] loss:1.252: 100%|██████████| 104/104 [00:19<00:00,  5.30it/s]
100%|██████████| 91/91 [00:15<00:00,  5.75it/s]
[epoch 3] train_loss: 1.097  val_accuracy: 0.604
train epoch[4/10] loss:0.730: 100%|██████████| 104/104 [00:19<00:00,  5.32it/s]
100%|██████████| 91/91 [00:15<00:00,  5.74it/s]
[epoch 4] train_loss: 1.025  val_accuracy: 0.607
train epoch[5/10] loss:0.961: 100%|██████████| 104/104 [00:19<00:00,  5.28it/s]
100%|██████████| 91/91 [00:16<00:00,  5.65it/s]
[epoch 5] train_loss: 0.941  val_accuracy: 0.676
train epoch[6/10] loss:0.853: 100%|██████████| 104/104 [00:19<00:00,  5.31it/s]
100%|██████████| 91/91 [00:15<00:00,  5.82it/s]
[epoch 6] train_loss: 0.915  val_accuracy: 0.659
train epoch[7/10] loss:1.032: 100%|██████████| 104/104 [00:19<00:00,  5.34it/s]
100%|██████████| 91/91 [00:15<00:00,  5.82it/s]
[epoch 7] train_loss: 0.864  val_accuracy: 0.684
train epoch[8/10] loss:0.704: 100%|██████████| 104/104 [00:19<00:00,  5.32it/s]
100%|██████████| 91/91 [00:15<00:00,  5.80it/s]
[epoch 8] train_loss: 0.842  val_accuracy: 0.706
train epoch[9/10] loss:1.279: 100%|██████████| 104/104 [00:19<00:00,  5.30it/s]
100%|██████████| 91/91 [00:15<00:00,  5.83it/s]
[epoch 9] train_loss: 0.825  val_accuracy: 0.714
train epoch[10/10] loss:0.796: 100%|██████████| 104/104 [00:19<00:00,  5.31it/s]
100%|██████████| 91/91 [00:15<00:00,  5.82it/s]
[epoch 10] train_loss: 0.801  val_accuracy: 0.703
Finished Training

Process finished with exit code 0

predict.py:

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "./test.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)

    #torch.load() 函数会根据路径加载模型的权重,并返回一个包含模型参数的字典
    #load_state_dict() 函数将加载的模型参数字典应用到 model 中,从而将预训练模型的参数加载到 model 中
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

预测结果:

我感觉pycharm的plt显示并不是特别明了

class: daisy        prob: 4.2e-06
class: dandelion    prob: 9.61e-07
class: roses        prob: 0.000773
class: sunflowers   prob: 1.28e-05
class: tulips       prob: 0.999

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/251446.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

介绍strncpy函数

strncpy函数需要引用#include <string.h>头文件 函数原型&#xff1a; char *_Dest 是字符串的去向 char *_Source是字符串的来源 size_t_Count是复制字符串的大小 #include <stdio.h> #include <string.h> int main() { char arr[128] { \0 }; …

数据结构之排序

目录 ​ 1.常见的排序算法 2.插入排序 直接插入排序 希尔排序 3.交换排序 冒泡排序 快速排序 hoare版本 挖坑法 前后指针法 非递归实现 4.选择排序 直接选择排序 堆排序 5.归并排序 6.排序总结 一起去&#xff0c;更远的远方 1.常见的排序算法 排序&#xff1a;所…

积分球均匀光源遥感器如何保持稳定

积分球的亮度均匀性取决于其内部涂层的反射率和分布情况。当光源通过积分球时&#xff0c;光会被内部的涂层多次反射&#xff0c;最终从出光口均匀地散射出去。为了提高亮度均匀性&#xff0c;可以采用具有高反射率和均匀分布的光源&#xff0c;同时选择合适的涂层材料和涂层厚…

NXP应用随记(五):eMios功能点阅读随记

目录 1、概念点 2、eMios功能点 2.1、eMIOS - Single Action Input Capture (SAIC) 2.2、eMIOS - Single Action Output Compare (SAOC) 2.3、eMIOS - Double Action Output Compare (DAOC) 2.4、eMIOS - Pulse/Edge Counting (PEC) – Single Shot 2.5、eMIOS - Pulse/E…

算法:程序员的数学读书笔记

目录 ​0的故事 ​一、按位计数法 二、不使用按位计数法的罗马数字 三、十进制转二进制​​​​​​​ ​四、0所起到的作用​​​​​​​ 逻辑 一、为何逻辑如此重要 二、兼顾完整性和排他性 三、逻辑 四、德摩根定律 五、真值表 六、文氏图 七、卡诺图 八、逻…

【算法Hot100系列】最长回文子串

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

向华为学习:基于BLM模型的战略规划研讨会实操的详细说明,含研讨表单(二)

上一篇文章&#xff0c;华研荟结合自己的经验和实践&#xff0c;详细介绍了基于BLM模型的战略规划研讨会的设计和组织流程&#xff0c;提高效率的做法。有朋友和我私信沟通说&#xff0c;其实这个流程不单单适合于BLM模型的战略规划研讨会&#xff0c;实际上&#xff0c;使用其…

Linux centos7安装redis 6.2.14 gz并且使用systemctl为开机自启动 / 彻底删除 redis

1.下载 && 减压 wget http://download.redis.io/releases/redis-6.2.14.tar.gz tar -zvxf redis-6.2.14.tar.gz 2.编译&#xff08;分开运行&#xff09; cd redis-6.2.14 make cd src make install 安装目录展示 3.redis.conf 配置更改 daemonize yes supervised s…

【STM32入门】4.1中断基本知识

1.中断概览 在开展红外传感器遮挡计次的实验之前&#xff0c;有必要系统性的了解“中断”的基本知识. 中断是指&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转…

交友网站的设计与实现(源码+数据库+论文+开题报告+说明文档)

项目描述 临近学期结束&#xff0c;还是毕业设计&#xff0c;你还在做java程序网络编程&#xff0c;期末作业&#xff0c;老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下&#xff0c;你想解决的问…

听力健康“吃”出来

大多数的研究报告都指出&#xff0c;听力下降的最常见原因是年龄和噪音暴露。然而&#xff0c;近年来越来越多的文章开始探讨其他因素对听力的影响。食物不仅是维持人类基本生存的必需品&#xff0c;随着营养学的进步&#xff0c;人们也逐渐认识到食物中的营养与保持健康之间存…

Unity中Shader URP 简介

文章目录 前言一、URP&#xff08;Universal Render Pipeline&#xff09;由名字可知&#xff0c;这是一个 通用的 渲染管线1、Universal&#xff08;通用性&#xff09;2、URP的由来 二、Build-in Render Pipeline&#xff08;内置渲染管线&#xff09;1、LWRP&#xff08;Lig…

产品经理在项目周期中扮演的角色Axure的安装与基本使用

目录 一.项目周期流程 二.Axure是什么 三.Axure安装 3.1 一键式安装 3.2 汉化 3.3 授权登录 四.Axure的界面介绍及基本使用 4.1 菜单栏的使用 4.2 工具栏的使用 4.3 页面概要的使用及组件的使用 4.4 组件的样式设计 一.项目周期流程 在一般的项目周期中包含的工作内容有&…

Tektronix泰克TCP303示波器电流探头

主要特点和优点&#xff1a; ● 交流/直流测量功能 ● DC~100MHz电流探头放大器&#xff08;TCPA300&#xff09;&#xff0c;当使用&#xff1a; - DC~100MHz, 30A DC&#xff08;TCP312&#xff09; - DC~50MHz, 50A DC&#xff08;TCP305&#xff09; - DC~5MHz, 150A DC&a…

VC++项目的32位、64位的配置和链接问题

新建一个项目&#xff0c;默认是x86配置&#xff1b; 添加包含目录、库目录&#xff0c;之后可以编译通过&#xff1b; 但是链接会出错&#xff0c;因为链接的dll是64位&#xff1b; 把项目配置改为x64&#xff1b; 需要把包含目录和库目录针对x64重新添加&#xff0c;否则会…

CSS学习笔记整理

CSS 即 层叠样式表/CSS样式表/级联样式表&#xff0c;也是标记语言&#xff0c; 用于设置HTML页面中的文本内容&#xff08;字体、大小、对齐方式等&#xff09;、图片的外形&#xff08;宽高、边框样式、边距&#xff09;以及版面的布局和外观显示样式 目录 准备工作 Chrome调…

【JavaWeb学习笔记】10 - 手写Tomcat底层,Maven的初步使用

一、Maven 1.Maven示意图 类似Java访问数据库 2.创建Maven案例演示 配置阿里镜像 找到setting目录 但一开始配置不存在该文件 需要去Maven主目录下的conf拿到settings拷贝到上述目录 拷贝到admin/.m2后打开该settings 在<mirrors>内输入镜像地址 <mirror> …

SpringBoot中处理处理国际化

SpringBoot中处理处理国际化 1. 创建SpringBoot项目2. resource下创建i18n目录3. 右键i18n新建资源包4. 弹框中添加需要支持的国际化语言5. messages.properties中添加需要国际化的键6. application.yaml添加配置7. 国际化工具8. 使用功能9 场景问题 1. 创建SpringBoot项目 2.…

【Flink-cdc-Mysql-To-Kafka】使用 Flinksql 利用集成的 connector 实现 Mysql 数据写入 Kafka

【Flink-cdc-Mysql-To-Kafka】使用 Flinksql 利用集成的 connector 实现 Mysql 数据写入 Kafka 1&#xff09;环境准备2&#xff09;准备相关 jar 包3&#xff09;实现场景4&#xff09;准备工作4.1.Mysql4.2.Kafka 5&#xff09;Flink-Sql6&#xff09;验证 1&#xff09;环境…

毅速:3D打印随形水路 提高良品率和生产效率的新利器

随着科技的不断发展&#xff0c;3D打印技术已经成为模具制造领域的一种重要技术。其中&#xff0c;模具随形水路的设计和制造是提高注塑产品良品率和生产效率的关键环节。 模具随形水路是一种根据产品形状设计的水路&#xff0c;可以更靠近产品&#xff0c;并在模具内热点集中区…
最新文章