news 2026/2/8 7:59:39

Day 42 图像数据与显存

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 42 图像数据与显存

文章目录

  • Day 42 · 图像数据与显存
    • 1. 图像数据基础
      • 1.1 灰度图像(MNIST)
      • 1.2 彩色图像(CIFAR-10)
    • 2. 图像相关的神经网络
      • 2.1 灰度图像 MLP(MNIST)
      • 2.2 彩色图像 MLP(CIFAR-10)
      • 2.3 batch_size 与模型结构
    • 3. 显存占用与 batch_size 选择
      • 3.1 模型参数与梯度
      • 3.2 优化器状态
      • 3.3 数据批量(batch_size)
      • 3.4 中间激活
      • 3.5 batch_size 的显存估算示例(SGD)
      • 3.6 实战建议

Day 42 · 图像数据与显存

本节聚焦图像数据(MNIST 与 CIFAR-10)的形状特征、MLP 模型定义以及 batch_size 与显存的关系,目标是把概念清晰拆解并配合可运行的示例。

1. 图像数据基础

  • 图像数据需要同时保留通道、宽和高,因此形状比结构化表格更复杂。
  • 了解灰度图和彩色图的差异是后续处理(预处理、建模、显存估算)的前提。

1.1 灰度图像(MNIST)

下面对比图像与结构化数据的形状:

数据类型形状示例说明
表格数据(1000, 5)1000 个样本、5 个特征,顺序不体现空间关系
MNIST 图像(1, 28, 28)1 个通道(灰度),28×28 像素保留空间位置信息

像素最初是 0~255 的uint8,通过transforms.ToTensor()会被缩放到 [0,1] 并转成float32,方便做梯度计算。

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoaderfromtorchvisionimportdatasets,transformsimporttorchvisionimportmatplotlib.pyplotaspltimportnumpyasnp# 为了得到可复现的随机结果torch.manual_seed(42)# MNIST 的预处理:归一化 + 标准化transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])# 下载/加载训练与测试集train_dataset=datasets.MNIST(root='./data',train=True,download=True,transform=transform)test_dataset=datasets.MNIST(root='./data',train=False,download=True,transform=transform)

随机挑一张手写数字观察像素与标签,顺便复习反标准化操作。

# 随机取一张图片并可视化sample_idx=int(torch.rand(1).item()*len(train_dataset))image,label=train_dataset[sample_idx]defshow_gray_tensor(img):"""反标准化并展示灰度图。"""img=img*0.3081+0.1307# 还原为 0~1 范围np_img=img.numpy()plt.imshow(np_img[0],cmap='gray')plt.title(f'MNIST Label:{label}')plt.axis('off')plt.show()show_gray_tensor(image)

使用Channels × Height × Width的格式是 PyTorch 的默认约定。下面确认一下单张图片的形状。

image.shape
torch.Size([1, 28, 28])

1.2 彩色图像(CIFAR-10)

  • 彩色图像通常具有 3 个通道(RGB),张量形状为(3, H, W)
  • 在可视化时,matplotlib期望(H, W, C),需要permute调整顺序。
  • 模型输入通常再额外加上一维 batch,形状(Batch, C, H, W)
# 下载 Cifar10 并展示一张彩色图片color_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])cifar_train=datasets.CIFAR10(root='./data',train=True,download=True,transform=color_transform)color_img,color_label=cifar_train[0]# Cifar10 的标签名称classes=['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']plt.imshow((color_img.permute(1,2,0)*0.5+0.5).numpy())# 反标准化并调整维度顺序plt.title(f'CIFAR-10 Label:{classes[color_label]}')plt.axis('off')plt.show()
100%|██████████| 170M/170M [00:30<00:00, 5.66MB/s]

2. 图像相关的神经网络

图像输入在送入全连接层前通常需要展平,通道维度也会体现在输入特征数上。这里先看两个基础 MLP,再讨论 batch_size 与模型的关系。

2.1 灰度图像 MLP(MNIST)

关键点:

  1. nn.Flatten()view负责把[1, 28, 28]拉直成 784 维。
  2. 输入层参数数量直接与像素数相关。
  3. 在推理或训练中,batch_size只会额外增加一个前缀维度[batch, 1, 28, 28]
