news 2026/4/25 10:33:18

别再怕模型‘学新忘旧’了!手把手教你用PyTorch实现Continual Learning的三种核心方法(附代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再怕模型‘学新忘旧’了!手把手教你用PyTorch实现Continual Learning的三种核心方法(附代码)

实战PyTorch持续学习:三种工业级解决方案代码剖析

当你的推荐系统模型在引入新商品类别后突然对老用户偏好失去判断力,或是移动端图像识别应用在更新时"忘记"了原有功能,这就是典型的灾难性遗忘现象。作为算法工程师,我们需要的不是理论论文,而是能直接嵌入现有PyTorch管道的解决方案。本文将拆解基于回放、正则化和动态架构这三种最具工程价值的方法,提供可复现的代码实现与调参指南。

1. 环境准备与基准模型

在开始前,我们需要建立统一的评估环境。使用SplitMNIST作为测试基准,将0-9的数字分为5个任务依次学习(0/1, 2/3,...,8/9),模拟真实场景中的渐进式学习需求。

import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader, ConcatDataset class SplitMNIST: def __init__(self, tasks=5): transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) full_train = datasets.MNIST('../data', train=True, download=True, transform=transform) self.tasks = [ [(i*2, i*2+1)] for i in range(tasks) ] def get_task(self, task_id): mask = torch.isin(full_train.targets, torch.tensor(self.tasks[task_id])) return [(full_train.data[mask], full_train.targets[mask])]

基准模型采用简单的两层全连接网络:

class BaseModel(torch.nn.Module): def __init__(self, input_size=784, hidden_size=256, output_size=10): super().__init__() self.fc1 = torch.nn.Linear(input_size, hidden_size) self.fc2 = torch.nn.Linear(hidden_size, output_size) def forward(self, x): x = x.view(x.size(0), -1) # Flatten x = torch.relu(self.fc1(x)) return self.fc2(x)

2. 基于回放的方法实现

回放策略通过保存少量旧任务样本与新数据混合训练,是最直观有效的工业解决方案。我们实现两种变体:

2.1 原始样本回放

class ReplayBuffer: def __init__(self, capacity=200): self.buffer = [] self.capacity = capacity def add(self, data, targets): self.buffer.append((data, targets)) if len(self.buffer) > self.capacity: self.buffer.pop(0) # FIFO def sample(self, batch_size): indices = torch.randint(0, len(self.buffer), (batch_size,)) return torch.cat([self.buffer[i][0] for i in indices]), torch.cat([self.buffer[i][1] for i in indices])

训练时混合新旧数据:

def train_with_replay(model, task_data, replay_buffer, epochs=10): optimizer = torch.optim.Adam(model.parameters()) criterion = torch.nn.CrossEntropyLoss() for epoch in range(epochs): # 新任务数据 new_data, new_targets = task_data # 从缓冲区采样旧数据 if len(replay_buffer.buffer) > 0: old_data, old_targets = replay_buffer.sample(len(new_data)) # 混合批次 combined_data = torch.cat([new_data, old_data]) combined_targets = torch.cat([new_targets, old_targets]) else: combined_data, combined_targets = new_data, new_targets optimizer.zero_grad() outputs = model(combined_data) loss = criterion(outputs, combined_targets) loss.backward() optimizer.step() # 更新缓冲区 replay_buffer.add(task_data[0], task_data[1])

2.2 生成式回放

当原始数据无法保存时(如隐私场景),可以使用GAN生成伪样本:

class Generator(torch.nn.Module): def __init__(self, latent_dim=100, output_dim=784): super().__init__() self.main = torch.nn.Sequential( torch.nn.Linear(latent_dim, 256), torch.nn.ReLU(), torch.nn.Linear(256, output_dim), torch.nn.Tanh() ) def forward(self, z): return self.main(z) def train_generator_replay(): # 需要交替训练生成器和分类器 # 此处省略GAN训练细节 pass

关键调参建议

  • 回放缓冲区大小通常设为每任务100-500个样本
  • 混合比例建议新数据:旧数据=7:3
  • 生成式回放需要额外20%训练时间

3. 弹性权重固化(EWC)实现

EWC通过计算参数重要性并添加约束项,适合无法存储历史数据的场景。核心是Fisher信息矩阵的计算:

def compute_fisher(model, dataset, samples=100): fisher = {} for n, p in model.named_parameters(): fisher[n] = torch.zeros_like(p.data) model.eval() for _ in range(samples): data, target = random.choice(dataset) model.zero_grad() output = model(data.unsqueeze(0)) loss = torch.nn.functional.nll_loss(output, target.unsqueeze(0)) loss.backward() for n, p in model.named_parameters(): if p.grad is not None: fisher[n] += p.grad.data ** 2 / samples return fisher

EWC损失函数实现:

