深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别|第1例

文章目录

  • 前言
  • 一、数据准备
      • 1.1 数据集介绍
      • 1.2 数据集文件结构
  • 二、项目实战
      • 2.1 数据标签划分
      • 2.2 数据预处理
      • 2.3 构建模型
      • 2.4 开始训练
      • 2.5 结果可视化
  • 三、数据集个体预测

前言

SqueezeNet是一种轻量且高效的CNN模型,它参数比AlexNet少50倍,但模型性能(accuracy)与AlexNet接近。顾名思义,Squeeze的中文意思是压缩和挤压的意思,所以我们通过算法的名字就可以猜想到,该算法一定是通过压缩模型来降低模型参数量的。当然任何算法的改进都是在原先的基础上提升精度或者降低模型参数,因此该算法的主要目的就是在于降低模型参数量的同时保持模型精度。


我的环境:

  • 基础环境:python3.7
  • 编译器:pycharm
  • 深度学习框架:pytorch
  • 数据集代码获取:链接(提取码:2357 )

一、数据准备

本案例使用的数据集是眼疾识别数据集iChallenge-PM。

1.1 数据集介绍

iChallenge-PM是百度大脑和中山大学中山眼科中心联合举办的iChallenge比赛中,提供的关于病理性近视(Pathologic Myopia,PM)的医疗类数据集,包含1200个受试者的眼底视网膜图片,训练、验证和测试数据集各400张。

  • training.zip:包含训练中的图片和标签
  • validation.zip:包含验证集的图片
  • valid_gt.zip:包含验证集的标签

该数据集是从AI Studio平台中下载的,具体信息如下:
在这里插入图片描述

1.2 数据集文件结构

数据集中共有三个压缩文件,分别是:

  • training.zip
├── PALM-Training400
│   ├── PALM-Training400.zip
│   │   ├── H0002.jpg
│   │   └── ...
│   ├── PALM-Training400-Annotation-D&F.zip
│   │   └── ...
│   └── PALM-Training400-Annotation-Lession.zip
        └── ...
  • valid_gt.zip:标记的位置 里面的PM_Lable_and_Fovea_Location.xlsx就是标记文件
├── PALM-Validation-GT
│   ├── Lession_Masks
│   │   └── ...
│   ├── Disc_Masks
│   │   └── ...
│   └── PM_Lable_and_Fovea_Location.xlsx

  • validation.zip:测试数据集
├── PALM-Validation
│   ├── V0001.jpg
│   ├── V0002.jpg
│   └── ...

二、项目实战

项目结构如下:
在这里插入图片描述

2.1 数据标签划分

该眼疾数据集格式有点复杂,这里我对数据集进行了自己的处理,将训练集和验证集写入txt文本里面,分别对应它的图片路径和标签。

import os
import pandas as pd
# 将训练集划分标签
train_dataset = r"F:\SqueezeNet\data\PALM-Training400\PALM-Training400"
train_list = []
label_list = []


train_filenames = os.listdir(train_dataset)

for name in train_filenames:
    filepath = os.path.join(train_dataset, name)
    train_list.append(filepath)
    if name[0] == 'N' or name[0] == 'H':
        label = 0
        label_list.append(label)
    elif name[0] == 'P':
        label = 1
        label_list.append(label)
    else:
        raise('Error dataset!')


with open('F:/SqueezeNet/train.txt', 'w', encoding='UTF-8') as f:
    i = 0
    for train_img in train_list:
        f.write(str(train_img) + ' ' +str(label_list[i]))
        i += 1
        f.write('\n')
# 将验证集划分标签
valid_dataset = r"F:\SqueezeNet\data\PALM-Validation400"
valid_filenames = os.listdir(valid_dataset)
valid_label = r"F:\SqueezeNet\data\PALM-Validation-GT\PM_Label_and_Fovea_Location.xlsx"
data = pd.read_excel(valid_label)
valid_data = data[['imgName', 'Label']].values.tolist()

with open('F:/SqueezeNet/valid.txt', 'w', encoding='UTF-8') as f:
    for valid_img in valid_data:
        f.write(str(valid_dataset) + '/' + valid_img[0] + ' ' + str(valid_img[1]))
        f.write('\n')

2.2 数据预处理

这里采用到的数据预处理,主要有调整图像大小、随机翻转、归一化等。

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

transform_BZ = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5]
)


class LoadData(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag

        self.train_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.RandomHorizontalFlip(),  # 随机左右翻转图像
            transforms.RandomVerticalFlip(),  # 随机上下翻转图像
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])
        self.val_tf = transforms.Compose([
            transforms.Resize(224),  # 调整图像大小为224x224
            transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
            transform_BZ  # 执行某些复杂变换操作
        ])

    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))
        return imgs_info

    def padding_black(self, img):
        w, h = img.size
        scale = 224. / max(w, h)
        img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
        size_fg = img_fg.size
        size_bg = 224
        img_bg = Image.new("RGB", (size_bg, size_bg))
        img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                              (size_bg - size_fg[1]) // 2))

        img = img_bg
        return img

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]

        img_path = os.path.join('', img_path)
        img = Image.open(img_path)
        img = img.convert("RGB")
        img = self.padding_black(img)
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs_info)