classMnistMLP(nn.Module):def__init__(self):super().__init__()self.flatten=nn.Flatten()# 只展平成 784 维,不动 batch 维self.layer1=nn.Linear(28*28,128)self.relu=nn.ReLU()self.layer2=nn.Linear(128,10)# 10 个数字类别defforward(self,x):x=self.flatten(x)x=self.layer1(x)x=self.relu(x)x=self.layer2(x)returnx device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')mnist_model=MnistMLP().to(device)

参数计算示例(float32):

  • layer1:784 × 128 = 100,352个权重 +128个偏置 → 100,480 参数,约 402 KB。
  • layer2:128 × 10 = 1,280个权重 +10个偏置 → 1,290 参数,约 5 KB。
  • 总计 101,770 参数,梯度在反向传播时会复制一份,显存翻倍。

2.2 彩色图像 MLP(CIFAR-10)

彩色图片有 3 个通道,展平后特征数量直接变为3 × 32 × 32 = 3072,因此参数规模更大。

classColorMLP(nn.Module):def__init__(self,input_size=3*32*32,hidden_size=128,num_classes=10):super().__init__()self.flatten=nn.Flatten()self.fc1=nn.Linear(input_size,hidden_size)self.relu=nn.ReLU()self.fc2=nn.Linear(hidden_size,num_classes)defforward(self,x):x=self.flatten(x)x=self.fc1(x)x=self.relu(x)x=self.fc2(x)returnx color_model=ColorMLP().to(device)

参数量对比

  • 第一层:3072 × 128 = 393,216权重 + 128 偏置 → 393,344 参数。
  • 第二层与灰度模型相同:1,290 参数。
  • 总量 394,634 参数,仅第一层就比灰度模型多约 4 倍显存。

2.3 batch_size 与模型结构

  • 模型定义只需关注单个样本的形状(C, H, W)
  • torchsummary.summary(model, input_size=(1, 28, 28))也是如此,batch 维度无需传入。
  • DataLoader决定 batch_size,模型在 forward 中自动沿着 batch 维进行并行运算。
classBatchAgnosticMLP(nn.Module):def__init__(self):super().__init__()self.flatten=nn.Flatten()# nn.Flatten()会将每个样本的图像展平为 784 维向量,但保留 batch 维度。self.layer1=nn.Linear(784,128)self.relu=nn.ReLU()self.layer2=nn.Linear(128,10)defforward(self,x):x=self.flatten(x)# 输入:[batch_size, 1, 28, 28] → [batch_size, 784]x=self.layer1(x)# [batch_size, 784] → [batch_size, 128]x=self.relu(x)x=self.layer2(x)# [batch_size, 128] → [batch_size, 10]returnx batch_model=BatchAgnosticMLP()
组件是否依赖 batch_size说明
模型定义只需传入单样本形状
torchsummary / torchinfoinput_size=(C,H,W)即可
DataLoader在此设置batch_size=64等参数
训练循环for data, target in loader:自动批处理

3. 显存占用与 batch_size 选择

当数据集太大无法一次放入显存时,需要合理选择 batch_size。
如果 batch_size 过小,算力利用不足;过大则容易 OOM。下面拆分显存的主要组成部分。

# 示例数据加载器,方便后续讨论 batch_sizetrain_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)test_loader=DataLoader(test_dataset,batch_size=1000,shuffle=False)

3.1 模型参数与梯度

  • MNIST MLP 共有 101,770 个 float32 参数,单份参数占用约 403 KB。
  • 反向传播会为梯度再开辟一份同等大小的显存,训练时约 806 KB。
  • 基本换算:1 Byte = 8 bit1 KB = 1024 Byte,便于将参数量转成显存大小。

3.2 优化器状态

  • SGD 默认不额外保存动量,显存占用 = 参数 + 梯度。
  • Adam 需要同时保存一阶动量m与二阶动量v,额外占用约 2 倍参数大小。
  • 以 MNIST MLP 为例:Adam 需要额外 ~806 KB,总显存约 1.6 MB。

