DETR目标检测实战:用PyTorch从零搭建你的第一个Transformer检测模型
在计算机视觉领域,目标检测一直是核心任务之一。传统方法如Faster R-CNN、YOLO等依赖复杂的锚框设计和后处理流程,而DETR(Detection Transformer)的出现彻底改变了这一局面。本文将带你从零开始,用PyTorch实现一个完整的DETR模型,涵盖数据准备、模型架构、训练技巧等实战环节。
1. 环境准备与数据加载
1.1 安装必要依赖
首先确保你的环境已安装PyTorch 1.7+和Torchvision 0.8+。推荐使用conda创建虚拟环境:
conda create -n detr python=3.8 conda activate detr pip install torch torchvision torchaudio pip install pycocotools matplotlib tqdm1.2 准备COCO数据集
DETR通常使用COCO数据集进行训练和评估。下载并解压数据集后,目录结构应如下:
coco/ ├── annotations │ ├── instances_train2017.json │ └── instances_val2017.json ├── train2017 │ └── *.jpg └── val2017 └── *.jpg提示:COCO数据集约18GB,确保有足够磁盘空间。也可使用
torchvision.datasets.CocoDetection简化加载过程。
2. 模型架构实现
2.1 Backbone网络
DETR使用ResNet作为特征提取器。以下是简化版的实现:
import torch from torch import nn from torchvision.models import resnet50 class Backbone(nn.Module): def __init__(self, pretrained=True): super().__init__() resnet = resnet50(pretrained=pretrained) self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) return x2.2 Transformer编码器-解码器
这是DETR的核心组件:
from torch.nn import MultiheadAttention class TransformerEncoderLayer(nn.Module): def __init__(self, d_model=256, nhead=8, dim_feedforward=2048): super().__init__() self.self_attn = MultiheadAttention(d_model, nhead) self.linear1 = nn.Linear(d_model, dim_feedforward) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(0.1) def forward(self, src): src2 = self.norm1(src) src2 = self.self_attn(src2, src2, src2)[0] src = src + self.dropout(src2) src2 = self.norm2(src) src2 = self.linear2(F.relu(self.linear1(src2))) src = src + self.dropout(src2) return src3. 完整DETR模型组装
3.1 模型整合
将各组件组合成完整模型:
class DETR(nn.Module): def __init__(self, num_classes=91, num_queries=100): super().__init__() self.backbone = Backbone() self.transformer = nn.Transformer( d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048 ) self.query_embed = nn.Embedding(num_queries, 256) self.class_embed = nn.Linear(256, num_classes + 1) self.bbox_embed = MLP(256, 256, 4, 3) def forward(self, x): features = self.backbone(x) hs = self.transformer(features, self.query_embed.weight) outputs_class = self.class_embed(hs) outputs_coord = self.bbox_embed(hs).sigmoid() return {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}3.2 辅助损失实现
DETR使用匈牙利算法进行预测匹配:
from scipy.optimize import linear_sum_assignment def hungarian_matcher(outputs, targets): bs, num_queries = outputs["pred_logits"].shape[:2] indices = [] for i in range(bs): cost_class = -outputs["pred_logits"][i].softmax(-1)[..., :-1] cost_bbox = torch.cdist(outputs["pred_boxes"][i], targets[i]["boxes"]) cost = cost_class + cost_bbox row_ind, col_ind = linear_sum_assignment(cost.cpu()) indices.append((row_ind, col_ind)) return indices4. 训练流程与技巧
4.1 训练配置
推荐使用以下超参数配置:
| 参数 | 推荐值 | 说明 |
|---|---|---|
| batch_size | 8-16 | 根据GPU显存调整 |
| lr | 1e-4 | 初始学习率 |
| epochs | 50-100 | 训练轮次 |
| weight_decay | 1e-4 | L2正则化 |
4.2 学习率调度
使用warmup策略提升训练稳定性:
from torch.optim.lr_scheduler import LambdaLR def get_lr_scheduler(optimizer, warmup_epochs=10): def lr_lambda(epoch): if epoch < warmup_epochs: return (epoch + 1) / warmup_epochs return 0.1 ** (epoch // 30) return LambdaLR(optimizer, lr_lambda)4.3 数据增强策略
有效的数据增强能显著提升模型性能:
- 随机水平翻转(p=0.5)
- 随机缩放(0.8-1.2倍)
- 随机裁剪(最小IoU=0.3)
- 颜色抖动(亮度=0.2, 对比度=0.2, 饱和度=0.2)
5. 模型评估与优化
5.1 评估指标
使用标准COCO评估指标:
- AP (平均精度)
- AP50 (IoU=0.5时的AP)
- AP75 (IoU=0.75时的AP)
- AP_small (小目标AP)
- AP_medium (中目标AP)
- AP_large (大目标AP)
5.2 常见问题解决
以下是训练中可能遇到的问题及解决方案:
收敛慢:
- 增加warmup周期
- 使用更大的batch size
- 尝试AdamW优化器
小目标检测效果差:
- 增加输入图像分辨率
- 使用多尺度特征融合
- 尝试Deformable DETR变体
显存不足:
- 减小batch size
- 使用梯度累积
- 尝试混合精度训练
# 混合精度训练示例 from torch.cuda.amp import GradScaler, autocast scaler = GradScaler() with autocast(): outputs = model(images) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()6. 推理部署
6.1 模型导出
将训练好的模型导出为TorchScript:
model.eval() traced_model = torch.jit.trace(model, torch.rand(1, 3, 800, 800)) traced_model.save("detr_model.pt")6.2 优化推理速度
提升推理效率的技巧:
- 使用TensorRT加速
- 量化模型(FP16/INT8)
- 剪枝冗余注意力头
- 使用更轻量backbone(如ResNet18)
# 量化示例 quantized_model = torch.quantization.quantize_dynamic( model, {nn.Linear}, dtype=torch.qint8 )在实际项目中,我发现DETR的端到端特性大大简化了部署流程,特别是在需要动态调整检测目标的场景中表现优异。通过合理调整查询数量(num_queries),可以在精度和速度之间取得良好平衡。