本系列文章主要是通过手写数字识别这一经典的CNN入门例子,来让大家熟悉深度学习框架Pytorch的基本操作,达到可以实现自己网络结构的目的。
本文为该系列文章的第一篇,主要介绍了手写数字数据集(MNIST)相关信息。
本文目录
- 本系列文章目录
- 一、MNIST数据集简介
- 1、图像数据集格式解析
- 2、标签数据集格式解析
- 二、代码实现
一、MNIST数据集简介
MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据库(官网地址:http://yann.lecun.com/exdb/mnist/),包含6万个训练样本以及1万个测试样本,数据文件格式如下图:
该数据集以二进制big-endian的方式进行存储。
big-endian(大端):较高的有效字节存放在较低的存储器地址,较低的有效字节存放在较高的存储器地址。
1、图像数据集格式解析
以训练数据集train-images-idx3-ubyte.gz为例,其在官网上的数据格式介绍如下:
解释说明:
- 1-4个字节:存储的是magic number,对应大小为2051;
- 5-8个字节:存储的是图像的数量,为10000;
- 9-12个字节:每张图片的行数(高度),为28;
- 13-16个字节:每张图片的列数(宽度),为28;
- 从第17个字节开始,每个字节存储一张图片中一个像素点的值(0~255);
2、标签数据集格式解析
该训练集对应的标签文件说明如下:
解释说明:
- 1-4个字节:存储的是magic number,对应大小为2051;
- 5-8个字节:存储的是图像的数量,为10000;
- 从第9个字节开始,每个字节存储一个标签(范围为0-9);
二、代码实现
(1) 将字节数组按大端模式进行解析,并转成整数,代码如下:
def bytes2int(byte_array):
return int.from_bytes(byte_array, byteorder='big', signed=False)
(2) 显示图像数据集基本信息:
def show_image_info_data(path):
with gzip.open(path) as pfile:
data = pfile.read()
print("数据类型:",type(data))
print("数据长度:",len(data))
print("magic number:", byte2int(data[0:4]))
print("图像数量:", int.from_bytes(data[4:8], byteorder='big', signed=False))
print("图像rows:", int.from_bytes(data[8:12], byteorder='big', signed=False))
print("图像cols:", int.from_bytes(data[12:16], byteorder='big', signed=False))
运行结果:
数据类型: <class 'bytes'>
数据长度: 47040016
magic number: 2051
图像数量: 60000
图像rows: 28
图像cols: 28
(3) 手写数字图像随机展示:
(4) 显示标签数据集基本信息:
def show_label_info_data(path):
with gzip.open(path) as pfile:
data = pfile.read()
print("数据类型:",type(data))
print("数据长度:",len(data))
print("magic number:", bytes2int(data[0:4]))
print("图像数量:", bytes2int(data[4:8]))
print("前10个标签为:")
label_list = []
for i in range(0,10):
label_list.append(str(bytes2int(data[8+i:9+i])))
print(",".join(label_list))
运行结果:
数据类型: <class 'bytes'>
数据长度: 60008
magic number: 2049
图像数量: 60000
前10个标签为:
5,0,4,1,9,2,1,3,1,4