news 2026/4/17 5:55:47

从零开始做一个最简单的CNN实例

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从零开始做一个最简单的CNN实例

一、CNN基本概念

站内已经有详细的教程

【深度学习】一文搞懂卷积神经网络(CNN)的原理(超详细)_卷积神经网络原理-CSDN博客、


二、完成一个简单实例需要掌握什么

1.张量基本操作

我们将张量基本操作分为 4 个层次

(1)创建(固定形状、随机初始化)

x = torch.randn(1, 1, 28, 28) # 模拟一张 28×28 单通道灰度图,batch=1 #1 N - batch 个数(一次喂 1 张图) #1 C - 通道数(单通道灰度,所以是 1;RGB 的话这里是 3) #28 H - 高(纵向像素行数) #28 W - 宽(横向像素列数) y = torch.zeros(2, 3, 4) # 全 0 z = torch.ones(3, 3) # 全 1 w = torch.eye(5) # 5×5 单位矩阵 v = torch.arange(12) # 0..11 的一维向量

(2)查看形状 & 维度压缩

print(x.shape) # torch.Size([1, 1, 28, 28]) print(x.size()) # 同上,shape 的别名 print(x.numel()) # 总元素个数 1×1×28×28=784 x_squeezed = x.squeeze() # 去掉所有长度为 1 的维度 → [28,28] x_unsqueezed = x_squeezed.unsqueeze(0).unsqueeze(0) # 加回去 → [1,1,28,28]

(3)切片 & 索引

# 取 batch 里第 0 张图,通道 0,高 10:18,宽 10:18 → [8,8] patch = x[0, 0, 10:18, 10:18] # 隔行隔列采样 sub = x[0, 0, ::2, ::2] # 14×14 # 布尔索引 mask = x > 0.5 # 同 shape 的 BoolTensor x_pos = x[mask] # 一维向量,只含 >0.5 的元素

(4)变形 / 压平

# 把图片拉成一维向量 flat = x.flatten() # 等价于 x.view(-1) → 784 flat = x.view(-1) # 同上 # 保留 batch,把每张图拉成 784 维特征 feat = x.view(1, -1) # → [1, 784] # 更安全的写法(内存连续) feat = x.reshape(1, -1) # 若内存不连续会自动复制
2.nn.Module骨架

nn.Module 骨架 = PyTorch 里“所有可训练模型”的唯一官方模板

模板要求:

(1)必须继承nn.Module

(2)所有“可学习参数”放进__init__,用nn.Xxx层封装

(3)计算图写在forward,PyTorch 自动帮你做反向传播

模板示例:

#最小可运行骨架 class Net(nn.Module): # ① 继承 def __init__(self): # ② 放层 super().__init__() # 初始化父类 self.conv1 = nn.Conv2d(1, 16, 3, padding=1) self.pool = nn.MaxPool2d(2) self.fc1 = nn.Linear(16*14*14, 10) def forward(self, x): # ③ 写计算图 x = self.pool(F.relu(self.conv1(x))) # [B,16,14,14] x = x.view(x.size(0), -1) # 压平 x = self.fc1(x) # [B,10] return x #使用 model = Net() # 实例化 out = model(torch.randn(4,1,28,28)) # 前向一次 print(out.shape) # torch.Size([4, 10])
3.卷积层nn.Conv2d的 5 个关键参数

nn.Conv2d通常这样写

