news 2026/5/5 5:26:40

Transformer的注意力权重的理解

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Transformer的注意力权重的理解
""" Transformer 注意力权重分析工具 详细解析注意力矩阵的含义和使用方法 """ import torch import torch.nn as nn import numpy as np import math # ============================================================ # 简化的多头注意力(用于演示) # ============================================================ class SimpleMultiHeadAttention(nn.Module): """简化的多头注意力,便于理解""" def __init__(self, d_model, num_heads): super().__init__() assert d_model % num_heads == 0 self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.attention_weights = None def split_heads(self, x): batch_size, seq_len, d_model = x.size() x = x.view(batch_size, seq_len, self.num_heads, self.d_k) return x.transpose(1, 2) def forward(self, query, key, value, mask=None): batch_size = query.size(0) Q = self.W_q(query) K = self.W_k(key) V = self.W_v(value) Q = self.split_heads(Q) # [batch, num_heads, seq_len_q, d_k] K = self.split_heads(K) # [batch, num_heads, seq_len_k, d_k] V = self.split_heads(V) # [batch, num_heads, seq_len_v, d_k] # 核心:计算注意力分数 scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # 这就是注意力权重矩阵! attention_weights = torch.softmax(scores, dim=-1) self.attention_weights = attention_weights context = torch.matmul(attention_weights, V) context = context.transpose(1, 2).contiguous() context = context.view(batch_size, -1, self.d_model) output = self.W_o(context) return output, attention_weights # ============================================================ # 注意力权重分析器 # ============================================================ class AttentionAnalyzer: """注意力权重分析工具""" @staticmethod def print_separator(title): print("\n" + "="*70) print(f" {title}") print("="*70) @staticmethod def analyze_attention_weights(attention_weights, token_names=None): """ 详细分析注意力权重矩阵 attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k] """ batch_size, num_heads, seq_len_q, seq_len_k = attention_weights.shape AttentionAnalyzer.print_separator("注意力权重矩阵结构分析") print(f"📊 注意力权重矩阵的形状: {attention_weights.shape}") print(f" - batch_size (批次大小): {batch_size}") print(f" - num_heads (注意力头数): {num_heads}") print(f" - seq_len_q (查询序列长度): {seq_len_q}") print(f" - seq_len_k (键序列长度): {seq_len_k}") print(f"\n💡 矩阵含义解释:") print(f" attention_weights[batch_i, head_j, pos_q, pos_k] 表示:") print(f" → 在第 i 个样本中") print(f" → 在第 j 个注意力头中") print(f" → 位置 pos_q 的token") print(f" → 对位置 pos_k 的token的注意力权重") # 生成默认token名称 if token_names is None: token_names = [f"Token_{i}" for i in range(seq_len_q)] return token_names @staticmethod def get_token_attention(attention_weights, token_position, batch_idx=0, head_idx=0): """ 获取指定token对其他所有token的注意力权重 参数: attention_weights: [batch_size, num_heads, seq_len_q, seq_len_k] token_position: 要查询的token位置(从0开始) batch_idx: 批次索引 head_idx: 注意力头索引 返回: 该token对所有其他token的注意力分布 """ # 提取该token的注意力分布 token_attn = attention_weights[batch_idx, head_idx, token_position, :] return token_attn @staticmethod def visualize_first_token_attention(attention_weights, token_names=None, batch_idx=0, head_idx=0): """可视化第一个token的注意力分布""" AttentionAnalyzer.print_separator("第一个Token的注意力分析") # 获取第一个token的注意力 first_token_attn = AttentionAnalyzer.get_token_attention( attention_weights, token_position=0, batch_idx=batch_idx, head_idx=head_idx ) seq_len = first_token_attn.size(0) if token_names is None: token_names = [f"Token_{i}" for i in range(seq_len)] print(f"🎯 第一个Token (位置0) 的注意力分布:") print(f" 批次: {batch_idx}, 注意力头: {head_idx}") print(f"\n{'位置':<8} {'Token名称':<15} {'注意力权重':<15} {'百分比':<10} {'可视化'}") print("-" * 70) attn_numpy = first_token_attn.detach().numpy() for i, (token_name, attn_value) in enumerate(zip(token_names, attn_numpy)): bar_length = int(attn_value * 50) # 最大50个字符 bar = "█" * bar_length percentage = attn_value * 100 print(f"{i:<8} {token_name:<15} {attn_value:<15.4f} {percentage:<10.2f}% {bar}") print(f"\n✅ 验证: 所有权重之和 = {attn_numpy.sum():.6f} (应该 ≈ 1.0)") return first_token_attn @staticmethod def compare_all_heads(attention_weights, token_position=0, batch_idx=0): """比较不同注意力头中同一token的注意力分布""" AttentionAnalyzer.print_separator(f"Token {token_position} 在所有注意力头中的表现") num_heads = attention_weights.size(1) seq_len = attention_weights.size(3) print(f"📊 对比 {num_heads} 个注意力头\n") for head_idx in range(num_heads): token_attn = AttentionAnalyzer.get_token_attention( attention_weights, token_position=token_position, batch_idx=batch_idx, head_idx=head_idx ) attn_numpy = token_attn.detach().numpy() # 找出最关注的位置 max_idx = np.argmax(attn_numpy) max_value = attn_numpy[max_idx] print(f"Head {head_idx}: 最关注位置 {max_idx} (权重={max_value:.4f})") # 显示top-3关注的位置 top3_indices = np.argsort(attn_numpy)[-3:][::-1] top3_values = attn_numpy[top3_indices] print(f" Top-3: ", end="") for idx, val in zip(top3_indices, top3_values): print(f"位置{idx}({val:.3f}) ", end="") print("\n") @staticmethod def create_attention_heatmap_text(attention_weights, batch_idx=0, head_idx=0, token_names=None): """创建文本形式的注意力热力图""" AttentionAnalyzer.print_separator("注意力权重热力图(文本版)") attn_matrix = attention_weights[batch_idx, head_idx].detach().numpy() seq_len = attn_matrix.shape[0] if token_names is None: token_names = [f"T{i}" for i in range(seq_len)] print(f"\n批次 {batch_idx}, 注意力头 {head_idx}") print(f"每行表示一个query token对所有key token的注意力\n") # 打印列标题 print("Query\\Key ", end="") for name in token_names: print(f"{name:>8}", end="") print() print("-" * (12 + 8 * seq_len)) # 打印每一行 for i, row_name in enumerate(token_names): print(f"{row_name:<10} ", end="") for j in range(seq_len): value = attn_matrix[i, j] # 使用不同符号表示权重大小 if value > 0.5: symbol = "██" elif value > 0.3: symbol = "▓▓" elif value > 0.1: symbol = "▒▒" elif value > 0.05: symbol = "░░" else: symbol = " " print(f"{symbol:>6}", end=" ") print(f" {row_name}") # ============================================================ # 完整示例 # ============================================================ def demo_attention_analysis(): """完整的注意力分析演示""" print("\n" + "🚀 "*30) print("Transformer 注意力权重矩阵详解") print("🚀 "*30) # ============================================================ # 1. 创建模型和数据 # ============================================================ AttentionAnalyzer.print_separator("1. 准备模型和数据") d_model = 64 num_heads = 4 batch_size = 2 seq_len = 6 print(f"模型配置:") print(f" - d_model: {d_model}") print(f" - num_heads: {num_heads}") print(f" - 序列长度: {seq_len}") # 创建注意力模块 attention = SimpleMultiHeadAttention(d_model, num_heads) attention.eval() # 创建输入数据 x = torch.randn(batch_size, seq_len, d_model) # 定义token名称(方便理解) token_names = ["我", "爱", "学习", "深度", "学习", "<EOS>"] print(f"\n输入序列: {' '.join(token_names)}") # ============================================================ # 2. 前向传播 # ============================================================ AttentionAnalyzer.print_separator("2. 执行注意力计算") with torch.no_grad(): output, attention_weights = attention(x, x, x) print(f"输出形状: {output.shape}") print(f"注意力权重形状: {attention_weights.shape}") print(f" → [batch_size={batch_size}, num_heads={num_heads}, " f"seq_len_q={seq_len}, seq_len_k={seq_len}]") # ============================================================ # 3. 分析注意力权重结构 # ============================================================ AttentionAnalyzer.analyze_attention_weights(attention_weights, token_names) # ============================================================ # 4. 分析第一个token的注意力 # ============================================================ first_token_attn = AttentionAnalyzer.visualize_first_token_attention( attention_weights, token_names=token_names, batch_idx=0, head_idx=0 ) # ============================================================ # 5. 比较不同注意力头 # ============================================================ AttentionAnalyzer.compare_all_heads( attention_weights, token_position=0, batch_idx=0 ) # ============================================================ # 6. 创建热力图 # ============================================================ AttentionAnalyzer.create_attention_heatmap_text( attention_weights, batch_idx=0, head_idx=0, token_names=token_names ) # ============================================================ # 7. 实用代码示例 # ============================================================ AttentionAnalyzer.print_separator("7. 实用代码示例") print("\n💻 如何在你的代码中使用:") print(""" # 方法1: 直接索引获取第一个token的注意力 first_token_attention = attention_weights[0, 0, 0, :] # 含义: 批次0, 头0, 位置0的token对所有token的注意力 # 方法2: 获取任意token的注意力 token_position = 2 # 第3个token token_attention = attention_weights[0, 0, token_position, :] # 方法3: 获取所有token之间的注意力矩阵 full_attention_matrix = attention_weights[0, 0, :, :] # shape: [seq_len, seq_len] # 方法4: 平均所有注意力头 avg_attention = attention_weights.mean(dim=1) # [batch, seq_len, seq_len] # 方法5: 查看token i 对 token j 的注意力 i, j = 0, 3 # 第1个token对第4个token的注意力 attn_value = attention_weights[0, 0, i, j].item() print(f"Token {i} → Token {j} 的注意力: {attn_value:.4f}") """) # ============================================================ # 8. 关键概念总结 # ============================================================ AttentionAnalyzer.print_separator("8. 关键概念总结") print(""" 🎯 注意力权重矩阵的核心理解: 1. 矩阵维度: [batch_size, num_heads, seq_len_q, seq_len_k] 2. 含义: attention_weights[b, h, i, j] 表示 → 第b个样本 → 第h个注意力头 → 第i个位置的token → 对第j个位置的token的注意力强度 3. 每一行的和 = 1.0 (因为经过了softmax) → attention_weights[b, h, i, :].sum() ≈ 1.0 4. Self-Attention中: seq_len_q == seq_len_k → 矩阵是方阵 → 对角线表示token对自己的注意力 5. 不同的注意力头会学习不同的模式 → 有的头关注局部信息 → 有的头关注长距离依赖 → 有的头关注句法结构 6. 在你原始代码的位置: → EncoderLayer.forward() 返回的 attn_weights → DecoderLayer.forward() 返回的 self_attn_weights 和 cross_attn_weights → MultiHeadAttention 类中的 self.attention_weights """) print("\n✅ 现在你应该完全理解注意力权重矩阵了!") print("💡 建议: 在调试器中逐步执行,观察矩阵的变化") return attention_weights # ============================================================ # 运行演示 # ============================================================ if __name__ == "__main__": attention_weights = demo_attention_analysis()
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/2 7:31:51