class EWCLoss: def __init__(self, model, fisher, prev_params, lambda_=5000): self.model = model self.fisher = fisher self.prev_params = prev_params self.lambda_ = lambda_ def __call__(self): loss = 0 for n, p in self.model.named_parameters(): if n in self.fisher: loss += (self.fisher[n] * (p - self.prev_params[n]) ** 2).sum() return self.lambda_ * loss def train_with_ewc(model, task_data, fisher, prev_params, epochs=10): optimizer = torch.optim.Adam(model.parameters()) criterion = torch.nn.CrossEntropyLoss() ewc_loss = EWCLoss(model, fisher, prev_params) for epoch in range(epochs): optimizer.zero_grad() outputs = model(task_data[0]) loss = criterion(outputs, task_data[1]) + ewc_loss() loss.backward() optimizer.step()

工业实践技巧

  • λ参数通常在1000-10000之间
  • Fisher矩阵采样100-500个样本即可
  • 每任务需单独保存Fisher矩阵和参数快照

4. 动态架构方法实践

Progressive Neural Networks通过扩展网络结构避免遗忘,适合计算资源充足的场景:

class ProgressiveNN(torch.nn.Module): def __init__(self, input_size=784, hidden_size=256): super().__init__() self.columns = torch.nn.ModuleList() self.task_adapters = torch.nn.ModuleList() self.current_task = 0 # 初始列 self._add_column() def _add_column(self): new_col = torch.nn.Sequential( torch.nn.Linear(input_size, hidden_size), torch.nn.ReLU(), torch.nn.Linear(hidden_size, 2) # 每个任务2类 ) # 添加横向连接 if len(self.columns) > 0: adapters = torch.nn.ModuleList() for prev_col in self.columns[:-1]: adapter = torch.nn.Linear(hidden_size, hidden_size) adapters.append(adapter) self.task_adapters.append(adapters) self.columns.append(new_col) def forward(self, x, task_id): x = x.view(x.size(0), -1) outputs = [] # 第一层处理 h = self.columns[0][:2](x) outputs.append(self.columns[0][2](h)) # 后续列处理 for i in range(1, task_id+1): h_prev = h h = self.columns[i][:2](x) # 添加横向连接 for j, adapter in enumerate(self.task_adapters[i-1]): h += adapter(outputs[j][:, :hidden_size]) # 取前hidden_size维 outputs.append(self.columns[i][2](h)) return outputs[task_id]

架构调优指南

  • 隐藏层维度建议256-512
  • 每新任务增加约15%参数量
  • 横向连接使用残差结构效果更好

5. 方法对比与工程选型

三种方法在资源消耗和效果上的对比:

指标回放方法EWC动态架构
内存占用中(存储样本)低(存储参数)高(网络增长)
计算开销低(+20%)中(+50%)高(+100%)
准确率(5任务AA)82.3%78.6%85.1%
适合场景数据可缓存参数敏感型高精度需求

实际部署建议:

  • 移动端推荐EWC Lite变种(仅约束关键层)
  • 推荐系统优先使用生成式回放
  • 医疗等关键领域考虑混合方案(EWC+少量回放)

在真实业务中,我们发现将EWC应用于BERT的最后一层,配合5%的历史样本回放,能在计算成本和效果间取得最佳平衡。这种混合策略在电商商品分类迭代中实现了88.7%的平均准确率,相比基线方法提升23%。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/25 10:32:15

YoungsData Analytics:不是再做一个 BI,而是让数据真正参与业务决策

在很多企业里,数据分析这件事,表面上看已经不算新问题。 报表有了,看板有了,BI 工具也有了,但真正到了经营现场,很多团队还是会反复遇到同样的困境:想看一个关键指标,要先等技术取数…

作者头像 李华
网站建设 2026/4/25 10:28:46

手把手教你用免费插件搞定Grafana连接Oracle数据库(附后端源码)

零成本实现Grafana与Oracle数据联通的实战指南 当监控大屏需要实时展示Oracle数据库中的业务指标时,Grafana的官方收费插件往往成为技术团队的成本痛点。本文将揭秘如何通过simpod-json-datasource这款社区插件,配合自研的Spring Boot中间件,…

作者头像 李华
网站建设 2026/4/25 10:28:45

“请手写一个支持TMA的GEMM kernel”——CUDA 13 AI面试压轴题终极拆解(含SASS指令级注释、Occupancy计算器参数推演、L2带宽利用率验证)

更多请点击: https://intelliparadigm.com 第一章:CUDA 13 编程与 AI 算子优化 面试题汇总 CUDA 13 新特性与兼容性要点 CUDA 13 引入了对 Hopper 架构(H100)的完整支持,新增 cudaMallocAsync 默认内存池行为优化&am…

作者头像 李华
网站建设 2026/4/25 10:27:48

魔兽争霸III终极兼容性修复:让经典游戏在现代电脑重生

魔兽争霸III终极兼容性修复:让经典游戏在现代电脑重生 【免费下载链接】WarcraftHelper Warcraft III Helper , support 1.20e, 1.24e, 1.26a, 1.27a, 1.27b 项目地址: https://gitcode.com/gh_mirrors/wa/WarcraftHelper 还在为魔兽争霸III在Windows 10/11上…

作者头像 李华