深度学习基础知识-tf.keras实例:衣物图像多分类分类器

参考书籍:《Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition (Aurelien Geron [Géron, Aurélien])》


在这里插入图片描述
本次使用的数据集是tf.keras.datasets.fashion_mnist,里面包含6w张图,涵盖10个分类。

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import pickle

'''
fashion_mnist = keras.datasets.fashion_mnist.load_data()
with open('fashion_mnist.pkl', 'wb') as f:
    pickle.dump(fashion_mnist, f)
'''
def load_data():
    with open('fashion_mnist.pkl', 'rb') as f:
        mnist = pickle.load(f)
    (X_train_full, y_train_full), (X_test, y_test) = mnist
    X_valid, X_train = X_train_full[:5000] / 255.0, X_train_full[5000:] / 255.0
    y_valid, y_train = y_train_full[:5000], y_train_full[5000:]
    X_test = X_test / 255.0
    return X_train, X_valid, X_test, y_train, y_valid, y_test

class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
 "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
#print(class_names[y_train[0]]) # coat
# 查看
some_image = X_train[0]
plt.imshow(some_image, cmap="binary")
plt.axis("off")
plt.show()

随便拿一张来看:
在这里插入图片描述
构建网络:

'''
model = keras.models.Sequential()
# 28x28 -> 1x784 也可以用InputLayer(input_shape=[28, 28])
model.add(keras.layers.Flatten(input_shape=[28, 28]))
# 其他激活函数:https://keras.io/api/layers/activations/
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
# output layer. 10个输出
model.add(keras.layers.Dense(10, activation="softmax"))
'''
# 也可以这么写:
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    keras.layers.Dense(300, activation="relu"),
    keras.layers.Dense(100, activation="relu"),
    keras.layers.Dense(10, activation="softmax")
])
# 下面235500 = 784 x 300 + 300, 前面表示每个input都要跑向300个节点,所以要给权重w。然后每个节点要加一个偏置b
# 30100 = 300 x 100 + 100
print(model.summary())
'''
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 300)               235500    
                                                                 
 dense_1 (Dense)             (None, 100)               30100     
                                                                 
 dense_2 (Dense)             (None, 10)                1010      
                                                                 
=================================================================
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________
'''

import pydot
keras.utils.plot_model(model, 'model.png')

在这里插入图片描述


pydot安装与plot_model报错的解决:

参考:https://blog.csdn.net/shangxiaqiusuo1/article/details/85283432

先下载 https://graphviz.gitlab.io/_pages/Download/windows/graphviz-2.38.msi
然后双击,安装到D:\Program Files (x86)\Graphviz2.38\

  1. 建立变量名GRAPHVIZ_DOT,值为D:\Program Files (x86)\Graphviz2.38\bin\dot.exe
  2. 在用户环境变量添加一个新的变量:建立变量名 GRAPHVIZ_INSTALL_DIR, 值为D:\Program Files (x86)\Graphviz2.38
  3. 在系统环境变量的PATH中添加Graphviz的bin目录路径,如D:\Program Files (x86)\Graphviz2.38\bin
pip install graphviz
pip install pydot
pip install pydot-ng

在python文件中输入import pydot,然后按住Ctrl+鼠标左键点击pydot,会进入pydot的源文件,然后找到 self.prog = ‘dot’ ,改成 self.prog = ‘dot.exe’

这样改完如果还不行,在python文件里添加:

import os
os.environ["PATH"] += os.pathsep + 'D:/Program Files (x86)/Graphviz2.38/bin/'

基本上这样就ok了。


# 用index或名字都可以access层
hidden1 = model.layers[1]
print(model.get_layer('dense') is hidden1)
weights, biases = hidden1.get_weights()
print(weights)
# bias最开始初始化为0
print(biases)

设置+训练模型:

# 使用这个loss是因为数据有10种离散的、互斥的标签
# optimizer=keras.optimizers.SGD(lr=xx)
# 这样可以设置学习率。default lr=0.01
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="sgd",
              metrics=["accuracy"])
X_train, X_valid, X_test, y_train, y_valid, y_test = load_data() # 获取数据的函数,略
history = model.fit(X_train, y_train, epochs=30, validation_data=(X_valid, y_valid))
'''
Epoch 1/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.6992 - accuracy: 0.7704 - val_loss: 0.4999 - val_accuracy: 0.8378
Epoch 2/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.4876 - accuracy: 0.8311 - val_loss: 0.4659 - val_accuracy: 0.8364
Epoch 3/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.4440 - accuracy: 0.8449 - val_loss: 0.4394 - val_accuracy: 0.8480
...
Epoch 29/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.2330 - accuracy: 0.9160 - val_loss: 0.3047 - val_accuracy: 0.8906
Epoch 30/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.2291 - accuracy: 0.9167 - val_loss: 0.2969 - val_accuracy: 0.8938
'''

