Visualization
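Attention-visualization utilities for video temporal grounding: the script below measures how the query tokens attend to each frame's vision tokens in a Qwen-VL-style model, plots per-layer heatmaps and target vs. non-target frame comparisons, and saves the raw numbers for later analysis.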
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, List
import os
import json


def visualize_attention_distribution(
    attentions,
    input_ids,
    processor,
    gt_start_frame,
    gt_end_frame,
    query_text,
    video_id: str,
    save_dir: str = "/home/share/svmd5vm0/home/scut_czy1/attn_map",
    show_all_layers: bool = True,
    figsize: tuple = (20, 12),
):
    """
    Visualize the attention distribution from the query to each frame.

    Args:
        attentions: attention tuple from the model, one tensor per layer of shape
            (batch, num_heads, seq_len, seq_len)
        input_ids: input token ids
        processor: tokenizer processor
        gt_start_frame: ground-truth start frame
        gt_end_frame: ground-truth end frame
        query_text: query text
        video_id: video ID, used in the output file names
        save_dir: output directory
        show_all_layers: whether to plot the attention of all layers
        figsize: figure size
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Get the IDs of the special vision tokens
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_start|>')
    vision_end_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_end|>')

    # 2. Locate the query tokens in the input sequence
    input_ids_list = input_ids[0].tolist()
    query = query_text.strip()
    if query.endswith('.'):
        query = query[:-1]
    query_ids = processor.tokenizer(query, add_special_tokens=False)["input_ids"]

    query_start_idx = None
    query_end_idx = None
    for i in range(len(input_ids_list) - len(query_ids) + 1):
        if input_ids_list[i:i + len(query_ids)] == query_ids:
            query_start_idx = i
            query_end_idx = i + len(query_ids) - 1
            break

    if query_start_idx is None:
        print(f"Warning: Query tokens not found for video {video_id}")
        return

    # 3. Locate the vision tokens of each frame
    vision_start_indices = [i for i, x in enumerate(input_ids_list) if x == vision_start_token_id]
    vision_end_indices = [i for i, x in enumerate(input_ids_list) if x == vision_end_token_id]
    num_frames = len(vision_start_indices)
    num_layers = len(attentions)

    if num_frames == 0:
        print(f"Warning: No vision tokens found for video {video_id}")
        return

    gt_end_frame = min(gt_end_frame, num_frames - 1)

    # 4. Extract the attention score for every layer and every frame
    # layer_frame_attention: [num_layers, num_frames]
    layer_frame_attention = []
    for layer_idx in range(num_layers):
        frame_scores = []
        layer_attn = attentions[layer_idx][0]  # [num_heads, seq_len, seq_len]
        for frame_idx in range(num_frames):
            v_start = vision_start_indices[frame_idx]
            v_end = vision_end_indices[frame_idx]
            # Attention from the query tokens to the vision tokens of this frame
            query_to_frame_attn = layer_attn[:, query_start_idx:query_end_idx + 1, v_start + 1:v_end]
            # Average over all heads, query tokens and vision patches
            frame_score = query_to_frame_attn.mean().item()
            frame_scores.append(frame_score)
        layer_frame_attention.append(frame_scores)

    layer_frame_attention = np.array(layer_frame_attention)  # [num_layers, num_frames]

    # 5. Average attention over all layers
    avg_attention = layer_frame_attention.mean(axis=0)  # [num_frames]

    # 6. Create the visualization
    if show_all_layers and num_layers > 1:
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

        # ========== Plot 1: attention heatmap over all layers ==========
        ax1 = fig.add_subplot(gs[0, :])
        im = ax1.imshow(layer_frame_attention, aspect='auto', cmap='YlOrRd', interpolation='nearest')
        ax1.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
        ax1.set_ylabel('Layer Index', fontsize=12, fontweight='bold')
        ax1.set_title(f'Attention Heatmap: Query → Frames (All Layers)\nQuery: "{query}"',
                      fontsize=14, fontweight='bold', pad=20)

        # Mark the ground-truth segment
        ax1.axvline(x=gt_start_frame - 0.5, color='blue', linestyle='--', linewidth=2, label='GT Start')
        ax1.axvline(x=gt_end_frame + 0.5, color='blue', linestyle='--', linewidth=2, label='GT End')

        # Colorbar
        cbar = plt.colorbar(im, ax=ax1)
        cbar.set_label('Attention Score', fontsize=10, fontweight='bold')
        ax1.legend(loc='upper right')

        # ========== Plot 2: average attention bar chart ==========
        ax2 = fig.add_subplot(gs[1, :])
        frames = np.arange(num_frames)
        colors = ['lightcoral' if gt_start_frame <= i <= gt_end_frame else 'lightblue'
                  for i in range(num_frames)]
        bars = ax2.bar(frames, avg_attention, color=colors, edgecolor='black', linewidth=0.5)

        # Highlight the target frames
        for i in range(gt_start_frame, gt_end_frame + 1):
            bars[i].set_edgecolor('red')
            bars[i].set_linewidth(2)

        ax2.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
        ax2.set_ylabel('Average Attention Score', fontsize=12, fontweight='bold')
        ax2.set_title('Average Attention Distribution (All Layers & Heads)',
                      fontsize=14, fontweight='bold', pad=15)
        ax2.grid(axis='y', alpha=0.3, linestyle='--')

        # Shade the ground-truth segment
        ax2.axvspan(gt_start_frame - 0.5, gt_end_frame + 0.5, alpha=0.2, color='red',
                    label=f'GT Frames [{gt_start_frame}, {gt_end_frame}]')
        ax2.legend(loc='upper right')

        # ========== Plot 3: target vs. non-target frame attention ==========
        ax3 = fig.add_subplot(gs[2, 0])
        target_attention = avg_attention[gt_start_frame:gt_end_frame + 1]
        non_target_mask = np.ones(num_frames, dtype=bool)
        non_target_mask[gt_start_frame:gt_end_frame + 1] = False
        non_target_attention = avg_attention[non_target_mask]

        comparison_data = [target_attention, non_target_attention]
        box = ax3.boxplot(comparison_data, labels=['Target Frames', 'Non-Target Frames'],
                          patch_artist=True, showmeans=True)
        box['boxes'][0].set_facecolor('lightcoral')
        box['boxes'][1].set_facecolor('lightblue')

        ax3.set_ylabel('Attention Score', fontsize=12, fontweight='bold')
        ax3.set_title('Target vs Non-Target Frames', fontsize=13, fontweight='bold', pad=15)
        ax3.grid(axis='y', alpha=0.3, linestyle='--')

        # Summary statistics
        target_mean = target_attention.mean()
        non_target_mean = non_target_attention.mean()
        ratio = target_mean / (non_target_mean + 1e-7)
        stats_text = f'Target Mean: {target_mean:.4f}\n'
        stats_text += f'Non-Target Mean: {non_target_mean:.4f}\n'
        stats_text += f'Ratio: {ratio:.2f}x'
        ax3.text(0.02, 0.98, stats_text, transform=ax3.transAxes, fontsize=10,
                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # ========== Plot 4: layer-wise attention trend ==========
        ax4 = fig.add_subplot(gs[2, 1])
        layer_target_mean = []
        layer_non_target_mean = []
        for layer_idx in range(num_layers):
            target_mean = layer_frame_attention[layer_idx, gt_start_frame:gt_end_frame + 1].mean()
            non_target_mean = layer_frame_attention[layer_idx, non_target_mask].mean()
            layer_target_mean.append(target_mean)
            layer_non_target_mean.append(non_target_mean)

        layers = np.arange(num_layers)
        ax4.plot(layers, layer_target_mean, 'o-', color='red', linewidth=2,
                 markersize=6, label='Target Frames')
        ax4.plot(layers, layer_non_target_mean, 's-', color='blue', linewidth=2,
                 markersize=6, label='Non-Target Frames')
        ax4.set_xlabel('Layer Index', fontsize=12, fontweight='bold')
        ax4.set_ylabel('Mean Attention Score', fontsize=12, fontweight='bold')
        ax4.set_title('Layer-wise Attention Trend', fontsize=13, fontweight='bold', pad=15)
        ax4.legend(loc='best')
        ax4.grid(alpha=0.3, linestyle='--')

    else:
        # Simplified version: only plot the average attention
        fig, ax = plt.subplots(figsize=(12, 6))
        frames = np.arange(num_frames)
        colors = ['lightcoral' if gt_start_frame <= i <= gt_end_frame else 'lightblue'
                  for i in range(num_frames)]
        bars = ax.bar(frames, avg_attention, color=colors, edgecolor='black', linewidth=0.5)

        for i in range(gt_start_frame, gt_end_frame + 1):
            bars[i].set_edgecolor('red')
            bars[i].set_linewidth(2)

        ax.set_xlabel('Frame Index', fontsize=12, fontweight='bold')
        ax.set_ylabel('Average Attention Score', fontsize=12, fontweight='bold')
        ax.set_title(f'Attention Distribution\nQuery: "{query}"', fontsize=14, fontweight='bold', pad=20)
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.axvspan(gt_start_frame - 0.5, gt_end_frame + 0.5, alpha=0.2, color='red',
                   label=f'GT Frames [{gt_start_frame}, {gt_end_frame}]')
        ax.legend(loc='upper right')

    # 7. Save the figure
    save_path = os.path.join(save_dir, f"{video_id}_attention_distribution.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved attention visualization to: {save_path}")
    plt.close()

    # 8. Save the numerical data (NPZ)
    save_data_path = os.path.join(save_dir, f"{video_id}_attention_data.npz")
    np.savez(
        save_data_path,
        layer_frame_attention=layer_frame_attention,
        avg_attention=avg_attention,
        gt_start_frame=gt_start_frame,
        gt_end_frame=gt_end_frame,
        query=query
    )
    print(f"Saved attention data to: {save_data_path}")

    # 9. Return summary statistics
    target_attention = avg_attention[gt_start_frame:gt_end_frame + 1]
    non_target_mask = np.ones(num_frames, dtype=bool)
    non_target_mask[gt_start_frame:gt_end_frame + 1] = False
    non_target_attention = avg_attention[non_target_mask]

    stats = {
        'video_id': video_id,
        'query': query,
        'num_frames': num_frames,
        'num_layers': num_layers,
        'gt_range': (gt_start_frame, gt_end_frame),
        'target_attention_mean': float(target_attention.mean()),
        'target_attention_std': float(target_attention.std()),
        'non_target_attention_mean': float(non_target_attention.mean()),
        'non_target_attention_std': float(non_target_attention.std()),
        'attention_ratio': float(target_attention.mean() / (non_target_attention.mean() + 1e-7)),
        'attention_concentration': float(target_attention.sum() / avg_attention.sum()),
    }
    return stats


def batch_visualize_attention(
    model,
    processor,
    data_list: List[dict],
    save_dir: str = "/home/share/svmd5vm0/home/scut_czy1/attn_map",
    device: str = "cuda",
):
    """
    Visualize the attention of multiple videos in a batch.

    Args:
        model: the model
        processor: the processor
        data_list: list of samples, each containing:
            - video_path: path to the video
            - query: query text
            - start_frame: ground-truth start frame
            - end_frame: ground-truth end frame
            - video_id: video ID
        save_dir: output directory
        device: device
    """
    model.eval()
    all_stats = []

    for data in data_list:
        print(f"\nProcessing video: {data['video_id']}")

        # Build the chat input
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": data['video_path'], "fps": 1},
                    {"type": "text", "text": data['query']},
                ],
            }
        ]
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Forward pass with attention outputs
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)

        # Visualize
        stats = visualize_attention_distribution(
            attentions=outputs.attentions,
            input_ids=inputs['input_ids'],
            processor=processor,
            gt_start_frame=data['start_frame'],
            gt_end_frame=data['end_frame'],
            query_text=data['query'],
            video_id=data['video_id'],
            save_dir=save_dir,
        )
        all_stats.append(stats)

    # Save the statistics of all videos
    stats_path = os.path.join(save_dir, "all_stats.json")
    with open(stats_path, 'w') as f:
        json.dump(all_stats, f, indent=4)
    print(f"\nSaved all statistics to: {stats_path}")

    return all_stats


# ========== Usage example ==========
if __name__ == "__main__":
    """
    Usage example
    """
    # Example 1: visualize a single video
    from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        "/home/share/svmd5vm0/home/scut_czy1/Qwen3-VL-2B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="eager"
    )
    processor = AutoProcessor.from_pretrained("/home/share/svmd5vm0/home/scut_czy1/Qwen3-VL-2B-Instruct")

    query_text = "A person is reading a book"

    # Build the input
    messages = [{
        "role": "user",
        "content": [
            {"type": "video",
             "video": "/home/share/svmd5vm0/home/scut_czy1/datasets/Charadesfps/videos_1FPS/0A8CF.mp4",
             "fps": 1},
            {"type": "text", "text": query_text},
        ],
    }]
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to("cuda")

    # Forward pass with attention outputs
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Visualize
    stats = visualize_attention_distribution(
        attentions=outputs.attentions,
        input_ids=inputs['input_ids'],
        processor=processor,
        gt_start_frame=5,
        gt_end_frame=9,
        query_text=query_text,
        video_id="video_001",
        save_dir="/home/share/svmd5vm0/home/scut_czy1/attn_map"
    )
    print("Statistics:", stats)

    # Example 2: batch processing
    """
    data_list = [
        {
            'video_path': 'video1.mp4',
            'query': 'person drinking water',
            'start_frame': 5,
            'end_frame': 9,
            'video_id': 'video_001'
        },
        {
            'video_path': 'video2.mp4',
            'query': 'person opening door',
            'start_frame': 10,
            'end_frame': 15,
            'video_id': 'video_002'
        },
    ]

    all_stats = batch_visualize_attention(
        model=model,
        processor=processor,
        data_list=data_list,
        save_dir="./visualizations"
    )
    """
    print("Visualization utilities are ready!")
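The saved artifacts can be reloaded later without rerunning the model. Below is a minimal sketch, assuming the .npz and all_stats.json files written by the functions above already exist; save_dir and video_id are placeholder values, not paths from the original script.

# Minimal follow-up sketch (assumption: the files produced above exist; paths are placeholders).
import json
import numpy as np

save_dir = "./attn_map"    # placeholder output directory
video_id = "video_001"     # placeholder video ID

# Reload the per-video attention data saved by visualize_attention_distribution()
data = np.load(f"{save_dir}/{video_id}_attention_data.npz")
layer_frame_attention = data["layer_frame_attention"]  # [num_layers, num_frames]
avg_attention = data["avg_attention"]                  # [num_frames]
print("Query:", data["query"])
print("GT range:", int(data["gt_start_frame"]), "-", int(data["gt_end_frame"]))
print("Most attended frame:", int(avg_attention.argmax()))

# Aggregate the per-video statistics written by batch_visualize_attention()
with open(f"{save_dir}/all_stats.json") as f:
    all_stats = json.load(f)
ratios = [s["attention_ratio"] for s in all_stats if s is not None]
print("Mean target/non-target attention ratio:", sum(ratios) / max(len(ratios), 1))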