ResNet18持续学习方案:云端保存checkpoint,随时继续训练
引言
当你训练一个深度学习模型时,最让人崩溃的事情莫过于:电脑突然死机、断电或者程序崩溃,导致几个小时的训练进度全部丢失。这种情况就像写论文时忘记保存,突然蓝屏一样令人绝望。特别是对于ResNet18这样的经典图像分类模型,训练时间往往需要数小时甚至数天。
本文将介绍一种云端保存checkpoint的解决方案,让你可以随时中断训练,随时继续训练,再也不用担心训练进度丢失。这种方法特别适合研究团队在长时间训练ResNet18时使用,即使本地电脑不稳定,也能保证训练进度安全可靠。
我们将使用PyTorch框架,结合CSDN星图镜像广场提供的预置环境,一步步教你如何实现这个方案。即使你是深度学习新手,也能轻松上手。
1. 为什么需要云端保存checkpoint?
在深入技术细节之前,我们先理解为什么这个方案如此重要。
想象你在玩一个没有存档功能的游戏,每次退出都要从头开始。深度学习训练也是如此,如果没有保存中间状态(checkpoint),一旦训练中断,所有进度都会丢失。云端保存checkpoint可以解决以下痛点:
- 硬件不稳定:本地电脑可能因为各种原因(断电、过热等)崩溃
- 资源限制:本地GPU可能无法一次性完成长时间训练
- 团队协作:多人可以共享同一个训练进度继续工作
- 模型选择:可以基于不同checkpoint测试不同阶段的模型性能
ResNet18虽然比更深的ResNet模型轻量,但在大数据集上训练仍然需要相当长的时间。使用云端checkpoint方案,你可以安心睡觉,第二天接着训练,就像什么都没发生过一样。
2. 环境准备与镜像选择
2.1 选择预置镜像
为了快速开始,我们推荐使用CSDN星图镜像广场提供的PyTorch预置镜像,它已经包含了所有必要的依赖:
- PyTorch框架:支持ResNet18模型训练
- CUDA支持:充分利用GPU加速
- 常用工具:如Jupyter Notebook方便调试
这个镜像开箱即用,省去了繁琐的环境配置过程,特别适合新手快速上手。
2.2 启动GPU实例
在CSDN星图平台上:
- 搜索并选择PyTorch预置镜像
- 根据你的需求选择GPU型号(如RTX 3090)
- 启动实例,等待环境准备完成
启动后,你可以通过Web终端或Jupyter Notebook访问这个环境。
3. ResNet18训练基础代码
让我们从基础的ResNet18训练代码开始,这是后续添加checkpoint功能的基础。
import torch import torch.nn as nn import torch.optim as optim from torchvision import models, transforms, datasets # 1. 准备数据集 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) train_dataset = datasets.ImageFolder('path/to/train', transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) # 2. 初始化模型 model = models.resnet18(pretrained=True) num_classes = 10 # 根据你的数据集调整 model.fc = nn.Linear(model.fc.in_features, num_classes) # 3. 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 4. 训练循环 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = model.to(device) for epoch in range(10): # 假设训练10个epoch running_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 100 == 99: # 每100个batch打印一次 print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}') running_loss = 0.0这段代码完成了基本的ResNet18训练流程,但还没有实现checkpoint保存功能。接下来我们将逐步增强它。
4. 实现云端checkpoint保存与恢复
4.1 添加checkpoint保存功能
我们需要在训练过程中定期保存模型状态,包括:
- 模型参数
- 优化器状态
- 当前epoch
- 训练损失等指标
修改训练循环,添加保存逻辑:
import os import torch def save_checkpoint(state, filename='checkpoint.pth.tar'): torch.save(state, filename) # 这里可以添加代码将checkpoint上传到云存储 for epoch in range(10): running_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): # ... 原有训练代码 ... # 每个epoch结束后保存checkpoint save_checkpoint({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'loss': running_loss / len(train_loader) }, filename=f'checkpoint_epoch{epoch+1}.pth.tar')4.2 从checkpoint恢复训练
当需要继续训练时,我们可以从保存的checkpoint恢复:
def load_checkpoint(model, optimizer, filename): if os.path.isfile(filename): checkpoint = torch.load(filename) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch'] loss = checkpoint['loss'] return start_epoch, loss else: print(f"Checkpoint file {filename} not found!") return 0, float('inf') # 使用示例 start_epoch, _ = load_checkpoint(model, optimizer, 'checkpoint_epoch5.pth.tar') for epoch in range(start_epoch, 10): # 继续训练...4.3 集成云存储服务
为了真正实现"云端"保存,我们需要将checkpoint文件上传到云存储。这里以阿里云OSS为例:
import oss2 def upload_to_oss(local_file, remote_file): # 配置你的OSS信息 auth = oss2.Auth('your-access-key-id', 'your-access-key-secret') bucket = oss2.Bucket(auth, 'your-endpoint', 'your-bucket-name') bucket.put_object_from_file(remote_file, local_file) print(f"Uploaded {local_file} to OSS as {remote_file}") def download_from_oss(remote_file, local_file): auth = oss2.Auth('your-access-key-id', 'your-access-key-secret') bucket = oss2.Bucket(auth, 'your-endpoint', 'your-bucket-name') bucket.get_object_to_file(remote_file, local_file) print(f"Downloaded {remote_file} from OSS to {local_file}")然后修改save_checkpoint和load_checkpoint函数:
def save_checkpoint(state, filename='checkpoint.pth.tar'): torch.save(state, filename) upload_to_oss(filename, f'resnet18_checkpoints/{filename}') def load_checkpoint(model, optimizer, filename): download_from_oss(f'resnet18_checkpoints/{filename}', filename) if os.path.isfile(filename): # ... 原有加载代码 ...5. 完整训练脚本与最佳实践
现在我们将所有部分组合成一个完整的、可投入生产的训练脚本:
import os import torch import torch.nn as nn import torch.optim as optim from torchvision import models, transforms, datasets import oss2 # 配置 CHECKPOINT_DIR = 'checkpoints' os.makedirs(CHECKPOINT_DIR, exist_ok=True) # OSS配置 OSS_CONFIG = { 'access_key_id': 'your-access-key-id', 'access_key_secret': 'your-access-key-secret', 'endpoint': 'your-endpoint', 'bucket_name': 'your-bucket-name' } # 云存储工具函数 def get_oss_bucket(): auth = oss2.Auth(OSS_CONFIG['access_key_id'], OSS_CONFIG['access_key_secret']) return oss2.Bucket(auth, OSS_CONFIG['endpoint'], OSS_CONFIG['bucket_name']) def upload_to_oss(local_file, remote_file): bucket = get_oss_bucket() bucket.put_object_from_file(remote_file, local_file) print(f"Uploaded {local_file} to OSS as {remote_file}") def download_from_oss(remote_file, local_file): bucket = get_oss_bucket() try: bucket.get_object_to_file(remote_file, local_file) print(f"Downloaded {remote_file} from OSS to {local_file}") return True except: print(f"File {remote_file} not found in OSS") return False # Checkpoint工具函数 def save_checkpoint(state, filename): filepath = os.path.join(CHECKPOINT_DIR, filename) torch.save(state, filepath) upload_to_oss(filepath, f'resnet18_checkpoints/{filename}') print(f"Checkpoint saved to {filepath} and uploaded to OSS") def load_checkpoint(model, optimizer, filename): filepath = os.path.join(CHECKPOINT_DIR, filename) if download_from_oss(f'resnet18_checkpoints/{filename}', filepath) or os.path.isfile(filepath): checkpoint = torch.load(filepath) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch'] loss = checkpoint['loss'] print(f"Loaded checkpoint from {filepath}, starting from epoch {start_epoch}") return start_epoch, loss else: print(f"No checkpoint found at {filepath}") return 0, float('inf') # 主训练函数 def train_resnet18(data_dir, num_epochs=10, resume_from=None): # 准备数据 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) train_dataset = datasets.ImageFolder(data_dir, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) # 初始化模型 model = models.resnet18(pretrained=True) num_classes = len(train_dataset.classes) model.fc = nn.Linear(model.fc.in_features, num_classes) # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 设备设置 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = model.to(device) # 恢复训练 start_epoch = 0 if resume_from: start_epoch, _ = load_checkpoint(model, optimizer, resume_from) # 训练循环 for epoch in range(start_epoch, num_epochs): model.train() running_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 100 == 99: avg_loss = running_loss / 100 print(f'Epoch [{epoch+1}/{num_epochs}], Batch {i+1}, Loss: {avg_loss:.3f}') running_loss = 0.0 # 每个epoch结束后保存checkpoint checkpoint_name = f'resnet18_epoch{epoch+1}.pth.tar' save_checkpoint({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'loss': running_loss / len(train_loader), 'classes': train_dataset.classes }, checkpoint_name) print('Training complete!') # 使用示例 if __name__ == '__main__': # 从头开始训练 # train_resnet18('path/to/your/dataset', num_epochs=20) # 从第5个epoch继续训练 train_resnet18('path/to/your/dataset', num_epochs=20, resume_from='resnet18_epoch5.pth.tar')5.1 最佳实践建议
- 保存频率:
- 小型数据集:每个epoch保存一次
大型数据集:每N个batch保存一次(如每1000个batch)
命名规范:
- 包含模型名称、epoch/batch编号、日期等信息
例如:
resnet18_epoch5_20230815.pth.tar版本控制:
- 保留几个关键checkpoint(如最佳性能的、特定epoch的)
删除中间checkpoint节省空间
云存储组织:
- 按项目/实验创建不同目录
例如:
resnet18_experiment1/checkpoints/监控训练:
- 保存训练指标(如loss, accuracy)到日志文件
- 考虑使用TensorBoard等工具可视化
6. 常见问题与解决方案
在实际使用中,你可能会遇到以下问题:
6.1 Checkpoint文件太大
问题:ResNet18的checkpoint文件可能达到几十MB,频繁保存会占用大量存储空间。
解决方案: 1. 使用torch.save的_use_new_zipfile_serialization=False参数可以减小文件大小python torch.save(state, filename, _use_new_zipfile_serialization=False)2. 只保存模型参数(不保存优化器状态),但恢复时需要重新初始化优化器 3. 定期清理旧的checkpoint
6.2 云存储上传/下载速度慢
问题:大文件上传下载耗时,影响训练效率。
解决方案: 1. 使用压缩(checkpoint文件通常压缩率很高)python import gzip with gzip.open(filename, 'wb') as f: torch.save(state, f)2. 考虑增量上传,只上传变化的部分 3. 在训练结束时才上传完整checkpoint,训练中只保存在本地
6.3 多GPU训练的特殊处理
问题:使用DataParallel或DistributedDataParallel时,模型保存和加载需要特殊处理。
解决方案: 1. 保存时去除module.前缀:python # 多GPU训练时保存 state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict() torch.save(state_dict, filename)2. 加载时处理多GPU情况:python if torch.cuda.device_count() > 1: model = nn.DataParallel(model) model.load_state_dict(torch.load(filename))
6.4 跨平台兼容性问题
问题:在不同操作系统或PyTorch版本间加载checkpoint可能出错。
解决方案: 1. 保存时指定兼容模式:python torch.save(state, filename, _use_new_zipfile_serialization=True)2. 加载时使用map_location参数:python torch.load(filename, map_location='cuda:0') # 或 'cpu'3. 记录PyTorch版本信息在checkpoint中
总结
通过本文,你已经学会了如何实现ResNet18的云端持续学习方案。让我们回顾一下核心要点:
- 云端checkpoint的重要性:解决了长时间训练中的中断风险,支持灵活暂停和继续训练
- 关键实现步骤:定期保存模型状态、优化器状态和训练进度到云端存储
- 完整生产级代码:提供了可直接使用的训练脚本,包含checkpoint保存和恢复功能
- 最佳实践建议:合理的保存频率、命名规范和存储组织方式
- 常见问题解决方案:处理大文件、慢速传输、多GPU训练等实际问题
现在你就可以尝试在自己的项目中实现这个方案了。实测下来,这种方法非常稳定可靠,特别适合需要长时间训练的研究项目。
💡获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。