高光谱图像分类实战:从Python环境搭建到PyTorch模型部署
当大多数人还在RGB图像的世界里打转时,计算机视觉的前沿已经悄然进入了高光谱时代。想象一下,你的相机不仅能捕捉红绿蓝三种颜色,而是能记录数百个连续光谱波段——这就是高光谱成像带来的革命。本文将带你从零开始,用Python和PyTorch构建一个完整的高光谱图像分类系统。
1. 高光谱图像基础与环境配置
高光谱图像与传统RGB图像的最大区别在于其光谱维度的丰富性。一个典型的高光谱数据集可能包含200-300个光谱波段,每个像素都携带着完整的光谱特征。这种"图谱合一"的特性,使得我们能够通过光谱特征精确区分看似相似的不同物质。
环境准备清单:
- Python 3.8+
- PyTorch 1.10+
- scikit-learn
- NumPy
- Matplotlib
- Spectral Python (SPy) 库
pip install torch torchvision scikit-learn numpy matplotlib spectral提示:建议使用Anaconda创建独立环境,避免依赖冲突。GPU加速可显著提升训练速度,确保安装对应版本的CUDA工具包。
Indian Pines和Pavia University是两个最常用的公开高光谱数据集。前者包含145×145像素的图像和16类地物,后者则是610×340像素和9类地物。我们将以Indian Pines为例:
from spectral import open_image # 加载Indian Pines数据集 img = open_image('Indian_pines.hdr') data = img.load() print(f"数据维度:{data.shape}") # 输出:(145, 145, 200)2. 数据预处理与特征工程
高光谱数据的维度诅咒是首要挑战。200多个波段意味着高计算成本和可能的过拟合风险。主成分分析(PCA)是最常用的降维手段:
from sklearn.decomposition import PCA # 将三维数据转为二维矩阵(样本数×特征数) X = data.reshape(-1, data.shape[2]) # 保留95%方差的PCA降维 pca = PCA(n_components=0.95) X_pca = pca.fit_transform(X) print(f"降维后特征数:{X_pca.shape[1]}")波段选择策略对比:
| 方法 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| PCA | 自动保留最大方差 | 失去物理意义 | 通用降维 |
| 波段相关性 | 保留原始特征 | 需要领域知识 | 特定应用 |
| 信息熵 | 选择信息量大的波段 | 计算成本高 | 精细分类 |
数据标准化同样关键。由于不同波段的光谱强度范围差异巨大,我们需要对每个波段单独归一化:
from sklearn.preprocessing import StandardScaler scaler = StandardScaler() X_scaled = scaler.fit_transform(X_pca)3. 构建PyTorch分类模型
我们将实现一个混合光谱-空间网络(HybridSN),它结合了3D卷积(提取光谱特征)和2D卷积(提取空间特征)的优势:
import torch import torch.nn as nn class HybridSN(nn.Module): def __init__(self, num_classes): super().__init__() # 光谱特征提取(3D卷积) self.conv3d_1 = nn.Conv3d(1, 8, kernel_size=(7,3,3)) self.conv3d_2 = nn.Conv3d(8, 16, kernel_size=(5,3,3)) self.conv3d_3 = nn.Conv3d(16, 32, kernel_size=(3,3,3)) # 空间特征提取(2D卷积) self.conv2d = nn.Conv2d(576, 64, kernel_size=3) # 分类头 self.fc1 = nn.Linear(18496, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, num_classes) def forward(self, x): # 3D卷积部分 x = F.relu(self.conv3d_1(x)) x = F.relu(self.conv3d_2(x)) x = F.relu(self.conv3d_3(x)) # 转为2D输入 batch, channels, _, height, width = x.shape x = x.view(batch, -1, height, width) # 2D卷积部分 x = F.relu(self.conv2d(x)) x = x.view(batch, -1) # 全连接层 x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x模型训练技巧:
- 使用标签平滑(Label Smoothing)缓解样本不均衡问题
- 采用学习率预热(Learning Rate Warmup)稳定初期训练
- 结合CutMix数据增强提升泛化能力
4. 训练优化与结果分析
数据划分是另一个关键点。高光谱分类常面临样本量少的问题,我们采用空间分块策略:
from sklearn.model_selection import train_test_split # 创建空间块(避免像素级泄露) block_size = 5 blocks = [] for i in range(0, data.shape[0]-block_size, block_size): for j in range(0, data.shape[1]-block_size, block_size): blocks.append((i,j)) # 按7:2:1划分训练/验证/测试集 train_blocks, test_blocks = train_test_split(blocks, test_size=0.3) val_blocks, test_blocks = train_test_split(test_blocks, test_size=0.33)训练参数配置:
import torch.optim as optim model = HybridSN(num_classes=16).cuda() criterion = nn.CrossEntropyLoss(label_smoothing=0.1) optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) # 学习率调度器 scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=1e-3, steps_per_epoch=len(train_loader), epochs=100 )性能对比(Indian Pines数据集):
| 模型 | 总体精度 | Kappa系数 | 训练时间(epoch) |
|---|---|---|---|
| 2D-CNN | 83.2% | 0.812 | 45s |
| 3D-CNN | 86.7% | 0.853 | 68s |
| HybridSN | 91.4% | 0.902 | 52s |
| 论文SOTA | 94.1% | 0.932 | - |
可视化结果同样重要。使用混淆矩阵和分类图可以直观展示模型表现:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay # 绘制混淆矩阵 cm = confusion_matrix(true_labels, preds) disp = ConfusionMatrixDisplay(cm, display_labels=class_names) disp.plot(cmap='Blues')5. 实战技巧与避坑指南
数据增强策略:
- 光谱域:添加高斯噪声、波段随机屏蔽
- 空间域:随机旋转、镜像、小块裁剪
class HSI_Augment: def __call__(self, sample): # 光谱增强 if random.random() > 0.5: noise = torch.randn_like(sample) * 0.01 sample += noise # 空间增强 if random.random() > 0.5: sample = torch.flip(sample, dims=[-1]) # 水平翻转 return sample常见报错解决:
- 内存不足:减小批大小或使用梯度累积
- 过拟合:增加Dropout层或权重衰减
- 训练不稳定:使用梯度裁剪(Gradient Clipping)
# 梯度裁剪示例 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)模型部署优化:
- 使用TorchScript导出模型
- 应用半精度(FP16)推理
- 实现ONNX格式转换
# 导出TorchScript model.eval() traced_script = torch.jit.trace(model, example_input) traced_script.save("hs_classifier.pt")在实际项目中,我们发现将高光谱分类与传统RGB检测结合能显著提升系统鲁棒性。例如在农业应用中,高光谱识别作物病害,RGB定位病害区域,两者协同工作比单一模态效果提升约15%。