news 2025/12/25 4:38:09

DAY 40 Dataset类和Dataloader类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DAY 40 Dataset类和Dataloader类

一、Dataset类的_getitem_和_len_方法

在 PyTorch 中,torch.utils.data.Dataset 是所有自定义数据集的抽象基类,它规定了数据集必须实现两个核心方法:__len__ 和 __getitem__。这两个方法是 DataLoader 加载数据的基础,决定了数据集的 “大小” 和 “如何按索引取样本”。

Dataset 类的核心作用:Dataset 类的设计目标是封装数据集的逻辑(如数据读取、预处理、标签映射等),对外暴露统一的接口,让 DataLoader 可以无感地加载、批量处理、打乱数据。
自定义数据集时,必须继承 Dataset 并实现 __len__ 和 __getitem__(否则实例化会抛出 NotImplementedError)。

__getitem__ 方法
1. 核心作用
根据传入的索引 index,返回该索引对应的单个样本(通常是 “特征 + 标签” 的组合)。
DataLoader 会循环调用该方法(按索引取样本),并将多个样本拼接成批次,是数据加载的核心逻辑。
2. 实现规则

  • 入参:仅接收一个整数 index(范围:0 ≤ index < __len__());
  • 返回值:格式灵活,常见形式:

元组:(feature, label)(最常用);

字典:{"feature": feature, "label": label}(多模态 / 多特征场景更易读);

单个值:仅特征(无监督学习场景)。

3. 关键注意点

  • 索引合法性:DataLoader 通常会保证 index 在 [0, __len__()-1] 范围内,但自定义时建议避免越界;
  • 预处理逻辑:数据预处理(如归一化、图像裁剪、文本分词)建议放在该方法中(DataLoader 支持多进程加载,预处理并行执行效率更高);
  • 数据类型:返回的特征建议转为 torch.Tensor(方便后续模型计算),标签可根据需求保留 int/float 或转为 Tensor。

__len__ 方法
1. 核心作用
返回数据集的总样本数量,DataLoader 依赖该方法知道数据集的 “边界”,例如:

  • 计算迭代轮次(总样本数 / 批次大小);
  • 随机打乱时确定索引范围。

2. 实现规则

  • 无入参,仅返回一个非负整数;
  • 必须与数据集的实际样本数一致(否则会导致索引越界或数据加载不全)。

二、Dataloader类

DataLoader 核心作用:

  1. 自动按 batch_size 从 Dataset 中取多个样本,拼接成批次数据(如把多个 (feature, label) 拼接成 (batch_feature, batch_label));
  2. 支持数据打乱(shuffle),避免模型过拟合;
  3. 支持多进程加载(num_workers),提升数据读取效率(尤其适合大数据集 / 硬盘读取场景);
  4. 灵活的批次拼接逻辑(collate_fn),适配不同类型数据(如变长文本、多模态数据);
  5. 支持内存锁页(pin_memory),加速数据从 CPU 到 GPU 的传输。
参数名作用与说明默认值
dataset必须传入的 Dataset 实例(自定义 / 内置均可),DataLoader 基于它取样本
batch_size每个批次的样本数量1
shuffle是否在每个 epoch 开始时打乱数据索引(训练集建议 True,测试集建议 False)False
num_workers用于数据加载的子进程数(多进程加速);0 表示主进程加载0
  1. drop_last
若数据集总数不能被 batch_size 整除,是否丢弃最后一个不完整批次False
collate_fn自定义批次拼接函数,用于处理样本的拼接逻辑(如变长文本、自定义数据结构)None
pin_memory是否将加载的数据存入 CUDA 锁页内存(GPU 训练时设为 True,加速传输)False
timeout数据加载的超时时间(秒),防止子进程挂起0
sampler自定义索引采样策略(优先级高于 shuffle)None
batch_sampler自定义批次索引采样策略(与 batch_size/shuffle/sampler 互斥)None

DataLoader 工作原理:

  1. 索引生成:根据 Dataset.__len__() 获取总索引范围,结合 shuffle/sampler 生成索引序列;
  2. 批次切分:将索引序列按 batch_size 切分成多个批次索引(如 [0,1], [2,3], [4]);
  3. 样本读取:对每个批次的索引,调用 Dataset.__getitem__(index) 获取单个样本;
  4. 批次拼接:通过 collate_fn 将多个单个样本拼接成批次数据(默认拼接成 Tensor 矩阵);
  5. 多进程加速:num_workers > 0 时,子进程并行执行 “样本读取 + 预处理”,主进程仅负责拼接和分发。

核心结论

Dataset类:定义数据的内容和格式(即“如何获取单个样本”),包括:

- 数据存储路径/来源(如文件路径、数据库查询)。

- 原始数据的读取方式(如图像解码为PIL对象、文本读取为字符串)。

- 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过`transform`参数实现)。

- 返回值格式(如`(image_tensor, label)`)。

DataLoader类:定义数据的加载方式和批量处理逻辑(即“如何高效批量获取数据”),包括:

- 批量大小(batch_size)。

- 是否打乱数据顺序(shuffle)。

三、MNIST手写数字数据集