基于SpringBoot的车辆报废回收系统(毕业设计项目源码+文档)

课题摘要 在机动车报废回收行业规范化、数字化升级的背景下&#xff0c;传统车辆报废回收模式存在 “流程审批繁琐、车辆溯源难、数据统计滞后、监管透明度低” 的痛点&#xff0c;难以满足车主便捷报废、企业高效运营、监管部门精准管控的需求。基于 SpringBoot 的车辆报废回收…

作者头像 李华
网站建设 2026/5/1 6:46:46

租用日本服务器价格便宜的原因

在 2026 年的海外服务器租赁市场中&#xff0c;日本服务器呈现出 “高配置 低门槛” 的独特优势&#xff0c;更关键的是&#xff0c;低价并未牺牲核心品质 ——90% 以上服务商提供 NTT/KDDI 原生 IP、CN2 GIA 直连线路&#xff0c;稳定性与纯净度远超同价位其他地区服务器。这…

作者头像 李华
网站建设 2026/5/1 20:04:19

数据结构:广义表

广义表 资料&#xff1a;https://pan.quark.cn/s/43d906ddfa1b、https://pan.quark.cn/s/90ad8fba8347、https://pan.quark.cn/s/d9d72152d3cf 一、广义表的定义 广义表&#xff08;Generalized List&#xff09;是线性表的扩展&#xff0c;是由零个或多个原子&#xff08;Atom…

作者头像 李华
网站建设 2026/4/28 6:48:02

Linux进程间通信内存映射(mmap)实现篇

Linux 内核中 mmap 的实现(基于 2.6.12) 概述 基于 2.6.12 内核, 说明 mmap 系统调用的核心数据结构、系统调用路径及关键实现. 主要文件: mm/mmap.c、mm/msync.c、mm/filemap.c、include/linux/mm.h、include/linux/mman.h. 核心数据结构 mm_struct (进程地址空间描述符) // i…

作者头像 李华
网站建设 2026/4/28 15:39:06

a5 4444444444

444444444444444444

作者头像 李华