nn.Conv2d( in_channels, # 必填 out_channels, # 必填 kernel_size, # 必填 stride=1, # 默认 1 padding=0, # 默认 0 ) nn.Conv2d(1, 16, 3) # 最简 nn.Conv2d(1, 16, 3, stride=2, padding=1) # 常用

(1) in_channels
输入特征图的“通道”张数。
例:灰度 1,RGB 彩色 3,上一层 feature map 64 张就当 64。

(2) out_channels
想让这一层“生”出多少张新特征图,就是 out_channels。
每个输出通道对应一个独立的卷积核,所以也是“卷积核个数”。

(3) kernel_size
卷积核的空间大小,常用 3(即 3×3)、5、7;也可给正方形 (3) 或长方形 (3,5)。

(4) stride
核在图上滑动的步长。stride=1 逐像素滑,stride=2 隔一跳一,会把输出尺寸减半(配合后面公式)。

(5) padding
在输入图四周补 0 的圈数。
padding=1 相当于给 28×28 外边再包一圈 0,变成 30×30,用来“保尺寸”或控制输出大小。

在卷积神经网络中,我们输入特征图,输入特征图的纵向像素大小我们称为输入高,经过cnn处理后的输出特征图的纵向像素大小我们称为输出高。

我们设输入高为H_in,输出高为H_out,那么我们可以根据以下公式得到输入高和输出高的关系:

H_out = (H_in + 2×padding − kernel_size) // stride + 1 (//表示向下取整)

例如:输入 32×32,卷积层

nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2)

H_out = (32 + 2×2 − 5) // 1 + 1 = 32

4.池化层nn.MaxPool2d

nn.MaxPool2d 就是“用一个小窗口在特征图上滑,每个窗口只留最大值”,用来快速砍掉冗余、把宽高减半。

示例:

pool = nn.MaxPool2d(kernel_size=2, stride=2) #窗口 2×2 步长 2

效果:
输入[B, C, 28, 28]→ 输出[B, C, 14, 14]
高宽直接砍半,通道数不变,没有可学习参数。

除了MaxPool2d还有AvgPool2d(平均池化层)

5.激活函数

激活函数就是“给线性输出加非线性”,没有它,再深的 CNN 也只是一层线性变换。

常见两种写法:

(1)函数式(最常用,直接调用)

x = F.relu(x) # 一行搞定,无需在 __init__ 注册

(2)层式(先实例化,再当层用)

self.relu = nn.ReLU() # 写在 __init__ x = self.relu(x) # 写在 forward

ReLU函数图像:

公式:

ReLU(x) = max(0, x)

  • 当 x ≥ 0 时,输出 x

  • 当 x < 0 时,输出 0

除ReLU外,常见激活函数还有Sigmoid和Tanh等。

6.把“特征图”拉平接全连接

把“特征图”拉平 = 把二维/三维的“图像”变成一维“向量”,才能塞进全连接层。

需要是先算好拉平后的节点数,再写Linear(节点数, 类别数)

节点数 = C × H × W

(1) 卷积+池化后看形状
x.shape # [B, 32, 7, 7]→ C=32, H=7, W=7
节点数 = 32×7×7 = 1568

(2) 拉平(两种写法等价)

x = torch.flatten(x, 1) # 从第1维开始压扁,保留batch # 或 x = x.view(x.size(0), -1) # x.size(0)就是B

结果:[B, 1568]

(3) 接全连接

self.fc = nn.Linear(1568, 10)
7.训练循环最小模板

会写“zero_grad→前向→loss→反向→step”四连击(zero_grad必须在backward之前,否则梯度会累加)

for data, target in loader: # 1. 取一批数据 optimizer.zero_grad() # 2. 清空旧梯度 output = model(data) # 3. 前向 → logits loss = F.cross_entropy(output, target) # 4. 算交叉熵 loss.backward() # 5. 反向传播,求梯度 optimizer.step() # 6. 用梯度更新权重

三、一个CNN的简单实例

