Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__


#加载fashion mnist数据集
def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)
    with gzip.open(images_path, 'rb') as imgpath:
        
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)
    return images, labels

#预处理数据
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32)
    x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.convert_to_tensor(y, dtype=tf.int32)
    return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)

batch_size = 128

train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)

train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)

#定义网络模型
model = Sequential([
    #Layer 1: [b, 784] => [b, 256]
    layers.Dense(256, activation=tf.nn.relu),
    #Layer 2: [b, 256] => [b, 128]
    layers.Dense(128, activation=tf.nn.relu),
    #Layer 3: [b, 128] => [b, 64]
    layers.Dense(64, activation=tf.nn.relu),
    #Layer 4: [b, 64] => [b, 32]
    layers.Dense(32, activation=tf.nn.relu),
    #Layer 5: [b, 32] => [b, 10], 输出类别结果
    layers.Dense(10)
])

#编译网络
model.build(input_shape=[None, 28*28])
model.summary()

#进行训练
total_epoches = 5
learn_rate = 0.01

#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()


#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()

optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):
    for step, (x,y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            logits = model(x)
            y_onehot = tf.one_hot(y, depth=10)
            #使用交叉熵作为loss
            loss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))
            #调用update_state更新loss度量信息
            loss_meter.update_state(loss_ce)
        #计算梯度
        grads = tape.gradient(loss_ce, model.trainable_variables)
        #更新梯度
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())
            loss_meter.reset_states()

    #使用测试集进行验证
    total_correct = 0
    total_num = 0
    #清除准确度的统计信息
    acc_meter.reset_states()
    for x,y in test_db:
        logits = model(x)
        #使用softmax得到各个类别的概率
        prob = tf.nn.softmax(logits, axis=1)
        #求出概率最大的结果参数位置,作为预测的分类结果
        pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
        #比较结果
        correct = tf.equal(pred, y)
        correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
        #计算精度
        total_correct += int(correct)
        total_num += x.shape[0]
        #使用metircs的update_state进行更新
        acc_meter.update_state(y, pred)
    
    acc = total_correct / total_num
    print("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/499048.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI 异构计算机设计方案:902-基于6U VPX 高带宽PCIe的GPU AI 异构计算机

基于6U VPX 高带宽PCIe的GPU AI 异构计算机 一、产品概述 基于6U 6槽 VPX 高带宽PCIe的GPU AI 异构计算机以PCIe总线为架构,通过高带宽的PCIe互联,实现主控计算板、GPU AI板卡,FPGA接口板,存储板的PCIe高带宽互联访问&…

揭秘百度百科审核内幕,百科词条审核究竟需要多久?

百度百科作为国内最大的网络百科全书平台之一,致力于提供全面、准确的知识服务,同时也承担着审核百科词条的工作。在互联网时代,人们对信息的需求日益增长,因此百度百科的审核工作显得尤为重要。那么,百度百科词条审核…

152 Linux C++ 通讯架构实战7 ,makefile编写改成for cpp,读配置文件,内存泄漏查找,设置标题实战

读写配置文件代码实战。nginx.conf 一个项目要启动,需要配置很多信息,第一项就是学习如何配置一个项目 nginx.conf的内容 #是注释行, #每个有效配置项用 等号 处理,等号前不超过40个字符,等号后不超过400个字符&#…

虚幻引擎资源加密方案解析

前段时间,全球游戏开发者大会(Game Developers Conference,简称GDC)在旧金山圆满落幕,会议提供了多份值得参考的数据报告。根据 GDC 调研数据,当下游戏市场中,Unreal Engine (下文简称虚幻)和 Unity 是使用最多的游戏引…

Qlib-Server:量化库数据服务器

Qlib-Server:量化库数据服务器 介绍 Qlib-Server 是 Qlib 的配套服务器系统,它利用 Qlib 进行基本计算,并提供广泛的服务器系统和缓存机制。通过 Qlib-Server,可以以集中的方式管理 Qlib 提供的数据。 框架 Qlib 的客户端/服务器框架基于 WebSocket 构建,这是因为 WebS…

js动态设置页面高度

准备一个div <div class"card-edit"><!-- 业务需求 --> </div>开始操作 // 获取页面的中需要设置高度的元素 如&#xff1a;card-editconst autoStyle document.getElementsByClassName(card-edit)[0]// 根据业务需求做判断// 此处设定&#…

DC电源模块的设计与制造流程

BOSHIDA DC电源模块的设计与制造流程 DC电源模块是一种用于将交流电转换为直流电的设备。它广泛应用于各种电子设备中&#xff0c;如电子产品、工业仪器、电视等。下面是DC电源模块的设计与制造流程的简要描述&#xff1a; 1. 需求分析&#xff1a;在设计DC电源模块之前&#…

sql——对于行列转换相关的操作

目录 一、lead、lag 函数 二、wm_concat 函数 三、pivot 函数 四、判断函数 遇到需要进行行列转换的数据处理需求&#xff0c;以 oracle 自带的表作为例子复习一下&#xff1a; 一、lead、lag 函数 需要行列转换的表&#xff1a; select deptno,count(empno) emp_num from…

R语言赋值符号<-、=、->、<<-、->>的使用与区别

R语言的赋值符号有&#xff1c;-、、-&#xff1e;、&#xff1c;&#xff1c;-、-&#xff1e;&#xff1e;六种&#xff0c;它们的使用与区别如下: <-’&#xff1a;最常用的赋值符号。它将右侧表达式的值赋给左侧的变量&#xff0c;像一个向左的箭头。例如&#xff0c;x …

Python-VBA编程500例-024(入门级)

字符串写入的行数(Line Count For String Writing)在实际应用中有着广泛的应用场景。常见的应用场景有&#xff1a; 1、文本编辑及处理&#xff1a;在编写或编辑文本文件时&#xff0c;如使用文本编辑器或文本处理器&#xff0c;经常需要处理字符串并确定其在文件中的行数。这…

【数据结构 | 图论】如何用链式前向星存图(保姆级教程,详细图解+完整代码)

一、概述 链式前向星是一种用于存储图的数据结构&#xff0c;特别适合于存储稀疏图&#xff0c;它可以有效地存储图的边和节点信息&#xff0c;以及边的权重。 它的主要思想是将每个节点的所有出边存储在一起&#xff0c;通过数组的方式连接&#xff08;类似静态数组实现链表…

基于Springboot+Vue的酒店管理系统!新鲜出炉,可商用,带源码

新年了给大家分享一套基于SpringbootVue的酒店管理系统源码&#xff0c;在实际项目中可以直接复用。(免费提供&#xff0c;文末自取) 一、系统运行图&#xff08;管理端和用户端&#xff09; 1、管理登陆 2、房间管理 3、订单管理 4、用户登陆 5、房间预定 二、系统搭建视频教…

JavaEE—— HTTP协议和与Tomcat (末篇)

本篇文章&#xff0c;承接前面两篇文章&#xff1a; 在前面的两篇文章中&#xff0c;简单介绍了 什么是 HTTP 协议&#xff0c;介绍了抓包工具&#xff0c;如何构造 HTTP 请求&#xff0c;以及&#xff0c;如何使用第三方工具来简化构造请求的过程。 如果需要了解前面的知识可…

算法---动态规划练习-6(地下城游戏)

地下城游戏 1. 题目解析2. 讲解算法原理3. 编写代码 1. 题目解析 题目地址&#xff1a;点这里 2. 讲解算法原理 首先&#xff0c;定义一个二维数组 dp&#xff0c;其中 dp[i][j] 表示从位置 (i, j) 开始到达终点时的最低健康点数。 初始化数组 dp 的边界条件&#xff1a; 对…

AI赋能微服务:Spring Boot与机器学习驱动的未来应用开发

&#x1f9d1; 作者简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟。提供嵌入式方向的学习指导、简历面…

实践笔记-harbor搭建(版本:2.9.0)

harbor搭建 1.下载安装包&#xff08;版本&#xff1a;2.9.0&#xff09;2.修改配置文件3.安装4.访问harbor5.可能用得上的命令: 环境&#xff1a;centos7 1.下载安装包&#xff08;版本&#xff1a;2.9.0&#xff09; 网盘资源&#xff1a;https://pan.baidu.com/s/1fcoJIa4x…

Vue中的一些指令与计算方法

语法 插值语法 HTML的双标签内容中使用&#xff0c;在{{}}之内书写JS代码 属性语法 1.v-bind或: 2.:属性名"值"或v-bind"值" 事件语法 v-on或 v-on:事件名"方法名"或事件名"方法名" 选项 选项&#xff1a;可选的配置项——官方…

vue3封装Element动态表单组件

1. 封装组件DymanicForm.vue 使用component实现动态组件组件不能直接使用字符串传入&#xff0c;所以根据传入的组件名称找到对应的组件校验规则&#xff0c;可使用rule传入自定义规则&#xff0c;也可以使用封装好的基本规则 示例中使用了checkRequired暴露重置方法和校验方法…

奥比中光Astra SDK相机SDK openni相机成像原理

目录 1.1 成像原理简介 1.1.1 结构光 1.1.2 双目视觉 1.1.3 光飞行时间TOF​ 2.使用手册 参考网址 2.1 产品集成设计 2.2 SDK介绍与使用 2.3 常用API介绍 OPENNI API 2 OpenNI类&#xff08;OpenNI.h&#xff09; 1.1 成像原理简介 1.1.1 结构光 结构光&#xff0…

Elastic 8.13:Elastic AI 助手中 Amazon Bedrock 的正式发布 (GA) 用于可观测性

作者&#xff1a;来自 Elastic Brian Bergholm 今天&#xff0c;我们很高兴地宣布 Elastic 8.13 的正式发布。 有什么新特性&#xff1f; 8.13 版本的三个最重要的组件包括 Elastic AI 助手中 Amazon Bedrock 支持的正式发布 (general availability - GA)&#xff0c;新的向量…