news 2026/4/28 16:56:23

PyTorch 基础知识点汇总

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
PyTorch 基础知识点汇总

这篇笔记是 PyTorch 基础点,还有自己写代码时容易卡壳的地方整理了一下。主要包含 Tensor 操作、自动求导逻辑、线性回归模型构建、以及 GPU 环境切换。

1. Tensor(张量)基础操作

Tensor 是 PyTorch 最核心的数据格式。咱们平时处理图像或者文本,最后全都要转成这玩意。

  • 维度代表的意思:

    • 0 维 (Scalar):就是一个数(比如 Loss 值)。

    • 1 维 (Vector):一维数组。代表一个样本的特征(比如一个词向量)。

    • 2 维 (Matrix):二维数组。代表多个样本的特征集合(行是样本,列是特征)。

    • 3 维/4 维 (High-dim):图像里最常用,格式通常是 [Batch大小, 通道数, 高, 宽]。

  • 跟 NumPy 互转(实战高频):

import torch import numpy as np # 1. Numpy 转 Tensor np_data = np.array([1.0, 2.0, 3.0]) tensor_data = torch.from_numpy(np_data) # 2. Tensor 转 Numpy np_data_back = tensor_data.numpy() # 3. 改变形状:PyTorch 中用 view x = torch.zeros(4, 4) y = x.view(-1, 8) # -1 代表自动计算维度,这里会自动转成 2x8 print(y.size()) # 查看维度信息

2. 自动求导机制 (Autograd)

PyTorch 强在能自动算反向传播的梯度。

  • 标记求导:在创建 Tensor 时指定 requires_grad=True,系统就会追踪这个变量。

  • 链式法则:调用 .backward() 时,它会顺着计算图瞬间算出所有权重参数的梯度。

  • 避坑指南(梯度累加):PyTorch 的梯度默认是累加的。所以每次训练前必须清零梯度(optimizer.zero_grad()),不然跑出来的梯度全是错的。

3. 模型构建与训练(线性回归代码示例)

这是一个完整的线性回归构建逻辑。

第一步:定义模型类

import torch.nn as nn class LinearRegressionModel(nn.Module): def __init__(self): super(LinearRegressionModel, self).__init__() # 定义一个全连接层,输入1维,输出1维 self.linear = nn.Linear(1, 1) def forward(self, x): # 规定数据的前向传播逻辑 out = self.linear(x) return out model = LinearRegressionModel()

第二步:配置训练参数

# 定义损失函数 (回归任务常用 MSE) criterion = nn.MSELoss() # 定义优化器 (SGD 随机梯度下降) optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

第三步:核心训练循环(通用模板)

epochs = 1000 for epoch in range(epochs): # 将输入数据转为 Tensor inputs = torch.from_numpy(x_train_numpy) labels = torch.from_numpy(y_train_numpy) # 1. 梯度清零 (必须放在最前面) optimizer.zero_grad() # 2. 前向传播 (拿到模型预测结果) outputs = model(inputs) # 3. 计算损失 (对比预测值与真实标签) loss = criterion(outputs, labels) # 4. 反向传播 (自动求导算出梯度) loss.backward() # 5. 更新参数 (优化器执行一步更新) optimizer.step() if (epoch+1) % 50 == 0: # 用 loss.item() 把 Tensor 转回普通数值打印 print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. GPU 加速与模型保存

  • 切换 GPU 训练:

# 判断是否有可用 CUDA 环境 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 1. 把模型搬到显卡上 model.to(device) # 2. 训练循环里把数据搬到显卡上 inputs = inputs.to(device) labels = labels.to(device)
  • 模型保存与读取:

# 推荐保存权重参数字典 (state_dict) torch.save(model.state_dict(), 'my_model.pkl') # 读取模型 (要先实例化上面的类) model.load_state_dict(torch.load('my_model.pkl'))

5. 关于 torch.hub 的说明

torch.hub 是官方的模型库(Model Zoo),里面有很多预训练好的模型(如 ResNet、Transformer)。

  • 加载方法:

# 一行代码加载预训练模型 model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
  • 碎碎念:虽然调包很省事,但咱们在打基础阶段还是得少用。大作业或者平时练手,先把最基础的层(Linear、Conv2d 等)练熟,以后再用这些预训练模型会更有底。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/28 16:48:26

如何快速生成Beyond Compare 5密钥:完整激活指南与实用工具

如何快速生成Beyond Compare 5密钥:完整激活指南与实用工具 【免费下载链接】BCompare_Keygen Keygen for BCompare 5 项目地址: https://gitcode.com/gh_mirrors/bc/BCompare_Keygen 您是否正在寻找Beyond Compare 5的激活解决方案?当这款强大的…

作者头像 李华
网站建设 2026/4/28 16:47:40

OpCore Simplify:突破性OpenCore EFI自动化配置工具深度解析

OpCore Simplify:突破性OpenCore EFI自动化配置工具深度解析 【免费下载链接】OpCore-Simplify A tool designed to simplify the creation of OpenCore EFI 项目地址: https://gitcode.com/GitHub_Trending/op/OpCore-Simplify 在传统黑苹果安装过程中&…

作者头像 李华
网站建设 2026/4/28 16:44:30

用RandLA-Net处理S3DIS数据集:从原始点云到6折交叉验证的完整实战解析

用RandLA-Net处理S3DIS数据集:从原始点云到6折交叉验证的完整实战解析 在三维点云语义分割领域,S3DIS数据集作为室内场景的标杆性基准,常被用于验证算法性能。RandLA-Net凭借其高效的随机降采样和局部特征聚合机制,成为处理大规模…

作者头像 李华
网站建设 2026/4/28 16:43:25

Milvus实战:5分钟搞定一个图片相似搜索Demo(Docker + Python全流程)

Milvus实战:5分钟构建图片相似搜索系统(DockerPython全流程) 想象一下这样的场景:你手机里存了上万张旅行照片,突然想找三年前在京都拍的那张红叶照,但只记得画面里有座木桥和橙色枫叶。传统的关键词搜索完…

作者头像 李华
网站建设 2026/4/28 16:36:26

AgentCorral:可视化集中管理Claude Code配置,告别JSON碎片化

1. 项目概述:为什么我们需要一个Claude Code配置管理工具?如果你和我一样,在日常开发中重度依赖Claude Code,那你肯定也经历过这样的混乱时刻:上周在A项目里精心调教了一个代码审查Agent,这周在B项目里想复…

作者头像 李华
网站建设 2026/4/28 16:35:25

Flutter动画详解:创建流畅的用户体验

Flutter动画详解:创建流畅的用户体验 引言 在现代移动应用开发中,动画是提升用户体验的关键因素。精心设计的动画可以使应用界面更加生动、直观,增强用户与应用的互动感。Flutter提供了强大而灵活的动画系统,使开发者能够创建各…

作者头像 李华