import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块 import matplotlib.pyplot as plt # 设置随机种子,确保结果可复现 torch.manual_seed(42) # 1. 数据预处理,该写法非常类似于管道pipeline # transforms 模块提供了一系列常用的图像预处理操作 # 先归一化,再标准化 transform = transforms.Compose([ transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用 ]) # 2. 加载MNIST数据集,如果没有会自动下载 train_dataset = datasets.MNIST( root='./data', train=True, download=True, transform=transform ) test_dataset = datasets.MNIST( root='./data', train=False, transform=transform )

作业

# 1. 导入必要库 import torch from torchvision import datasets, transforms import matplotlib.pyplot as plt import numpy as np # 2. 固定随机种子(可选,保证结果一致) torch.manual_seed(42) # 3. 定义数据预处理(CIFAR-10专用均值/标准差) # 说明:CIFAR-10的全局均值和标准差是行业公认值,标准化用 transform = transforms.Compose([ transforms.ToTensor(), # 转Tensor:把0-255的PIL图片→0-1的Tensor,维度[C, H, W](3,32,32) transforms.Normalize( mean=[0.4914, 0.4822, 0.4465], # R/G/B三通道均值 std=[0.2470, 0.2435, 0.2616] # R/G/B三通道标准差 ) ]) # 4. 加载CIFAR-10数据集(自动下载) # 训练集 train_dataset = datasets.CIFAR10( root='./data', # 数据集保存路径 train=True, # 加载训练集(False则加载测试集) download=True, # 本地没有则自动下载 transform=transform # 应用预处理 ) # 5. 关键:提取单张图片并可视化 # 5.1 取数据集第0个样本(特征Tensor + 标签) img_tensor, label_idx = train_dataset[0] # img_tensor.shape = [3,32,32],label_idx是0-9的整数 print(f"图片Tensor形状:{img_tensor.shape}") # 输出:torch.Size([3, 32, 32]) print(f"图片标签索引:{label_idx}") # 输出:6(对应类别“青蛙”) # 5.2 定义CIFAR-10类别名称(对应索引0-9) cifar10_classes = [ '飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车' ] print(f"图片对应类别:{cifar10_classes[label_idx]}") # 输出:青蛙 # 5.3 预处理还原(因为Normalize后数值不在0-1,需要反归一化才能正常显示) # 反归一化公式:img = (img_tensor * std) + mean mean = np.array([0.4914, 0.4822, 0.4465]) std = np.array([0.2470, 0.2435, 0.2616]) # Tensor→numpy,维度从[C,H,W]→[H,W,C](matplotlib需要这个顺序) img_np = img_tensor.numpy().transpose((1, 2, 0)) img_np = img_np * std + mean # 反归一化 img_np = np.clip(img_np, 0, 1) # 确保数值在0-1之间(避免归一化后溢出) # 5.4 可视化图片 plt.figure(figsize=(4, 4)) # 设置图片大小 plt.imshow(img_np) # 显示图片 plt.title(f"Label: {cifar10_classes[label_idx]} (索引{label_idx})") plt.axis('off') # 隐藏坐标轴 plt.show()

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/17 1:03:00

二维码QRCode的属性

TQRCode组件生成二维码的核心属性配置&#xff0c;TQRCode是 Delphi 中常用的二维码生成组件&#xff08;多为第三方 / QRCode 库封装&#xff09;&#xff0c;以下逐一解析每个属性的功能、取值规则和实际应用场景&#xff1a; 一、核心属性解析 属性名代码赋值功能详解取值…

作者头像 李华
网站建设 2025/12/17 1:02:32

LobeChat知乎内容分发策略

LobeChat在知乎内容生态中的智能生成与分发实践 当知乎上一个关于“2024年大模型技术趋势”的提问悄然登上热榜&#xff0c;却迟迟没有高质量回答时&#xff0c;背后可能正有一套自动化系统在悄然运转——它监听话题热度、调用AI模型检索最新论文、整合权威观点&#xff0c;并在…

作者头像 李华
网站建设 2025/12/23 20:01:47

FCC认证是否有有效期?有哪些认证方式?需要审厂吗?

FCC 认证的有效期、认证方式与审厂要求&#xff0c;会根据认证的类型&#xff08;FCC ID/SDoC&#xff09;有明确区别&#xff0c;以下是详细说明&#xff1a;有效期规则FCC 认证本身没有固定的有效期限制&#xff0c;但是会受两个因素影响有效性&#xff1a;产品的设计变更&am…

作者头像 李华
网站建设 2025/12/17 1:02:19

电池做CE认证的流程是怎样的?

依据欧盟新电池法规&#xff08;EU&#xff09;2023/1542&#xff0c;电池 CE 认证需覆盖安全、环保等多维度合规要求&#xff0c;流程清晰且环节明确&#xff0c;具体步骤如下&#xff1a;前期规划与资料准备确定合规标准&#xff1a;先明确电池对应的适用标准&#xff0c;比如…

作者头像 李华
网站建设 2025/12/17 1:00:19

LobeChat 360搜索推广策略

LobeChat&#xff1a;构建私有化AI交互入口的技术实践 在生成式AI浪潮席卷各行各业的今天&#xff0c;一个现实问题摆在开发者和企业面前&#xff1a;如何在享受大语言模型强大能力的同时&#xff0c;不牺牲数据安全与系统可控性&#xff1f;市面上的主流对话产品虽然体验流畅&…

作者头像 李华