这里设置了30次循环,其实未必达到最优,也基本不会过拟合。

如果训练集是有偏的,比如某些类overrepresented,某些类underrepresented,那么在fit()前应该设置class_weight,给underrepresented类以更大的权重,overrepresented类以更小的权重。如果有的case需要格外注意,比如某些cases是专家标注,另外一些是普通标注的,那么可以用per-instance weights,即设置sample_weight。如果class_weight和sample_weight都设置了,keras会把它们相乘。另外,也可以为验证集单独设置sample weights。

另外,history.history是个字典,里面包含loss,accuracy,val_loss, val_accuracy(每个都是epochs个数据),所以可以画图:
(就是把上面打印的信息以图的方式反映出来)

import pandas as pd
import matplotlib.pyplot as plt

print(history.history)
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
# 'gca'代表Get Current Axis
plt.gca().set_ylim(0, 1) 
plt.show()

在这里插入图片描述
为什么在前几个epoch上,validation的结果看起来比train好?
因为validation error是在每个epoch结束时计算的,而training error是一个running mean,即在每个epoch运行时计算的,所以training的图像应该移半个epoch。此时前几个epoch的图像应该是比较近似的,甚至overlap。

等调完超参数(如learning rate, layer num, batch_size…),评估一下模型:

print(model.evaluate(X_test, y_test))
'''
loss, accuracy
[0.34293538331985474, 0.8896999955177307]
'''

注意:如果loss很大可能是X_test没有归一化。

保存模型:其他保存模型的方法:https://blog.csdn.net/qq_22841387/article/details/130194553

import joblib
joblib.dump(model, "my_model.joblib")

# 导入使用load,导入后可以使用新数据继续训练 model.fit()
model = joblib.load("my_model.joblib")

预测

X_new = X_test[:3]
y_proba = model.predict(X_new)
print(y_proba.round(2))
'''
[[0.   0.   0.   0.   0.   0.   0.   0.   0.   1.  ]
 [0.   0.   0.99 0.   0.01 0.   0.   0.   0.   0.  ]
 [0.   1.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
'''

import numpy as np
y_pred = model.predict(X_new)
# tf 2.6前可以用model.predict_classes(X_new), 2.6开始删除了该函数
labels = np.argmax(y_pred, axis=1)
print(labels) # [9 2 1]
# 显示分类名称
print(np.array(class_names)[labels]) # ['Ankle boot' 'Pullover' 'Trouser']

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/26939.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Jmeter性能测试 (入门)

Jmeter是一款优秀的开源测试工具, 是每个资深测试工程师,必须掌握的测试工具,熟练使用Jmeter能大大提高工作效率。 熟练使用Jmeter后, 能用Jmeter搞定的事情,你就不会使用LoadRunner了。 本文将通过一个实际的测试例…

IPV6地址基础

IPv6是英文“Internet Protocol Version 6”(互联网协议第6版)的缩写,是互联网工程任务组(IETF)设计的用于替代IPv4的下一代IP协议。其地址数量号称可以为全世界的每一粒沙子编上一个地址 1. ipv6地址表示方法 IPv6的…

MQTT与EMQ

文章目录 1 MQTT协议与EMQ中间件1.1 物联网消息协议MQTT1.1.1 什么是MQTT1.1.2 MQTT相关概念1.1.3 消息服务质量QoS——信息的可靠投递1.1.3.1 QoS0——消息服务质量为0,消息发送至多一次1.1.3.2 QoS1——消息发送至少一次1.1.3.3 QoS2——消息发送仅一次1.1.3.4 不…

Oracle中的数据导出(4)

目录 法一:使用SQL plus命令脚本 法二:使用PLSQL Developer工具 前几篇文章描述了如何将Oracle中的数据导出到库外,但是导出的数据结果都是文本文档,这样页面查看不和谐,编辑又略显麻烦。因此这篇文章将描述如何将Or…

Pb协议的接口测试

【摘要】 Protocol Buffers 是谷歌开源的序列化与反序列化框架。它与语言无关、平台无关、具有可扩展的机制。用于序列化结构化数据,此工具对标 XML ,支持自动编码,解码。比 XML 性能好,且数据易于解析。更多有关工具的介绍可参考…

氟化物选择吸附树脂Tulsimer ®CH-87 ,锂电行业废水行业矿井水除氟专用树脂

