02、Tensorflow实现手写数字识别(数字0-9)

02、Tensorflow实现手写数字识别(数字0-9)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0与pycharm

1、识别目标

识别手写仅仅是为了区分手写的0到9,所以实际上是一个多分类问题。
STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

from sklearn.metrics import classification_report 这行代码的主要作用是导入classification_report 函数,以便在后续的代码中使用它来评估分类模型的性能。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。


STEP2:屏蔽无用警告并允许中文

# 使用warnings模块来忽略特定类型的警告  
warnings.simplefilter(action='ignore', category=FutureWarning)  
# 配置tensorflow的日志记录级别  
logging.getLogger("tensorflow").setLevel(logging.ERROR)  
# 设置TensorFlow的autograph模块的详细级别  
tf.autograph.set_verbosity(0)  
# 设置numpy的打印选项  
np.set_printoptions(precision=2)  

STEP3:加载数据集并分割测试集

# load dataset
def load_data():
    X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")
    y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")
    return X, y

# load dataset
X, y = load_data()

print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

原始的输入的数据集是5000* 400数组,共包含5000个手写数字的数据,其中400为20*20像素的图片,
在这里插入图片描述


STEP4:模型构建与训练

# 构建模型  
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的  
model = Sequential(
    [
        ### START CODE HERE ###  
        tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维   
        Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"  
        Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"  
        Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"  
        ### END CODE HERE ###  
    ], name="my_model"
)  # 定义模型名称为"my_model"  
model.summary()  # 打印模型的概述信息  

# 配置模型的训练参数  
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)  
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001  
)

# 训练模型  
history = model.fit(
    X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据  
    epochs=100  # 训练100轮  
)


STEP5:结果可视化与打印准确度信息

fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])
    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T
    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')
    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    prediction_p = tf.nn.softmax(prediction)
    yhat = np.argmax(prediction_p)
    # 错误结果标红
    if y_test[random_index, 0] == yhat:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)
        ax.set_axis_off()
    else:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')
        ax.set_axis_off()

fig.suptitle("Label, yhat", fontsize=14)
plt.show()

# 给出预测的测试集误差
def evaluation(y_test, y_predict):
    accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']
    s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']
    precision=s['precision']
    recall=s['recall']
    f1_score=s['f1-score']
    #kappa=cohen_kappa_score(y_test, y_predict)
    return accuracy,precision,recall,f1_score #, kappa

y_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)

print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

3、运行结果

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现手写数字识别(数字0-9)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings

# 使用warnings模块来忽略特定类型的警告
warnings.simplefilter(action='ignore', category=FutureWarning)
# 配置tensorflow的日志记录级别
logging.getLogger("tensorflow").setLevel(logging.ERROR)
# 设置TensorFlow的autograph模块的详细级别
tf.autograph.set_verbosity(0)
# 设置numpy的打印选项
np.set_printoptions(precision=2)

# load dataset
def load_data():
    X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")
    y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")
    return X, y

# load dataset
X, y = load_data()

print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# # 绘图可选
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(5, 5))
# fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
# # fig.tight_layout(pad=0.5)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # reshape the image
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
#     fig.suptitle("Label, image", fontsize=14)
# plt.show()

# 构建模型
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的
model = Sequential(
    [
        ### START CODE HERE ###
        tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维
        Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"
        Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"
        Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"
        ### END CODE HERE ###
    ], name="my_model"
)  # 定义模型名称为"my_model"
model.summary()  # 打印模型的概述信息

# 配置模型的训练参数
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001
)

# 训练模型
history = model.fit(
    X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据
    epochs=100  # 训练100轮
)


fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):
    # Select random indices
    random_index = np.random.randint(X_test.shape[0])
    # Select rows corresponding to the random indices and
    # reshape the image
    X_random_reshaped = X_test[random_index].reshape((20, 20)).T
    # Display the image
    ax.imshow(X_random_reshaped, cmap='gray')
    # Predict using the Neural Network
    prediction = model.predict(X_test[random_index].reshape(1, 400))
    prediction_p = tf.nn.softmax(prediction)
    yhat = np.argmax(prediction_p)
    # Display the label above the image
    if y_test[random_index, 0] == yhat:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)
        ax.set_axis_off()
    else:
        ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')
        ax.set_axis_off()

fig.suptitle("Label, yhat", fontsize=14)
plt.show()

# 给出预测的测试集误差
def evaluation(y_test, y_predict):
    accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']
    s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']
    precision=s['precision']
    recall=s['recall']
    f1_score=s['f1-score']
    #kappa=cohen_kappa_score(y_test, y_predict)
    return accuracy,precision,recall,f1_score #, kappa

y_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)

print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/190725.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SASS的导入文件详细教程

文章目录 前言导入SASS文件使用SASS部分文件默认变量值嵌套导入原生的CSS导入后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:Sass和Less 🐱‍👓博主在前端领域还有很多知识和技术需要掌握,正在不断努…