2.3 构建模型

import torch
import torch.nn as nn
import torch.nn.init as init


class Fire(nn.Module):

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)


class SqueezeNet(nn.Module):

    def __init__(self, version='1_0', num_classes=1000):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == '1_0':
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == '1_1':
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # FIXME: Is this needed? SqueezeNet should only be called from the
            # FIXME: squeezenet1_x() functions
            # FIXME: This checking is not done for the other models
            raise ValueError("Unsupported SqueezeNet version {version}:"
                             "1_0 or 1_1 expected".format(version=version))

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)

2.4 开始训练

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from model import SqueezeNet
import torchsummary
from dataloader import LoadData
import copy

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = SqueezeNet(num_classes=2).to(device)
# print(model)
#print(torchsummary.summary(model, (3, 224, 224), 1))


# 加载训练集和验证集
train_data = LoadData(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)
test_data = LoadData(r"F:\SqueezeNet\valid.txt", True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)


# 编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)
    print('num_batches:', num_batches)
    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率

    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失

        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

# 编写验证函数
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss




# 开始训练

epochs = 20

train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标


loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 定义Adam优化器

for epoch in range(epochs):

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_function, optimizer)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_function)

    # 保存最佳模型到 best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(best_model.state_dict(), PATH)

print('Done')

在这里插入图片描述

2.5 结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()

可视化结果如下:
在这里插入图片描述
可以自行调整学习率以及batch_size,这里我的超参数并没有调整。

三、数据集个体预测

import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import transforms
from model import SqueezeNet
import torch

data_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((224, 224)),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

