1. 项目概述与核心价值
最近在整理自己的开源项目时,我一直在思考一个问题:一个模型训练完成后,如何让它能持续学习新知识,而不是像“一次性用品”那样被束之高阁?这正是“持续学习”要解决的核心痛点。SKY-lv/continuous-learning 这个项目,从名字就能看出其野心——它不是一个简单的模型训练脚本集合,而是一个旨在构建具备“终身学习”能力的智能系统的框架。简单来说,它想让你的AI模型像人一样,在遇到新任务、新数据时,能够记住旧知识,同时高效地学习新东西,而不是学了新的就忘了旧的,或者需要把所有新旧数据混在一起重新训练一遍。
这个需求在现实场景中太普遍了。想象一下,你开发了一个图像分类模型,能识别猫和狗。上线后,用户希望它能再识别鸟和鱼。传统做法是收集所有猫、狗、鸟、鱼的图片,重新训练一个模型。这不仅耗时耗力,更重要的是,你可能已经丢失了最初训练猫狗分类时的那部分数据,或者数据隐私法规不允许你混合新旧数据。持续学习就是为了优雅地解决这类“灾难性遗忘”问题。它让模型能够在不遗忘旧任务的前提下,按顺序、增量式地学习一系列新任务。这个项目提供了一个结构化的工具箱,帮助研究者和开发者快速搭建、评估和比较不同的持续学习算法。
对于机器学习工程师和研究者而言,这个项目的价值在于它提供了一个标准化的“试验场”。持续学习领域算法众多,比如基于正则化的方法、基于动态架构的方法、基于回放记忆的方法等。自己从头实现并公平比较这些算法非常繁琐。而这个项目很可能将这些经典和前沿的算法集成在一个统一的框架下,定义了清晰的数据流、任务切换接口和评估协议。你可以像搭积木一样,选择不同的骨干网络、不同的持续学习策略,在标准的数据集(如Split MNIST, Split CIFAR-100)上进行实验,快速验证你的想法或复现他人的工作。对于应用开发者,它则提供了一种将“静态模型”升级为“可进化模型”的可能路径,虽然直接用于生产环境还需要大量的工程化工作,但其核心思想极具启发性。
2. 持续学习的核心挑战与算法流派
在深入这个项目的具体实现之前,我们必须先理解持续学习要对抗的“头号敌人”:灾难性遗忘。当一个神经网络在任务A上训练得很好之后,如果直接用任务B的数据去训练它,网络会迅速调整其参数以适应任务B,但这通常会导致其在任务A上的性能急剧下降,仿佛完全忘记了之前学到的知识。这是因为神经网络参数是共享的,优化过程没有对“哪些参数对旧任务很重要”施加保护。
为了应对这一挑战,学术界发展出了几大主流技术流派,这也是像continuous-learning这类项目会重点集成和对比的方向。
2.1 基于正则化的方法:给重要的参数“上锁”
这类方法的思路很直观:在旧任务上表现良好的网络,其参数的重要性是不同的。有些参数稍微改动就会严重影响旧任务的性能,这些就是“重要参数”;有些参数则相对“自由”。基于正则化的方法,就是在学习新任务时,对重要的旧参数施加惩罚,限制其变化幅度。
最著名的代表是EWC。它的核心思想是,在完成旧任务后,计算网络每个参数对于旧任务损失函数的“重要性”(通常用费雪信息矩阵的对角线近似)。当学习新任务时,在损失函数中增加一个正则项:如果某个参数偏离了旧任务时的最优值,且该参数很重要,那么就会受到很大的惩罚。这就好比给重要的神经元上了一把“锁”,新任务的学习只能在不触动这些锁的前提下,调整其他相对不重要的参数。
另一个常见方法是LwF。它不需要保存旧数据,而是利用“知识蒸馏”的思想。在学习新任务时,不仅要求网络输出对新任务标签的预测,还要求其输出(经过温度缩放后的软标签)尽可能接近网络自身在旧任务上的“记忆输出”。这个“记忆输出”是网络在接触新数据前,对旧任务数据的预测。通过这种方式,网络在适应新任务的同时,被“提醒”要保持对旧任务输出分布的拟合。
实操心得:基于正则化的方法实现相对简单,且不需要存储原始数据,在隐私敏感场景下有优势。但它的效果严重依赖于重要性估计的准确性。EWC对超参数(正则化系数λ)比较敏感,调参需要耐心。LwF则对任务之间的相似性有要求,如果新旧任务差异巨大,蒸馏效果会打折扣。
2.2 基于回放记忆的方法:给模型“复习旧课”
这是最符合直觉的一类方法,既然怕忘记,那就时不时把旧知识拿出来复习一下。这类方法会维护一个固定大小的“记忆缓冲区”,存储一部分旧任务的代表性样本(或其特征)。在学习新任务时,不仅使用新任务的数据,还会从缓冲区中采样一些旧数据,混合在一起进行训练。
iCaRL是一个经典的基于回放记忆的算法。它不仅是简单地回放,还引入了一套完整的流程:1) 使用herding算法选择最具代表性的旧样本存入记忆;2) 在学习新类时,会同时用记忆中的旧样本和新样本一起训练;3) 使用一个“原型向量”来进行最近邻分类,而不是直接使用分类头的输出,这有助于稳定特征空间。iCaRL展示了如何将回放记忆与精心的样本选择和分类策略结合,取得比简单回放好得多的效果。
基于回放的方法效果通常很稳定,因为它让模型直接“看到”了旧数据。但它的缺点也很明显:需要额外的存储空间,并且可能引发隐私问题(存储了原始数据)。此外,如何高效地从海量旧数据中选取最具代表性的那一小部分存入记忆,本身就是一个研究课题。
2.3 基于动态架构的方法:给新任务“开小灶”
如果说前两种方法是在一个固定的“大脑”里想办法,那么动态架构的方法则更“奢侈”一些:它允许网络结构随着新任务的到来而增长。当学习新任务时,网络可以添加新的神经元、新的层,或者将一部分参数“隔离”出来专用于新任务,从而避免对旧任务参数的干扰。
PNN是这一思路的早期代表,它为每个新任务添加一个新的侧枝网络,并通过横向连接从旧网络中提取特征。HAT则更精巧,它在网络的每一层引入可学习的注意力掩码,当学习一个任务时,会激活一部分神经元并“冻结”它们,后续任务只能使用未被冻结的神经元。这样就在一个共享的网络中,为不同任务划分了专用的子网络。
这类方法的优势是理论上可以完全避免遗忘,因为旧任务的参数被物理隔离或保护起来了。但代价是模型会随着任务数量线性(甚至更快)地增长,导致计算和存储开销变大,不够“优雅”。在实际部署中,模型膨胀是一个需要严肃考虑的问题。
一个像样的continuous-learning项目,必然会涵盖上述至少两到三种流派的核心算法,并提供统一的接口,让使用者可以方便地进行算法A vs 算法B的对比实验。接下来,我们就来拆解一下,要实现这样一个框架,需要设计哪些核心模块。
3. 项目框架的核心模块设计
一个优秀的持续学习框架,其价值不仅在于实现了多少个SOTA算法,更在于其架构的清晰度、扩展的便捷性和实验的可复现性。根据我对这类项目的理解,SKY-lv/continuous-learning的理想架构应该包含以下几个核心模块,这也是我们评估或自建类似框架时的设计蓝图。
3.1 任务流与数据加载器
这是整个框架的基石。持续学习的核心是“任务序列”,框架必须能清晰地定义和管理任务的到来顺序。
首先,需要定义一个Task或Experience的抽象。每个任务应包含:任务ID、训练数据集、验证数据集、测试数据集,以及该任务所涉及的类别标签集合。对于图像分类,常用的基准测试包括:
- Split MNIST: 将10个手写数字类别按顺序分成5个任务,每个任务学习区分2个数字。
- Split CIFAR-100: 将100个类别分成10个或20个任务,每个任务学习10个或5个新类别。
- Permuted MNIST: 每个任务中,对MNIST图像的像素进行不同的随机排列,从而创造出输入分布不同但输出空间相同的任务序列。
框架需要提供标准的数据集切割和转换工具。例如,一个data/benchmarks.py模块,里面定义了get_split_mnist_tasks(num_tasks=5)这样的函数,返回一个任务列表。
更重要的是数据加载策略。在训练第N个任务时,数据加载器应该只能访问当前任务N的数据(对于无回放的方法),或者能访问当前任务数据+记忆缓冲区中的旧数据(对于有回放的方法)。框架需要封装好这个逻辑,对算法开发者透明。通常,会有一个ReplayDataLoader类,它内部维护一个当前任务数据集的DataLoader和一个记忆缓冲区的DataLoader,在每次迭代时按一定比例从两者中采样数据,合并成一个批次。
3.2 算法抽象与策略接口
这是框架的灵魂。所有持续学习算法,无论属于哪个流派,都应该继承自一个统一的基类,例如ContinualLearningStrategy。这个基类定义了算法生命周期中的几个关键钩子函数:
class ContinualLearningStrategy(nn.Module): def __init__(self, model, optimizer, args): super().__init__() self.model = model self.optimizer = optimizer def observe(self, batch, task_id): """核心训练步骤。接收一个数据批次,返回损失值。""" # 算法在这里实现其核心逻辑:计算损失,可能包含正则项、蒸馏项等。 # 返回的loss会被用于反向传播。 pass def before_task(self, task_id, train_loader): """在开始训练一个新任务前调用。用于初始化记忆、调整网络结构等。""" pass def after_task(self, task_id, train_loader): """在完成一个任务的训练后调用。用于更新重要性估计、选择样本存入记忆等。""" pass def on_eval(self): """在切换到评估模式时调用。用于设置网络中的特定模块(如HAT的掩码)。""" pass这种设计模式极大地提升了代码的模块化和可扩展性。要实现一个新的算法,你只需要继承这个基类,然后实现observe方法(以及可选的before_task/after_task)。框架的主训练循环会固定,它负责迭代数据、调用strategy.observe()、执行反向传播和优化器更新。算法开发者只需关心算法本身的逻辑。
例如,EWC策略会在after_task中计算并保存参数的重要性和最优值,然后在observe中计算正则损失并加到总损失上。而一个简单的回放策略会在observe中从记忆缓冲区采样数据,与当前批次混合训练。
3.3 模型与评估协议
模型:框架通常会支持常见的骨干网络,如用于MNIST的简单CNN、用于CIFAR的ResNet-18/32等。关键点在于,分类头(最后的全连接层)需要特殊处理。在持续学习分类任务中,分类头的输出维度应该是所有已见类别的总数,并且需要处理“增量分类头”的问题——即每学一个新任务,分类头就要增加对应数量的输出神经元。框架需要提供一种优雅的方式来自动管理分类头的扩展。
评估协议:这是衡量持续学习算法好坏的金标准。评估必须在所有已学过的任务上进行。训练完第T个任务后,我们需要在任务1, 2, ..., T的测试集上分别评估模型的准确率。最终会得到一个T x T的精度矩阵,其中元素a_{i,j}表示在训练完第i个任务后,在第j个任务测试集上的准确率。
两个核心指标是:
- 平均准确率:训练完所有任务后,计算模型在每个任务上最终准确率的平均值。这反映了模型的整体性能。
- 遗忘度:衡量模型对旧任务的遗忘程度。对于任务
i (i < T),其遗忘度可以定义为max_{k \in {i,...,T-1}}(a_{k,i}) - a_{T,i},即它在历史最高精度和最终精度之间的差值。所有旧任务遗忘度的平均值反映了算法的抗遗忘能力。
一个成熟的框架会自动在每轮训练后运行这套评估流程,并记录和可视化这些指标,生成类似于论文中的结果图表。
3.4 训练循环与实验管理
这是将以上所有模块串联起来的“胶水代码”。一个典型的训练循环伪代码如下:
# 初始化模型、优化器、策略 model = get_model() optimizer = torch.optim.Adam(model.parameters()) strategy = EWCStrategy(model, optimizer, ewc_lambda=100) # 获取任务序列 tasks = get_split_cifar100_tasks(num_tasks=10) for task_id, task_data in enumerate(tasks): print(f"Starting Task {task_id}") train_loader = task_data['train'] strategy.before_task(task_id, train_loader) # 训练当前任务 for epoch in range(num_epochs): for batch in train_loader: loss = strategy.observe(batch, task_id) loss.backward() optimizer.step() optimizer.zero_grad() strategy.after_task(task_id, train_loader) # 评估所有已学任务 eval_results = evaluate_on_all_tasks(model, strategy, tasks[:task_id+1]) log_results(task_id, eval_results)此外,一个实用的框架还会集成实验管理工具,比如通过配置文件(YAML或JSON)来定义实验超参数(学习率、批大小、正则化系数、记忆缓冲区大小等),并支持像Weights & Biases或TensorBoard这样的工具来跟踪实验过程、记录指标和可视化结果。这能让你轻松地运行一组对比实验,并清晰地分析结果。
4. 核心算法实现要点与避坑指南
假设我们要在continuous-learning框架中实现两个代表性算法:EWC(基于正则化)和一个小型回放策略。这里分享一些实现细节和容易踩的坑。
4.1 EWC 实现详解与调参陷阱
EWC的核心是在损失函数中加入正则项:L_total = L_ce + λ * Σ_i F_i * (θ_i - θ*_i)^2,其中F_i是参数θ_i的费雪信息矩阵对角线值(重要性),θ*_i是旧任务上的最优参数值。
实现步骤:
- 保存旧参数:在
after_task中,将当前模型所有需要保护参数的当前值θ*深拷贝保存下来。 - 计算费雪信息矩阵:这是EWC最微妙的一步。费雪信息矩阵
F衡量的是参数对模型预测分布的贡献。一种标准的近似方法是:在旧任务的数据集上执行一次前向传播,对于每个样本,计算模型输出对数概率对参数的梯度,然后对这些梯度的平方取平均。在实际操作中,我们通常按以下步骤进行:fisher_dict = {} model.eval() for batch in old_task_dataloader: model.zero_grad() output = model(batch.x) # 假设是分类任务,计算对数似然 log_likelihood = F.log_softmax(output, dim=1) # 为每个样本的每个类别计算梯度并累加 for i in range(batch.x.size(0)): label = batch.y[i] # 取对应真实标签的对数概率 log_prob = log_likelihood[i, label] log_prob.backward(retain_graph=True) # 关键:retain_graph for name, param in model.named_parameters(): if param.grad is not None: if name not in fisher_dict: fisher_dict[name] = param.grad.data.clone().pow(2) else: fisher_dict[name] += param.grad.data.clone().pow(2) # 对所有样本取平均 for name in fisher_dict: fisher_dict[name] /= len(old_task_dataloader.dataset)重要提示:计算费雪矩阵时,
backward()需要设置retain_graph=True,因为我们需要为每个样本单独计算梯度。这个过程计算量较大,通常只在旧任务的一个子集(例如部分训练数据)上进行,这本身也是一种近似。 - 整合到损失中:在
observe方法中,计算完交叉熵损失L_ce后,遍历所有参数,计算EWC正则损失并累加。ewc_loss = 0 for name, param in model.named_parameters(): if name in self.fisher_dict: fisher = self.fisher_dict[name] old_param = self.optimal_params[name] ewc_loss += (fisher * (param - old_param).pow(2)).sum() total_loss = cross_entropy_loss + self.ewc_lambda * ewc_loss
避坑指南:
- λ的选择是玄学:EWC的超参数
λ(正则化强度)对结果影响巨大。值太小,约束不够,还是会遗忘;值太大,会严重阻碍新任务的学习。它没有理论上的最优值,必须通过网格搜索在验证集上确定。通常需要尝试[1, 10, 100, 1000, 5000]这样的数量级。 - 费雪矩阵的估计质量:使用全部训练数据计算费雪矩阵开销太大,通常只用一部分数据。这会导致重要性估计不准。一个改进是使用对角经验费雪,它直接使用损失函数对参数的梯度平方的移动平均来更新重要性,这更高效且适合在线学习场景。
- 只保护部分层:通常只对网络的特征提取层(卷积层、全连接层)施加EWC约束,而不对最后的分类头施加,因为分类头本身就需要为新的类别进行调整。
- 内存消耗:需要为每个旧任务存储一套
{最优参数θ*, 费雪矩阵F}。对于大型网络和多个任务,这会占用可观的内存。可以考虑只存储最重要的那部分参数(例如,按F值排序,只保留top-k%)。
4.2 简单经验回放实现与样本选择策略
一个最基础的经验回放策略实现起来比EWC更直观。其核心是维护一个固定大小的记忆缓冲区M。
实现步骤:
- 缓冲区管理:在
__init__中初始化一个空缓冲区,可以是一个列表或更高效的数据结构(如环形缓冲区)。 - 样本选择与存储:在
after_task中,需要从刚结束的任务数据中选择一部分样本存入缓冲区。最简单的策略是随机选择。但更有效的策略是:- Herding(iCaRL):选择那些最接近该类特征均值的样本,旨在保留最能代表该类分布的原型。
- 基于不确定性的选择:选择模型预测最不确定(熵最高)的样本,这些通常是决策边界附近的“难样本”,回放它们可能收益更大。
- 基于覆盖度的选择:使用聚类方法(如K-Means)选择样本,以确保缓冲区中的样本能尽可能覆盖旧任务的数据分布。
- 训练过程:在
observe方法中,从当前任务批次batch_new中取出数据,同时从记忆缓冲区M中随机采样一个小批次batch_old。将两者合并,计算损失并进行反向传播。这里的关键是,batch_old的标签需要被正确映射到当前扩展后的分类头对应的位置。
避坑指南:
- 缓冲区大小是关键:缓冲区能存多少旧样本,直接决定了抗遗忘能力的上限。通常,每个旧类别存储20-50个样本是一个常见的起点。你需要平衡内存限制和性能需求。
- 新旧数据比例:在混合批次中,新旧数据的比例需要调整。常见做法是保持
|batch_new| : |batch_old|在 2:1 到 1:1 之间。这个比例也可以作为一个可调的超参数。 - 分类头偏移:这是新手最容易出错的地方。假设旧任务有50类,新任务有10类。那么当前模型的分类头输出维度是60。对于缓冲区中一个属于旧任务第5类的样本,它的标签在计算损失时,对应的应该是第5个输出神经元,而不是标签“5”本身(如果使用0-based索引)。框架需要处理好这种标签的映射逻辑,通常是在数据存入缓冲区时,就将其标签转换为一个全局的、连续的任务无关的类别ID。
- 数据增强:对回放样本应用适度的数据增强(如随机裁剪、翻转)可以增加多样性,提升效果,这被称为“增强回放”。
5. 高级话题与未来方向探索
当你熟练掌握了基础算法的实现和调参后,continuous-learning这个领域还有更多深水区和有趣的方向值得探索,这些也往往是开源项目试图涵盖或提供接口的前沿部分。
5.1 任务无关的持续学习与任务标识符
我们之前讨论的都属于“任务感知”的持续学习,即算法明确知道当前处于哪个任务(有一个task_id)。但在更现实的场景中,数据流可能不会自带清晰的任务边界标签。这就是“任务无关的持续学习”。此时,算法需要自己检测分布变化,判断是否进入了新任务。
一种常见方法是使用“任务标识符”。例如,在HAT中,网络会为每个检测到的“新任务”学习一组新的注意力掩码。另一种思路是使用生成模型(如VAE、GAN)来为输入数据学习一个表征,并通过监控表征分布的变化来检测任务边界。实现任务无关的CL是框架设计的一个高级挑战,它要求数据流接口和算法接口更加灵活。
5.2 在线持续学习与流式数据
大部分研究假设每个任务的数据可以多次遍历(多轮训练)。但在线持续学习场景下,数据以流的形式到来,每个样本通常只能被看到一次。这对算法提出了更高的要求:必须在单次更新中做出正确的调整。
基于回放的方法天然适合在线场景,因为缓冲区总是保存着部分历史。基于正则化的方法(如在线EWC)则需要在线更新重要性估计。一些专门针对在线场景的算法,如GEM和A-GEM,通过将梯度投影到对旧任务损失梯度方向夹角为钝角的方向上,来保证不增加旧任务的损失,这种约束可以在单个批次更新中完成计算。
5.3 持续学习与其他学习范式的结合
这是当前研究非常活跃的领域,也是检验一个框架扩展性的试金石。
- 持续学习与元学习:元学习旨在“学会如何学习”。将元学习应用于持续学习,目标是让模型在经历一系列任务后,获得快速适应新任务且不遗忘旧任务的能力。例如,MAML的持续学习变体,试图找到一个好的参数初始化点,从这个点出发,对每个新任务进行少量更新就能达到好性能,同时这个更新过程本身是受约束的以避免遗忘。
- 持续学习与自监督学习:自监督学习利用数据自身的结构作为监督信号。在持续学习中引入自监督辅助任务(如旋转预测、拼图),可以帮助模型学习到更通用、更鲁棒的特征表示,这些特征对任务变化不敏感,从而减轻遗忘。在框架中,这可能意味着损失函数是交叉熵损失和自监督损失的加权和。
- 持续学习与联邦学习:在联邦学习场景中,数据分布在多个客户端且隐私敏感。持续学习可以帮助每个客户端在本地进行增量学习,同时通过联邦聚合来融合知识,并避免全局模型的遗忘。这要求框架能模拟分布式的训练环境。
一个设计良好的continuous-learning框架应该为这些交叉研究提供可能性。例如,提供灵活的损失函数组合接口、支持自定义的数据流模拟器、以及便于分布式训练的钩子。
6. 项目实践:从克隆到贡献
假设你对SKY-lv/continuous-learning项目产生了兴趣,无论是想使用它进行实验,还是想为其贡献代码,以下是一个标准的实践路径。
6.1 环境搭建与快速启动
首先,克隆项目并查看其结构。
git clone https://github.com/SKY-lv/continuous-learning.git cd continuous-learning ls -la一个理想的项目结构可能如下:
continuous-learning/ ├── README.md # 项目总览、安装、快速开始 ├── requirements.txt # Python依赖 ├── configs/ # 实验配置文件(YAML/JSON) ├── src/ │ ├── data/ # 数据集加载与任务构建模块 │ ├── models/ # 骨干网络定义 │ ├── strategies/ # 所有持续学习算法实现(EWC, LwF, iCaRL等) │ ├── trainers/ # 训练循环和评估逻辑 │ └── utils/ # 日志、工具函数 ├── scripts/ # 启动训练和评估的脚本 ├── experiments/ # 默认的实验输出目录 └── tests/ # 单元测试接下来,按照README.md的指引安装依赖,通常是用pip install -r requirements.txt。核心依赖一般是PyTorch,torchvision,numpy,tqdm,tensorboard/wandb等。
尝试运行一个示例脚本,这是验证环境是否正确的关键一步。
python scripts/train.py --config configs/ewc_split_mnist.yaml这个命令应该会启动一个在Split MNIST数据集上使用EWC算法的训练实验。观察控制台输出,看是否有错误,并检查experiments/目录下是否生成了日志和模型检查点。
6.2 代码走读与核心逻辑追踪
要深入理解项目,最好的方法是“跑通一个实验,跟踪一条数据流”。选择最简单的配置(如Split MNIST + EWC),在关键函数处设置断点或添加打印语句。
- 追踪数据流:从
train.py的main函数开始,找到数据加载部分。看它是如何调用src/data/benchmarks.py中的函数来构建任务序列的。理解返回的tasks数据结构。 - 追踪训练循环:找到核心的训练循环。它很可能在
src/trainers/base_trainer.py中。观察它如何遍历任务,在每个任务中如何遍历数据批次。重点关注它如何调用你选择的策略(如EWC)的observe方法。 - 深入策略内部:打开
src/strategies/ewc.py。对照我们前面讲过的原理,看它的__init__,before_task,observe,after_task方法是如何实现的。它是如何计算和存储费雪矩阵的?正则项是如何加到损失上的? - 追踪评估流程:找到在任务训练结束后被调用的评估函数。它应该是在所有已学任务上测试性能,并计算平均准确率和遗忘度。查看这些指标是如何被计算和记录的。
通过这样一次完整的跟踪,你就能彻底掌握这个框架的工作流程,为后续的修改或贡献打下基础。
6.3 实现一个新算法并提交贡献
假设你想实现一篇新论文中的算法X。以下是标准的贡献流程:
- 复现环境:确保你的本地环境能无错误地运行项目现有的所有示例。
- 创建新策略文件:在
src/strategies/目录下创建x.py。让你的新类继承自基础策略类(例如BaseStrategy)。# src/strategies/x.py from .base import BaseStrategy class XStrategy(BaseStrategy): def __init__(self, model, optimizer, args): super().__init__(model, optimizer, args) # 初始化你的算法特有的参数,比如记忆缓冲区、正则化系数等 self.memory_buffer = [] self.buffer_size = args.buffer_size def observe(self, batch, task_id): # 实现算法X的核心训练逻辑 # 1. 可能从缓冲区合并数据 # 2. 计算损失(包含算法特有的项) # 3. 返回总损失 pass def after_task(self, task_id, train_loader): # 任务结束后,更新缓冲区或其他内部状态 self.update_memory_buffer(train_loader, task_id) # ... 其他必要方法 - 注册你的策略:为了让框架能发现并使用你的新策略,需要在策略包的
__init__.py中导入并注册它,或者通过配置文件动态加载。一种常见做法是使用策略名称到类的映射字典。# src/strategies/__init__.py from .ewc import EWCStrategy from .lwf import LwFStrategy from .x import XStrategy # 新增导入 STRATEGY_MAP = { 'ewc': EWCStrategy, 'lwf': LwFStrategy, 'x': XStrategy, # 新增映射 } - 创建配置文件:在
configs/目录下复制一个现有的YAML文件(如ewc_split_mnist.yaml),修改为x_split_mnist.yaml。主要改动strategy.name为'x',并添加算法X所需的特有参数(如buffer_size,x_lambda等)。 - 测试与验证:运行你的新配置
python scripts/train.py --config configs/x_split_mnist.yaml。确保训练能正常进行,没有报错。然后,在标准基准测试(如Split MNIST, Split CIFAR-100)上运行你的算法,并与基线算法(如EWC、简单回放)进行比较,确保其性能符合预期(至少不比随机训练差)。 - 编写单元测试:在
tests/目录下为你的新策略添加简单的单元测试,测试其初始化、前向传播和损失计算的基本功能。 - 提交Pull Request:将你的更改推送到你自己的项目Fork分支,然后在原项目仓库发起Pull Request。PR描述中应清晰说明:1) 实现的算法简介;2) 所做的代码改动;3) 在标准数据集上的性能结果(可以附上截图或日志);4) 任何可能影响现有功能的变更。
贡献心得:在开始实现一个复杂算法前,先尝试在项目框架内复现一个已有的简单算法(比如自己从头实现一个简单的经验回放)。这能帮你彻底理解框架的接口和数据流。另外,多阅读项目已有的代码风格和文档规范,保持代码风格一致是PR被接受的重要因素。最后,积极与项目维护者沟通,在Issue或PR中讨论你的实现方案,可以避免走弯路。