news 2026/5/15 14:27:31

day 44 简单 CNN 实战

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
day 44 简单 CNN 实战

文章目录

    • Day 44 · 简单 CNN 实战
      • 今日目标
      • 0. 回顾与动机
      • 1. 数据准备与增强
      • 2. 模型设计路线
        • Batch Normalization 的作用
        • 特征图尺寸推导
      • 3. 损失函数、优化器与学习率调度
      • 4. 训练与可视化流程
      • 5. 启动训练
      • 6. MLP vs CNN 快速对比

Day 44 · 简单 CNN 实战

用一次完整的深度学习流程来体验从 MLP 迈向 CNN:数据增强 → 卷积建模 → 训练调度 → 结果分析。

今日目标

  • 回顾 MLP 的局限,明确为何要引入卷积结构
  • 掌握一套可复用的图像数据增强配置
  • 手写一个 3 个卷积块 + 分类头的基础 CNN,并梳理关键超参数
  • 通过可视化观察学习率调度、损失与准确率的演化

0. 回顾与动机

MLP 会将3 × 32 × 32 = 3072 3\times32\times32=30723×32×32=3072个输入映射到 0 个像素完全展平,导致局部空间关系被破坏,只能依赖权重自行“猜测”边缘与纹理信息。卷积层通过局部感受野、权值共享与池化压缩,能够自动提取从低级到高级的多层特征,因此在同等参数量下通常能获得更高的准确率。

如果不清楚什么是卷积神经网络,可以学习参考:https://docs.qq.com/doc/DTFNucmRzc3RlRk5k

1. 数据准备与增强

CIFAR-10 官方给出了每个通道的均值(0.4914, 0.4822, 0.4465)和标准差(0.2023, 0.1994, 0.2010),直接用于Normalize。训练集添加随机裁剪/翻转/颜色抖动/旋转来模拟真实拍摄噪声,测试集只做张量化 + 标准化以保持分布一致。

小技巧:增强操作不会改变 DataLoader 的 batch 数量,而是对样本在线扰动,让每个 epoch 都看到“新”图像。

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorchvisionimportdatasets,transformsfromtorch.utils.dataimportDataLoaderimportmatplotlib.pyplotaspltimportnumpyasnp# Matplotlib 的中文显示与负号配置plt.rcParams["font.family"]=["SimHei"]plt.rcParams['axes.unicode_minus']=False# 检查可用设备device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')print(f"使用设备:{device}")# —— 数据增强 ——train_transform=transforms.Compose([transforms.RandomCrop(32,padding=4),# 先填充再裁剪,模拟平移transforms.RandomHorizontalFlip(),# 左右翻转提升鲁棒性transforms.ColorJitter(# 控制亮度/对比度/饱和度/色调brightness=0.2,contrast=0.2,saturation=0.2,hue=0.1),transforms.RandomRotation(15),# 轻微旋转匹配拍摄角度变化transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))])# 测试集仅做确定性处理test_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))])# —— 数据集与加载器 ——train_dataset=datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform)test_dataset=datasets.CIFAR10(root='./data',train=False,download=False,transform=test_transform)batch_size=64train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)
使用设备: cuda

2. 模型设计路线

  1. 卷积块Conv2d → BatchNorm → ReLU → MaxPool提取多尺度特征并逐步减小空间维度。
  2. 分类头:将卷积输出展平成向量,接 1~2 层全连接 + Dropout,输出 10 维 logits。
  3. 正则化:BatchNorm 稳定分布、Dropout 降低 co-adaptation。
Batch Normalization 的作用
  • 以 batch 为单位对每个通道做标准化,缓解“内部协变量偏移”;
  • 学习gammabeta进行再缩放,允许网络恢复到任意分布;
  • 由于分布稳定,可使用更大学习率,加速收敛并起到轻微正则化效果。
特征图尺寸推导

假设输入尺寸为 (32\times32):

  1. 第 1 个卷积块:保持尺寸,池化后得到32×16×16
  2. 第 2 个卷积块:池化后得到64×8×8
  3. 第 3 个卷积块:池化后得到128×4×4,展平即2048维向量。
    这些中间输出就是“特征图”,可借助 Grad-CAM 等方法做可视化解释。
