在 MATLAB 中实现基于 迁移学习(Transfer Learning, TL) 与 SqueezeNet 网络的滚动轴承故障诊断,是一种高效利用预训练模型进行小样本故障分类的方法。以下是一个完整的实现流程,包括数据准备、网络修改、迁移学习训练和评估。
🧠 背景简述
SqueezeNet:轻量级 CNN,参数少、速度快,适合嵌入式或资源受限场景。
迁移学习(TL):利用 ImageNet 上预训练的 SqueezeNet 特征提取能力,微调最后几层用于轴承故障分类。
滚动轴承故障诊断:通常使用振动信号 → 转换为时频图像(如 STFT、CWT、Wigner-Ville 等)→ 图像分类任务。
✅ 实现步骤(MATLAB)
- 准备故障数据集(图像格式)
假设你已将原始振动信号转换为图像(如 227×227 的 RGB 图像),并按类别组织在文件夹中:
dataset/
├── normal/
├── inner_fault/
├── outer_fault/
└── ball_fault/
⚠️ SqueezeNet 输入尺寸为 227×227×3,务必统一图像尺寸。
- 加载预训练 SqueezeNet
matlab
% 加载预训练 SqueezeNet(ImageNet)
net = squeezenet;
- 修改网络结构(适配你的故障类别数)
matlab
% 获取原网络层
layers = net.Layers;
% 假设你的故障类别数为 4
numClasses = 4;
% 替换最后两层:conv10 和 softmax + classification
newLayers = [
layers(1:end-2) % 保留前面所有层
convolution2dLayer(1, numClasses, ‘Name’, ‘new_conv’) % 新卷积层
softmaxLayer(‘Name’, ‘new_softmax’)
classificationLayer(‘Name’, ‘new_classoutput’)
];
% 显示新网络
analyzeNetwork(newLayers);
- 准备图像数据存储(ImageDatastore)
matlab
dataFolder = ‘dataset’; % 你的数据路径
imds = imageDatastore(dataFolder, …
‘IncludeSubfolders’, true, …
‘LabelSource’, ‘foldernames’);
% 划分训练集和验证集(例如 80% 训练,20% 验证)
[imdsTrain, imdsValidation] = splitEachLabel(imds, 0.8, ‘randomized’);
- 设置训练选项(迁移学习策略)
冻结前面的特征提取层,只训练新添加的层(可选):
matlab
% 获取要训练的层(仅新层)
layerNames = {‘new_conv’, ‘new_softmax’, ‘new_classoutput’};
options = trainingOptions(‘sgdm’, …
‘InitialLearnRate’, 1e-3, …
‘MaxEpochs’, 20, …
‘MiniBatchSize’, 32, …
‘Shuffle’, ‘every-epoch’, …
‘ValidationData’, imdsValidation, …
‘ValidationFrequency’, 30, …
‘Verbose’, false, …
‘Plots’, ‘training-progress’);
% 若想微调整个网络,可设置 LearnRateSchedule=‘piecewise’ 并降低初始学习率
🔍 进阶技巧:先冻结特征提取层训练新层,再解冻微调整个网络(两阶段训练)。
- 开始训练
matlab
netTransfer = trainNetwork(imdsTrain, newLayers, options);
- 评估模型性能
matlab
% 预测验证集
YPred = classify(netTransfer, imdsValidation);
YTrue = imdsValidation.Labels;
% 计算准确率
accuracy = mean(YPred == YTrue);
fprintf(‘验证集准确率: %.2f%%\n’, accuracy*100);
% 混淆矩阵
figure;
confusionchart(YTrue, YPred);
title(‘混淆矩阵 - 轴承故障诊断’);
📌 补充建议
数据增强(防止过拟合):
matlab
augImds = augmentedImageDatastore([227 227], imdsTrain, …
‘ColorPreprocessing’, ‘gray2rgb’); % 若原图是灰度图
注意:SqueezeNet 需要 3 通道输入,若原始为灰度图,需转为 RGB(如 gray2rgb)。信号转图像方法推荐:
短时傅里叶变换(STFT) → 时频图
连续小波变换(CWT) → 小波尺度图(效果通常更好)
使用 cwtfilterbank 或 stft 函数生成图像使用预训练权重初始化新层(可选):
可从原 conv10 层复制部分权重(若类别数相近)
📚 参考资料
MATLAB 官方文档:Transfer Learning Using SqueezeNet
CWRU 轴承数据集(常用公开数据集):https://engineering.case.edu/bearingdatacenter