img = Image.open("F:\SqueezeNet\data\PALM-Validation400\V0008.jpg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
name = ['非病理性近视', '病理性近视']
model_weight_path = r"F:\SqueezeNet\best_model.pth"
model = SqueezeNet(num_classes=2)
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    output = torch.squeeze(model(img))

    predict = torch.softmax(output, dim=0)
    # 获得最大可能性索引
    predict_cla = torch.argmax(predict).numpy()
    print('索引为', predict_cla)
print('预测结果为:{},置信度为: {}'.format(name[predict_cla], predict[predict_cla].item()))
plt.show()
索引为 1
预测结果为:病理性近视,置信度为: 0.9768268465995789

在这里插入图片描述

更详细的请看paddle版本的实现:深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/77682.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Springboot写单元测试

导入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><exclusions><exclusion><groupId>org.junit.vintage</groupId><artifactId>junit-vintag…

Python学习笔记_基础篇(七)_常用模块

模块&#xff0c;用一砣代码实现了某个功能的代码集合。 类似于函数式编程和面向过程编程&#xff0c;函数式编程则完成一个功能&#xff0c;其他代码用来调用即可&#xff0c;提供了代码的重用性和代码间的耦合。而对于一个复杂的功能来&#xff0c;可能需要多个函数才能完成…

机器人CPP编程基础-03变量类型Variables Types

机器人CPP编程基础-02变量Variables 全文AI生成。 C #include<iostream>using namespace std;main() {int a10,b35; // 4 bytescout<<"Value of a : "<<a<<" Address of a : "<<&a <<endl;cout<<"Val…

详细记录Pycharm配置已安装好的Conda虚拟环境

当安装好conda环境之后&#xff0c;想要在Pycharm中使用&#xff0c;那么就要在Pycharm中导入&#xff0c;我这里使用的pycharm-professional-2023.2这个版本&#xff0c;下面是详细步骤&#xff1a; 1.打开File->Settings&#xff1a; 2.找到Project——>Python Inter…

网络层协议

网络层协议 IP协议基本概念协议头格式网段划分特殊的IP地址IP地址的数量限制私有IP地址和公网IP地址路由IP协议头格式后续 在复杂的网络环境中确定一个合适的路径 IP协议 承接上文&#xff0c;TCP协议并不会直接将数据传递给对方&#xff0c;而是交付给下一层协议&#xff0c;…

LeetCode Top100 Liked 题单(序号34~51)

​34. Find First and Last Position of Element in Sorted Array ​ 题意&#xff1a;找到非递减序列中目标的开头和结尾 我的思路 用二分法把每一个数字都找到&#xff0c;最后返回首尾两个数 代码 Runtime12 ms Beats 33.23% Memory14 MB Beats 5.16% class Solution {…

ssh远程连接服务器

一、远程连接服务器简介 二、连接加密技术简介 三、ssh服务配置 四、用户登录ssh服务 Enforcing会强制限制&#xff0c;如端口为22&#xff0c;可以访问&#xff0c;如果是2000端口&#xff0c;不能使用 Permissive是宽容的模式&#xff0c;不限制使用端口 Enforcing会重启失败…

Redis系列(一):深入了解Redis数据类型和底层数据结构

Redis有以下几种常用的数据类型&#xff1a; redis数据是如何组织的 为了实现从键到值的快速访问&#xff0c;Redis 使用了一个哈希表来保存所有键值对。 Redis全局哈希表&#xff08;Global Hash Table&#xff09;是指在Redis数据库内部用于存储所有键值对的主要数据结构。…

(三)行为模式:2、命令模式(Command Pattern)(C++示例)

目录 1、命令模式&#xff08;Command Pattern&#xff09;含义 2、命令模式的UML图学习 3、命令模式的应用场景 4、命令模式的优缺点 5、C实现命令模式的实例 1、命令模式&#xff08;Command Pattern&#xff09;含义 命令模式&#xff08;Command&#xff09;&#xff…

亚马逊、ebay、虾皮电商卖家如何做测评,提高店铺排名?

测评是什么呢&#xff1f; 不管是在亚马逊&#xff0c;速卖通&#xff0c;阿里国际&#xff0c;虾皮&#xff0c;Lazada&#xff0c;沃尔玛&#xff0c;美客多&#xff0c;ebay等跨境电商平台&#xff0c;测评都是成本最低且最有效的一种推广方式。 通俗来说&#xff0c;测评…

leetcode292. Nim 游戏(博弈论 - java)

Nim 游戏 Nim 游戏题目描述博弈论 上期经典算法 Nim 游戏 难度 - 简单 原题链接 - Nim游戏 题目描述 你和你的朋友&#xff0c;两个人一起玩 Nim 游戏&#xff1a; 桌子上有一堆石头。 你们轮流进行自己的回合&#xff0c; 你作为先手 。 每一回合&#xff0c;轮到的人拿掉 1 -…

MySQL 中 不等于 会过滤掉 Null 的问题

null值与任意值比较时都为fasle not in 、"!"、"not like"条件过滤都会过滤掉null值的数据 SELECT * from temp; SELECT * from temp where score not in (70); 返回null解决方法: SELECT * from temp where score not in (70) or score is null;SELECT…

OC调用Swift编写的framework

一、前言 随着swift趋向稳定&#xff0c;越来越多的公司都开始用swift来编写苹果相关的业务了&#xff0c;关于swift的利弊这里就不多说了。这里详细介绍OC调用swift编写的framework库的步骤 二、制作framework 1、新建项目&#xff0c;选择framework 2、填写framework的名称…

Excel革命,基于电子表格开发的新工具,不是Access和Power Fx

深谙其道 在日常工作中&#xff0c;Excel是许多人不可或缺的办公工具。 是微软的旗下产品&#xff0c;属于Microsoft 365套件中的一部分&#xff0c;强大的数据处理和计算功能&#xff0c;被普遍应用在全球各行各业的人群当中&#xff0c;是一款强大且普及的电子表格软件。 于…

Salient主题 - 创意多用途和WooCommerce商城主题

Salient主题是下一代WordPress主题&#xff0c;给任何人带来专业的设计结果&#xff0c;而不需要任何编码。Salient 提供永久更新的专业剖面模板库&#xff0c;目前有超过425个可供选择 – 所有这些都充满热情并坚持高标准的审美质量。 网址: Salient主题 - 创意多用途和WooCo…

Google play应用成功上架要点——如何防止封号、拒审、下架?

Google Play是全球最大的移动应用商店之一&#xff0c;它是运行Android操作系统的设备的官方应用商店。它提供各种数字内容&#xff0c;包括应用程序&#xff08;应用&#xff09;、游戏、音乐、书籍等&#xff0c;包括免费和付费选项。这也为许多游戏/APP出海的企业或开发者提…

Vue CLI创建Vue项目详细步骤

&#x1f680; 一、安装Node环境&#xff08;建议使用LTS版本&#xff09; 在开始之前&#xff0c;请确保您已经安装了Node.js环境。您可以从Node.js官方网站下载LTS版本&#xff0c;以确保稳定性和兼容性。 中文官网下载 确认已安装 Node.js。可以在终端中运行 node -v 命令…

【AI大模型】训练Al大模型 (上篇)

大模型超越AI 前言 洁洁的个人主页 我就问你有没有发挥&#xff01; 知行合一&#xff0c;志存高远。 目前所指的大模型&#xff0c;是“大规模深度学习模型”的简称&#xff0c;指具有大量参数和复杂结构的机器学习模型&#xff0c;可以处理大规模的数据和复杂的问题&#x…

缓存平均的两种算法

引言 线边库存物料的合理性问题是物流仿真中研究的重要问题之一,如果线边库存量过多,则会对生产现场的布局产生负面影响,增加成本,降低效益。 写在前面 仿真分析后对线边Buffer的使用情况进行合理的评估就是一个非常重要的事情。比较关心的参数包括:缓存位最大值…

探索网络架构的关键角色:六种常用的服务器类型

在今天的数字时代&#xff0c;服务器是支撑各种在线服务和应用的基石。不同类型的服务器在网络架构中扮演着不同的角色&#xff0c;从网页传输到电子邮件交换&#xff0c;再到文件传输和内容分发。本文将深入探讨六种最常用的服务器类型&#xff0c;解释它们的功能和重要性&…