news 2026/3/10 4:27:30

面试-Torch函数

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
面试-Torch函数

0. 连续张量和非连续张量

1.核心含义:“连续(contiguous)” 描述的是张量底层数据在内存中的存储方式。
2.连续张量:张量的元素在内存中按“行优先”顺序连续排列,没有间隔,能通过固定步长遍历所有元素;
3.非连续张量:经过transpose()、permute()等操作后,张量的维度顺序变了,但底层数据的存储顺序没改,导致元素在内存中不再连续,遍历需要不规则步长。

用 “书” 举例:

  • 连续张量:书按[0,0]→[0,1]→[0,2]→[1,0]→[1,1]→[1,2]的顺序摆放在一排,没有空隙;
  • 非连续张量(如转置后):维度变成[列,行],但书的摆放顺序还是原来的[0,0]→[0,1]→[0,2]→[1,0]→[1,1]→[1,2],此时要按列取数(如[0,0]→[1,0]→[0,1]→[1,1]),需要跳着找书,内存不连续。

1. torch.view()

核心作用:重塑张量形状,采用方式是 “共享内存”(修改新张量会影响原张量的位置),要求张量是 “连续的(contiguous)”,否则会报错。
特点:不改变原始 x ,通过共享内存的方式改变张量的形状,并且仅支持连续张量。因为view()需要按固定步长重塑维度。

进一步解读:PyTorch 中像view()这类操作,并不会复制张量的底层数据,而是创建一个新的 “视图(view)” —— 新张量和原张量共用同一块内存空间,只是对数据的 “解读方式”(维度、步长)不同。因此,修改新张量的某个元素,原张量对应位置的元素也会同步改变,反之亦然。

importtorch# 原始张量x=torch.randn(2,3)print("原始x shape:",x.shape)# ([2,6])# 重塑x_view=x.view(2,3,3)print("重塑x shape:",x_view.shape)# ([2,3,3])# 验证共享内存x_view[0,0,0]=100.0print("原始x[0,0]:",x[0,0])# tensor(100.)

2. torch.reshape()

核心作用:重塑张量形状,无需张量连续,是更推荐的通用重塑方法。
特点:reshape兼容非连续张量,view仅支持连续张量;功能上几乎等价,新手优先用reshape。

importtorch# 原始张量x=torch.arange(12).reshape(3,4)# torch.Size([3,4])print("x shape:",x.shape)# 重塑为[4,3]x_trans=x.transpose(0,1)# 连续张量->非连续张量x_reshape=x_trans.reshape(4,3)# [3,4] -> [4,3]print("x_reshape:",x_reshape.shape)# 展平x_flat=x_reshape.reshape(-1)# -1 表示自动计算维度print("x_flat shape:",x_flat.shape)

3. torch.triu()

核心作用:提取张量的上三角部分,其余元素置 0;常用来构造因果掩码(如 Transformer 的自注意力)。
特点:提取张量的上三角部分。其中,diagonal(对角线偏移,默认 0,diagonal=1表示主对角线以上的部分)。

importtorch# 原始矩阵:torch.ones(x,y) 创建x=torch.ones(3,3)# 提取上三角部分,(diagonal=1:主对角线以上保留)x_triu=torch.triu(x,diagonal=1)print("x_triu:",x_triu)# 输出:# tensor([[0., 1., 1.],# [0., 0., 1.],# [0., 0., 0.]])# 构造因果掩码矩阵seq_len=3mask=torch.triu(torch.full(seqlen,seqlen),float("-inf"),diagonal=1)print("mask:",mask)# 输出:# tensor([[-inf, -inf, -inf],# [ -inf, -inf, -inf],# [ -inf, -inf, -inf]])

4. torch.full()

核心作用:创建指定形状、所有元素均为固定值的张量;常用于构造掩码(如负无穷、0/1 掩码)。
特点:必须得传入默认值。
参数:size(张量形状)、fill_value(填充值)、device(可选,指定设备)。比 torch.ones(x,y) 和 torch.zeros(x,y) 要更灵活。

