mnist-数据集的学习-3.6手写数字识别

📅 2026/7/6 4:15:59 👁️ 阅读次数 📝 编程学习
mnist-数据集的学习-3.6手写数字识别

# coding: utf-8

# 指定文件编码为utf-8,防止中文注释/中文打印乱码

# 兼容Python版本:尝试导入网络请求库

try:

# Python3内置网络下载工具,用于远程下载MNIST数据集文件

import urllib.request

# 如果导入失败,说明是Python2环境,主动抛出错误提示用户

except ImportError:

raise ImportError('You should use Python 3.x')

# 绘图库,展示手写数字图片

import matplotlib.pyplot as plt

# 随机库,随机抽取样本图片

import random

# 处理文件路径相关工具

import os.path

# 解压gz压缩包,MNIST原始数据是.gz格式

import gzip

# 序列化/反序列化,用来把数据集保存为pkl缓存文件,加速下次读取

import pickle

# 文件系统操作:创建路径、判断文件是否存在等

import os

# 数值计算核心库,存储、处理图像像素数据

import numpy as np

# -------------------------- 全局字体配置,彻底解决中文方框乱码 --------------------------

plt.rcParams["font.sans-serif"] = ["SimHei"] # Windows黑体,支持中文

plt.rcParams["axes.unicode_minus"] = False # 解决图像负号显示异常

# MNIST官方数据集亚马逊镜像下载地址

url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/' # mirror site

# 字典存储4个MNIST数据文件的文件名

key_file = {

'train_img':'train-images-idx3-ubyte.gz', # 训练集图片压缩包

'train_label':'train-labels-idx1-ubyte.gz', # 训练集标签压缩包

'test_img':'t10k-images-idx3-ubyte.gz', # 测试集图片压缩包

'test_label':'t10k-labels-idx1-ubyte.gz' # 测试集标签压缩包

}

# 获取当前mnist.py脚本所在的文件夹绝对路径

dataset_dir = os.path.dirname(os.path.abspath(__file__))

# 拼接缓存文件完整路径:脚本目录下的mnist.pkl

save_file = dataset_dir + "/mnist.pkl"

# 训练样本总数6万张

train_num = 60000

# 测试样本总数1万张

test_num = 10000

# 单张图片原始维度:(通道数1, 高28像素, 宽28像素)

img_dim = (1, 28, 28)

# 单张图片展平后一维长度:28*28=784

img_size = 784


def _download(file_name):

"""私有函数:下载单个gz数据集文件,已存在则跳过"""

# 拼接文件完整本地路径:脚本目录 + 文件名

file_path = dataset_dir + "/" + file_name

# 判断本地是否已有该文件,存在直接返回,不用重复下载

if os.path.exists(file_path):

return

# 打印日志,提示正在下载哪个文件

print("Downloading " + file_name + " ... ")

# 远程下载文件,第一个参数下载链接,第二个参数本地保存路径

urllib.request.urlretrieve(url_base + file_name, file_path)

# 下载完成提示

print("Done")


def download_mnist():

"""公有调用函数:循环下载key_file里全部4个gz文件"""

# 遍历字典里4个文件名,逐个调用下载函数

for v in key_file.values():

_download(v)


def _load_label(file_name):

"""私有函数:读取标签gz文件,转换为numpy数组返回"""

# 拼接标签文件本地完整路径

file_path = dataset_dir + "/" + file_name

# 打印转换日志

print("Converting " + file_name + " to NumPy Array ...")

# 以二进制只读模式打开gz压缩文件

with gzip.open(file_path, 'rb') as f:

# 读取全部二进制数据,转uint8无符号整数,offset=8跳过文件头部8字节描述信息

labels = np.frombuffer(f.read(), np.uint8, offset=8)

# 转换完成提示

print("Done")

# 返回全部标签数组,形状(样本数,),数值0~9

return labels


def _load_img(file_name):

"""私有函数:读取图片gz文件,转换为展平后的numpy数组返回"""

# 拼接图片文件本地完整路径

file_path = dataset_dir + "/" + file_name

# 打印转换日志

print("Converting " + file_name + " to NumPy Array ...")

# 二进制只读打开gz图片压缩包

with gzip.open(file_path, 'rb') as f:

# 读取二进制像素,跳过头部16字节文件描述信息,转uint8像素值(0~255)

data = np.frombuffer(f.read(), np.uint8, offset=16)

# 将一维像素数组重塑:(样本数量, 784),每张图片展平为784维向量

data = data.reshape(-1, img_size)

# 转换完成提示

print("Done")

# 返回展平后的图片数组,形状(样本数,784)

return data


def _convert_numpy():

"""私有函数:整合4个文件,组装完整数据集字典"""

# 定义空字典存放全部图像、标签数据

dataset = {}

# 读取训练图片存入字典key

dataset['train_img'] = _load_img(key_file['train_img'])

# 读取训练标签存入字典key

dataset['train_label'] = _load_label(key_file['train_label'])

# 读取测试图片存入字典key

dataset['test_img'] = _load_img(key_file['test_img'])

# 读取测试标签存入字典key

dataset['test_label'] = _load_label(key_file['test_label'])

# 返回完整数据集字典

return dataset


def init_mnist():

"""初始化MNIST:下载文件 + 转numpy + 保存本地pkl缓存"""

# 第一步:下载全部4个gz原始文件

download_mnist()

# 第二步:把gz文件转成numpy数组,得到完整数据集

dataset = _convert_numpy()

# 提示开始生成缓存文件

print("Creating pickle file ...")

# 二进制写入模式打开缓存文件

with open(save_file, 'wb') as f:

# 将数据集序列化存入pkl,-1使用最高压缩协议

