别再为数据分布不同发愁了!用Python实战带你搞懂Domain Adaptation的三种核心方法
当你在MNIST手写数字数据集上训练的分类器,面对SVHN街景门牌号数字时准确率暴跌50%,这不是模型出了问题,而是遇到了**域偏移(Domain Shift)**的经典困境。这种现象在真实业务场景中比比皆是:医学影像分析中不同医院采集的CT扫描、自动驾驶系统中晴天和雨天的道路图像、电商平台里手机拍摄和专业棚拍的商品图片...数据分布的差异让精心调优的模型瞬间"失明"。
传统解决方案是重新标注目标域数据,但成本往往令人望而却步。域适应(Domain Adaptation)技术正是破局关键——它让模型学会忽略域间差异,专注挖掘跨域不变特征。本文将用Python代码实战演示三种最具代表性的方法:
- 基于分布对齐的MMD方法:通过核函数匹配两个域的高维特征分布
- 对抗训练框架DANN:用梯度反转层欺骗域判别器
- 双重任务DRCN:分类与重建并行的多任务学习
1. 环境准备与数据加载
我们先配置实验环境,使用PyTorch框架和经典数据集构建测试场景:
import torch import torch.nn as nn from torchvision import datasets, transforms # 数据预处理 transform = transforms.Compose([ transforms.Resize(32), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) # 加载MNIST作为源域 source_data = datasets.MNIST( root='./data', train=True, download=True, transform=transform) # 加载SVHN作为目标域 target_data = datasets.SVHN( root='./data', split='train', download=True, transform=transform)两个数据集虽然都是数字分类,但分布差异显著:
| 特征 | MNIST源域 | SVHN目标域 |
|---|---|---|
| 图像来源 | 手写数字扫描 | 街景门牌号照片 |
| 背景复杂度 | 纯白背景 | 复杂街道背景 |
| 数字形态 | 标准书写体 | 印刷体+变形 |
| 色彩模式 | 灰度图像 | RGB彩色图像 |
2. 基线模型与域差异评估
在实现域适应前,我们先建立性能基准。使用在MNIST上训练的ResNet-18直接测试SVHN:
model = resnet18(pretrained=False) model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) # 适配MNIST单通道 # 训练过程省略... test_acc = evaluate(model, target_test_loader) print(f"直接迁移准确率: {test_acc:.2f}%") # 典型输出约35-40%这种**直接迁移(Direct Transfer)**的表现验证了域偏移的存在。接下来我们引入最大均值差异(MMD)量化两个域的分布距离:
def mmd_rbf(source, target, gamma=1.0): # 计算高斯核矩阵 XX = torch.exp(-gamma * (source @ source.t())) YY = torch.exp(-gamma * (target @ target.t())) XY = torch.exp(-gamma * (source @ target.t())) return XX.mean() + YY.mean() - 2 * XY.mean() # 提取特征后计算MMD with torch.no_grad(): src_feat = model.feature_extractor(source_samples) tgt_feat = model.feature_extractor(target_samples) mmd_value = mmd_rbf(src_feat, tgt_feat) print(f"MMD距离: {mmd_value.item():.4f}") # 典型值约0.8-1.23. 基于MMD的域适应实现
核心思想是在训练时最小化MMD距离,使特征提取器生成域不变表示:
class MMD_Loss(nn.Module): def __init__(self, gamma=1.0): super().__init__() self.gamma = gamma def forward(self, src_feat, tgt_feat): return mmd_rbf(src_feat, tgt_feat, self.gamma) # 修改网络结构 class DomainAdaptNet(nn.Module): def __init__(self): super().__init__() self.feature_extractor = nn.Sequential( nn.Conv2d(3, 64, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Flatten() ) self.classifier = nn.Linear(128*5*5, 10) def forward(self, x): features = self.feature_extractor(x) return self.classifier(features) # 训练循环中加入MMD损失 model = DomainAdaptNet() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) mmd_loss = MMD_Loss() for epoch in range(10): for (src_data, src_labels), (tgt_data, _) in zip(source_loader, target_loader): src_pred = model(src_data) tgt_feat = model.feature_extractor(tgt_data) # 计算总损失 cls_loss = F.cross_entropy(src_pred, src_labels) adapt_loss = mmd_loss(model.feature_extractor(src_data), tgt_feat) total_loss = cls_loss + 0.5 * adapt_loss # 平衡系数需调优 optimizer.zero_grad() total_loss.backward() optimizer.step()关键点在于平衡分类损失和适应损失的权重系数。经过训练后,SVHN测试准确率通常能提升至55-60%,同时MMD距离降低到0.3左右。
4. 对抗训练方法DANN实战
比MMD更激进的是域对抗神经网络(DANN),它通过梯度反转层(GRL)实现特征空间的对抗对齐:
class GradientReversalFn(torch.autograd.Function): @staticmethod def forward(ctx, x, alpha): ctx.alpha = alpha return x.view_as(x) @staticmethod def backward(ctx, grad_output): return -ctx.alpha * grad_output, None class DANN(nn.Module): def __init__(self): super().__init__() self.feature_extractor = nn.Sequential( # 与MMD相同的特征提取层 ) self.classifier = nn.Linear(128*5*5, 10) self.domain_discriminator = nn.Sequential( nn.Linear(128*5*5, 256), nn.ReLU(), nn.Linear(256, 1) ) def forward(self, x, alpha=1.0): features = self.feature_extractor(x) reversed_features = GradientReversalFn.apply(features, alpha) domain_pred = self.domain_discriminator(reversed_features) return self.classifier(features), domain_pred.squeeze() # 训练过程需要同时优化两个目标 model = DANN() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(10): for (src_data, src_labels), (tgt_data, _) in zip(source_loader, target_loader): # 源域数据标签为0,目标域为1 domain_labels = torch.cat([ torch.zeros(src_data.size(0)), torch.ones(tgt_data.size(0)) ]) # 合并数据 all_data = torch.cat([src_data, tgt_data]) class_pred, domain_pred = model(all_data) # 计算损失 cls_loss = F.cross_entropy(class_pred[:len(src_data)], src_labels) domain_loss = F.binary_cross_entropy_with_logits( domain_pred, domain_labels) total_loss = cls_loss + 0.3 * domain_loss # 平衡系数 optimizer.zero_grad() total_loss.backward() optimizer.step()DANN的核心创新在于梯度反转层——在反向传播时对域判别器的梯度取反,迫使特征提取器生成"欺骗性"特征。这种方法通常能达到60-65%的准确率,但对超参数(如α系数)更敏感。
5. 基于重建的DRCN方法
第三种思路是通过重建目标域数据来学习共享表示,典型代表是深度重建分类网络(DRCN):
class DRCN(nn.Module): def __init__(self): super().__init__() # 共享编码器 self.encoder = nn.Sequential( nn.Conv2d(3, 64, 5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, 5), nn.ReLU(), nn.MaxPool2d(2) ) # 分类分支 self.classifier = nn.Sequential( nn.Flatten(), nn.Linear(128*5*5, 10) ) # 重建分支 self.decoder = nn.Sequential( nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(), nn.ConvTranspose2d(64, 3, 5, stride=2), nn.Tanh() ) def forward(self, x, mode='train'): features = self.encoder(x) if mode == 'classify': return self.classifier(features) elif mode == 'reconstruct': return self.decoder(features) # 交替训练策略 model = DRCN() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(10): # 阶段1:源域分类 for src_data, src_labels in source_loader: pred = model(src_data, mode='classify') loss = F.cross_entropy(pred, src_labels) optimizer.zero_grad() loss.backward() optimizer.step() # 阶段2:目标域重建 for tgt_data, _ in target_loader: recon = model(tgt_data, mode='reconstruct') loss = F.mse_loss(recon, tgt_data) optimizer.zero_grad() loss.backward() optimizer.step()DRCN通过共享编码器迫使网络找出对分类和重建都有用的特征。实际部署时可以采用更复杂的训练策略:
- 先用源域数据预训练分类分支
- 冻结分类器,用目标域数据训练解码器
- 微调整个网络
这种方法在笔者的多个工业项目中表现稳定,尤其适合目标域完全没有标签的场景,典型准确率在58-63%之间。
6. 方法对比与选型指南
三种方法各有优劣,以下是关键对比:
| 指标 | MMD方法 | DANN | DRCN |
|---|---|---|---|
| 理论复杂度 | 中等(核方法) | 高(对抗训练) | 低(多任务学习) |
| 计算开销 | 额外MMD计算 | 需额外判别器 | 需解码器 |
| 超参数敏感性 | 核带宽选择 | 梯度反转系数 | 任务平衡权重 |
| 最佳适用场景 | 中小型数据集 | 大数据集 | 无标签目标域 |
| 典型准确率 | 55-60% | 60-65% | 58-63% |
实际选择时建议:先尝试MMD作为基线,数据量大时用DANN追求更高性能,当目标域完全无标签则首选DRCN。工业场景中,组合多种方法的集成策略往往能取得更好效果。
在医疗影像分析项目中,笔者团队曾遇到内窥镜图像(源域)与超声图像(目标域)的适配问题。最终方案是先用DRCN进行初步对齐,再用DANN精细调整,使模型在目标域的F1分数从0.42提升至0.68。关键教训是:域适应不是一次性过程,而需要根据数据特性设计分层策略。