importtorch# 创建2x3的全5张量 torch.full((seq, seq), float("-inf"))x_full=torch.full((2,3),5.0)print("x_full:\n",x_full)# 输出:# tensor([[5., 5., 5.],# [5., 5., 5.]])# 创建3x3的全负无穷张量(注意力掩码常用)mask=torch.full((3,3),float("-inf"),device="cpu")print("mask:\n",mask)# 输出:# tensor([[-inf, -inf, -inf],# [-inf, -inf, -inf],# [-inf, -inf, -inf]])

5. torch.transpose()

核心作用:交换张量的两个维度;常用于矩阵转置、调整注意力张量的维度顺序(如[bsz, seq_len, heads]→[bsz, heads, seq_len])。
参数:dim0、dim1(要交换的两个维度索引)。

importtorch# 通过 torch.randn()、torch.full()、torch.ones()创建张量x=torch.randn(1,512,16,1024)x=torch.full((1,512,16,1024),"float(-inf)")x=torch.full((1,512,16,1024))# 报错,torch.full(tensor, value)必须得同时传入默认值、张量两个元素x=torch.ones(1,512,16,1024)print("x:",x)# torch.Size([bsz, seq, heads, dim])

6. torch.cat()

核心作用:指定维度上拼接多个张量;要求除拼接维度外,其他维度形状完全一致。
参数:tensors(待拼接的张量列表)、dim(拼接维度)。

importtorch# 通过 torch.randn()、torch.full()、torch.ones()创建张量x=torch.randn(1,512,16,1024)x=torch.full((1,512,16,1024),"float(-inf)")x=torch.full((1,512,16,1024))# 报错,torch.full(tensor, value)必须得同时传入默认值、张量两个元素x=torch.ones(1,512,16,1024)print("x:",x)# [bsz, seq, heads, dim]x2=torch.randn(1,512,8,1024)# 在维度2上进行拼接x_cat=torch.cat([x1,x2],dim=2)print("x_cat shape:",x_cat.shape)# torch.Size([1,512,24,1024])# 注意力 KV 缓存拼接past_kv=torch.randn(1,10,1024)# [bsz, seq, dim],这里seq代表已经处理了 10 个kv健cur_kv=torch.randn(1,1,1024)# 当前 kv 键值对new_kv=torch.cat([past_kv,new_kv],dim=1)print("new_kv cache:",new_kv)# torch.cat([a, b], dim=c):torch.Size([1, 11, 1024])

7. torch.arange()

核心作用:创建连续整数序列的一维张量;常用于生成索引、位置编码等。
特点:torch.arange() 是根据步长来生成张量的,没有默认值,只能生成一维张量;torch.full() 能生成任意维度张量,且支持默认值;torch.randn() 随机生成指定维度的张量,不支持默认值。
参数:start(起始值,默认 0)、end(结束值,不包含)、step(步长,默认 1)。

# 生成0到9的整数:[0,1,2,...,9]x1=torch.arange(10)print("x1:",x1)# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])# 生成1到9,步长2:[1,3,5,7,9]x2=torch.arange(1,10,2)print("x2:",x2)# tensor([1, 3, 5, 7, 9])# 结合size()使用:生成与张量某维度长度匹配的索引x=torch.randn(2,5,8)# 生成0到x.size(1)-1的索引(x.size(1)=5)idx=torch.arange(x.size(1))print("idx:",idx)# tensor([0, 1, 2, 3, 4])

8. tensor.size() / tensor.shape

核心作用:获取张量的形状信息;size()是方法,shape是属性,功能几乎等价。
参数:dim(可选,指定维度索引,返回该维度的长度;不指定则返回 torch.Size 对象)。如 x.size(0) 代表张量中第一个维度的大小。

