基于tensorflow深度学习的猫狗分类识别

 3f6a7ab0347a4af1a75e6ebadee63fc1.gif

🤵‍♂️ 个人主页:@艾派森的个人主页

✍🏻作者简介:Python学习者
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+


目录

实验背景

实验目的

实验环境

实验过程

1.加载数据

2.数据预处理

3.构建模型

4.训练模型

5.模型评估

源代码


 

实验背景

        近年来,深度学习在计算机视觉领域取得了巨大的成功,尤其是在图像分类任务上。图像分类是计算机视觉领域的基本问题之一,而猫狗分类作为图像分类中的经典问题,吸引了广泛的研究兴趣。猫狗分类问题具有很高的实际应用价值。在现实世界中,人们经常需要对动物进行分类,如在宠物识别、动物行为分析和动物保护等领域。传统的图像分类方法通常需要手工设计特征提取器和分类器,这在处理复杂的图像数据时面临着挑战。

        深度学习通过学习端到端的特征提取和分类模型,不需要手动设计特征提取器,因此在猫狗分类问题上具有巨大的潜力。卷积神经网络(Convolutional Neural Networks,简称CNN)是深度学习中最常用的模型之一,特别适用于图像数据的处理。猫狗分类问题的研究可以帮助我们深入理解深度学习在图像分类任务中的应用,并且可以为其他图像分类问题的研究提供经验和指导。此外,研究人员还可以通过比较不同深度学习模型的性能和对比传统方法的效果,评估深度学习在猫狗分类问题上的优势和局限性。

        此外,随着深度学习模型的不断发展和算力的提升,研究人员可以尝试更复杂的模型架构、数据增强技术和迁移学习方法,以进一步提高猫狗分类任务的准确性和鲁棒性。因此,基于深度学习的猫狗分类实验具有重要的研究价值,可以推动深度学习在图像分类领域的发展,同时为实际应用场景提供更好的解决方案。

实验目的

        本实验的目的是基于深度学习方法进行猫狗分类,通过设计和训练深度神经网络模型,实现对输入图像进行准确的猫狗分类。具体目标包括:

        1.建立一个高性能的猫狗分类模型:通过深度学习技术,构建一个能够从原始图像数据中自动学习到猫狗分类特征的神经网络模型。该模型能够准确地对输入图像进行分类,具备较高的分类准确率和泛化能力。

        2.探索不同深度学习模型的性能差异:比较不同深度学习模型(如卷积神经网络、残差网络等)在猫狗分类任务上的性能表现,评估它们的准确率、召回率、精确率等指标,并分析其优势和不足之处。

        3.优化模型性能:通过调整模型的超参数、网络结构以及训练策略等,进一步提高猫狗分类模型的性能。例如,可以尝试不同的激活函数、优化器、学习率调度等,以提高模型的收敛速度和泛化能力。

        4.数据增强和处理:应用数据增强技术,如随机裁剪、旋转、翻转等,扩充训练数据集的多样性,提高模型对于各种场景和变化的鲁棒性。同时,对原始图像数据进行预处理,如图像归一化、均衡化等,以便更好地适应模型输入要求。

        5.评估模型性能:使用独立的测试数据集对训练好的模型进行评估,计算分类准确率、混淆矩阵等指标,评估模型的性能。同时,可以与其他传统方法进行比较,验证基于深度学习的方法在猫狗分类问题上的优越性。

实验环境

Python3.9

Jupyter notebook

实验过程

1.加载数据

首先导入本次实验用到的第三方库

efd0993698ff41559c35b628cd72996f.png

 接着定义我们数据集的路径

821554dd344e4e84a8167bcbfc84883f.png

定义训练集、测试集、验证集生成器

88f0b8a8ddd8405e974afdfe2ab1d65e.png

 将生成器连接到文件夹中的数据a8b15d355bf9441d986e3e193352b175.png

 可视化一些数据图片,来个九宫格展示

f99d7afa18b64ddca0f33b91123197b0.png

 02d2e066cba24c14a33141bcc97ce52a.png

2.数据预处理

5bcd5225624c4ab3b56bdd3a5397160c.png