pickle.dump(dataset, f, -1)

# 缓存生成完成提示

print("Done!")


def _change_one_hot_label(X):

"""私有函数:普通数字标签转独热编码标签

输入:一维数组 [5,0,3...]

输出:二维数组 [[0,0,0,0,0,1,0,0,0,0], [1,0,0...], ...]

"""

# 创建全0矩阵:行数=样本总数,列数=数字类别10(0-9)

T = np.zeros((X.size, 10))

# 遍历每一个样本的索引和对应标签

for idx, row in enumerate(T):

# 将该行对应数字位置设为1,其余保持0,完成独热编码

row[X[idx]] = 1

# 返回独热编码二维标签数组

return T


def load_mnist(normalize=True, flatten=True, one_hot_label=False):

"""

对外主接口函数:读取MNIST数据集(优先读取本地缓存,无缓存自动初始化下载)

Parameters

----------

normalize : bool

True=像素值从0~255归一化为0.0~1.0浮点数;False=保留0~255整数

flatten : bool

True=每张图片展平为一维784向量;False=保留四维形状(样本数,1,28,28)灰度图

one_hot_label : bool

True=标签转为10维独热编码;False=直接返回0~9数字

Returns

-------

(训练图像, 训练标签), (测试图像, 测试标签)

"""

# 判断缓存文件mnist.pkl是否存在,不存在则执行全套下载+转换初始化

if not os.path.exists(save_file):

init_mnist()

# 二进制读取缓存文件,加载序列化数据集

with open(save_file, 'rb') as f:

dataset = pickle.load(f)

# 归一化分支处理

if normalize:

# 遍历训练、测试图片两个key

for key in ('train_img', 'test_img'):

# 将uint8像素转为32位浮点型

dataset[key] = dataset[key].astype(np.float32)

# 全部像素除以255,映射到0~1区间

dataset[key] /= 255.0

# 独热编码标签分支处理

if one_hot_label:

# 训练标签转独热

dataset['train_label'] = _change_one_hot_label(dataset['train_label'])

# 测试标签转独热

dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

# 不展平图片分支:恢复原图四维结构

if not flatten:

# 遍历训练、测试图片

for key in ('train_img', 'test_img'):

# 重塑维度:(样本数量, 通道1, 高28, 宽28)

dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

# 返回 ((训练图,训练标签), (测试图,测试标签)) 二元组

return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])

# -------------------------- 简易神经网络工具函数 --------------------------

def sigmoid(x):

"""激活函数sigmoid"""

return 1 / (1 + np.exp(-x))

def simple_two_layer_net_predict(x, W1, b1, W2, b2):

"""两层全连接网络前向传播预测"""

a1 = np.dot(x, W1) + b1

z1 = sigmoid(a1)

a2 = np.dot(z1, W2) + b2

return a2

# ===================== 方式1:直接在本文件运行测试 =====================

# 当前脚本直接执行时才运行下方测试代码;被其他文件import导入时不执行

if __name__ == '__main__':

print("========== 测试1:默认参数加载数据集 ==========")

# 调用主函数,默认配置:归一化+展平一维+普通数字标签

(x_train, t_train), (x_test, t_test) = load_mnist()

# 打印训练图片数组维度:60000张,每张784像素

print("训练图片 shape:", x_train.shape)

# 打印训练标签数组维度:60000个数字标签

print("训练标签 shape:", t_train.shape)

# 打印测试图片数组维度:10000张,每张784像素

print("测试图片 shape:", x_test.shape)

# 打印测试标签数组维度:10000个数字标签

print("测试标签 shape:", t_test.shape)

# 打印第一张训练图片对应的数字标签

print("第一张训练图片标签数字:", t_train[0])

print("\n========== 测试2:自定义参数加载 ==========")

# 自定义参数:不展平图片、标签转为独热编码

(x_train_2, t_train_2), (x_test_2, t_test_2) = load_mnist(flatten=False, one_hot_label=True)

# 打印未展平图片维度:(60000, 1通道, 28高, 28宽)

print("四维原图训练集 shape:", x_train_2.shape)

# 打印独热标签维度:60000行,每行10个0/1

print("独热编码训练标签 shape:", t_train_2.shape)

# 打印第一张图片的独热编码数组

print("第一张图片独热标签:", t_train_2[0])

# 从独热编码还原真实数字

print("独热标签对应数字:", np.argmax(t_train_2[0]))

print("\n========== 测试3:可视化手写数字图片 ==========")

# 1. 展示第一张图片

img = x_train[0].reshape(28,28)

plt.figure(figsize=(4,4))

plt.imshow(img, cmap="gray")

plt.title(f"手写数字:{t_train[0]}")

plt.show()

# 2. 随机抽取5张图片批量展示

print("随机抽取5张手写数字样本:")

for _ in range(5):

idx = random.randint(0, len(x_train)-1)

img = x_train[idx].reshape(28, 28)

label = t_train[idx]

plt.figure(figsize=(3,3))

plt.imshow(img, cmap="gray")

plt.title(f"手写数字:{label}")

plt.show()

print("\n========== 测试4:简易两层神经网络预测演示 ==========")

# 初始化两层网络权重与偏置

W1 = np.random.randn(784, 50)

b1 = np.zeros(50)

W2 = np.random.randn(50, 10)

b2 = np.zeros(10)

# 对第一张图片做预测

pred_out = simple_two_layer_net_predict(x_train[0], W1, b1, W2, b2)

pred_num = np.argmax(pred_out)

real_num = t_train[0]

print(f"网络预测数字:{pred_num},图片真实数字:{real_num}")

print("(权重随机初始化,预测结果不准确,后续训练后可提升识别准确率)")