整个程序是一个基于进化算法的多模态融合架构搜索框架(DC-NAS),核心目标是自动搜索最优的多模态特征融合架构,用于分类任务。以下是程序的完整执行流程,并同步说明各辅助文件的调用时机和作用:
一、初始化阶段:环境与配置准备
启动入口:从
train_DC.py开始执行,首先加载配置参数和基础模块。- 调用
config.get_configs()获取超参数(如种群大小、迭代次数、融合方式等),定义数据路径和GPU资源。 - 忽略冗余警告(
simplefilter)和TensorFlow日志(os.environ['TF_CPP_MIN_LOG_LEVEL']),确保输出简洁。
- 调用
数据加载与预处理:
- 调用
datasetsplit/split.py:load_data_features()从5个视图目录(view_data_dir1到view_data_dir5)加载多模态数据,每个目录对应不同的数据拆分或视图来源。- 通过
splits_data()使用StratifiedShuffleSplit进行分层抽样,生成训练集的不同比例拆分(用于后续种群训练的多样性),并将数据整理为data_list(包含训练集、测试集、标签等)。
- 调用
data_utils/data_uitl.py:get_views()实际读取每个视图的.npy数据文件(如rgb1train.npy、dep1test.npy),并将标签转换为独热编码(to_categorical)。
- 调用
二、种群初始化:生成初始模型结构
- 生成初始种群:
train_DC.py调用population_init.generate_population_tree()生成初始模型结构种群(ini_population)。- 调用
random_tree.py:randomTree()随机生成树结构,其中叶节点代表视图(如1a、2b,数字为视图号,字母为数据子集标识),内部节点代表融合操作(如-0对应add)。
- 调用
utils_tree.py:new_tree()辅助生成随机树,结合配置中的视图数量(nb_view)和融合方式(fusion_ways)确保结构合法性。
三、进化迭代:训练与优化种群
进化算法的核心循环(train_DC.py的train()函数),共执行nb_iters次迭代,每次迭代包含训练评估和进化操作两个阶段。
阶段1:种群训练与评估(多进程并行)
多进程分配:
multi_proccess_train()函数管理训练任务,根据gpu_list分配GPU资源,使用multiprocessing.Pool并行训练种群中的个体。- 调用
list2str_tree()将模型结构编码为字符串,通过shared_code_sets记录已训练结构,避免重复计算。
单个模型训练(
train_individual()):- 解析编码:从个体编码(如
['1a', '2b', '-0'])中提取视图号和数据子集,调用tree_to_strlist.viewfusion()转换为树结构。 - 构建模型:调用
code2net_tree.py的code2net_tree(),将树结构转换为Keras模型:- 为每个视图创建输入层、BatchNorm和Dropout层,通过
Dense层统一特征维度。 - 递归执行融合操作(
fusion()函数支持add/mul/cat等方式),最终输出分类结果(softmax层)。
- 为每个视图创建输入层、BatchNorm和Dropout层,通过
- 训练与评估:
- 使用
ModelCheckpoint保存最优模型,EarlyStopping防止过拟合。 - 评估测试集准确率(
sklearn.metrics.accuracy_score),记录模型参数总量,结果写入result.csv(调用utils.write_result_file())。
- 使用
- 解析编码:从个体编码(如
阶段2:进化操作(生成下一代种群)
生成后代:
train_DC.py调用gen_offspring_tree.gen_offspring()从当前种群(P_t)生成后代(Q_t)。- 交叉操作:调用
gen_offspring_tree.py的crossover():- 随机选择两棵树的非根节点作为交叉点,交换子树(
split_tree()分割树,paste()拼接子树)。 - 调用
quchong()去重重复视图节点,确保结构合法性。
- 随机选择两棵树的非根节点作为交叉点,交换子树(
- 变异操作:调用
gen_offspring_tree.py的mutation():- 随机修改节点标签:视图节点(如
1a→3b)或融合节点(如-0→-2)。 - 或通过
mutation_new_tree_crossover()与随机新树交叉实现变异。
- 随机修改节点标签:视图节点(如
- 交叉操作:调用
选择与更新:从父代(
P_t)和后代(Q_t)中筛选最优个体组成下一代种群(P_{t+1}),** 调用best_fives.py** 存档最优结构(best_fivess、Archive1等)。
四、结束阶段:输出最优结果
- 迭代结束后,最优模型结构和性能保存在
result.csv和best_fives.py的变量中,可用于后续部署或分析。 - 调用
utils.py:load_result()加载历史结果,辅助选择最优个体。
各文件调用关系与核心作用总结
| 阶段 | 核心文件 | 辅助文件 | 主要作用 |
|---|---|---|---|
| 初始化 | train_DC.py | config.py、split.py | 加载配置、准备多模态数据 |
| 种群生成 | train_DC.py | random_tree.py、utils_tree.py | 生成随机树结构作为初始模型 |
| 模型训练 | train_DC.py | code2net_tree.py、utils.py | 树结构转模型、训练评估、记录结果 |
| 进化操作 | gen_offspring_tree.py | test_treelib.py | 交叉/变异生成新结构,去重和深度控制 |
| 最优解存档 | train_DC.py | best_fives.py | 保存迭代过程中的最优模型结构 |
整个流程通过进化算法不断迭代优化模型结构,结合多进程并行训练提高效率,最终找到性能最优的多模态融合架构。