Python深度学习基于Tensorflow(3)Tensorflow 构建模型

文章目录

        • 数据导入和数据可视化
        • 数据集制作以及预处理
        • 模型结构
        • 低阶 API 构建模型
        • 中阶 API 构建模型
        • 高阶 API 构建模型
        • 保存和导入模型

这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。

这里以CIFAR-10为数据集,CIFAR-10为小型数据集,一共包含10个类别的 RGB 彩色图像:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图像的尺寸为 32×32(像素),3个通道 ,数据集中一共有 50000 张训练圄片和 10000 张测试图像。CIFAR-10数据集有3个版本,这里使用Python版本。

数据导入和数据可视化

这里不用书中给的CIFAR-10数据,直接使用TensorFlow自带的玩意导入数据,可能需要魔法,其实TensorFlow中的数据特别的经典。

![[Pasted image 20240506194103.png]]

接下来导入cifar10数据集并进行可视化展示

import matplotlib.pyplot as plt
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape
# ((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))

index_name = {
    0:'airplane',
    1:'automobile',
    2:'bird',
    3:'cat',
    4:'deer',
    5:'dog',
    6:'frog',
    7:'horse',
    8:'ship',
    9:'truck'
}

def plot_100_img(imgs, labels):
    fig = plt.figure(figsize=(20,20))
    for i in range(10):
        for j in range(10):
            plt.subplot(10,10,i*10+j+1)
            plt.imshow(imgs[i*10+j])
            plt.title(index_name[labels[i*10+j][0]])
            plt.axis('off')
    plt.show()

plot_100_img(x_test[:100])

![[Pasted image 20240506200312.png]]

数据集制作以及预处理

数据集预处理很简单就能实现,直接一行代码。

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))

# 提取出一行数据
# train_data.take(1).get_single_element()

这里接着对数据预处理操作,也很容易就能实现。

def process_data(img, label):
    img = tf.cast(img, tf.float32) / 255.0
    return img, label

train_data = train_data.map(process_data)

# 提取出一行数据
# train_data.take(1).get_single_element()

这里对数据还有一些存储和提取操作

dataset 中 shuffle()、repeat()、batch()、prefetch()等函数的主要功能如下。
1)repeat(count=None) 表示重复此数据集 count 次,实际上,我们看到 repeat 往往是接在 shuffle 后面的。为何要这么做,而不是反过来,先 repeat 再 shuffle 呢? 如果shuffle 在 repeat 之后,epoch 与 epoch 之间的边界就会模糊,出现未遍历完数据,已经计算过的数据又出现的情况。
2)shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 表示将数据打乱,数值越大,混乱程度越大。为了完全打乱,buffer_size 应等于数据集的数量。
3)batch(batch_size, drop_remainder=False) 表示按照顺序取出 batch_size 大小数据,最后一次输出可能小于batch ,如果程序指定了每次必须输入进批次的大小,那么应将drop_remainder 设置为 True 以防止产生较小的批次,默认为 False。
4)prefetch(buffer_size) 表示使用一个后台线程以及一个buffer来缓存batch,提前为模型的执行程序准备好数据。一般来说,buffer的大小应该至少和每一步训练消耗的batch数量一致,也就是 GPU/TPU 的数量。我们也可以使用AUTOTUNE来设置。创建一个Dataset便可从该数据集中预提取元素,注意:examples.prefetch(2) 表示将预取2个元素(2个示例),而examples.batch(20).prefetch(2) 表示将预取2个元素(2个批次,每个批次有20个示例),buffer_size 表示预提取时将缓冲的最大元素数返回 Dataset。

![[Pasted image 20240506201344.png]]

最后我们对数据进行一些缓存操作

learning_rate = 0.0002
batch_size = 64
training_steps = 40000
display_step = 1000

AUTOTUNE = tf.data.experimental.AUTOTUNE
train_data = train_data.map(process_data).shuffle(5000).repeat(training_steps).batch(batch_size).prefetch(buffer_size=AUTOTUNE)

目前数据准备完毕!

模型结构

模型的结构如下,现在使用低阶,中阶,高阶 API 来构建这一个模型

![[Pasted image 20240506202450.png]]

低阶 API 构建模型
import matplotlib.pyplot as plt
import tensorflow as tf

