day41早停策略和模型权重的保存@浙大疏锦行
基于day40代码实现模型权重的保存和早停
# 定义损失函数和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=0.001)# 训练参数num_epochs=1000check_interval=10# 每多少轮检查一次验证集# 记录列表train_losses=[]test_losses=[]epochs_rec=[]# ===== 早停策略参数 =====best_test_loss=float('inf')patience=20# 容忍多少次验证集loss不下降 (20 * 10 = 200 epochs)counter=0early_stopped=Falsebest_model_path='best_credit_model.pth'# ======================start_time=time.time()# 使用tqdm显示进度条withtqdm(total=num_epochs,desc="训练进度",unit="epoch")aspbar:forepochinrange(num_epochs):model.train()# 前向传播outputs=model(X_train)loss=criterion(outputs,y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 定期评估和检查早停if(epoch+1)%check_interval==0:model.eval()withtorch.no_grad():test_outputs=model(X_test)test_loss=criterion(test_outputs,y_test)model.train()# 记录train_losses.append(loss.item())test_losses.append(test_loss.item())epochs_rec.append(epoch+1)pbar.set_postfix({'Train Loss':f'{loss.item():.4f}','Test Loss':f'{test_loss.item():.4f}','Best':f'{best_test_loss:.4f}','Patience':f'{counter}/{patience}'})# ===== 早停逻辑 =====iftest_loss.item()<best_test_loss:best_test_loss=test_loss.item()counter=0# 保存最佳模型权重torch.save(model.state_dict(),best_model_path)else:counter+=1ifcounter>=patience:print(f"\n早停触发!在第{epoch+1}轮停止训练。")print(f"最佳测试集损失:{best_test_loss:.4f}")early_stopped=Truebreak# ===================pbar.update(1)print(f"训练耗时:{time.time()-start_time:.2f}秒")# 绘制损失曲线plt.figure(figsize=(10,6))plt.plot(epochs_rec,train_losses,label='Train Loss')plt.plot(epochs_rec,test_losses,label='Test Loss')plt.title('Training and Test Loss (with Early Stopping)')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.grid(True)plt.show()训练进度: 31% 309/1000 [00:01<00:02, 292.95epoch/s, Train Loss=0.4183, Test Loss=0.4887, Best=0.4769, Patience=19/20]
早停触发!在第 310 轮停止训练。 最佳测试集损失: 0.4769 训练耗时: 1.06 秒
# 模型评估 - 加载最佳模型print("\n--- 加载最佳模型进行评估 ---")ifos.path.exists(best_model_path):# 重新初始化模型结构best_model=CreditModel(input_dim).to(device)# 加载权重best_model.load_state_dict(torch.load(best_model_path))best_model.eval()withtorch.no_grad():outputs=best_model(X_test)_,predicted=torch.max(outputs.data,1)total=y_test.size(0)correct=(predicted==y_test).sum().item()accuracy=100*correct/totalprint(f'最佳模型测试集准确率:{accuracy:.2f}%')# 简单的推理示例print("\n--- 推理示例 ---")print(f"真实标签:{y_test[:10].cpu().numpy()}")print(f"预测标签:{predicted[:10].cpu().numpy()}")else:print("未找到保存的模型文件。")
@浙大疏锦行