1. 环境准备:从零搭建TransUNet开发环境
第一次接触医学图像分割时,我被各种专业术语和复杂的工具链搞得晕头转向。直到遇到TransUNet这个结合了Transformer和U-Net优势的模型,才发现原来搭建环境可以这么简单。下面我就用最直白的语言,带你一步步避开我踩过的所有坑。
首先需要准备一台配备NVIDIA显卡的电脑(显存建议8GB以上),实测GTX 1080Ti也能跑但会比较吃力。操作系统推荐Ubuntu 20.04,Windows用户可以用WSL2但会有5%左右的性能损耗。我这里以Ubuntu系统为例,Windows用户只需把apt换成对应的包管理命令即可。
打开终端依次执行这些命令:
sudo apt update sudo apt install -y python3.8 python3-pip git nvidia-driver-510安装完基础依赖后,强烈建议先配置conda环境。我遇到过无数次因为包版本冲突导致的神秘bug,用conda能避免99%的这类问题:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh bash Miniconda3-latest-Linux-x86_64.sh创建专属的python环境(我这里用python3.8,因为有些医学图像处理库对更高版本支持不好):
conda create -n transunet python=3.8 conda activate transunet1.1 关键依赖安装指南
接下来安装PyTorch时有个大坑要注意:必须装带CUDA支持的版本!我有次偷懒直接pip install torch,结果训练速度慢了20倍。根据你的CUDA版本选择对应命令(用nvidia-smi查看CUDA版本):
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113然后是TransUNet的核心依赖,这里有个小技巧——先装大包再装小包能避免依赖冲突:
pip install nibabel opencv-python tqdm matplotlib pip install tensorboardX==2.5.1 # 这个版本和最新PyTorch兼容性最好特别提醒:医学图像处理必备的SimpleITK建议用conda安装,用pip容易出问题:
conda install -c simpleitk simpleitk2. 数据集处理实战技巧
拿到原始医学图像数据(通常是.nii或.dcm格式)时,我建议先在3D Slicer这类专业软件里肉眼检查下数据质量。有次我直接开始预处理,后来发现30%的扫描图像存在伪影,白白浪费了两天训练时间。
2.1 NIfTI文件预处理
创建2Ddata目录存放切片结果:
mkdir -p ./2Ddata mkdir -p ./data/train_npz用这个改良版的预处理脚本时,注意修改data_path指向你的原始数据目录。我增加了异常捕获机制,避免因为单个文件损坏导致整个流程中断:
import os from PIL import Image import numpy as np import nibabel as nib def safe_load_nii(file_path): try: img = nib.load(file_path) return img.get_fdata() except: print(f"损坏文件跳过: {file_path}") return None def process_file(file_path): img_data = safe_load_nii(file_path) if img_data is None: return label_path = file_path.replace('_gt.nii.gz', '_label.nii.gz') label_data = safe_load_nii(label_path) if label_data is None: return # 动态计算归一化范围 win_center, win_width = -125, 400 min_val = win_center - win_width//2 max_val = win_center + win_width//2 img_clipped = np.clip(img_data, min_val, max_val) img_normalised = (img_clipped - min_val) / (max_val - min_val) for i in range(img_data.shape[2]): slice_num = f"{i+1:04d}" case_name = os.path.splitext(os.path.basename(file_path))[0].replace("_gt.nii","") # 保存图像和标签 img_slice = (img_normalised[:,:,i] * 255).astype(np.uint8) Image.fromarray(img_slice).convert('L').save(f"./2Ddata/{case_name}_{slice_num}.png") label_slice = label_data[:,:,i].astype(np.uint8) Image.fromarray(label_slice).convert('L').save(f"./2Ddata/{case_name}_{slice_num}_label.png")2.2 生成NPZ训练文件
这个增强版脚本会自动跳过损坏文件,并显示进度条:
from tqdm import tqdm import numpy as np import cv2 import glob def generate_npz(): img_files = [f for f in glob.glob('./2Ddata/*.png') if not f.endswith('_label.png')] for img_path in tqdm(img_files, desc="生成NPZ文件"): try: image = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) label_path = img_path.replace('.png', '_label.png') label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) if image is None or label is None: print(f"跳过损坏文件: {img_path}") continue filename = os.path.splitext(os.path.basename(img_path))[0] np.savez(f'./data/train_npz/{filename}.npz', image=image, label=label) except Exception as e: print(f"处理 {img_path} 时出错: {str(e)}")3. 模型训练全流程解析
3.1 下载预训练权重
官方提供的Imagenet预训练模型需要从Google Drive下载,国内用户可能会遇到网络问题。我把它转存到了百度网盘(提取码:trans),下载后放到./model/vit_checkpoint/imagenet21k目录下。
验证下载是否正确:
md5sum ./model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz # 正确输出应为:3da1e512e7676fadf0106a6a7a3a1c2c3.2 启动训练任务
修改train.py中的这些关键参数能显著影响结果:
parser.add_argument('--batch_size', type=int, default=24) # 显存不足时减小这个值 parser.add_argument('--num_epochs', type=int, default=150) parser.add_argument('--lr', type=float, default=3e-4) # 学习率太大容易震荡 parser.add_argument('--img_size', type=int, default=224) # 分辨率越高显存消耗越大启动训练的最佳实践:
nohup python -u train.py --dataset Synapse --vit_name R50-ViT-B_16 > train.log 2>&1 & tail -f train.log # 实时查看日志训练过程中用Tensorboard监控指标:
tensorboard --logdir ./logs --bind_all4. 模型测试与结果分析
4.1 测试集推理
修改test.py中的模型路径后运行:
python test.py --is_savenii --test_save_dir ./test_results遇到显存不足时可以添加--no_cuda参数改用CPU推理(速度会慢10倍左右)。
4.2 结果可视化技巧
用这个脚本可以生成对比图,左边是原图,中间是预测结果,右边是真实标签:
import matplotlib.pyplot as plt def visualize_sample(npz_path, pred_path): data = np.load(npz_path) pred = np.load(pred_path) plt.figure(figsize=(15,5)) plt.subplot(131) plt.imshow(data['image']) plt.title("Input") plt.subplot(132) plt.imshow(pred['predict']) plt.title("Prediction") plt.subplot(133) plt.imshow(data['label']) plt.title("Ground Truth") plt.savefig("comparison.png", dpi=300)最后提醒几个常见坑点:
- 数据集路径中不要有中文或空格
- 训练中断后恢复时记得检查学习率是否重置
- 验证集Dice系数波动大时可以尝试减小学习率
- 遇到CUDA out of memory错误时先尝试减小batch size