news 2026/5/23 13:29:01

一文看明白PyTorch 模型设计训练保存加载预测

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
一文看明白PyTorch 模型设计训练保存加载预测

需求

输入x
128维

fc1
Linear 128→96

ReLU激活

Dropout 0.2

fc2
Linear 96→64

ReLU激活

Dropout 0.2

fc3
Linear 64→32

输出out
32维

代码样例

包含训练 → 保存 → 加载 → 预测,代码可以直接运行:

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoader,TensorDataset# -----------------------------# 1. 定义模型# -----------------------------classSimpleModel(nn.Module):def__init__(self):super(SimpleModel,self).__init__()self.fc1=nn.Linear(128,96)self.fc2=nn.Linear(96,64)self.fc3=nn.Linear(64,32)self.relu=nn.ReLU()self.dropout=nn.Dropout(0.2)defforward(self,x):x=self.relu(self.fc1(x))x=self.dropout(x)x=self.relu(self.fc2(x))x=self.dropout(x)out=self.fc3(x)returnout# -----------------------------# 2. 准备数据 (示例随机数据)# -----------------------------X=torch.randn(1000,128)y=torch.randn(1000,32)dataset=TensorDataset(X,y)batch_size=32dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)# -----------------------------# 3. 定义损失函数和优化器# MSELoss Mean Squared Error(均方误差)# -----------------------------model=SimpleModel()criterion=nn.MSELoss()optimizer=optim.Adam(model.parameters(),lr=0.001)# -----------------------------# 4. 训练循环# -----------------------------num_epochs=20forepochinrange(num_epochs):model.train()# 训练模式epoch_loss=0forbatch_X,batch_yindataloader:optimizer.zero_grad()outputs=model(batch_X)loss=criterion(outputs,batch_y)loss.backward()optimizer.step()epoch_loss+=loss.item()*batch_X.size(0)epoch_loss/=len(dataset)print(f"Epoch{epoch+1}/{num_epochs}, Loss:{epoch_loss:.4f}")# -----------------------------# 5. 保存训练好的模型参数# -----------------------------torch.save(model.state_dict(),"simple_model.pth")print("模型参数已保存到 simple_model.pth")# -----------------------------# 6. 加载模型进行预测# -----------------------------# 重新创建模型对象model_loaded=SimpleModel()# 加载保存的参数model_loaded.load_state_dict(torch.load("simple_model.pth"))# 切换到评估模式model_loaded.eval()# 假设有新样本 x_newx_new=torch.randn(5,128)withtorch.no_grad():# 推理时禁用梯度y_pred=model_loaded(x_new)print("加载模型预测结果形状:",y_pred.shape)# [5, 32]

✅ 特点

  1. 训练完成后保存权重simple_model.pth可以随时加载。
  2. 加载模型时必须重新创建类,然后load_state_dict
  3. 推理时切换到eval()模式,保证 Dropout 不随机失活。
  4. 使用torch.no_grad()提升预测效率,减少显存占用。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/23 13:27:39

新手开发者首次使用Taotoken从注册到发出第一个API请求的全流程

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 新手开发者首次使用Taotoken从注册到发出第一个API请求的全流程 本文面向初次接触大模型API的开发者,提供一个从零开始…

作者头像 李华
网站建设 2026/5/23 13:26:56

python运行提速方案全解

纯PYTHON运行10亿次求余要140秒,VBS只要130秒,VC,.NET 只要1.5秒。PYTHON不可能去提速了吧,除了用第三方软件或装强大的IDE把PY代码编绎成EXE,还有什么方法提速看到这个对比,确实容易让人对 Python 的性能感到一丝绝望…

作者头像 李华
网站建设 2026/5/23 13:26:03

从牧高笛看露营装备业:增量不再,存量难吞

2022年的时候,牧高笛的仓库里,一顶帐篷从入库到出库,平均只需要55天。到了2025年,这个数字变成了261天。将近九个月。足够一个孩子从孕育到出生,而一顶帐篷还在货架上蒙灰。牧高笛不是没有努力过。2025年全年&#xff…

作者头像 李华
网站建设 2026/5/23 13:24:03

告别游戏平台切换烦恼:用Playnite打造你的专属游戏中心

告别游戏平台切换烦恼:用Playnite打造你的专属游戏中心 【免费下载链接】Playnite Video game library manager with support for wide range of 3rd party libraries and game emulation support, providing one unified interface for your games. 项目地址: ht…

作者头像 李华
网站建设 2026/5/23 13:21:15

终极指南:快速掌握Hybrid A*路径规划器

终极指南:快速掌握Hybrid A*路径规划器 【免费下载链接】path_planner Hybrid A* Path Planner for the KTH Research Concept Vehicle 项目地址: https://gitcode.com/gh_mirrors/pa/path_planner 想要为你的自动驾驶项目或机器人系统找到一个高效、可靠的路…

作者头像 李华
网站建设 2026/5/23 13:20:01

如何快速掌握开源电磁仿真工具:openEMS的5个高效技巧指南

如何快速掌握开源电磁仿真工具:openEMS的5个高效技巧指南 【免费下载链接】openEMS openEMS is a free and open-source electromagnetic field solver using the EC-FDTD method. 项目地址: https://gitcode.com/gh_mirrors/ope/openEMS 想要学习电磁仿真但…

作者头像 李华