news 2026/5/4 2:19:05

PyTorch张量操作全攻略:从入门到精通

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
PyTorch张量操作全攻略:从入门到精通

在开篇中,我们用 3 分钟跑通了第一个手写数字识别网络。接下来,我们将从头开始深入 PyTorch 的每一个核心组件。第一步,就是张量(Tensor)——它是 PyTorch 的基石,相当于 NumPy 的ndarray嫁接了 GPU 加速和自动求导。

很多新手一上来就写网络,却发现数据处理错误、形状不匹配、device 不对。根本原因就是张量操作不熟。这篇文章会用大量实例,带你系统掌握张量的创建、变形、索引、运算和广播机制,并对比 NumPy 的异同,让你彻底告别形状错误。


一、张量是什么?

张量就是多维数组。你完全可以把它理解成可以跑在 GPU 上、且支持自动求导的 NumPy 数组

在 PyTorch 里,几乎所有的数据和模型参数都是张量。一个标量是 0 维张量,一个向量是 1 维张量,矩阵是 2 维,图像数据通常是 4 维(batch, channel, height, width)


二、创建张量的 10 种常用方法

引入库:

import torch import numpy as np

2.1 从数据直接创建

# 从列表创建 a = torch.tensor([1, 2, 3]) b = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) print(a, a.dtype) # torch.int64 print(b, b.dtype) # torch.float32

2.2 从 NumPy 互相转换

y([1, 2, 3]) t = torch.from_numpy(arr) # numpy -> tensor,共享内存 t2 = torch.tensor(arr) # 会复制一份 # tensor 转 numpy n = t.numpy() # 同样共享内存

注意from_numpy创建出来的张量和原数组共享内存,修改其中一个,另一个也会变。torch.tensor()则是深拷贝。

2.3 创建特殊张量

zeros = torch.zeros(2, 3) # 全 0 ones = torch.ones(2, 3) # 全 1 eye = torch.eye(3) # 单位矩阵 rand = torch.rand(2, 3) # [0,1) 均匀分布 randn = torch.randn(2, 3) # 标准正态分布 N(0,1)

2.4 按范围创建

arange = torch.arange(0, 10, step=2) # tensor([0,2,4,6,8]) linspace = torch.linspace(0, 1, steps=5) # tensor([0.00, 0.25, 0.50, 0.75, 1.00])

2.5 创建同形状张量

x = torch.ones(2, 3) y = torch.zeros_like(x) # 形状同 x,全 0 z = torch.randn_like(x) # 形状同 x,正态分布

2.6 指定数据类型和设备

f = torch.tensor([1, 2], dtype=torch.float32, device='cuda')

三、张量的基本属性

t = torch.randn(2, 3, 4) print(t.shape) # torch.Size([2, 3, 4]) print(t.size()) # 功能同上 print(t.dtype) # torch.float32 print(t.device) # cpu 或 cuda:0 print(t.ndim) # 维度数 3 print(t.numel()) # 元素总数 2*3*4=24

四、张量索引与切片:和 NumPy 一模一样

索引方式完全遵循 Python / NumPy 的规范。

t = torch.arange(12).reshape(3, 4) print(t) # tensor([[ 0, 1, 2, 3], # [ 4, 5, 6, 7], # [ 8, 9, 10, 11]]) # 取第 1 行 print(t[0]) # tensor([0, 1, 2, 3]) # 取第 1 列 print(t[:, 0]) # tensor([0, 4, 8]) # 区域切片 print(t[0:2, 1:3]) # 前 2 行,第 1-2 列 # 步长步取 print(t[::2, ::2]) # 每隔一行一列取一次

高级索引(与 NumPy 一样的 fancy indexing):

# 整数数组索引 idx = [0, 2] print(t[idx]) # 取第 0 和第 2 行 # 布尔索引 mask = t > 5 print(t[mask]) # tensor([6, 7, 8, 9, 10, 11])

五、张量变形:reshape, view, transpose, permute

这是实际写代码时出错最多的地方,必须记牢。

5.1reshapevsview

x = torch.arange(12) a = x.reshape(3, 4) # 安全,但可能复制 b = x.view(3, 4) # 必须内存连续,否则报错 c = x.contiguous().view(3, 4) # 保证内存连续后再 view

规则view只在张量内存连续时可用,通常 reshape 更保险;但 reshape 在非连续时会复制一份,不共享数据。

5.2 增加/移除维度

x = torch.tensor([1, 2, 3]) # shape (3,) print(x.unsqueeze(0)) # shape (1,3) 在第 0 维前加一维 print(x.unsqueeze(1)) # shape (3,1) 在第 1 维后加一维 y = torch.randn(1, 3, 1, 4) print(y.squeeze()) # 移除所有大小为 1 的维度,变成 (3,4) print(y.squeeze(0)) # 只移除第 0 维,若它等于 1

5.3transposepermute

t = torch.randn(2, 3, 4) # 交换两维 t1 = t.transpose(0, 2) # shape (4,3,2) # 多重转置用 permute t2 = t.permute(2, 1, 0) # shape (4,3,2)

transpose一次只能交换两个维度,permute可以一次性对全部维度重新排列。

5.4 扩维广播常用expandrepeat

a = torch.tensor([[1], [2], [3]]) # shape (3,1) b = a.expand(3, 4) # 广播成 (3,4),不复制数据 c = a.repeat(1, 4) # 实际复制数据成 (3,4)

expand只在需要时扩展,不分配新内存;repeat是真正复制。


六、张量的数学运算

6.1 基本运算

a = torch.tensor([1.0, 2.0]) b = torch.tensor([3.0, 4.0]) print(a + b) # 按元素加 print(a * b) # 按元素乘(不是矩阵乘法) print(a @ b) # 点积 print(a.pow(2)) # 平方 print(a.sqrt()) # 开方

