手写数字识别实战:从MNIST数据集到神经网络推理
MNIST数据集简介
MNIST是机器学习领域最著名的数据集之一,包含0-9的手写数字图像:
- 训练图像:60,000张
- 测试图像:10,000张
- 图像尺寸:28×28像素(灰度图像)
- 像素值范围:0-255
数据加载与预处理
使用提供的mnist.py脚本加载数据:
importsys,os sys.path.append(os.pardir)fromdataset.mnistimportload_mnist# 加载MNIST数据集(x_train,t_train),(x_test,t_test)=load_mnist(flatten=True,# 将图像展开为一维数组normalize=True,# 将像素值正规化到0.0-1.0one_hot_label=False)print(f"训练数据形状:{x_train.shape}")# (60000, 784)print(f"训练标签形状:{t_train.shape}")# (60000,)关键参数说明:
- normalize: 是否将像素值除以255进行正规化
- flatten: 是否将28×28的图像展开为784维向量
- one_hot_label: 是否使用one-hot编码表示标签
神经网络推理实现
1. 网络结构
- 输入层:784个神经元(对应28×28像素)
- 隐藏层1:50个神经元
- 隐藏层2:100个神经元
- 输出层:10个神经元(对应0-9数字分类)
2. 核心函数
importpickleimportnumpyasnpdefsigmoid(x):"""Sigmoid激活函数"""return1/(1+np.exp(-x))defsoftmax(x):"""Softmax函数"""ifx.ndim==2:x=x-np.max(x,axis=1,keepdims=True)exp_x=np.exp(x)returnexp_x/np.sum(exp_x,axis=1,keepdims=True)else:x=x-np.max(x)exp_x=np.exp(x)returnexp_x/np.sum(exp_x)definit_network():"""加载预训练的网络参数"""withopen("sample_weight.pkl",'rb')asf:network=pickle.load(f)returnnetworkdefpredict(network,x):"""前向传播推理"""W1,W2,W3=network['W1'],network['W2'],network['W3']b1,b2,b3=network['b1'],network['b2'],network['b3']# 第1层a1=np.dot(x,W1)+b1 z1=sigmoid(a1)# 第2层a2=np.dot(z1,W2)+b2 z2=sigmoid(a2)# 输出层a3=np.dot(z2,W3)+b3 y=softmax(a3)returny单张图像推理
# 获取测试数据x_test,t_test=get_data()# 初始化网络network=init_network()# 单张图像推理accuracy_cnt=0foriinrange(len(x_test)):y=predict(network,x_test[i])# 预测概率分布p=np.argmax(y)# 取概率最高的索引ifp==t_test[i]:# 与真实标签比较accuracy_cnt+=1accuracy=float(accuracy_cnt)/len(x_test)print(f"识别精度:{accuracy:.4f}")# 约93.52%批处理优化
批处理能大幅提升计算效率,充分利用NumPy的矩阵运算优化:
defpredict_batch(network,x_batch):"""批量推理"""W1,W2,W3=network['W1'],network['W2'],network['W3']b1,b2,b3=network['b1'],network['b2'],network['b3']a1=np.dot(x_batch,W1)+b1 z1=sigmoid(a1)a2=np.dot(z1,W2)+b2 z2=sigmoid(a2)a3=np.dot(z2,W3)+b3 y=softmax(a3)returny# 批处理推理batch_size=100accuracy_cnt=0foriinrange(0,len(x_test),batch_size):# 提取批数据x_batch=x_test[i:i+batch_size]t_batch=t_test[i:i+batch_size]# 批量推理y_batch=predict_batch(network,x_batch)# 获取预测结果(axis=1表示按行取最大值索引)p_batch=np.argmax(y_batch,axis=1)# 统计正确数量accuracy_cnt+=np.sum(p_batch==t_batch)accuracy=float(accuracy_cnt)/len(x_test)print(f"批处理识别精度:{accuracy:.4f}")关键概念解析
1. 数据正规化(Normalization)
将输入数据的值调整到特定范围(如0.0-1.0),有助于:
- 提高训练稳定性
- 加速收敛速度
- 改善模型性能
2. 批处理(Batch Processing)
- 优势:减少数据读取开销,充分利用硬件并行计算能力
- 实现:通过
axis参数指定计算维度 - 效果:显著提升推理速度
3. 网络形状对应
# 网络各层参数形状W1.shape# (784, 50) 输入层→隐藏层1W2.shape# (50, 100) 隐藏层1→隐藏层2W3.shape# (100, 10) 隐藏层2→输出层# 数据处理形状变化x_test.shape# (10000, 784) 10000张图像,每张784维x_batch.shape# (100, 784) 100张图像批处理y_batch.shape# (100, 10) 100个预测结果,每个10维概率分布可视化图像数据
fromPILimportImageimportmatplotlib.pyplotaspltdefimg_show(img_array):"""显示MNIST图像"""img=img_array.reshape(28,28)# 从784维恢复为28×28plt.imshow(img,cmap='gray')plt.axis('off')plt.show()# 显示第一张训练图像img_show(x_train[0])print(f"标签:{t_train[0]}")总结与展望
通过本文,我们实现了:
- MNIST数据集的加载和预处理
- 神经网络的前向传播推理
- 单张图像和批量图像的分类
- 达到了约93.5%的识别精度
下一步优化方向:
- 调整网络结构(层数、神经元数量)
- 使用更好的优化算法
- 添加正则化防止过拟合
- 使用卷积神经网络(CNN)进一步提升精度
批处理技术不仅是推理阶段的优化手段,在训练阶段同样重要。后续学习深度神经网络时,我们将看到批处理如何与梯度下降算法结合,实现高效的学习过程。
思考题:如果将批大小从100改为1,会对计算效率和精度产生什么影响?为什么大多数情况下我们不使用批大小为1?