## 定义模型
class CustomModel(tf.Module):
    def __init__(self, name=None):
        super(CustomModel, self).__init__(name=name)
        self.w1 = tf.Variable(tf.initializers.RandomNormal()([32*32*3, 256]))
        self.b1 = tf.Variable(tf.initializers.RandomNormal()([256]))
        self.w2 = tf.Variable(tf.initializers.RandomNormal()([256, 128]))
        self.b2 = tf.Variable(tf.initializers.RandomNormal()([128]))
        self.w3 = tf.Variable(tf.initializers.RandomNormal()([128, 64]))
        self.b3 = tf.Variable(tf.initializers.RandomNormal()([64]))
        self.w4 = tf.Variable(tf.initializers.RandomNormal()([64, 10]))
        self.b4 = tf.Variable(tf.initializers.RandomNormal()([10]))

    def __call__(self, x):
        x = tf.cast(x, tf.float32)
        x = tf.reshape(x, [x.shape[0], -1])
        x = tf.nn.relu(x @ self.w1 + self.b1)
        x = tf.nn.relu(x @ self.w2 + self.b2)
        x = tf.nn.relu(x @ self.w3 + self.b3)
        x = tf.nn.softmax(x @ self.w4 + self.b4)
        return x
model = CustomModel()


## 定义损失
def compute_loss(y, y_pred):
    y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)
    loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)
    return tf.reduce_mean(loss)

## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)

## 定义准确率
def compute_accuracy(y, y_pred):
    correct_pred = tf.equal(tf.argmax(y_pred, axis=1), tf.cast(tf.reshape(y, -1), tf.int64))
    correct_pred = tf.cast(correct_pred, tf.float32)
    return tf.reduce_mean(correct_pred)

## 定义一次epoch
def train_one_epoch(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = compute_loss(y, y_pred)
        accuracy = compute_accuracy(y, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss.numpy(), accuracy.numpy()

## 开始训练

loss_list, acc_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):
    loss, acc = train_one_epoch(batch_x, batch_y)
    loss_list.append(loss)
    acc_list.append(acc)
    if i % 10 == 0:
        print(f'第{i}次训练->', 'loss:' ,loss, 'acc:', acc)
中阶 API 构建模型
## 定义模型
class CustomModel(tf.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(256, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_3 = tf.keras.layers.Dense(64, activation='relu')
        self.dense_4 = tf.keras.layers.Dense(10, activation='softmax')
        
    def __call__(self, x):
        x = self.flatten(x)
        x = self.dense_1(x)
        x = self.dense_2(x)
        x = self.dense_3(x)
        x = self.dense_4(x)
        return x

model = CustomModel()

## 定义损失以及准确率
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)


## 定义一次epoch
def train_one_epoch(x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = compute_loss(y, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss)
    train_accuracy(y, y_pred)


## 开始训练
loss_list, accuracy_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):
    train_one_epoch(batch_x, batch_y)
    loss_list.append(train_loss.result())
    accuracy_list.append(train_accuracy.result())
    if i % 10 == 0:
        print(f"第{i}次训练: loss: {train_loss.result()} accuarcy: {train_accuracy.result()}")
高阶 API 构建模型
## 定义模型
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=[32,32,3]),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

## 定义optimizer,loss, accuracy
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)

## 开始训练
model.fit(train_data.take(10000))
保存和导入模型

保存模型

tf.keras.models.save_model(model, 'model_folder')

导入模型

model = tf.keras.models.load_model('model_folder')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/609587.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SparkStructuredStreaming状态编程

spark官网关于spark有状态编程介绍比较少,本文是一篇个人理解关于spark状态编程。 官网关于状态编程代码例子: spark/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredComplexSessionization.scala at v3.5.0 apache/spark (github…

智能评估时代:SurveyKing开源问卷系统YYDS

最近有同事在设计问卷系统,我碰巧在 GitHub 上发现了一个开源的问卷/考试系统,觉得它非常不错,给他推荐了下。今天我打算和家人们分享一下这个发现。 项目介绍 官方网站:https://surveyking.cn/ github地址:https://…

springboot整合websocket,超简单入门

springBoot整合webSocket,超简单入门 webSocket简洁 WebSocket 是一种基于 TCP 协议的全双工通信协议,它允许客户端和服务器之间建立持久的、双向的通信连接。相比传统的 HTTP 请求 - 响应模式,WebSocket 提供了实时、低延迟的数据传输能力。…

数据库(MySQL)基础:约束

一、概述 1.概念:约束是作用于表中字段上的规则,用于限制存储在表中的数据。 2.目的:保证数据库中数据的正确、有效性和完整性。 3.分类 约束描述关键字非空约束限制该字段的数据不能为nullnot null唯一约束保证该字段的所有数据都是唯一…

QX---mini51单片机学习---(6)独立键盘

目录 1键盘简绍 2按键的工作原理 3键盘类型 4独立键盘与矩阵键盘的特点 5本节相关原理图 6按键特性 7实践 1键盘简绍 2按键的工作原理 内部使用轻触按键,常态按下按键触点才闭合 3键盘类型 编码键盘与非编码键盘 4独立键盘与矩阵键盘的特点 5本节相关原理…

硬性清空缓存的方法

前端发布代码后,我们是需要刷新页面再验证的。有时候仅仅f5 或者ctrlshiftdelete快捷键仍然有历史缓存,这时可以通过下面的方法硬性清空缓存。 以谷歌浏览器为例,打开f12,右键点击刷新按钮,选择【清空缓存并硬性加载】…

计算机网络5——运输层2TCP原理

文章目录 一、传输控制协议 TCP 概述1、TCP最主要的特点2、TCP的连接 二、可靠传输的工作原理1、停止等待协议1)无差错情况2)出现差错3)确认丢失和确认迟到4)信道利用率 2、连续 ARQ协议 三、TCP 报文段的首部格式 一、传输控制协…