6.2 矩阵乘法

x = torch.randn(2, 3) y = torch.randn(3, 4) result = torch.mm(x, y) # 矩阵乘,结果 (2,4) result = x @ y # 等价写法 # 对于批量矩阵乘用 torch.bmm 或 torch.matmul

6.3 聚合操作

t = torch.randn(3, 4) print(t.sum()) # 所有元素和 print(t.sum(dim=0)) # 按行方向求和(压缩第 0 维),形状 (4,) print(t.sum(dim=1, keepdim=True)) # 按列方向求和并保持维度,形状 (3,1) print(t.mean(), t.max(), t.min()) print(t.argmax(dim=1)) # 每行最大值的索引

dim参数必须理解清楚sum(dim=0)就是把第 0 维压缩掉,你可以想象成“在这个方向上拍扁”。


七、广播机制 (Broadcasting)

广播是 PyTorch 里最重要的隐式操作之一,它允许形状不同的张量在运算时自动扩展。

规则:从最后一个维度向前对比,满足以下条件之一即可广播:

  • 两个维度大小相等

  • 其中一个维度是 1

  • 其中一个维度不存在

例子:

a = torch.randn(3, 1) # shape (3,1) b = torch.randn(1, 4) # shape (1,4) c = a + b # 广播成 (3,4)

常见错误形如(3,4) + (3,)会触发广播吗?答案是不会直接报错,因为(3,)可以被广播成(1,3)再和(3,4)尝试,但最后一个维度 4 vs 3 不匹配,会报错。要显式调成(3,1)(1,4)


八、张量在 CPU 与 GPU 间移动

x = torch.randn(3, 3) if torch.cuda.is_available(): device = torch.device("cuda") x_gpu = x.to(device) # 搬到 GPU # 或者 x.cuda() x_cpu = x_gpu.cpu() # 搬回 CPU

注意:模型和数据必须在同一个设备上,否则会报 RuntimeError。


九、张量与自动求导的初遇

张量通过requires_grad=True开启梯度追踪:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) y = x.pow(2).sum() y.backward() print(x.grad) # tensor([2., 4., 6.])

每当你对一个标量调用backward(),所有路径上的梯度都会自动计算并累加到张量的.grad属性中。

这也是后面要详细讲解的autograd 机制,现在你只需要知道张量自带这个超级能力。


十、本讲总结与练习

今天我们全面拆解了张量的创建、索引、变形、运算和广播。掌握这些操作,你就拿到了玩转 PyTorch 的钥匙。

试着做几个练习

  1. torch.randn创建一个形状为(4, 5)的张量,提取出第 1、3 行和第 2、4 列构成的子矩阵。

  2. 实现一个形状为(3, 1)的张量与形状为(4,)的向量相加,结果形状是什么?

  3. 将上面创建的张量搬到 GPU 并验证 device。


如果这篇文章对你有帮助,请你

  • 点个收藏,方便后面查阅

  • 关注我,第一时间收到系列更新

  • 在评论区打卡:你平时在张量上最常犯的错误是什么?一起避坑!

下篇见,我们继续拆解 PyTorch 的核心。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/4 2:18:09

PINGPONG基准测试:评估AI在多语言代码理解中的表现

1. 项目背景与核心价值在全球化协作的软件开发环境中,多语言代码混合的场景越来越普遍。一个Java后端工程师可能需要调用Python编写的机器学习模型,而前端开发者又需要理解这些接口的返回格式。这种跨语言协作的常态催生了对代码理解与对话能力的新需求—…

作者头像 李华
网站建设 2026/5/4 2:12:11

HPH的构造详解

HPH(高压氢系统)是氢能利用中的关键设备,其构造直接决定了储氢密度与安全性。简单来说,HPH由内胆、碳纤维缠绕层、阀体及温控装置四大部分构成。理解这四者的协同工作,才能真正掌握高压氢技术的核心。 HPH的核心部件有…

作者头像 李华
网站建设 2026/5/4 2:10:31

【C++ STL】探索STL的奥秘——vector底层的深度剖析和模拟实现!

vector的基本成员变量在模拟实现vector之前我们首先要了解vector的基本成员变量,然后在逐步进入到vector的一些核心接口的实现。如何知道这些成员变量呢?下面通过源码一探究竟:在这里插入图片描述有了上面的认识,那么我们模拟实现…

作者头像 李华
网站建设 2026/5/4 2:08:28

零基础入门:用快马AI生成你的第一个带详解的Python服务器

今天想和大家分享一个特别适合编程新手的实践:用Python Flask搭建最简单的服务器。作为一个刚入门的小白,我发现在InsCode(快马)平台上尝试这个项目特别友好,完全不需要担心环境配置的问题。 为什么选择Flask? Flask是Python最轻量…

作者头像 李华
网站建设 2026/5/4 2:05:29

CVPR 2024审稿人视角:除了创新性,你的论文在这些细节上可能已经丢分了

CVPR 2024审稿人内参:那些被忽视却致命的论文细节陷阱 当你在深夜反复调试模型参数时,可能不会想到论文的页边距会成为审稿人打低分的理由。在计算机视觉顶会CVPR的评审中,每年有37%的论文因非技术性缺陷被降档——这个数字来自对近三年评审数…

作者头像 李华
网站建设 2026/5/4 2:03:28

Navicat学生实用指南

下载与安装Navicat官网提供Windows、macOS和Linux版本下载。访问官网后选择对应操作系统版本,点击下载按钮获取安装包。Windows用户双击安装包,按照向导提示完成安装。macOS用户将Navicat图标拖拽至Applications文件夹即可完成安装。连接数据库启动Navic…

作者头像 李华