news 2026/5/23 19:09:43

从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率提到99.5%以上的

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率提到99.5%以上的

从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率提到99.5%以上的

当第一次接触MNIST数据集时,我天真地以为用几层卷积神经网络就能轻松达到99%以上的准确率。现实很快给了我一记耳光——我的第一个简单CNN模型在测试集上只能达到97%左右的准确率。这促使我开启了一段持续优化的旅程,最终将准确率提升到99.5%以上。在这个过程中,我深刻体会到模型优化不是简单的堆叠层数,而是需要系统性地思考数据、架构和训练策略的协同作用。

1. 基础CNN模型搭建与初步优化

我的起点是一个典型的LeNet风格架构,包含两个卷积层和两个全连接层。这个基础版本在10个epoch后达到了97.11%的测试准确率,但存在几个明显问题:

class BasicCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2(x), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) return self.fc2(x)

第一轮优化主要关注代码结构和训练效率:

  1. 使用nn.Sequential重构网络模块,提升可读性和复用性
  2. 添加批归一化层(BatchNorm)加速收敛
  3. 采用nn.Flatten()替代手动展平操作
  4. 设置ReLU的inplace参数为True减少内存占用

优化后的模型结构如下:

class ImprovedCNN(nn.Module): def __init__(self): super().__init__() self.features = nn.Sequential( nn.Conv2d(1, 10, 5), nn.MaxPool2d(2), nn.ReLU(True), nn.BatchNorm2d(10), nn.Conv2d(10, 20, 5), nn.MaxPool2d(2), nn.ReLU(True), nn.BatchNorm2d(20), nn.Flatten() ) self.classifier = nn.Linear(320, 10)

这些改动看似简单,却带来了显著提升:

优化项准确率提升训练时间变化
BatchNorm+0.8%-15%
结构化代码-代码可维护性↑
inplace ReLU内存占用↓20%

2. 训练策略的精细调整

当模型架构达到一个平台期后,我开始关注训练过程的优化。这一阶段的关键发现是:好的模型需要匹配好的训练策略

2.1 学习率动态调整

固定学习率就像用恒定的速度爬山——开始可能合适,但随着地形变化就会变得低效。我实现了学习率动态调整:

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='max', factor=0.5, patience=3, threshold=0.0001 )

配合验证集准确率监控,当指标停滞时自动降低学习率。这种策略在第85个epoch帮助模型突破了99.5%的关键瓶颈。

2.2 数据增强的艺术

MNIST虽然是干净的数据集,但适度的数据增强能显著提升模型鲁棒性。我采用了以下增强组合:

transform = transforms.Compose([ transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), transforms.RandomRotation((-10, 10)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])

增强策略对比实验

增强方式测试准确率过拟合程度
无增强98.9%中等
仅平移99.2%
平移+旋转99.5%很低
过度增强98.1%极低(欠拟合)

2.3 正则化技术组合

Dropout与权重衰减的协同使用产生了意想不到的效果:

self.classifier = nn.Sequential( nn.Linear(64*3*3, 256), nn.ReLU(), nn.Dropout(0.5), # 关键位置的高dropout率 nn.Linear(256, 10) )

配合权重初始化策略:

def weights_init(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') model.apply(weights_init)

3. 深度架构探索:从CNN到ResNet

当传统CNN的优化空间逐渐缩小,我开始尝试更先进的架构。ResNet的残差连接设计特别适合解决深度网络中的梯度消失问题。

3.1 残差块实现要点

class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(out_channels) self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_channels) ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) return F.relu(out)

3.2 自定义ResNet18架构

针对MNIST的28x28小尺寸特点,我对标准ResNet18做了适配调整:

class ResNetMNIST(nn.Module): def __init__(self, block, layers, num_classes=10): super().__init__() self.in_channels = 16 self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(16) self.layer1 = self._make_layer(block, 16, layers[0], stride=1) self.layer2 = self._make_layer(block, 32, layers[1], stride=2) self.layer3 = self._make_layer(block, 64, layers[2], stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1,1)) self.fc = nn.Linear(64, num_classes)

