一、Dataset类的_getitem_和_len_方法
在 PyTorch 中,torch.utils.data.Dataset 是所有自定义数据集的抽象基类,它规定了数据集必须实现两个核心方法:__len__ 和 __getitem__。这两个方法是 DataLoader 加载数据的基础,决定了数据集的 “大小” 和 “如何按索引取样本”。
Dataset 类的核心作用:Dataset 类的设计目标是封装数据集的逻辑(如数据读取、预处理、标签映射等),对外暴露统一的接口,让 DataLoader 可以无感地加载、批量处理、打乱数据。
自定义数据集时,必须继承 Dataset 并实现 __len__ 和 __getitem__(否则实例化会抛出 NotImplementedError)。
__getitem__ 方法
1. 核心作用
根据传入的索引 index,返回该索引对应的单个样本(通常是 “特征 + 标签” 的组合)。
DataLoader 会循环调用该方法(按索引取样本),并将多个样本拼接成批次,是数据加载的核心逻辑。
2. 实现规则
- 入参:仅接收一个整数 index(范围:0 ≤ index < __len__());
- 返回值:格式灵活,常见形式:
元组:(feature, label)(最常用);
字典:{"feature": feature, "label": label}(多模态 / 多特征场景更易读);
单个值:仅特征(无监督学习场景)。
3. 关键注意点
- 索引合法性:DataLoader 通常会保证 index 在 [0, __len__()-1] 范围内,但自定义时建议避免越界;
- 预处理逻辑:数据预处理(如归一化、图像裁剪、文本分词)建议放在该方法中(DataLoader 支持多进程加载,预处理并行执行效率更高);
- 数据类型:返回的特征建议转为 torch.Tensor(方便后续模型计算),标签可根据需求保留 int/float 或转为 Tensor。
__len__ 方法
1. 核心作用
返回数据集的总样本数量,DataLoader 依赖该方法知道数据集的 “边界”,例如:
- 计算迭代轮次(总样本数 / 批次大小);
- 随机打乱时确定索引范围。
2. 实现规则
- 无入参,仅返回一个非负整数;
- 必须与数据集的实际样本数一致(否则会导致索引越界或数据加载不全)。
二、Dataloader类
DataLoader 核心作用:
- 自动按 batch_size 从 Dataset 中取多个样本,拼接成批次数据(如把多个 (feature, label) 拼接成 (batch_feature, batch_label));
- 支持数据打乱(shuffle),避免模型过拟合;
- 支持多进程加载(num_workers),提升数据读取效率(尤其适合大数据集 / 硬盘读取场景);
- 灵活的批次拼接逻辑(collate_fn),适配不同类型数据(如变长文本、多模态数据);
- 支持内存锁页(pin_memory),加速数据从 CPU 到 GPU 的传输。
| 参数名 | 作用与说明 | 默认值 |
| dataset | 必须传入的 Dataset 实例(自定义 / 内置均可),DataLoader 基于它取样本 | — |
| batch_size | 每个批次的样本数量 | 1 |
| shuffle | 是否在每个 epoch 开始时打乱数据索引(训练集建议 True,测试集建议 False) | False |
| num_workers | 用于数据加载的子进程数(多进程加速);0 表示主进程加载 | 0 |
| 若数据集总数不能被 batch_size 整除,是否丢弃最后一个不完整批次 | False |
| collate_fn | 自定义批次拼接函数,用于处理样本的拼接逻辑(如变长文本、自定义数据结构) | None |
| pin_memory | 是否将加载的数据存入 CUDA 锁页内存(GPU 训练时设为 True,加速传输) | False |
| timeout | 数据加载的超时时间(秒),防止子进程挂起 | 0 |
| sampler | 自定义索引采样策略(优先级高于 shuffle) | None |
| batch_sampler | 自定义批次索引采样策略(与 batch_size/shuffle/sampler 互斥) | None |
DataLoader 工作原理:
- 索引生成:根据 Dataset.__len__() 获取总索引范围,结合 shuffle/sampler 生成索引序列;
- 批次切分:将索引序列按 batch_size 切分成多个批次索引(如 [0,1], [2,3], [4]);
- 样本读取:对每个批次的索引,调用 Dataset.__getitem__(index) 获取单个样本;
- 批次拼接:通过 collate_fn 将多个单个样本拼接成批次数据(默认拼接成 Tensor 矩阵);
- 多进程加速:num_workers > 0 时,子进程并行执行 “样本读取 + 预处理”,主进程仅负责拼接和分发。
核心结论
Dataset类:定义数据的内容和格式(即“如何获取单个样本”),包括:
- 数据存储路径/来源(如文件路径、数据库查询)。
- 原始数据的读取方式(如图像解码为PIL对象、文本读取为字符串)。
- 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过`transform`参数实现)。
- 返回值格式(如`(image_tensor, label)`)。
DataLoader类:定义数据的加载方式和批量处理逻辑(即“如何高效批量获取数据”),包括:
- 批量大小(batch_size)。
- 是否打乱数据顺序(shuffle)。
三、MNIST手写数字数据集
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块 import matplotlib.pyplot as plt # 设置随机种子,确保结果可复现 torch.manual_seed(42) # 1. 数据预处理,该写法非常类似于管道pipeline # transforms 模块提供了一系列常用的图像预处理操作 # 先归一化,再标准化 transform = transforms.Compose([ transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用 ]) # 2. 加载MNIST数据集,如果没有会自动下载 train_dataset = datasets.MNIST( root='./data', train=True, download=True, transform=transform ) test_dataset = datasets.MNIST( root='./data', train=False, transform=transform )作业
# 1. 导入必要库 import torch from torchvision import datasets, transforms import matplotlib.pyplot as plt import numpy as np # 2. 固定随机种子(可选,保证结果一致) torch.manual_seed(42) # 3. 定义数据预处理(CIFAR-10专用均值/标准差) # 说明:CIFAR-10的全局均值和标准差是行业公认值,标准化用 transform = transforms.Compose([ transforms.ToTensor(), # 转Tensor:把0-255的PIL图片→0-1的Tensor,维度[C, H, W](3,32,32) transforms.Normalize( mean=[0.4914, 0.4822, 0.4465], # R/G/B三通道均值 std=[0.2470, 0.2435, 0.2616] # R/G/B三通道标准差 ) ]) # 4. 加载CIFAR-10数据集(自动下载) # 训练集 train_dataset = datasets.CIFAR10( root='./data', # 数据集保存路径 train=True, # 加载训练集(False则加载测试集) download=True, # 本地没有则自动下载 transform=transform # 应用预处理 ) # 5. 关键:提取单张图片并可视化 # 5.1 取数据集第0个样本(特征Tensor + 标签) img_tensor, label_idx = train_dataset[0] # img_tensor.shape = [3,32,32],label_idx是0-9的整数 print(f"图片Tensor形状:{img_tensor.shape}") # 输出:torch.Size([3, 32, 32]) print(f"图片标签索引:{label_idx}") # 输出:6(对应类别“青蛙”) # 5.2 定义CIFAR-10类别名称(对应索引0-9) cifar10_classes = [ '飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车' ] print(f"图片对应类别:{cifar10_classes[label_idx]}") # 输出:青蛙 # 5.3 预处理还原(因为Normalize后数值不在0-1,需要反归一化才能正常显示) # 反归一化公式:img = (img_tensor * std) + mean mean = np.array([0.4914, 0.4822, 0.4465]) std = np.array([0.2470, 0.2435, 0.2616]) # Tensor→numpy,维度从[C,H,W]→[H,W,C](matplotlib需要这个顺序) img_np = img_tensor.numpy().transpose((1, 2, 0)) img_np = img_np * std + mean # 反归一化 img_np = np.clip(img_np, 0, 1) # 确保数值在0-1之间(避免归一化后溢出) # 5.4 可视化图片 plt.figure(figsize=(4, 4)) # 设置图片大小 plt.imshow(img_np) # 显示图片 plt.title(f"Label: {cifar10_classes[label_idx]} (索引{label_idx})") plt.axis('off') # 隐藏坐标轴 plt.show()@浙大疏锦行