import torch, torch.nn as nn, torch.nn.functional as F from matplotlib import pyplot as plt from torchvision import datasets, transforms from torch.utils.data import DataLoader # 1. 网络 ========================================================== class CNN(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 16, 3, padding=1) # 28x28 → 28x28 self.pool = nn.MaxPool2d(2) # 28x28 → 14x14 self.fc = nn.Linear(16 * 14 * 14, 10) # 3136 → 10类 def forward(self, x): x = self.pool(F.relu(self.conv(x))) # [B,16,14,14] x = x.view(x.size(0), -1) # [B,3136] return self.fc(x) # [B,10] # 2. 数据 ========================================================== transform = transforms.ToTensor() train_set = datasets.MNIST(root='.', train=True, download=True, transform=transform) test_set = datasets.MNIST(root='.', train=False, download=True, transform=transform) train_loader = DataLoader(train_set, batch_size=64, shuffle=True) test_loader = DataLoader(test_set, batch_size=64, shuffle=True) # 3. 训练准备 ====================================================== device = 'cuda' if torch.cuda.is_available() else 'cpu' model = CNN().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # 4. 训练循环 ====================================================== for epoch in range(3): # 3 个 epoch 意思一下 model.train() for x, y in train_loader: x, y = x.to(device), y.to(device) optimizer.zero_grad() out = model(x) loss = F.cross_entropy(out, y) loss.backward() optimizer.step() print(f'epoch {epoch+1} loss={loss.item():.4f}') # 5. 测试 ========================================================== model.eval() correct, total = 0, 0 with torch.no_grad(): for x, y in test_loader: x, y = x.to(device), y.to(device) pred = model(x).argmax(1) total += y.size(0) correct += (pred == y).sum().item() print(f'测试准确率: {100*correct/total:.2f}%') # =====随机 12 张预测可视化 ===== model.eval() sample_iter = iter(test_loader) images, labels = next(sample_iter) images, labels = images[:12].to(device), labels[:12] preds = model(images).argmax(1).cpu() fig, axes = plt.subplots(3, 4, figsize=(8,6)) for i, ax in enumerate(axes.ravel()): img = images[i].cpu().squeeze() ax.imshow(img, cmap='gray') ax.set_title(f'True:{labels[i].item()} Pred:{preds[i].item()}', color='green' if preds[i]==labels[i] else 'red') ax.axis('off') plt.tight_layout() plt.show()
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/10 23:20:02

从人工智障到得力助手:构建稳定AI Agent的5个核心原则

构建稳定AI Agent需遵循五大原则&#xff1a;1)定义清晰规格说明书(角色边界、技术栈、输入输出样本)&#xff1b;2)采用微服务化指令(Plan-Code-Test-Deploy)&#xff1b;3)实现状态持久化(记录思考过程、文件差异、任务清单)&#xff1b;4)合理使用上下文(文件检索、及时遗忘…

作者头像 李华
网站建设 2026/4/16 17:00:24

计算机小程序毕设实战-基于springboot+微信小程序的闲置物品处置平台的设计与实现 社区二手物品交易【完整源码+LW+部署说明+演示视频,全bao一条龙等】

博主介绍&#xff1a;✌️码农一枚 &#xff0c;专注于大学生项目实战开发、讲解和毕业&#x1f6a2;文撰写修改等。全栈领域优质创作者&#xff0c;博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围&#xff1a;&am…

作者头像 李华
网站建设 2026/4/16 9:05:05

stm32蜂鸣器实验

一、实验目的及要求1、掌握GPIO及其输出的识别方法。2、熟悉蜂鸣器和STM32微控制器的接口方法。3、了解蜂鸣器的工作原理及硬件电路。二、实验内容及原理蜂鸣器是一种一体化结构的电子讯响器&#xff0c;采用直流电压供电&#xff0c;广泛应用于计算机、打印机、 复印机、报警器…

作者头像 李华
网站建设 2026/4/10 7:38:03

12、网页元素盒子属性全解析

网页元素盒子属性全解析 在网页设计中,对元素盒子属性的控制至关重要,它能帮助我们精确地塑造页面上各个容器的外观和布局。下面将详细介绍一些关键的盒子属性,包括溢出(Overflow)、可见性(Visibility)、外边距(Margin)、边框(Borders)、内边距(Padding)以及背景…

作者头像 李华