classCNN(nn.Module):def__init__(self):super().__init__()# 第 1 段:输入 3 通道 -> 32 通道,保持空间尺寸,池化对 2self.conv_block1=nn.Sequential(nn.Conv2d(3,32,kernel_size=3,padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2)# 32x32 -> 16x16)# 第 2 段:32 -> 64 通道,继续减半空间尺寸self.conv_block2=nn.Sequential(nn.Conv2d(32,64,kernel_size=3,padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2)# 16x16 -> 8x8)# 第 3 段:64 -> 128 通道,再次减半self.conv_block3=nn.Sequential(nn.Conv2d(64,128,kernel_size=3,padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2)# 8x8 -> 4x4)# 分类头:2048 -> 512 -> 10self.classifier=nn.Sequential(nn.Linear(128*4*4,512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512,10))defforward(self,x):x=self.conv_block1(x)x=self.conv_block2(x)x=self.conv_block3(x)x=x.view(x.size(0),-1)# 展平成 (batch, 2048)returnself.classifier(x)model=CNN().to(device)

3. 损失函数、优化器与学习率调度

  • 损失函数CrossEntropyLoss直接接收 logits 和标签。
  • 优化器:Adam 适合快速验证,配合适度权重衰减稳定训练。
  • 调度器ReduceLROnPlateau在验证指标停滞时自动降低学习率,比固定周期衰减更智能。
criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=1e-3,weight_decay=1e-4)scheduler=optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=3,factor=0.5,verbose=True)
/home/ubuntu24/anaconda3/envs/torch-gpu/lib/python3.13/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate. warnings.warn(

4. 训练与可视化流程

训练函数需要做几件事:

  1. 记录每个 batch 的损失,以观察局部波动情况;
  2. 每个 epoch 统计训练/测试损失与准确率,供调度器和图表使用;
  3. 训练结束后绘制迭代级别与 epoch 级别的曲线,帮助定位过拟合或欠拟合。
defplot_iter_losses(losses,indices):plt.figure(figsize=(10,4))plt.plot(indices,losses,label='Iteration Loss')plt.xlabel('Iteration (batch index)')plt.ylabel('Loss')plt.title('Batch 级别的损失波动')plt.grid(True)plt.tight_layout()plt.legend()plt.show()defplot_epoch_metrics(train_acc,test_acc,train_loss,test_loss):epochs=range(1,len(train_acc)+1)plt.figure(figsize=(12,4))plt.subplot(1,2,1)plt.plot(epochs,train_acc,label='Train Acc')plt.plot(epochs,test_acc,label='Test Acc')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('训练 / 测试准确率')plt.grid(True)plt.legend()plt.subplot(1,2,2)plt.plot(epochs,train_loss,label='Train Loss')plt.plot(epochs,test_loss,label='Test Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('训练 / 测试损失')plt.grid(True)plt.legend()plt.tight_layout()plt.show()deftrain(model,train_loader,test_loader,criterion,optimizer,scheduler,device,epochs):model.train()iter_losses,iter_indices=[],[]train_acc_history,test_acc_history=[],[]train_loss_history,test_loss_history=[],[]forepochinrange(epochs):running_loss,correct,total=0.0,0,0forbatch_idx,(data,target)inenumerate(train_loader):data,target=data.to(device),target.to(device)optimizer.zero_grad()logits=model(data)loss=criterion(logits,target)loss.backward()optimizer.step()iter_losses.append(loss.item())iter_indices.append(epoch*len(train_loader)+batch_idx+1)running_loss+=loss.item()preds=logits.argmax(dim=1)total+=target.size(0)correct+=preds.eq(target).sum().item()if(batch_idx+1)%100==0:avg_loss=running_loss/(batch_idx+1)print(f"Epoch{epoch+1}/{epochs}| Batch{batch_idx+1}/{len(train_loader)}| Loss:{loss.item():.4f}| AvgLoss:{avg_loss:.4f}")epoch_train_loss=running_loss/len(train_loader)epoch_train_acc=100.0*correct/total train_loss_history.append(epoch_train_loss)train_acc_history.append(epoch_train_acc)model.eval()test_loss=0.0correct_test,total_test=0,0withtorch.no_grad():fordata,targetintest_loader:data,target=data.to(device),target.to(device)logits=model(data)loss=criterion(logits,target)test_loss+=loss.item()preds=logits.argmax(dim=1)total_test+=target.size(0)correct_test+=preds.eq(target).sum().item()epoch_test_loss=test_loss/len(test_loader)epoch_test_acc=100.0*correct_test/total_test test_loss_history.append(epoch_test_loss)test_acc_history.append(epoch_test_acc)scheduler.step(epoch_test_loss)print(f"Epoch{epoch+1}/{epochs}| Train Acc:{epoch_train_acc:.2f}% | Test Acc:{epoch_test_acc:.2f}% | LR:{optimizer.param_groups[0]['lr']:.5f}")model.train()plot_iter_losses(iter_losses,iter_indices)plot_epoch_metrics(train_acc_history,test_acc_history,train_loss_history,test_loss_history)returntest_acc_history[-1]

5. 启动训练

默认跑 20 个 epoch 就能看到 70%+ 的测试准确率,如需更高精度,可:

  • 增加卷积层深度或使用更大的特征维度;
  • 加入CosineAnnealingLRMixup等更强的数据/调度策略;
  • 使用AutoAugmentCutMix等进阶增强方式。
epochs=20print('开始训练 CNN ...')final_acc=train(model,train_loader,test_loader,criterion,optimizer,scheduler,device,epochs)print(f'训练结束,最终测试准确率:{final_acc:.2f}%')
开始训练 CNN ... Epoch 1/20 | Batch 100/782 | Loss: 1.9105 | AvgLoss: 2.0645 Epoch 1/20 | Batch 200/782 | Loss: 1.7723 | AvgLoss: 1.9431 Epoch 1/20 | Batch 300/782 | Loss: 1.4931 | AvgLoss: 1.8741 Epoch 1/20 | Batch 400/782 | Loss: 1.6812 | AvgLoss: 1.8228 Epoch 1/20 | Batch 500/782 | Loss: 1.7031 | AvgLoss: 1.7801 Epoch 1/20 | Batch 600/782 | Loss: 1.5635 | AvgLoss: 1.7461 Epoch 1/20 | Batch 700/782 | Loss: 1.5986 | AvgLoss: 1.7165 Epoch 1/20 | Train Acc: 37.31% | Test Acc: 50.18% | LR: 0.00100 ... Epoch 20/20 | Train Acc: 74.02% | Test Acc: 79.80% | LR: 0.00100

训练结束,最终测试准确率: 79.80%

6. MLP vs CNN 快速对比

模型参数规模 (示例)特征提取方式计算效率典型准确率
MLP (3072→512→256→10)≈3.7M全连接、打平输入、无空间感知需遍历全部权重,难以复用局部模式50%~55%
CNN (3×3 卷积×3 + FC)≈0.1M局部卷积 + 池化 + 权值共享以卷积核滑动复用计算,GPU 友好70%~80%

不同模型的参数量仅为估算,实际取决于具体层数、通道数和分类头设计。

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/12 1:17:26

Langchain-Chatchat微服务拆分可行性分析

Langchain-Chatchat微服务拆分可行性分析 在企业智能化转型加速的今天,越来越多组织希望借助大语言模型(LLM)构建专属的知识问答系统。然而,通用模型面对私有知识库时常常“答非所问”,而直接调用云端API又存在数据泄露…

作者头像 李华
网站建设 2026/5/8 5:58:34

Sublime Text Markdown Preview终极使用指南:高效写作与实时预览

Sublime Text Markdown Preview终极使用指南:高效写作与实时预览 【免费下载链接】sublimetext-markdown-preview markdown preview and build plugin for sublime text 2/3 项目地址: https://gitcode.com/gh_mirrors/su/sublimetext-markdown-preview Subl…

作者头像 李华
网站建设 2026/5/13 14:20:50

Langchain-Chatchat能否替代传统搜索引擎?企业内部知识检索新范式

Langchain-Chatchat:企业内部知识检索的新范式 在智能办公日益普及的今天,一个看似简单却困扰无数企业的难题正变得愈发突出:员工每天花多少时间在翻找文档? 一份制度文件藏在共享盘第三级目录,技术手册分散在多个部门…

作者头像 李华
网站建设 2026/5/15 9:17:19

DepthCrafter:无相机姿态的视频深度生成

DepthCrafter:无相机姿态的视频深度生成 【免费下载链接】DepthCrafter DepthCrafter是一款开源工具,能为开放世界视频生成时间一致性强、细节丰富的长深度序列,无需相机姿态或光流等额外信息。助力视频深度估计任务,效果直观可通…

作者头像 李华
网站建设 2026/5/8 9:17:07

仿写prompt:esbuild跨域与代理配置技术指南

仿写prompt:esbuild跨域与代理配置技术指南 【免费下载链接】esbuild An extremely fast bundler for the web 项目地址: https://gitcode.com/GitHub_Trending/es/esbuild 任务要求 请基于提供的参考文章,创作一篇关于esbuild跨域与代理配置的技…

作者头像 李华