一、整体比喻:游戏存档系统
想象你在玩一个超长的RPG游戏(训练模型):
训练模型 = 打超长RPG游戏 中断训练 = 游戏突然闪退 保存文件 = 游戏存档点 恢复训练 = 读取存档继续玩
二、为什么要保存?三个血泪教训
2.1 真实场景:训练中断的惨痛经历
场景1:电源故障 你:已经训练了3天3夜,到第98轮了... 啪!停电了! 你:😱 啊啊啊!什么都没了! 场景2:程序报错 你:训练到第50轮,效果越来越好... 报错:CUDA out of memory 你:😭 又要从头开始? 场景3:想修改参数 你:训了100轮,但学习率好像设大了 想:如果能回到第50轮改参数就好了... 现实:只能从头开始 😫
2.2 保存的好处:就像时光机!
防断电:停电也不怕
防报错:程序崩溃可以继续
可以调参:回到某个时间点改参数
可以比较:保存多个版本比较效果
可以分享:把存档给朋友继续训练
三、必须保存的四个核心文件(四大护法)
3.1 文件1:模型权重文件(.pth或.pt)
比喻:游戏角色的属性和装备
文件内容:模型的所有参数(权重) 就像:角色的力量、敏捷、智力、装备等 文件大小:几十MB到几GB 文件格式:.pth, .pt, .ckpt
包含什么:
- 每一层的权重(卷积核参数) - 每一层的偏置 - BatchNorm的均值和方差 - 所有可学习参数的状态
代码示例:
# 保存模型权重 torch.save(model.state_dict(), 'model_weights.pth') # 这个文件包含了: # conv1.weight: [[0.1, 0.2], [0.3, 0.4], ...] # conv1.bias: [0.01, 0.02, ...] # bn1.weight: [1.0, 0.9, ...] # bn1.running_mean: [0.0, 0.1, ...] # ...3.2 文件2:优化器状态文件(.pth)
比喻:游戏的进度和经验值
文件内容:优化器的状态 就像:角色当前的经验值、技能冷却时间等 重要性:⭐⭐⭐⭐⭐(没有这个就无法正确继续!)
包含什么:
1. 当前学习率 2. 动量缓冲(如果优化器有动量) 3. Adam的m和v估计(一阶矩、二阶矩估计) 4. 其他优化器特有的状态
为什么重要?
没有优化器状态: - 只能加载模型权重 - 但优化器会从头开始 - 相当于:角色等级保留了,但经验值清空了 - 结果:继续训练可能会出问题 有优化器状态: - 完美恢复训练状态 - 学习率调度正常 - 动量等状态保持
代码示例:
# 保存优化器状态 torch.save(optimizer.state_dict(), 'optimizer_state.pth') # 这个文件包含了: # state: (每个参数的优化器状态) # param1: # step: 1000 # 已经更新了1000次 # exp_avg: [...] # Adam的一阶矩估计 # exp_avg_sq: [...] # Adam的二阶矩估计 # param_groups: (优化器参数组) # lr: 0.001 # 当前学习率 # betas: (0.9, 0.999) # ...3.3 文件3:训练状态文件(.json或.pkl)
比喻:游戏的任务日志和成就系统
文件内容:训练的各种状态信息 就像:完成了哪些任务、当前章节、游戏设置等
应该包含:
训练状态 = { # 1. 训练进度 'epoch': 50, # 当前训练到第几轮 'batch': 125, # 当前批次的索引 'total_epochs': 100, # 总共要训练多少轮 # 2. 训练历史 'train_loss_history': [0.5, 0.4, 0.35, ...], # 训练损失历史 'val_loss_history': [0.6, 0.5, 0.45, ...], # 验证损失历史 'train_acc_history': [0.7, 0.75, 0.78, ...], # 训练准确率历史 'val_acc_history': [0.65, 0.7, 0.73, ...], # 验证准确率历史 # 3. 最佳记录 'best_val_loss': 0.42, # 最佳验证损失 'best_val_acc': 0.76, # 最佳验证准确率 'best_epoch': 45, # 达到最佳的轮数 # 4. 学习率状态 'lr_history': [0.001, 0.00095, 0.0009, ...], # 学习率历史 'current_lr': 0.0009, # 当前学习率 # 5. 其他状态 'random_seed': 42, # 随机种子(重要!) 'training_time': '36:15:22', # 已训练时间 'start_time': '2024-01-01 10:00:00', # 开始时间 }3.4 文件4:配置文件(.yaml或.json)
比喻:游戏的设置选项
文件内容:所有的训练配置参数 就像:游戏难度、画质设置、按键设置等 重要性:⭐⭐⭐⭐(没有这个就不知道当时怎么设置的)
包含什么:
配置文件 = { # 模型配置 'model_type': 'yolov8n', 'num_classes': 80, 'pretrained': True, # 数据配置 'data_path': './data/coco', 'img_size': 640, 'batch_size': 32, 'num_workers': 4, # 训练配置 'learning_rate': 0.001, 'optimizer': 'Adam', 'weight_decay': 0.0005, 'momentum': 0.937, # 数据增强配置 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4, 'flipud': 0.0, 'fliplr': 0.5, # 其他 'device': 'cuda:0', 'amp': True, # 混合精度 'save_dir': './runs', }四、完整的保存方案(四种存档策略)
4.1 方案1:简单存档(新手推荐)
比喻:只存一个快速存档
def save_simple_checkpoint(epoch, model, optimizer, path='checkpoint.pth'): """ 最简单的保存方法 就像:只按F5快速存档 """ checkpoint = { 'epoch': epoch, # 当前轮数 'model_state_dict': model.state_dict(), # 模型权重 'optimizer_state_dict': optimizer.state_dict(), # 优化器状态 'loss': current_loss, # 当前损失 } torch.save(checkpoint, path) print(f"✅ 已保存检查点到: {path} (第{epoch}轮)") # 使用示例:每10轮保存一次 if epoch % 10 == 0: save_simple_checkpoint(epoch, model, optimizer, f'checkpoint_epoch_{epoch}.pth')4.2 方案2:完整存档(推荐)
比喻:存一个完整的存档文件
def save_full_checkpoint(epoch, model, optimizer, scheduler, train_history, config, path='checkpoint_full.pth'): """ 完整的检查点保存 就像:存一个包含所有信息的存档 """ checkpoint = { # 1. 训练进度 'epoch': epoch, 'global_step': global_step, # 总训练步数 # 2. 模型状态 'model_state_dict': model.state_dict(), 'model_config': model.config if hasattr(model, 'config') else None, # 3. 优化状态 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict() if scheduler else None, # 4. 训练历史 'train_loss_history': train_history['train_loss'], 'val_loss_history': train_history['val_loss'], 'train_acc_history': train_history['train_acc'], 'val_acc_history': train_history['val_acc'], # 5. 最佳记录 'best_val_loss': best_val_loss, 'best_val_acc': best_val_acc, 'best_epoch': best_epoch, # 6. 配置信息 'config': config, # 7. 其他信息 'random_seed': random_seed, 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), 'training_time': training_time, } torch.save(checkpoint, path) print(f"💾 完整检查点已保存: {path}")4.3 方案3:智能存档(高级)
比喻:自动管理多个存档
class CheckpointManager: """ 智能检查点管理器 就像:游戏的自动存档系统 """ def __init__(self, save_dir='./checkpoints', max_save=5): self.save_dir = save_dir self.max_save = max_save # 最多保存几个检查点 os.makedirs(save_dir, exist_ok=True) def save(self, epoch, model, optimizer, val_acc, is_best=False): """ 保存检查点,自动管理空间 """ # 1. 创建检查点 checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_acc': val_acc, } # 2. 总是保存最新的 latest_path = os.path.join(self.save_dir, 'latest.pth') torch.save(checkpoint, latest_path) # 3. 如果是最好的,额外保存 if is_best: best_path = os.path.join(self.save_dir, 'best.pth') torch.save(checkpoint, best_path) print(f"🏆 保存最佳模型: {best_path} (准确率: {val_acc:.2f}%)") # 4. 定期保存(每10轮) if epoch % 10 == 0: epoch_path = os.path.join(self.save_dir, f'epoch_{epoch:03d}.pth') torch.save(checkpoint, epoch_path) # 5. 清理旧的检查点(保持最多max_save个) self._cleanup_old_checkpoints() def _cleanup_old_checkpoints(self): """清理旧的检查点文件""" # 获取所有检查点文件 checkpoint_files = [] for f in os.listdir(self.save_dir): if f.startswith('epoch_') and f.endswith('.pth'): epoch_num = int(f.split('_')[1].split('.')[0]) checkpoint_files.append((epoch_num, f)) # 按轮数排序 checkpoint_files.sort() # 删除多余的 if len(checkpoint_files) > self.max_save: for i in range(len(checkpoint_files) - self.max_save): old_file = checkpoint_files[i][1] old_path = os.path.join(self.save_dir, old_file) os.remove(old_path) print(f"🗑️ 删除旧检查点: {old_file}")4.4 方案4:YOLOv8风格存档
比喻:专业游戏主播的存档系统
def save_yolo_checkpoint(epoch, model, optimizer, results, save_dir='./runs/train'): """ YOLOv8风格的检查点保存 """ import yaml from datetime import datetime # 1. 创建时间戳 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') # 2. 创建检查点目录 checkpoint_dir = os.path.join(save_dir, f'exp_{timestamp}') os.makedirs(checkpoint_dir, exist_ok=True) # 3. 保存模型权重 weights_path = os.path.join(checkpoint_dir, 'weights', 'last.pt') os.makedirs(os.path.dirname(weights_path), exist_ok=True) torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, weights_path) # 4. 保存结果(metrics) results_path = os.path.join(checkpoint_dir, 'results.csv') results.to_csv(results_path, index=False) # 5. 保存配置 config = { 'data': model.args['data'] if hasattr(model, 'args') else 'coco.yaml', 'model': model.args['model'] if hasattr(model, 'args') else 'yolov8n', 'epochs': epoch, 'imgsz': 640, 'batch': 16, 'save_dir': checkpoint_dir, 'timestamp': timestamp, } config_path = os.path.join(checkpoint_dir, 'args.yaml') with open(config_path, 'w') as f: yaml.dump(config, f) # 6. 保存训练日志 log_path = os.path.join(checkpoint_dir, 'train.log') # ... 保存日志 print(f"📁 YOLO检查点保存到: {checkpoint_dir}") return checkpoint_dir五、如何恢复训练(读取存档)
5.1 恢复训练的完整流程
def load_and_resume_training(checkpoint_path, model, optimizer=None): """ 加载检查点并恢复训练 比喻:读取游戏存档继续玩 """ print(f"🔄 正在加载检查点: {checkpoint_path}") # 1. 加载检查点文件 checkpoint = torch.load(checkpoint_path, map_location='cpu') # 2. 加载模型权重 model.load_state_dict(checkpoint['model_state_dict']) print(f"✅ 模型权重已加载 (第{checkpoint.get('epoch', '未知')}轮)") # 3. 加载优化器状态(如果提供优化器) if optimizer is not None and 'optimizer_state_dict' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) print("✅ 优化器状态已加载") # 4. 获取训练进度 start_epoch = checkpoint.get('epoch', 0) + 1 # 从下一轮开始 best_val_acc = checkpoint.get('best_val_acc', 0) # 5. 加载训练历史(如果有) train_history = { 'train_loss': checkpoint.get('train_loss_history', []), 'val_loss': checkpoint.get('val_loss_history', []), 'train_acc': checkpoint.get('train_acc_history', []), 'val_acc': checkpoint.get('val_acc_history', []), } # 6. 打印恢复信息 print("\n恢复信息:") print(f" 起始轮次: {start_epoch}") print(f" 最佳准确率: {best_val_acc:.2%}") print(f" 训练历史长度: {len(train_history['train_loss'])}") return start_epoch, best_val_acc, train_history # 使用示例 print("游戏闪退了!读取存档继续...") start_epoch, best_acc, history = load_and_resume_training( 'checkpoints/best.pth', model, optimizer ) # 继续训练 for epoch in range(start_epoch, total_epochs): train_one_epoch(...)5.2 处理特殊情况
情况1:模型结构变了怎么办?
def load_with_partial_match(checkpoint_path, model): """ 部分加载:当模型结构有变化时 比喻:换了新角色,但继承部分旧装备 """ checkpoint = torch.load(checkpoint_path) model_state_dict = model.state_dict() # 只加载匹配的参数 loaded_params = 0 total_params = len(model_state_dict.keys()) for name, param in checkpoint['model_state_dict'].items(): if name in model_state_dict: # 检查形状是否匹配 if param.shape == model_state_dict[name].shape: model_state_dict[name] = param loaded_params += 1 else: print(f"⚠️ 形状不匹配: {name}") else: print(f"⚠️ 参数不存在: {name}") model.load_state_dict(model_state_dict, strict=False) print(f"✅ 部分加载完成: {loaded_params}/{total_params} 个参数") return model情况2:只有模型权重,没有优化器状态
def resume_without_optimizer(checkpoint_path, model, new_optimizer): """ 只有模型权重,没有优化器状态时 比喻:角色等级保留了,但经验值要重新积累 """ # 加载模型 checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) # 获取之前的训练信息 epoch = checkpoint.get('epoch', 0) old_lr = None # 尝试从优化器状态中提取学习率 if 'optimizer_state_dict' in checkpoint: old_optimizer_state = checkpoint['optimizer_state_dict'] if 'param_groups' in old_optimizer_state: old_lr = old_optimizer_state['param_groups'][0]['lr'] # 设置新的优化器,尽量使用之前的学习率 if old_lr is not None: for param_group in new_optimizer.param_groups: param_group['lr'] = old_lr print(f"📊 使用之前的学习率: {old_lr}") print(f"⚠️ 注意:优化器状态已重置,从第{epoch+1}轮继续") return epoch + 1六、实战:完整的训练+保存+恢复代码
6.1 完整的训练脚本(带保存功能)
import torch import torch.nn as nn import torch.optim as optim import os import time import json from datetime import datetime class TrainingManager: """ 完整的训练管理器(带保存恢复功能) 比喻:专业的游戏存档系统 """ def __init__(self, model, optimizer, config, save_dir='./checkpoints'): self.model = model self.optimizer = optimizer self.config = config self.save_dir = save_dir # 创建保存目录 os.makedirs(save_dir, exist_ok=True) # 训练状态 self.current_epoch = 0 self.global_step = 0 self.best_val_acc = 0 self.best_epoch = 0 # 训练历史 self.history = { 'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'lr': [], } # 开始时间 self.start_time = time.time() # 自动创建实验名 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') self.exp_name = f"exp_{timestamp}" self.exp_dir = os.path.join(save_dir, self.exp_name) os.makedirs(self.exp_dir, exist_ok=True) print(f"🎮 训练管理器已初始化") print(f"📁 实验目录: {self.exp_dir}") def save_checkpoint(self, epoch, is_best=False, reason='periodic'): """ 保存检查点 """ checkpoint = { # 训练进度 'epoch': epoch, 'global_step': self.global_step, # 模型状态 'model_state_dict': self.model.state_dict(), # 优化器状态 'optimizer_state_dict': self.optimizer.state_dict(), # 训练历史 'train_loss_history': self.history['train_loss'], 'val_loss_history': self.history['val_loss'], 'train_acc_history': self.history['train_acc'], 'val_acc_history': self.history['val_acc'], 'lr_history': self.history['lr'], # 最佳记录 'best_val_acc': self.best_val_acc, 'best_epoch': self.best_epoch, # 配置信息 'config': self.config, # 元数据 'exp_name': self.exp_name, 'save_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'training_time': time.time() - self.start_time, } # 保存文件 if is_best: filename = f'best_epoch_{epoch:03d}.pth' save_reason = '🏆 最佳模型' else: filename = f'epoch_{epoch:03d}.pth' save_reason = f'📅 {reason}保存' filepath = os.path.join(self.exp_dir, filename) torch.save(checkpoint, filepath) # 总是保存最新的 latest_path = os.path.join(self.exp_dir, 'latest.pth') torch.save(checkpoint, latest_path) # 保存配置为JSON(便于查看) config_path = os.path.join(self.exp_dir, 'config.json') with open(config_path, 'w') as f: json.dump(self.config, f, indent=2) print(f"{save_reason}: {filepath}") return filepath def load_checkpoint(self, checkpoint_path): """ 加载检查点 """ print(f"🔄 正在加载检查点: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location='cpu') # 恢复模型 self.model.load_state_dict(checkpoint['model_state_dict']) # 恢复优化器 self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 恢复训练状态 self.current_epoch = checkpoint['epoch'] self.global_step = checkpoint['global_step'] self.best_val_acc = checkpoint['best_val_acc'] self.best_epoch = checkpoint['best_epoch'] # 恢复历史 self.history = { 'train_loss': checkpoint['train_loss_history'], 'val_loss': checkpoint['val_loss_history'], 'train_acc': checkpoint['train_acc_history'], 'val_acc': checkpoint['val_acc_history'], 'lr': checkpoint['lr_history'], } # 计算已训练时间 if 'training_time' in checkpoint: self.start_time = time.time() - checkpoint['training_time'] print(f"✅ 检查点加载完成") print(f" 当前轮次: {self.current_epoch}") print(f" 最佳准确率: {self.best_val_acc:.2%} (第{self.best_epoch}轮)") print(f" 训练步数: {self.global_step}") return self.current_epoch + 1 # 返回下一轮 def train_epoch(self, train_loader, val_loader, criterion, device): """ 训练一个epoch """ self.model.train() train_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(device), targets.to(device) # 前向传播 self.optimizer.zero_grad() outputs = self.model(inputs) loss = criterion(outputs, targets) # 反向传播 loss.backward() self.optimizer.step() # 统计 train_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() self.global_step += 1 # 每50个batch打印一次 if batch_idx % 50 == 0: print(f" 批次 [{batch_idx}/{len(train_loader)}] " f"损失: {loss.item():.4f}") # 计算训练指标 avg_train_loss = train_loss / len(train_loader) train_acc = 100. * correct / total # 验证 val_loss, val_acc = self.validate(val_loader, criterion, device) # 更新历史 self.current_epoch += 1 self.history['train_loss'].append(avg_train_loss) self.history['train_acc'].append(train_acc) self.history['val_loss'].append(val_loss) self.history['val_acc'].append(val_acc) self.history['lr'].append(self.optimizer.param_groups[0]['lr']) # 更新最佳记录 if val_acc > self.best_val_acc: self.best_val_acc = val_acc self.best_epoch = self.current_epoch self.save_checkpoint(self.current_epoch, is_best=True) return avg_train_loss, train_acc, val_loss, val_acc def validate(self, val_loader, criterion, device): """验证函数""" self.model.eval() val_loss = 0 correct = 0 total = 0 with torch.no_grad(): for inputs, targets in val_loader: inputs, targets = inputs.to(device), targets.to(device) outputs = self.model(inputs) loss = criterion(outputs, targets) val_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() avg_val_loss = val_loss / len(val_loader) val_acc = 100. * correct / total return avg_val_loss, val_acc def auto_save_strategy(self): """ 自动保存策略 """ epoch = self.current_epoch # 策略1:每10轮保存一次 if epoch % 10 == 0: self.save_checkpoint(epoch, reason='每10轮保存') # 策略2:前10轮每轮都保存(重要!) elif epoch <= 10: self.save_checkpoint(epoch, reason='初期密集保存') # 策略3:损失下降明显时保存 if len(self.history['val_loss']) >= 2: loss_improve = self.history['val_loss'][-2] - self.history['val_loss'][-1] if loss_improve > 0.01: # 损失明显下降 self.save_checkpoint(epoch, reason='损失明显下降')6.2 使用示例
# 初始化训练管理器 config = { 'model': 'ResNet18', 'dataset': 'CIFAR10', 'batch_size': 128, 'learning_rate': 0.001, 'epochs': 100, } manager = TrainingManager(model, optimizer, config, save_dir='./my_experiments') # 尝试加载之前的检查点(如果存在) latest_checkpoint = os.path.join(manager.exp_dir, 'latest.pth') if os.path.exists(latest_checkpoint): print("发现之前的检查点,尝试恢复训练...") start_epoch = manager.load_checkpoint(latest_checkpoint) else: start_epoch = 0 print("没有找到检查点,从头开始训练") # 训练循环 for epoch in range(start_epoch, config['epochs']): print(f"\n{'='*60}") print(f"第 {epoch+1}/{config['epochs']} 轮") print(f"{'='*60}") # 训练一个epoch train_loss, train_acc, val_loss, val_acc = manager.train_epoch( train_loader, val_loader, criterion, device ) # 打印结果 print(f"训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}%") print(f"验证损失: {val_loss:.4f}, 验证准确率: {val_acc:.2f}%") # 自动保存 manager.auto_save_strategy() # 每轮结束后检查是否要早停 if manager.current_epoch - manager.best_epoch > 20: print("🎯 早停触发:连续20轮没有进步") break print(f"\n🎉 训练完成!最佳准确率: {manager.best_val_acc:.2f}%")七、文件组织结构建议
7.1 推荐的目录结构
my_experiment/ ├── checkpoints/ # 所有检查点 │ ├── exp_20240101_120000/ # 实验1 │ │ ├── epoch_001.pth # 第1轮检查点 │ │ ├── epoch_010.pth # 第10轮检查点 │ │ ├── best_epoch_045.pth # 最佳模型 │ │ ├── latest.pth # 最新模型 │ │ └── config.json # 配置文件 │ │ │ └── exp_20240102_150000/ # 实验2 │ ├── logs/ # 训练日志 │ ├── exp_20240101_120000.log │ └── exp_20240102_150000.log │ ├── tensorboard/ # TensorBoard日志 │ ├── exp_20240101_120000/ │ └── exp_20240102_150000/ │ └── results/ # 结果文件 ├── exp_20240101_120000.csv └── exp_20240102_150000.csv
7.2 不同阶段的保存策略
阶段1:刚开始训练(前10轮)
策略:每轮都保存 原因:初期变化快,容易出问题 文件:epoch_001.pth, epoch_002.pth, ..., epoch_010.pth
阶段2:稳定训练(10-100轮)
策略:每10轮保存一次 + 最佳模型 原因:变化相对稳定 文件:epoch_020.pth, epoch_030.pth, ..., best.pth
阶段3:精细训练(100轮后)
策略:每50轮保存一次 + 最佳模型 + 最新模型 原因:训练时间长,变化慢 文件:epoch_150.pth, epoch_200.pth, best.pth, latest.pth
八、常见问题解答(Q&A)
Q1:应该保存多少个检查点?
答:建议3-5个: 1. 最新的(latest.pth) 2. 最佳的(best.pth) 3. 几个关键轮次的(如epoch_050.pth, epoch_100.pth) 太多会占空间,太少不够用
Q2:文件太大怎么办?
答:可以: 1. 只保存模型权重(不保存优化器状态) 2. 使用torch.save的压缩格式 3. 定期清理旧的检查点 4. 使用云存储(如Google Drive)
Q3:训练中断后怎么判断从哪继续?
答:看这几个文件: 1. latest.pth → 最新的进度 2. 按时间戳排序 → 找最新的 3. 看文件名中的轮数 → 找最大的数字
Q4:多人协作时怎么管理?
答:建议: 1. 每人有自己的实验目录 2. 文件名包含用户名和日期 3. 使用git管理代码,不管理大文件 4. 大的检查点放共享存储
九、总结:记住这些要点
9.1 必须保存的四个文件
1. 🎯 模型权重文件(.pth)→ 最重要的! 2. 📊 优化器状态文件(.pth)→ 同样重要! 3. 📝 训练状态文件(.json)→ 记录历史 4. ⚙️ 配置文件(.yaml)→ 知道当时怎么设置的
9.2 推荐的保存频率
🔵 每轮都保存:前10轮 🟢 每10轮保存:10-100轮 🟡 每50轮保存:100轮以后 🔴 总是保存:最佳模型 + 最新模型
9.3 一句话建议
"训练前先想好怎么保存,就像出门前先想好带什么钥匙"
新手建议:就用最简单的方案,每10轮保存一次,同时保存最佳模型。这样既安全又不占太多空间!