3.构建模型

构建模型、定义优化器

aec104b621444d7692be3c8ecf2c23a5.png

 保存模型

f47c6eb09b024b6a9f6b4647fac41818.png

4.训练模型

ae585ea527a54d319af96f54786679f3.png

5.模型评估

将模型训练和验证的损失可视化出来、以及训练和验证的准确率

44b1487c2e5940ec84c40667c66e1c0d.png

a88ce5d207424f51b77dd3304f317550.png

对验证数据集进行评估 

cbfe6ac76d774a7f9378c160c8c88d8c.png

对测试数据集进行评估 

 8073c95d0959412487d0bc957e92b63f.png

 将模型的混淆矩阵一热力图的形式展示4a5bdae21e8a42b8865523e83614d45c.png

946454010fc940f385af9ceaf036e80e.png

源代码

import numpy as np
import random
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.metrics import confusion_matrix
import seaborn as sns
sns.set(style='darkgrid', font_scale=1.4)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# 数据集路径
train_dir = './train'
test_dir = './test'
CFG = dict(
    seed = 77,
    batch_size = 16,
    img_size = (299,299),
    epochs = 5,
    patience = 5
)
train_data_generator = ImageDataGenerator(
        validation_split=0.15,
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        preprocessing_function=preprocess_input,
        shear_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

val_data_generator = ImageDataGenerator(preprocessing_function=preprocess_input, validation_split=0.15)
test_data_generator = ImageDataGenerator(preprocessing_function=preprocess_input)
# 将生成器连接到文件夹中的数据
train_generator = train_data_generator.flow_from_directory(train_dir, target_size=CFG['img_size'], shuffle=True, seed=CFG['seed'], class_mode='categorical', batch_size=CFG['batch_size'], subset="training")
validation_generator = val_data_generator.flow_from_directory(train_dir, target_size=CFG['img_size'], shuffle=False, seed=CFG['seed'], class_mode='categorical', batch_size=CFG['batch_size'], subset="validation")
test_generator = test_data_generator.flow_from_directory(test_dir, target_size=CFG['img_size'], shuffle=False, seed=CFG['seed'], class_mode='categorical', batch_size=CFG['batch_size'])

# 样本和类的数量
nb_train_samples = train_generator.samples
nb_validation_samples = validation_generator.samples
nb_test_samples = test_generator.samples
classes = list(train_generator.class_indices.keys())
print('Classes:'+str(classes))
num_classes = len(classes)
# 可视化一些例子
plt.figure(figsize=(15,15))
for i in range(9):
    ax = plt.subplot(3,3,i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    batch = train_generator.next()
    imgs = (batch[0] + 1) * 127.5
    label = int(batch[1][0][0])
    image = imgs[0].astype('uint8')
    plt.imshow(image)
    plt.title('cat' if label==1 else 'dog')
plt.show()
base_model = InceptionResNetV2(weights='imagenet', include_top=False, input_shape=(CFG['img_size'][0], CFG['img_size'][1], 3))
x = base_model.output
x = Flatten()(x)
x = Dense(100, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax', kernel_initializer='random_uniform')(x)

# 构建模型
model = Model(inputs=base_model.input, outputs=predictions)

for layer in base_model.layers:
    layer.trainable = False
    
# 定义优化器
optimizer = Adam()
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# 保存模型
save_checkpoint = keras.callbacks.ModelCheckpoint(filepath='model.h5', monitor='val_loss', save_best_only=True, verbose=1)
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=CFG['patience'], verbose=True)
# 训练模型
history = model.fit(
        train_generator,
        steps_per_epoch=nb_train_samples // CFG['batch_size'],
        epochs=CFG['epochs'],
        callbacks=[save_checkpoint,early_stopping],
        validation_data=validation_generator,
        verbose=True,
        validation_steps=nb_validation_samples // CFG['batch_size'])
history_dict = history.history
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']
epochs_x = range(1, len(loss_values) + 1)
plt.figure(figsize=(10,10))
plt.subplot(2,1,1)
plt.plot(epochs_x, loss_values, 'b-o', label='Training loss')
plt.plot(epochs_x, val_loss_values, 'r-o', label='Validation loss')
plt.title('Training and validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

# Accuracy
plt.subplot(2,1,2)
acc_values = history_dict['accuracy']
val_acc_values = history_dict['val_accuracy']
plt.plot(epochs_x, acc_values, 'b-o', label='Training acc')
plt.plot(epochs_x, val_acc_values, 'r-o', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Acc')
plt.legend()
plt.tight_layout()
plt.show()
# 对验证数据集进行评估
score = model.evaluate(validation_generator, verbose=False)
print('Val loss:', score[0])
print('Val accuracy:', score[1])
# 对测试数据集进行评估
score = model.evaluate(test_generator, verbose=False)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# 混淆矩阵
y_pred = np.argmax(model.predict(test_generator), axis=1)
cm = confusion_matrix(test_generator.classes, y_pred)

# 热力图
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cbar=True, cmap='Blues',xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion matrix')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/33218.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

机器学习之K-means聚类算法

目录 K-means聚类算法 算法流程 优点 缺点 随机点聚类 人脸聚类 旋转物体聚类 K-means聚类算法 K-means聚类算法是一种无监督的学习方法,通过对样本数据进行分组来发现数据内在的结构。K-means的基本思想是将n个实例分成k个簇,使得同一簇内数据相…

基于小程序的用户服务技术研究

目录 1. 小程序开发技术原理 2. 用户服务设计3. 数据库设计和管理4. 安全和隐私保护5. 性能优化和测试总结 关于基于小程序的用户服务技术研究,这是一个非常广泛和复杂的领域,需要涉及多个方面的知识和技术。一般来说,基于小程序的用户服务技…

怎么学习数据库连接与操作? - 易智编译EaseEditing

学习数据库连接与操作可以按照以下步骤进行: 理解数据库基础知识: 在学习数据库连接与操作之前,首先要了解数据库的基本概念、组成部分和工作原理。 学习关系型数据库和非关系型数据库的区别,了解常见的数据库管理系统&#xff…

HTTP协议

HTTP协议专门用于定义浏览器与服务器之间交互数据的过程以及数据本身的格式 HTTP概述 HTTP是一种客户端(用户)请求和服务器(网站)应答的标准,它作为一种应用层协议,应用于分布式、协作式和超媒体信息系统…

【springboot】—— 后端Springboot项目开发

后端Springboot项目开发 步骤1 先创建数据库,并在下面创建一个user表,插入数据,sql如下: CREATE TABLE user (id int(11) NOT NULL AUTO_INCREMENT COMMENT ID,email varchar(255) NOT NULL COMMENT 邮箱,password varchar(255)…

王益分布式机器学习讲座~Random Notes (1)

0 并行计算是什么?并行计算框架又是什么 并行计算是一种同时使用多个计算资源(如处理器、计算节点)来执行计算任务的方法。通过将计算任务分解为多个子任务,这些子任务可以同时在不同的计算资源上执行,从而实现加速计…

ChatGLM2-6B发布,位居C-Eval榜首

ChatGLM-6B自2023年3月发布以来,就已经爆火,如今6月25日,清华二代发布(ChatGLM2-6B),位居C-Eval榜单的榜首! 项目地址:https://github.com/THUDM/ChatGLM2-6B HuggingFace&#xf…

Sequential用法

目录 1.官方文档解释 1.1原文参照 1.2中文解释 2.参考代码 3.一些参考使用 3.1生成网络 3.2 感知机的实现 3.3组装网络层 1.官方文档解释 1.1原文参照 A sequential container. Modules will be added to it in the order they are passed in the constructor. A…

【书】《Python全栈测试开发》——浅谈我所理解的『自动化』测试

目录 1. 自动化测试的What and Why?1.1 What1.2 Why2. 自动化的前戏需要准备哪些必备技能?3. 自动化测试类型3.1 Web自动化测试3.1.1 自动化测试设计模式3.1.2 自动化测试驱动方式3.1.3 自动化测试框架3.2 App自动化测试3.3 接口自动化测试4. 自动化调优《Python全栈测试开发…

Springboot钉钉免密登录集成(钉钉小程序和H5微应用)

欢迎访问我的个人博客:www.ifueen.com RT,因为业务需要把我们系统集成到钉钉里面一个小程序和一个H5应用,并且在钉钉平台上面实现无感登录,用户打开我们系统后不需要再输入密码即可登录进系统,查阅文档实际操作过之后记录一下过程…

Qt6.2教程——4.QT常用控件QPushButton

一,QPushButton简介 QPushButton是Qt框架中的一种基本控件,它是用户界面中最常见和最常用的控件之一。QPushButton提供了一个可点击的按钮,用户可以通过点击按钮来触发特定的应用程序操作。比如,你可能会在一个对话框中看到"…

VMware Tools安装“保熟“技巧

网上关于如何安装VMware Tools也有很多帖子,但是基本很难对症下药。下面笔者给出两种情况,读者可根据自己概况定位自己的问题,从而进行解决。 如果读者安装操作系统时是如笔者如下截图 那么读者可参考这个解决方案 安装VMware Tools选项显示灰色的正确解…

高等数学下拾遗+与matlab结合

如何学好高等数学 高等数学是数学的一门重要分支,包括微积分、线性代数、常微分方程等内容,它是许多理工科专业的基础课程。以下是一些学好高等数学的建议: 扎实的基础知识:高等数学的内容很多,包括初等数学的一些基…

【数据库】关系型数据库与非关系型数据库解析

【数据库】关系型数据库与非关系型数据库解析 文章目录 【数据库】关系型数据库与非关系型数据库解析1. 介绍2. 关系型数据库3. 非关系型数据库4. 区别4.1 数据存储方式不同4.2 扩展方式不同4.3 对事务性的支持不同4.4 总结 参考 1. 介绍 一个通俗易懂的比喻:关系型…

哈工大计算机网络传输层协议详解之:可靠数据传输的基本原理

哈工大计算机网络传输层协议详解之:可靠数据传输的基本原理 哈工大计算机网络课程传输层协议详解之:流水线机制与滑动窗口协议哈工大计算机网络课程传输层协议详解之:TCP协议哈工大计算机网络课程传输层协议详解之:拥塞控制原理剖…

Postman中读取外部文件

目录 前言: 一、postman中读取外部文件的格式 二、Postman中如何导入文件 三、在Postman读取导入的数据文件 前言: 在Postman中,您可以使用"数据文件"功能来读取外部文件,如CSV、JSON或Excel文件。这使得在测试中使用…

Bootstrap CSS 概览

文章目录 Bootstrap CSS 概览HTML 5 文档类型(Doctype)移动设备优先响应式图像全局显示、排版和链接基本的全局显示排版链接样式 避免跨浏览器的不一致容器(Container)Bootstrap 浏览器/设备支持 Bootstrap CSS 概览 在这一章中&a…

成为行业风向标,亚马逊云科技近年在数据库排名逐年上升

近10年,全球数据库市场加速变革,云数据库尤其是云原生数据库成为整个数据库市场的关键变量。某种程度上,亚马逊云科技作为全球云原生数据库的领导者,具有行业风向标的价值。 近期,发生了一件对全球数据库市场具有标志性…

爬虫入门指南(4): 使用Selenium和API爬取动态网页的最佳方法

文章目录 动态网页爬取静态网页与动态网页的区别使用Selenium实现动态网页爬取Selenium 的语法及介绍Selenium简介安装和配置创建WebDriver对象页面交互操作 元素定位 等待机制页面切换和弹窗处理截图和页面信息获取关闭WebDriver对象 使用API获取动态数据未完待续.... 动态网页…

GB51309实施后对于消防应急照明和疏散指示系统在城市隧道应用中的影响

安科瑞 崔丽洁 【摘要】:应急照明和疏散指示系统被广泛运用于城市隧道、楼宇建筑、地下管廊等各个方面。当隧道这类特殊建筑内出现火灾或事故时,可靠的应急照明和疏散指示系统对于人员的安全逃生有着重要的作用。随着GB51309-2018《消防应急照明和疏散指…
最新文章