3.3 数据批量(batch_size)

  • 单张 MNIST 张量大小:1 × 28 × 28 × 4 Byte ≈ 3 KB
  • batch_size=64→ 数据约 192 KB;batch_size=1024→ 数据约 3 MB。
  • 数据越多,显存越高;需要兼顾显卡上限和训练稳定性。

3.4 中间激活

  • 对两层 MLP 来说,中间特征约为batch_size × 128 × 4 Byte
  • batch_size=1024时约 512 KB,规模不大,但在深层网络或高分辨率输入时会快速增长。

3.5 batch_size 的显存估算示例(SGD)

batch_size数据占用中间激活估算总占用*
64~192 KB~32 KB~1 MB
256~768 KB~128 KB~1.7 MB
1024~3 MB~512 KB~4.5 MB
4096~12 MB~2 MB~15 MB

*示例总量 = 参数 + 梯度 + 数据 + 激活,忽略优化器额外状态,仅用于估算量级。

3.6 实战建议

  • DataLoader默认batch_size=1,并不会自动一次读取全部数据。
  • 寻找合适 batch_size 的常见流程:从 16/32 起步 → 逐渐增大 → 监控nvidia-smi,在 OOM 之前略微回退。
  • 较大的 batch_size 可充分利用 GPU 并行能力,梯度更平滑;但在数据较少或显存受限时应灵活调整。

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/2/7 18:48:05

HTML如何设计JQuery支持大文件上传的拖拽功能?

2023年11月2日 星期四 阴有小雨 外包项目日志 - 企业级大文件传输系统Day3 项目背景与架构设计 客户是某地质勘探研究院&#xff0c;每日需上传**20GB**的勘探数据&#xff08;含激光扫描点云、地质剖面图等&#xff09;&#xff0c;要求&#xff1a; 文件夹结构保留&#xf…

作者头像 李华
网站建设 2026/2/5 18:09:24

yolo-ORBSLAM2复现

这个也是一个经典的问题了&#xff0c;我是想复现&#xff0c;再进行修改&#xff0c;因为我不使用yolo作为检测&#xff0c;但要先搞清楚检测框是怎么送入slam的&#xff0c;所以先复现各位大佬们的。主要参考&#xff1a; https://github.com/JinYoung6/orbslam_addsemantic…

作者头像 李华
网站建设 2026/2/7 5:49:42

python基于大数据技术的购房推荐系统的设计与实现

Python基于大数据技术的购房推荐系统的设计与实现是一个复杂但具有广泛应用前景的项目。以下是对该系统的详细介绍&#xff1a; 一、系统概述 购房推荐系统利用Python编程语言的强大功能和丰富的大数据技术&#xff0c;结合机器学习算法和推荐算法&#xff0c;对购房数据进行深…

作者头像 李华
网站建设 2026/2/6 20:25:17

介观交通流仿真软件:DynusT_(20).DynusT在实际项目中的应用

DynusT在实际项目中的应用 在上一节中&#xff0c;我们已经了解了DynusT的基本功能和使用方法。本节将详细介绍如何在实际项目中应用DynusT进行交通流仿真。我们将通过具体的案例来展示如何设置仿真参数、导入交通网络数据、模拟交通流量以及分析仿真结果。这些案例将涵盖城市交…

作者头像 李华
网站建设 2026/2/6 18:15:51

深入JVM(三):JVM执行引擎

JVM执行引擎 一、JVM前后端编译 前端编译&#xff1a;使用编译器将Java文件编译成class字节码文件后端编译&#xff1a;将class字节码文件编译成机器码指令java 跨平台直接理解&#xff1a;前端编译将java文件编译成class文件&#xff0c; 然后使用jvm&#xff08;后端编译&…

作者头像 李华
网站建设 2026/2/5 11:00:35

通信系统仿真:通信系统基础理论_(8).抗干扰技术

抗干扰技术 1. 引言 在通信系统中,信号的传输会受到各种干扰的影响,这些干扰可能来自自然环境(如电磁波、雷电等)或人为因素(如其他通信系统、电子设备等)。这些干扰会降低通信系统的性能,导致信号失真、误码率增加等问题。因此,研究和应用抗干扰技术是非常重要的。本…

作者头像 李华