代码审计-PHP模型开发篇动态调试反序列化变量覆盖TP框架原生POP链

知识点 1、PHP审计-动态调试-变量覆盖 2、PHP审计-动态调试-原生反序列化 3、PHP审计-动态调试-框架反序列化PHP常见漏洞关键字 SQL注入: select insert update delete mysql_query mysqli等 文件上传: $_FILES,type"file"&…

Kafka 执行命令超时异常: Timed out waiting for a node assignment

Kafka 执行命令超时异常: Timed out waiting for a node assignment 问题描述: 搭建了一个kafka集群环境,在使用命令行查看已有topic时,报错如下: [rootlocalhost bin]# kafka-topics.sh --list --bootstrap-server…

Vue自定义封装音频播放组件(带拖拽进度条)

Vue自定义封装音频播放组件(带拖拽进度条) 描述 该款自定义组件可作为音频、视频播放的进度条,用于控制音频、视频的播放进度、暂停开始、拖拽进度条拓展性极高。 实现效果 具体效果可以根据自定义内容进行位置调整 项目需求 有播放暂停…

51单片机软件环境安装

keli5的安装 把CID放到破解程序中 破解程序会给一串数字然后填到那个框中 驱动程序的安装 安装完了以后 设备管理器会出现这个 同时c盘会出现这个文件夹

巨量千川的投放技巧,一站式全自动千川投流工具(抖音玩家必备)

随着抖音平台的快速发展,越来越多的品牌和广告商意识到抖音的潜力,并希望能够通过投放广告来获取更多的曝光和用户参与。在这个过程中,巨量千川成为了抖音玩家必备的一站式全自动千川投流工具,为广告商提供了投放技巧,…

word-快速入门

1、熟悉word界面 2、word排版习惯 3、排版文本基本格式 1、word界面 选项卡 功能组 点击功能组右下角小三角可以开启完整功能组,获得启动器 软件右上角有功能显示折叠按钮 2、排版好习惯 (1)随时保存 (2)规范文件命…

408算法题专项-2015

题目: 分析:时间复杂度尽可能高效,提示可能存在一种空间换时间的算法 思路一:空间换时间 思考:开数组储存结点数据域,对于只出现一次或多次出现第一次的,保留,对于多次出现的&…

流程详解!2024年成都市发明专利申请流程及各阶段操作要点

一、受理阶段 时间期限: 电子申请2天内,纸质申请当天现场提交,邮寄约为半月。 申请人: 1. 委托专利代理机构,签订委托代理协议和保密协议等; 2. 提供原始技术资料和个人以及单位信息等; 3…

片冰机工作原理

片冰机工作原理 1、制冰用的水需要加盐(行话叫做加药)至于多少量。看制冰量多少调制泵(柱塞泵)自动调整。 2、制冰机主体分两腔体外腔体内盘的一定密度的铜管。专业术语叫(蒸发腔)就是俗话讲的制冷的东西。 3、外腔体内是一个很规则的圆不锈钢腔体,中心有一三叶刮…

基于Django图像识别系统毕业设计(付源码)

前言:Django是一个由Python编写的具有完整架站能力的开源Web框架,Django本身基于MVC模型,即Model(模型)View(视图) Controller(控制器)设计模式,因此天然具有…

零售数据分析之连带销售分析怎么做

连带销售是指顾客在购买某款产品后,通常会顺手也买上另一款产品。这种情况在超市零售中屡见不鲜,因此通常来说在做超市零售数据分析时,都需要做一个详尽的连带销售分析。那么做零售数据分析中的连带销售分析,要计算分析哪些指标&a…

MBR与GPT分区表

文章目录 MBR分区表MBR分区表结构MBR分区表项查看U盘的分区表信息查看系统中所有磁盘的分区类型获取分区表信息 GPT分区表保护性MBRGPT分区表头格式GPT分区表项格式分区类型分区属性分区表项内容 MBR分区表 CHS :磁头(Heads)、柱面(Cylinder…

AH8651-220V转3.3V低成本方案

本篇文章将介绍一种220V转3.3V低成本方案,该方案采用AH8651芯片,无需外接电感,具有高效率的智能控制、宽广的交流输入范围、内置过流保护、欠压保护和过热自动关断等功能。AH8651可以通过SEL引脚选择输出电压,启动时通过内部高压电…
最新文章