x=torch.randn(2,3,4)# 获取整体形状print("x.size():",x.size())# torch.Size([2, 3, 4])print("x.shape:",x.shape)# torch.Size([2, 3, 4])# 获取指定维度的长度print("维度0长度:",x.size(0))# 2(批次大小)print("维度1长度:",x.size(1))# 3(序列长度)print("维度2长度:",x.size(2))# 4(特征维度)# 解包形状(常用操作)bsz,seq_len,hidden_dim=x.size()print(f"批次:{bsz}, 序列长度:{seq_len}, 特征维度:{hidden_dim}")# 批次:2, 序列长度:3, 特征维度:4

9. torch.unsqueeze() / torch.squeeze()

核心作用:插入和删除指定维度,插入和删除的维度的长度为1.
torch.unsqueeze(tensor, dim):在指定维度插入一个维度(维度长度为 1),常用于扩展掩码维度;
torch.squeeze(tensor, dim):删除长度为 1 的维度,简化张量形状。

# unsqueeze:扩展维度(注意力掩码常用)mask=torch.randn(2,3)# torch.Size([2,3])# 插入维度1和2:shape [2,1,1,3](匹配注意力分数维度)mask_unsq=mask.unsqueeze(1).unsqueeze(2)print("mask_unsq shape:",mask_unsq.shape)# torch.Size([2, 1, 1, 3])# squeeze:删除长度为1的维度x=torch.randn(2,1,3,1)x_sq=x.squeeze()# 删除所有长度为1的维度print("x_sq shape:",x_sq.shape)# torch.Size([2, 3])

总结:

  • 形状调整:reshape(通用)、view(共享内存)是核心,优先用reshape;size()/shape用于获取形状信息。
  • 维度操作:transpose(交换维度)、unsqueeze/squeeze(增 / 删维度)、cat(拼接张量)是维度调整高频函数。
  • 特殊张量创建:arange(生成序列)、full(固定值张量)、triu(上三角矩阵)常用于掩码、索引构造。
  • 记忆要点:cat要求非拼接维度形状一致;triu(diagonal=1)是 Transformer 因果掩码的核心;unsqueeze是扩展掩码维度的常用操作。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/6 4:30:24

Excel分类汇总完全指南:从数据分析到分页打印的专业应用

📊 第一章:分类汇总基础概念与原理 1.1 什么是分类汇总? 分类汇总是Excel中用于对数据按类别进行统计分析的强大功能。它能够: 自动识别数据类别并进行分组 对每个分组执行指定的计算(求和、平均值、计数等&#xf…

作者头像 李华
网站建设 2026/3/3 17:17:29

一遍搞定全流程!专科生专属AI论文神器 —— 千笔·专业论文写作工具

你是否在论文写作中感到力不从心?选题无头绪、资料难查找、格式总出错、查重率高得让人焦虑……这些难题是否让你夜不能寐?别再独自挣扎,现在有了更聪明的解决方案——千笔AI。它专为专科生量身打造,从选题到查重,一站…

作者头像 李华
网站建设 2026/3/9 11:27:20

Python Pydantic库深度解析

Pydantic是一个在Python生态中广泛使用的库,特别在Flask开发中,它帮助处理数据验证和配置管理。下面从五个方面详细讲解Pydantic。1. 它是什么Pydantic是一个基于Python类型注解的库,用于数据验证和设置管理。它允许你通过定义类来描述数据的…

作者头像 李华
网站建设 2026/3/4 8:27:57

实测才敢推!专科生专属降AIGC网站 —— 千笔

在AI技术深度渗透学术写作的当下,越来越多的学生开始依赖AI工具辅助完成论文、报告等学术内容。然而,随着查重系统对AI生成内容的识别能力不断提升,如何有效降低AI率和重复率成为摆在学生面前的难题。面对市场上琳琅满目的降AI率与降重复率工…

作者头像 李华
网站建设 2026/3/4 9:17:03

python python-jose库,深度解析

1. 它是什么 python-jose 是一个用于处理 JWT(JSON Web Token)的 Python 库。JWT 可以理解为一种数字“通行证”,它允许在不同系统之间安全地传递信息,就像现实生活中的证件(如身份证)包含了你的基本信息且…

作者头像 李华