氟化物选择吸附树脂 Tulsimer CH-87 是一款去除水溶液中氟离子的专用的凝胶型选择性离子交换树脂。它是具有氟化物选择性官能团的交联聚苯乙烯共聚物架构的树脂。 去除氟离子的能力可以达到 1ppm 以下的水平。中性至碱性的PH范围内有较好的工作效率,并且很容易再生…

Vue.js 中的过渡动画是什么?如何使用过渡动画?

Vue.js 中的过渡动画是什么?如何使用过渡动画? 在 Vue.js 中,过渡动画是一种在元素插入、更新或删除时自动应用的动画效果,可以为应用程序增加一些动态和生动的效果。本文将介绍 Vue.js 中过渡动画的概念、优势以及如何使用过渡动…

Nginx正则表达式、location、rewrite

目录 一、常用的Nginx正则表达式 二:localtion 1、location 分类 2、 location 常用的匹配规则 3、location 优先级 4、 location 示例 5、优先级总结 6、实际网站使用中,至少有三个匹配规则定义 (1)第一个必选规则 &…

chatgpt赋能python:将一行数变成列——Python简单实现

将一行数变成列——Python简单实现 在数据处理时,我们常常会遇到将一行的数据转换成列的情况,例如将多个数据在Excel表格中拆分为不同的列。这时候,Python可以帮助我们快速实现这个功能。 什么是Python? Python是一种高级&…

N-propargyloxycarbonyl-L-lysine,1215204-46-8,是一种基于赖氨酸的非天然氨基酸 (UAA)

产品描述: N-ε-propargyloxycarbonyl-L-lysine (H-L-Lys(Poc)-OH) 是一种基于赖氨酸的非天然氨基酸 (UAA)。 广泛用于多种生物体中荧光探针的生物偶联。 N- ε- Propargyloxycarbonyl-L-lysine (H-L-Lys (Poc) - OH) is a non natural amino acid (UAA) based on …

IDEA 终端命令行设置

一、说明 在使用 IDEA 进行程序开发时,需要使用到终端 Terminal 的功能,便于能够快速使用 shell 命令,进行各种相关的操作。 这些操作可以包括代码的版本控制、程序的打包部署等等 比如,前后端的集成开发环境(IDEA、We…

.gitignore忽略文件不生效

前言 .gitignore忽略文件时git仓库很重要的一个配置,在创建仓库时就会有模板选择和忽略文件。 .gitignore忽略文件意思是在上传到代码仓库时,控制把哪些代码文件不上传到代码仓库。 在实际开发中其实写的代码是没有多大的,主要的是插件本地…

基于牛顿拉夫逊的配电网潮流计算研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Symfony v6.2.11 正式发布,经典 PHP Web 开发框架

导读Symfony v6.2.11 发布了!Symfony 是一款基于 MVC 架构的 PHP 框架,致力于减少重复代码的编写,以加速 Web 应用的开发和维护。Symfony 与许多关系型数据库集成的也非常好,成本也较小。 此外,Symfony 致力于在企业背…

技术赋能-混流编排功能,助力京东618直播重保 | 京东云技术团队

每每到618、双11这样的大型活动的时候,每天都有几个重要的大v或者品牌直播需要保障。 以往的重点场次监播方式是这么造的: 对每路直播的源流、各档转码流分别起一个ffplay播放窗口,再手动调整尺寸在显示器桌面进行布局,排到一屏…

怎么用u盘制作pe系统启动盘

PE系统是一种小型的windows系统,通俗的说法也就是在电脑出现问题不能正常进入系统时的一种紧急备用系统。它容量小能量大,可以解决win系统中经常遇到的一些问题,对于经常使用电脑的用户来说,制作一个pe系统启动盘放在身边是很有必…

C++11:右值引用,实现移动语义和完美转发

目录 1、右值引用 2、移动语义(std::move) 3、完美转发(std::forward) 1、右值引用 右值引用(Rvalue reference)是C11引入的一个新特性,它是一种新的引用类型,用于表示将要被移动…

接口测试 —— 接口测试定义

1、接口测试概念 (重点) 接口测试是测试系统组件间接口的一种测试,它界于单元测试与系统测试中间。 接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点。 测试的重点是要检查数据的交换,传递和控制管理过…

浅入浅出 iptables 网络隔离原理

01 iptables简介 iptables ipfirewall(内核1.x时代) ipchains(内核2.x时代) iptables 网络协议栈 Link Layer 数据链路层的数据流向,根据mac寻址找到对应的网卡后向上进入网络层 Network Layer 网络层的数据流向&am…

磁盘配额与进阶文件系统管理(一)

磁盘配额Quota 用途 针对www server,例如 每个人网页空间的容量限制;针对mail server,例如 每个人的邮件空间限制;针对file server,例如 每个人最大可用的网络硬盘空间;限制某一群组所能使用的最大磁盘空…