mnist-数据集的学习-3.6手写数字识别
# coding: utf-8
# 指定文件编码为utf-8,防止中文注释/中文打印乱码
# 兼容Python版本:尝试导入网络请求库
try:
# Python3内置网络下载工具,用于远程下载MNIST数据集文件
import urllib.request
# 如果导入失败,说明是Python2环境,主动抛出错误提示用户
except ImportError:
raise ImportError('You should use Python 3.x')
# 绘图库,展示手写数字图片
import matplotlib.pyplot as plt
# 随机库,随机抽取样本图片
import random
# 处理文件路径相关工具
import os.path
# 解压gz压缩包,MNIST原始数据是.gz格式
import gzip
# 序列化/反序列化,用来把数据集保存为pkl缓存文件,加速下次读取
import pickle
# 文件系统操作:创建路径、判断文件是否存在等
import os
# 数值计算核心库,存储、处理图像像素数据
import numpy as np
# -------------------------- 全局字体配置,彻底解决中文方框乱码 --------------------------
plt.rcParams["font.sans-serif"] = ["SimHei"] # Windows黑体,支持中文
plt.rcParams["axes.unicode_minus"] = False # 解决图像负号显示异常
# MNIST官方数据集亚马逊镜像下载地址
url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/' # mirror site
# 字典存储4个MNIST数据文件的文件名
key_file = {
'train_img':'train-images-idx3-ubyte.gz', # 训练集图片压缩包
'train_label':'train-labels-idx1-ubyte.gz', # 训练集标签压缩包
'test_img':'t10k-images-idx3-ubyte.gz', # 测试集图片压缩包
'test_label':'t10k-labels-idx1-ubyte.gz' # 测试集标签压缩包
}
# 获取当前mnist.py脚本所在的文件夹绝对路径
dataset_dir = os.path.dirname(os.path.abspath(__file__))
# 拼接缓存文件完整路径:脚本目录下的mnist.pkl
save_file = dataset_dir + "/mnist.pkl"
# 训练样本总数6万张
train_num = 60000
# 测试样本总数1万张
test_num = 10000
# 单张图片原始维度:(通道数1, 高28像素, 宽28像素)
img_dim = (1, 28, 28)
# 单张图片展平后一维长度:28*28=784
img_size = 784
def _download(file_name):
"""私有函数:下载单个gz数据集文件,已存在则跳过"""
# 拼接文件完整本地路径:脚本目录 + 文件名
file_path = dataset_dir + "/" + file_name
# 判断本地是否已有该文件,存在直接返回,不用重复下载
if os.path.exists(file_path):
return
# 打印日志,提示正在下载哪个文件
print("Downloading " + file_name + " ... ")
# 远程下载文件,第一个参数下载链接,第二个参数本地保存路径
urllib.request.urlretrieve(url_base + file_name, file_path)
# 下载完成提示
print("Done")
def download_mnist():
"""公有调用函数:循环下载key_file里全部4个gz文件"""
# 遍历字典里4个文件名,逐个调用下载函数
for v in key_file.values():
_download(v)
def _load_label(file_name):
"""私有函数:读取标签gz文件,转换为numpy数组返回"""
# 拼接标签文件本地完整路径
file_path = dataset_dir + "/" + file_name
# 打印转换日志
print("Converting " + file_name + " to NumPy Array ...")
# 以二进制只读模式打开gz压缩文件
with gzip.open(file_path, 'rb') as f:
# 读取全部二进制数据,转uint8无符号整数,offset=8跳过文件头部8字节描述信息
labels = np.frombuffer(f.read(), np.uint8, offset=8)
# 转换完成提示
print("Done")
# 返回全部标签数组,形状(样本数,),数值0~9
return labels
def _load_img(file_name):
"""私有函数:读取图片gz文件,转换为展平后的numpy数组返回"""
# 拼接图片文件本地完整路径
file_path = dataset_dir + "/" + file_name
# 打印转换日志
print("Converting " + file_name + " to NumPy Array ...")
# 二进制只读打开gz图片压缩包
with gzip.open(file_path, 'rb') as f:
# 读取二进制像素,跳过头部16字节文件描述信息,转uint8像素值(0~255)
data = np.frombuffer(f.read(), np.uint8, offset=16)
# 将一维像素数组重塑:(样本数量, 784),每张图片展平为784维向量
data = data.reshape(-1, img_size)
# 转换完成提示
print("Done")
# 返回展平后的图片数组,形状(样本数,784)
return data
def _convert_numpy():
"""私有函数:整合4个文件,组装完整数据集字典"""
# 定义空字典存放全部图像、标签数据
dataset = {}
# 读取训练图片存入字典key
dataset['train_img'] = _load_img(key_file['train_img'])
# 读取训练标签存入字典key
dataset['train_label'] = _load_label(key_file['train_label'])
# 读取测试图片存入字典key
dataset['test_img'] = _load_img(key_file['test_img'])
# 读取测试标签存入字典key
dataset['test_label'] = _load_label(key_file['test_label'])
# 返回完整数据集字典
return dataset
def init_mnist():
"""初始化MNIST:下载文件 + 转numpy + 保存本地pkl缓存"""
# 第一步:下载全部4个gz原始文件
download_mnist()
# 第二步:把gz文件转成numpy数组,得到完整数据集
dataset = _convert_numpy()
# 提示开始生成缓存文件
print("Creating pickle file ...")
# 二进制写入模式打开缓存文件
with open(save_file, 'wb') as f:
# 将数据集序列化存入pkl,-1使用最高压缩协议
pickle.dump(dataset, f, -1)
# 缓存生成完成提示
print("Done!")
def _change_one_hot_label(X):
"""私有函数:普通数字标签转独热编码标签
输入:一维数组 [5,0,3...]
输出:二维数组 [[0,0,0,0,0,1,0,0,0,0], [1,0,0...], ...]
"""
# 创建全0矩阵:行数=样本总数,列数=数字类别10(0-9)
T = np.zeros((X.size, 10))
# 遍历每一个样本的索引和对应标签
for idx, row in enumerate(T):
# 将该行对应数字位置设为1,其余保持0,完成独热编码
row[X[idx]] = 1
# 返回独热编码二维标签数组
return T
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""
对外主接口函数:读取MNIST数据集(优先读取本地缓存,无缓存自动初始化下载)
Parameters
----------
normalize : bool
True=像素值从0~255归一化为0.0~1.0浮点数;False=保留0~255整数
flatten : bool
True=每张图片展平为一维784向量;False=保留四维形状(样本数,1,28,28)灰度图
one_hot_label : bool
True=标签转为10维独热编码;False=直接返回0~9数字
Returns
-------
(训练图像, 训练标签), (测试图像, 测试标签)
"""
# 判断缓存文件mnist.pkl是否存在,不存在则执行全套下载+转换初始化
if not os.path.exists(save_file):
init_mnist()
# 二进制读取缓存文件,加载序列化数据集
with open(save_file, 'rb') as f:
dataset = pickle.load(f)
# 归一化分支处理
if normalize:
# 遍历训练、测试图片两个key
for key in ('train_img', 'test_img'):
# 将uint8像素转为32位浮点型
dataset[key] = dataset[key].astype(np.float32)
# 全部像素除以255,映射到0~1区间
dataset[key] /= 255.0
# 独热编码标签分支处理
if one_hot_label:
# 训练标签转独热
dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
# 测试标签转独热
dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
# 不展平图片分支:恢复原图四维结构
if not flatten:
# 遍历训练、测试图片
for key in ('train_img', 'test_img'):
# 重塑维度:(样本数量, 通道1, 高28, 宽28)
dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
# 返回 ((训练图,训练标签), (测试图,测试标签)) 二元组
return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
# -------------------------- 简易神经网络工具函数 --------------------------
def sigmoid(x):
"""激活函数sigmoid"""
return 1 / (1 + np.exp(-x))
def simple_two_layer_net_predict(x, W1, b1, W2, b2):
"""两层全连接网络前向传播预测"""
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
return a2
# ===================== 方式1:直接在本文件运行测试 =====================
# 当前脚本直接执行时才运行下方测试代码;被其他文件import导入时不执行
if __name__ == '__main__':
print("========== 测试1:默认参数加载数据集 ==========")
# 调用主函数,默认配置:归一化+展平一维+普通数字标签
(x_train, t_train), (x_test, t_test) = load_mnist()
# 打印训练图片数组维度:60000张,每张784像素
print("训练图片 shape:", x_train.shape)
# 打印训练标签数组维度:60000个数字标签
print("训练标签 shape:", t_train.shape)
# 打印测试图片数组维度:10000张,每张784像素
print("测试图片 shape:", x_test.shape)
# 打印测试标签数组维度:10000个数字标签
print("测试标签 shape:", t_test.shape)
# 打印第一张训练图片对应的数字标签
print("第一张训练图片标签数字:", t_train[0])
print("\n========== 测试2:自定义参数加载 ==========")
# 自定义参数:不展平图片、标签转为独热编码
(x_train_2, t_train_2), (x_test_2, t_test_2) = load_mnist(flatten=False, one_hot_label=True)
# 打印未展平图片维度:(60000, 1通道, 28高, 28宽)
print("四维原图训练集 shape:", x_train_2.shape)
# 打印独热标签维度:60000行,每行10个0/1
print("独热编码训练标签 shape:", t_train_2.shape)
# 打印第一张图片的独热编码数组
print("第一张图片独热标签:", t_train_2[0])
# 从独热编码还原真实数字
print("独热标签对应数字:", np.argmax(t_train_2[0]))
print("\n========== 测试3:可视化手写数字图片 ==========")
# 1. 展示第一张图片
img = x_train[0].reshape(28,28)
plt.figure(figsize=(4,4))
plt.imshow(img, cmap="gray")
plt.title(f"手写数字:{t_train[0]}")
plt.show()
# 2. 随机抽取5张图片批量展示
print("随机抽取5张手写数字样本:")
for _ in range(5):
idx = random.randint(0, len(x_train)-1)
img = x_train[idx].reshape(28, 28)
label = t_train[idx]
plt.figure(figsize=(3,3))
plt.imshow(img, cmap="gray")
plt.title(f"手写数字:{label}")
plt.show()
print("\n========== 测试4:简易两层神经网络预测演示 ==========")
# 初始化两层网络权重与偏置
W1 = np.random.randn(784, 50)
b1 = np.zeros(50)
W2 = np.random.randn(50, 10)
b2 = np.zeros(10)
# 对第一张图片做预测
pred_out = simple_two_layer_net_predict(x_train[0], W1, b1, W2, b2)
pred_num = np.argmax(pred_out)
real_num = t_train[0]
print(f"网络预测数字:{pred_num},图片真实数字:{real_num}")
print("(权重随机初始化,预测结果不准确,后续训练后可提升识别准确率)")