Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13)

        本笔记记录CNN做CIFAR100数据集的训练相关内容,代码中使用了类似VGG13的网络结构,做了两个Sequetial(CNN和全连接层),没有用Flatten层而是用reshape操作做CNN和全连接层的中转操作。由于网络层次较深,参数量相比之前的网络多了不少,因此只做了10次epoch(RTX4090),没有继续跑了,最终准确率大概在33.8%左右。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x,y

y_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)

batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)

sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, 
         tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))


#创建CNN网络,总共4个unit,每个unit主要是两个卷积层和Max Pooling池化层
cnn_layers = [
    #unit 1
    layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 2
    layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 3
    layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 4
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),

    #unit 5
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
    #layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),
    layers.MaxPool2D(pool_size=[2,2], strides=2),
]


def main():
    #[b, 32, 32, 3] => [b, 1, 1, 512]
    cnn_net = Sequential(cnn_layers)
    cnn_net.build(input_shape=[None, 32, 32, 3])
    
    #测试一下卷积层的输出
    #x = tf.random.normal([4, 32, 32, 3])
    #out = cnn_net(x)
    #print(out.shape)

    #创建全连接层, 输出为100分类
    fc_net = Sequential([
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(100, activation=None),
    ])
    fc_net.build(input_shape=[None, 512])

    #设置优化器
    optimizer = optimizers.Adam(learning_rate=1e-4)

    #记录cnn层和全连接层所有可训练参数, 实现的效果类似list拼接,比如
    # [1, 2] + [3, 4] => [1, 2, 3, 4]
    variables = cnn_net.trainable_variables + fc_net.trainable_variables
    #进行训练
    num_epoches = 10
    for epoch in range(num_epoches):
        for step, (x,y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                #[b, 32, 32, 3] => [b, 1, 1, 512]
                out = cnn_net(x)
                #flatten打平 => [b, 512]
                out = tf.reshape(out, [-1, 512])
                #使用全连接层做100分类logits输出
                #[b, 512] => [b, 100]
                logits = fc_net(out)
                #标签做one_hot encoding
                y_onehot = tf.one_hot(y, depth=100)
                #计算损失
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
            #计算梯度
            grads = tape.gradient(loss, variables)
            #更新参数
            optimizer.apply_gradients(zip(grads, variables))

            if (step % 100 == 0):
                print("Epoch[", epoch + 1, "/", num_epoches, "]: step-", step, " loss:", float(loss))
        #进行验证
        total_samples = 0
        total_correct = 0
        for x,y in test_db:
            out = cnn_net(x)
            out = tf.reshape(out, [-1, 512])
            logits = fc_net(out)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)

            total_samples += x.shape[0]
            total_correct += int(correct)

        #统计准确率
        acc = total_correct / total_samples
        print("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)
if __name__ == '__main__':
    main()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/558702.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在 Node.js 中配置代理 IP 采集文章

不说废话,直接上代码: const http require(http); const https require(https);// 之后可以使用 http 或 https 模块发起请求,它们将自动使用配置的代理 // 代理ip:https://www.kuaidaili.com/?refrg3jlsko0ymg const proxy …

JavaScript算数运算符

源码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </head> <b…

Bert语言大模型基础

一、Bert整体模型架构 基础架构是transformer的encoder部分&#xff0c;bert使用多个encoder堆叠在一起。 主要分为三个部分&#xff1a;1、输入部分 2、注意力机制 3、前馈神经网络 bertbase使用12层encoder堆叠在一起&#xff0c;6个encoder堆叠在一起组成编码端&#xf…

ZooKeeper设置监听器

ZooKeeper设置监听器&#xff0c;通过getData()/getChildern()/xists()方法。 步骤&#xff1a; 1.创建监听器&#xff1a;创建一个实现Watcher接口的类&#xff0c;实现process()方法。这个方法会在ZooKeeper向客户端发送一个Watcher事件通知的时候被调用。 2.注册监听器&…

【工厂模式】工厂方法模式、抽象工厂模式-简单例子

简单工厂模式&#xff0c;请跳转到我的另一篇博客【工厂模式】简单工厂模式-简单例子-CSDN博客 四、工厂方法模式 &#xff08;1&#xff09;这部分还是不变&#xff0c;创建一个Car接口&#xff0c;和两个实现类。 public interface Car {void name(); }public class WuLing…

深入刨析 mysql 底层索引结构B+树

文章目录 前言一、什么是索引&#xff1f;二、不同索引结构对比2.1 二叉树2.2 平衡二叉树2.3 B-树2.4 B树 三、mysql 的索引3.1 聚簇索引3.2 非聚簇索引 前言 很多人看过mysql索引的介绍&#xff1a;hash表、B-树、B树、聚簇索引、主键索引、唯一索引、辅助索引、二级索引、联…

C#语法知识之循环语句

5、循环语句 文章目录 1、while思考1 斐波那契数列思考2 判断一个数是否为质数思考3 找出100以内的质数 2、do...while3、for思考1 找水仙花数思考2 乘法表 1、while 1、作用 让代码重复去执行 2、语法相关 while(bool类型值){//当满足条件时&#xff0c;就会执行while语句…

大话设计模式-里氏代换原则

里氏代换原则&#xff08;Liskov Substitution Principle&#xff0c;LSP&#xff09; 概念 里氏代换原则是面向对象设计的基本原则之一&#xff0c;由美国计算机科学家芭芭拉利斯科夫&#xff08;Barbara Liskov&#xff09;提出。这个原则定义了子类型之间的关系&#xff0…

linux下使用qt+mpv调用GPU硬件解码

linux下GPU硬件解码接口&#xff0c;常用的有vdpau和vaapi。 mpv是基于mplayer开发的一个播放器。此外&#xff0c;mpv还提供了函数库libmpv&#xff0c;通过使用libmpv可以编写一个简单的播放器。 基于qtlibmpv的demo&#xff0c;官方例子代码如下&#xff1a;https://github.…

Java maven项目打包自动测试并集成jacoco生成代码测试覆盖度报告

引入Junit 引入 junit5 单元测试依赖 <properties><junit.version>5.10.2</junit.version><jacoco.version>0.8.12</jacoco.version></properties><dependencies><!-- 单元测试 --><dependency><groupId>org.jun…

JUC 线程间通信

前言 本篇文章我将解释《并发编程的艺术》一书中一个经典的实现线程间通信的案例&#xff0c;主要是使用wait() 和 notifyAll() 方法来实现的。 这段代码的作用是通过 wait() 和 notifyAll() 方法实现线程间的等待和通知机制。具体来说&#xff0c;代码中创建了两个线程&…

论文阅读-Multiple Targets Directed Greybox Fuzzing (Hongliang Liang,2024)

标题: Multiple Targets Directed Greybox Fuzzing (Hongliang Liang,2024) 作者: Hongliang Liang, Xinglin Yu, Xianglin Cheng, Jie Liu, Jin Li 期刊: IEEE Transactions on Dependable and Secure Computing 研究问题: 发现局限性&#xff1a;之前的定向灰盒测试在有…

webAssembly学习及使用rust

学习理解 webAssembly 概念知识&#xff0c;使用 API 进行 web 前端开发。 概念 是一种运行在现代网络浏览器中的新型代码&#xff0c;并且提供新的性能特性和效果。它有一种紧凑的二进制格式&#xff0c;使其能够以接近原生性能的速度运行。C/C、 C#、Rust等语言可以编译为 …

ruby 配置代理 ip(核心逻辑)

在 Ruby 中配置代理 IP&#xff0c;可以通过设置 Net::HTTP 类的 Proxy 属性来实现。以下是一个示例&#xff1a; require net/http// 获取代理Ip&#xff1a;https://www.kuaidaili.com/?refrg3jlsko0ymg proxy_address 代理IP:端口 uri URI(http://www.example.com)Net:…

【React】Sigma.js框架网络图-入门篇

一、介绍 Sigma.js是一个专门用于图形绘制的JavaScript库。 它使在Web页面上发布网络变得容易&#xff0c;并允许开发人员将网络探索集成到丰富的Web应用程序中。 Sigma.js提供了许多内置功能&#xff0c;例如Canvas和WebGL渲染器或鼠标和触摸支持&#xff0c;以使用户在网页上…

MATLAB R2024a:重塑商业数学软件的未来

在数字化浪潮席卷全球的今天&#xff0c;商业数学软件已经成为企业、研究机构乃至个人不可或缺的工具。而在这其中&#xff0c;MATLAB R2024a以其卓越的性能和广泛的应用领域&#xff0c;正逐步成为商业数学软件的新标杆。 MATLAB R2024a不仅继承了前代版本的优秀基因&#xf…

Golang 采集爬虫如何配置代理 IP

在 Golang 中配置代理 IP&#xff0c;可以通过设置 http.Transport 的 Proxy 属性来实现&#xff1a; 下述代码中的 代理IP 和 端口 替换为实际的代理服务器地址和端口&#xff0c;然后运行该程序即可通过代理服务器访问对应网站。 package mainimport ("fmt""…

超详细的Maven安装与使用还有内容讲解

文章目录 作用简介模型仓库 安装配置IDEA配置Maven坐标概念主要组成 IDEA创建Maven项目基本使用常用命令生命周期使用坐标导入jar包 注意事项清理maven仓库更新索引依赖 作用 Maven是专门用于管理和构建Java项目的工具&#xff0c;它的主要功能有&#xff1a; 提供了一套标准化…

MATLAB实现禁忌搜索算法优化柔性车间调度fjsp

禁忌搜索算法的流程可以归纳为以下几个步骤&#xff1a; 初始化&#xff1a; 利用贪婪算法或其他局部搜索算法生成一个初始解。清空禁忌表。设置禁忌长度&#xff08;即禁忌表中禁止操作的期限&#xff09;。邻域搜索产生候选解&#xff1a; 通过特定的搜索算子&#xff08;如…

AWS账号注册以及Claude 3 模型使用教程!

哈喽哈喽大家好呀&#xff0c;伙伴们&#xff01;你听说了吗&#xff1f;最近AWS托管了大热模型&#xff1a;Claude 3 Opus&#xff01;想要一探究竟吗&#xff1f;那就赶紧来注册AWS账号吧&#xff01;别担心&#xff0c;现在注册还免费呢&#xff01;而且在AWS上还有更多的大…