049、Focal Loss 核心思想:从 cross-entropy 到 α-balanced Focal Loss 的推导
一、一个让我失眠的调试问题
去年做工业缺陷检测项目,正样本(缺陷)占比不到0.1%,负样本(正常产品)占99.9%。用标准交叉熵训练YOLOv5,模型收敛后mAP@0.5只有0.12——它学会了“永远输出负样本”。我试过过采样、欠采样、数据增强,效果都像在泥潭里挣扎。
直到我在RetinaNet论文里看到Focal Loss,才意识到问题本质:交叉熵对易分类样本的梯度贡献太大,淹没了难分类样本的信号。这篇文章就带你从数学推导到代码实现,彻底搞懂Focal Loss为什么能解决这个问题。
二、交叉熵的“公平”其实是“不公平”
先看二分类交叉熵的标准形式:
CE(p, y) = -y * log(p) - (1-y) * log(1-p)其中y∈{0,1}是真实标签,p∈[0,1]是模型预测为正类的概率。
为了方便推导,定义pt:
pt = p if y=1 else 1-p这样交叉熵可以简写为:
CE(p, y) = -log(pt)问题出在哪里?假设一个负样本(y=0),模型预测p=0.9(即pt=0.1),这个样本被正确分类的概率只有10%,属于“难分类样本”。但交叉熵给它的损失是 -log(0.1) ≈ 2.3。再看一个负样本,模型预测p=0.01(pt=0.99),这是“易分类样本”,损失是 -log(0.99) ≈ 0.01。
关键点来了:易分类样本的损失虽然小,但它们的数量是难分类样本的成千上万倍。累加起来,易分类样本的总损失占据了主导地位,梯度更新方向被它们牵着鼻子走。这就是类别不平衡问题的本质——不是正负样本数量不平衡,而是“易分类样本”和“难分类样本”的梯度贡献不平衡。
三、Focal Loss的直觉:让模型“聚焦”难样本
Focal Loss的核心修改只有一行公式:
FL(pt) = -(1-pt)^γ * log(pt)对比交叉熵,多了一个调制因子 (1-pt)^γ。
这个因子做了什么?当pt接近1(易分类样本),(1-pt)γ接近0,损失被大幅压低。当pt接近0(难分类样本),(1-pt)γ接近1,损失几乎不变。
γ是聚焦参数,论文推荐γ=2。我实际调参的经验是:γ=0退化为交叉熵,γ=1效果不明显,γ=2~3效果最好,γ>5会导致训练不稳定(这里踩过坑,梯度消失严重)。
举个例子:γ=2时,易分类样本(pt=0.99)的损失从0.01降为0.01*(0.01)2=1e-6,几乎被忽略。难分类样本(pt=0.1)的损失从2.3降为2.3*(0.9)2≈1.86,只降低了20%。这样模型就会把注意力集中在那些“模棱两可”的样本上。
四、α-balanced Focal Loss:再加一层保险
Focal Loss解决了“难易样本”问题,但没解决“正负样本”问题。实际场景中,正样本往往既是“少数”又是“难分类”的。如果只用Focal Loss,模型可能过度关注负样本中的难分类样本(比如背景中的噪声),而忽略正样本。
解决方案是引入α平衡因子:
FL(pt) = -α_t * (1-pt)^γ * log(pt)其中α_t的定义和pt类似:
α_t = α if y=1 else 1-αα通常取0.25~0.75,具体值取决于正负样本比例。我的经验公式:α = 1 / (1 + 正负样本比)。比如正负样本比1:1000,α≈0.001。但别直接套用这个公式,它只是初始值,最终需要调参。
注意:α和γ不是独立参数。α控制正负样本的权重分配,γ控制难易样本的权重分配。两者协同工作:α让模型“看到”正样本,γ让模型“聚焦”难样本。
五、PyTorch实现:别踩这些坑
直接上代码,注释里写满了我的血泪史:
importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassFocalLoss(nn.Module):def__init__(self,alpha=0.25,gamma=2.0,reduction='mean'):super().__init__()self.alpha=alpha# 正样本权重,别设成0.5,那是交叉熵self.gamma=gamma# 聚焦参数,推荐2.0self.reduction=reduction# 'mean'或'sum'defforward(self,inputs,targets):# inputs: 模型输出logits,shape [N, C] 或 [N]# targets: 真实标签,shape [N],值域{0,1,...,C-1}# 别这样写:直接用F.binary_cross_entropy_with_logits# 因为Focal Loss需要手动计算pt# 先计算交叉熵的log部分ce_loss=F.cross_entropy(inputs,targets,reduction='none')# ce_loss shape: [N]# 计算pt:模型预测正确类别的概率# 这里踩过坑:直接用softmax再gather,但数值不稳定pt=torch.exp(-ce_loss)# 因为ce_loss = -log(pt)# pt shape: [N]# 计算Focal Lossfocal_loss=(1-pt)**self.gamma*ce_loss# 加入alpha平衡因子# 需要根据targets构建alpha_tifself.alphaisnotNone:# 假设二分类,alpha_t = alpha if y=1 else 1-alpha# 多分类时,alpha可以是一个列表alpha_t=self.alpha*targets+(1-self.alpha)*(1-targets)# 别这样写:直接乘alpha_t,因为targets可能是long类型focal_loss=alpha_t*focal_loss# 根据reduction聚合ifself.reduction=='mean':returnfocal_loss.mean()elifself.reduction=='sum':returnfocal_loss.sum()else:returnfocal_loss几个容易翻车的地方:
数值稳定性:不要直接用softmax计算pt,然后用log。上面用
torch.exp(-ce_loss)是安全的,因为ce_loss已经包含了log_softmax。alpha的维度:如果做多分类,alpha应该是一个长度为C的tensor,每个类别一个权重。别偷懒用一个标量。
reduction的选择:YOLO系列通常用’sum’,因为每个anchor的损失需要独立累加。分类任务用’mean’更稳定。
六、在YOLO中集成Focal Loss
YOLOv5/v8的损失函数在loss.py里,替换交叉熵部分:
# 原始YOLOv5的分类损失self.bce=nn.BCEWithLogitsLoss(reduction='mean')# 替换为Focal Lossself.fl=FocalLoss(alpha=0.25,gamma=2.0,reduction='mean')注意:YOLO的输出是multi-label分类(每个类别独立二分类),所以Focal Loss需要按二分类方式计算。上面的实现已经兼容了这种情况。
调参建议:
- 先从γ=2.0, α=0.25开始(论文默认值)
- 如果正样本极少(<0.1%),尝试α=0.1~0.15
- 如果模型过拟合,增大γ到3.0~4.0
- 如果模型欠拟合,减小γ到1.0~1.5
七、个人经验:什么时候用Focal Loss
强烈推荐使用:
- 目标检测中的一阶段检测器(YOLO、SSD、RetinaNet)
- 正负样本比例超过1:100的分类任务
- 存在大量“简单负样本”的场景(比如背景占90%以上的图像)
不建议使用:
- 二阶段检测器(Faster R-CNN),因为RPN已经做了样本筛选
- 正负样本比例接近1:1的任务
- 模型已经过拟合的情况(Focal Loss会加剧过拟合)
一个容易被忽视的点:Focal Loss会改变损失函数的尺度。如果从交叉熵切换到Focal Loss,学习率可能需要调低1~2个数量级。我习惯先用γ=0(即交叉熵)跑一个epoch,观察损失量级,再调整学习率后启用Focal Loss。
最后说句实在话:Focal Loss不是万能药。如果你的数据极度不平衡(正样本比例<0.01%),建议先做数据增强和难例挖掘,再用Focal Loss做精细调优。我那个缺陷检测项目,最终方案是:Online Hard Example Mining + Focal Loss + 数据增强,才把mAP从0.12拉到0.87。工具是死的,组合拳才是活的。