电子学会C/C++编程等级考试2022年12月(二级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:数组逆序重放 将一个数组中的值按逆序重新存放。例如,原来的顺序为8,6,5,4,1。要求改为1,4,5,6,8。输入 输入为两行:第一行数组中元素的个数n(1输出 输出为一行:输出逆序后数组的整数,每两个整数之间用空格分隔。样例输入 …

Linux:docker基础操作(3)

docker的介绍 Linux:Docker的介绍(1)-CSDN博客https://blog.csdn.net/w14768855/article/details/134146721?spm1001.2014.3001.5502 通过yum安装docker Linux:Docker-yum安装(2)-CSDN博客https://blog.…

Mac 最佳使用指南

官方 Mac 使用手册如何在macOS系统安装根证书mac Terminal config proxy 【mac 终端配置代理】iPhone 安装 iOS 17公测版(Public Beta)macOS 最佳命令行客户端:iTermMac 配置与 Linux 互信Mac mini 外接移动硬盘无法写入或者无法显示的解决方法如何在 ma…

2016年五一杯数学建模B题能源总量控制下的城市工业企业协调发展问题解题全过程文档及程序

2016年五一杯数学建模 B题 能源总量控制下的城市工业企业协调发展问题 原题再现 能源是国民经济的重要物质基础,是工业企业发展的动力,但是过度的能源消耗,会破坏资源和环境,不利于经济的可持续发展。目前我国正处于经济转型的关键时期&…

牛客网刷题笔记四 链表节点k个一组翻转

NC50 链表中的节点每k个一组翻转 题目: 思路: 这种题目比较习惯现在草稿本涂涂画画链表处理过程。整体思路是赋值新的链表,用游离指针遍历原始链表进行翻转操作,当游离个数等于k时,就将翻转后的链表接到新的链表后&am…

线性模型加上正则化

使用弹性网络回归(Elastic Net Regression)算法来预测波士顿房屋价格。弹性网络回归是一种结合了L1和L2正则化惩罚的线性回归模型,能够处理高维数据和具有多重共线性的特征。弹性网络回归的目标函数包括数据拟合损失和正则化项: m…

前端学习--React(4)路由

一、认识ReactRouter 一个路径path对应一个组件component,当我们在浏览器中访问一个path,对应的组件会在页面进行渲染 创建路由项目 // 创建项目 npx create router-demo// 安装路由依赖包 npm i react-router-dom// 启动项目 npm run start 简单的路…

一文读懂MySQL基础与进阶

Mysql基础与进阶 Part1 基础操作 数据库操作 在MySQL中,您可以使用一些基本的命令来创建和删除数据库。以下是这些操作的示例: 创建数据库: 要创建一个新的数据库,您可以使用CREATE DATABASE命令。以下是示例: CREA…

C++ day36 贪心算法 无重叠区间 划分字母区间 合并区间

题目1:435 无重叠区间 题目链接:无重叠区间 对题目的理解 移除数组中的元素,使得区间互不重叠,保证移除的元素数量最少,数组中至少包含一个元素 贪心算法 局部最优,使得重叠区间的个数最大&#xff0c…

MyBatis Generator使用总结

MyBatis Generator使用总结 介绍具体使用数据准备插件引入配置条件构建讲解demo地址 介绍 MyBatis Generator (MBG) 是 MyBatis 的代码生成器。它能够根据数据库表,自动生成 java 实体类、dao 层接口(mapper 接口)及m…

【STM32单片机】自动售货机控制系统设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用STM32F103C8T6单片机控制器,使用OLED显示模块、矩阵按键模块、LED和蜂鸣器、继电器模块等。 主要功能: 系统运行后,OLED显示系统初始界面,可通过…

Java王者荣耀

第一步是创建项目 项目名自拟 第二部创建个包名 来规范class 然后是创建类 GameFrame 运行类 package com.sxt;import java.awt.Graphics; import java.awt.Image; import java.awt.Toolkit; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; im…

网络和Linux网络_5(应用层)HTTP协议(方法+报头+状态码)

目录 1. HTTP协议介绍 1.1 URL介绍 1.2 urlencode和urldecode 1.3 HTTP协议格式 1.4 HTTP的方法和报头和状态码 2. 代码验证HTTP协议格式 HttpServer.hpp 2.2 html正式测试 Util.hpp index.html 2.3 再看HTTP方法和报头和状态码 2.3.1 方法_GET和POST等 2.3.2 报头…

springboot函数式web

1.通常是路由(请求路径)业务 2.函数式web:路由和业务分离 一个configure类 配置bean 路由等 实现业务逻辑 这样实现了业务和路由的分离

代码常见问题

1. 前端页面出现404了: 1)那说明你该页面里面有某个接口地址(url)写错了,后台没有这个接口 2)你后台写了这个接口,但是后台忘了重启服务了,这样的话前端也映射不上的 所以404的时…

歌曲《兄弟情深》:歌手荆涛歌曲中的真挚情感

在人生的道路上,有时我们会遇到迷茫、失落、困惑等种种情境。而在这些时刻,身边有一个真挚的兄弟,其意义是无法估量的。歌手荆涛演唱的《兄弟情深》即是对这种深厚情感的美妙歌颂。 一、迷茫时的指引 “当我迷茫时,有你帮目标重新…

箱型图 Box Plot 数据分析的法宝

文章目录 一、箱形图的介绍二、六大因数三、Box plot的应用四、箱形图的优劣势五、图形拓展 一、箱形图的介绍 箱形图又称为盒须图、盒式图、盒状图或箱线图,是一种用作显示一组数据分散情况资料的统计图。因型状如箱子而得名。在各种领域也经常被使用,…

Sqlserver 在数据库‘master’中拒绝了Create Database的权限

解决方案 打开SqlServer Manament Studio软件,然后登陆 选择安全性->登录名->找到您当前的用户 在您的登陆名上,点击右键-属性,配置相应的服务器角色权限(这块需要勾选dbcreator的权限,这块如果不清楚还需要啥…

c语言:模拟实现各种字符串函数(2)

strncpy函数: 功能:拷贝指定长度的字符串a到字符串b中 代码模拟实现: //strncpy char* my_strncpy(char* dest, char* str,size_t num) {char* ret dest;assert(dest && str);//断言,如果其中有一个为空指针&#xff…