3.3 预训练模型适配

直接使用torchvision的ResNet需要处理通道数不匹配问题:

model = torchvision.models.resnet18(pretrained=False) model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

架构对比实验结果

模型类型参数量测试准确率训练时间(每epoch)
基础CNN50K97.1%12s
优化CNN55K99.1%15s
自定义ResNet181.1M99.3%45s
torchvision ResNet1811M98.4%60s

4. 工程实践与性能优化

在实际部署中,我发现几个影响模型效用的关键因素:

4.1 GPU加速技巧

# 数据加载优化 train_loader = DataLoader( dataset, batch_size=512, shuffle=True, num_workers=4, pin_memory=True # 减少CPU-GPU传输延迟 ) # 混合精度训练 scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs = model(inputs) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()

4.2 训练监控与分析

使用TensorBoard记录关键指标:

writer = SummaryWriter() writer.add_scalar('Loss/train', loss.item(), global_step) writer.add_scalar('Accuracy/test', accuracy, global_step) writer.add_histogram('conv1/weights', model.conv1.weight, global_step)

4.3 模型压缩与部署

达到目标准确率后,我尝试了模型量化:

quantized_model = torch.quantization.quantize_dynamic( model, {nn.Linear}, dtype=torch.qint8 )

量化前后对比

指标原始模型量化模型
模型大小4.7MB1.2MB
推理延迟8.2ms3.1ms
准确率99.5%99.4%

这段优化之旅让我明白,在深度学习中,没有银弹式的解决方案。每个百分点的提升都需要数据、模型和训练策略的精心配合。当我在第85个epoch看到99.51%的测试准确率时,所有的调试和等待都变得值得。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/23 19:09:12

ops-nn MatMul 算子深度解读:从 Tiling 到 Cube/Vector 双缓冲

前言 昇腾CANN的ops-nn仓库里,MatMul算子是优化最深入的的一个。做模型适配的时候,很多人以为MatMul就是调个矩阵乘,没什么好调的,结果跑起来发现NPU利用率只有40%,同样的模型在A100上能跑满90%。问题不在NPU算力不够&…

作者头像 李华
网站建设 2026/5/23 19:08:15

AI-HF_Patch完全指南:解锁AI-Shoujo游戏的无限潜能

AI-HF_Patch完全指南:解锁AI-Shoujo游戏的无限潜能 【免费下载链接】AI-HF_Patch Automatically translate, uncensor and update AI-Shoujo! 项目地址: https://gitcode.com/gh_mirrors/ai/AI-HF_Patch 你是否正在寻找一款能够彻底提升AI-Shoujo游戏体验的增…

作者头像 李华
网站建设 2026/5/23 19:04:21

AT32F435飞控实战:如何利用其4MB Flash和288MHz主频解锁新功能

AT32F435飞控开发实战:解锁4MB Flash与288MHz主频的隐藏潜力 当大多数飞控开发者还在为STM32F405的1MB Flash捉襟见肘时,AT32F435RGT7带来的4MB存储空间和288MHz主频就像打开了新世界的大门。这款国产MCU不仅完美兼容原有生态,更在性能上实现…

作者头像 李华
网站建设 2026/5/23 18:56:03

体验分钟级接入为网站原型注入AI能力

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 体验分钟级接入为网站原型注入AI能力 在验证一个网站创意原型时,能否快速为其注入智能对话能力,往往决定了…

作者头像 李华
网站建设 2026/5/23 18:51:26

告别对齐烦恼:用PyTorch的CTCLoss搞定OCR和语音识别(附实战代码)

告别对齐烦恼:用PyTorch的CTCLoss搞定OCR和语音识别(附实战代码) 在序列学习任务中,数据对齐一直是困扰开发者的核心难题。想象一下这样的场景:当你试图从一张手写笔记图片中识别文字时,每个字符的位置、大…

作者头像 李华