1. 项目概述与核心挑战
在工业级深度学习推荐系统的构建中,我们面临着一个核心的“剪刀差”困境:一方面,模型需要处理海量的稀疏特征,这些特征通常被编码为规模极其庞大的嵌入表,其参数量动辄达到数十亿甚至数百亿级别;另一方面,我们依赖像TPU这样的专用硬件加速器来获得极致的计算吞吐量,但这些硬件最初是为密集、规整的矩阵运算(如Transformer、MLP)而设计的。如何让这两者高效协同,是决定整个训练系统成本与效率的关键。我过去几年深度参与过多个超大规模推荐系统的TPU训练优化项目,从输入数据准备到梯度回传,几乎踩遍了每一个环节的坑。今天,我想系统性地拆解一下,如何针对TPU架构,对推荐系统的训练流程进行全链路性能优化,特别是聚焦于最棘手的嵌入表操作。
简单来说,优化的目标就是让昂贵的TPU芯片时刻保持“忙碌”,避免它们因为等待数据(输入管道瓶颈)或因为低效的内存访问(嵌入表查找瓶颈)而空闲。这不仅仅是调几个参数那么简单,它涉及到从数据流架构、计算图调度到底层硬件资源分配的一整套系统工程。本文将围绕两个最关键的子系统展开:一是确保数据供给不卡脖子的输入管道,二是榨干TPU SparseCore硬件潜力的嵌入表操作优化。我们会看到,通过一系列组合拳,包括共享输入生成、动态水平扩展、混合分区策略以及计算流水线化,我们最终在真实的广告推荐模型上实现了平均超过一倍的性能提升,同时显著降低了对外部CPU/内存资源的依赖。
2. 输入管道优化:告别TPU“饥饿”
在分布式训练中,TPU阵列的计算能力非常强大,但如果喂给它们数据的速度跟不上其消耗速度,那么这些昂贵的芯片就会处于闲置状态,造成巨大的资源浪费。输入管道的优化,首要目标就是解决TPU的“饥饿”问题。
2.1 从“各自为政”到“共享厨房”:共享输入生成服务
在传统的模式,即本地输入生成中,每个独立的训练任务都需要运行自己专属的输入数据预处理流水线。这意味着,即使有十个模型都在处理同一份用户点击日志,提取相似的特征(如用户历史行为序列、商品属性),它们也会各自重复地进行完全一样的特征解析、转换和归一化操作。这相当于每个厨师都在自己的小厨房里,从洗菜、切菜开始准备完全相同的菜品,造成了CPU和内存资源的极大冗余。
我们引入的共享输入生成服务,其核心思想类似于一个中央厨房。这个“中央厨房”独立于任何具体的训练任务,它负责执行最耗时的公共特征转换计算。工作流程如下:
- 特征图定义与子图识别:所有模型将其特征处理逻辑(一个计算图)注册到SIG。SIG会分析这些图,识别出其中计算代价高昂且被多个模型共享的子图。例如,从原始日志中解析出“用户过去7天点击某类目的次数”这个特征,可能涉及复杂的窗口聚合和连接操作,是典型的候选。
- 计算与缓存:SIG服务内的专用工作节点会执行这些公共子图的计算,并将结果(即转换后的特征张量)持久化到高速存储中。这个过程被称为“物化”。
- 按需分发:当各个模型的输入读取器需要数据时,它们不再重复执行完整的特征转换,而是向SIG请求对应的物化结果。SIG根据请求的批次和特征键,快速读取并返回缓存的数据。
注意:SIG并非缓存所有原始数据,而是缓存中间特征转换结果。原始数据可能非常庞大且变化频繁,而特征转换结果相对稳定且可复用。在我们的实践中,SIG的缓存命中率超过95%,平均每个物化的特征子图被超过22个模型复用,峰值时可达400个以上。
带来的收益是立竿见影的:如图6所示,与LIG相比,SIG将输入读取器的资源成本降低了4.3倍到7.5倍。虽然TPU的成本在总成本中占主导地位,但SIG依然带来了12%到27%的总训练成本下降,几何平均为18%。更重要的是,它解决了资源争用问题:在LIG模式下,为每个任务配置足够的CPU/内存资源非常困难,经常导致TPU因等待数据而利用率低下。SIG通过资源共享,确保了TPU阵列能够持续获得数据供给。
2.2 动态水平扩展:应对波动的数据需求
输入读取器的资源需求并非一成不变。它主要受两个因素影响:
- 训练阶段:在初始训练阶段,模型需要快速处理历史积累的巨量数据,此时使用SIG,输入读取负载较低。而在追新训练阶段,模型需要近乎实时地学习最新产生的数据,由于数据新鲜度要求高,无法使用SIG的缓存,必须采用LIG模式,导致CPU需求激增。
- 模型复杂度:不同模型的特征数量、转换逻辑复杂度差异巨大,对输入读取器的压力也不同。
表I的数据清晰地说明了问题:对于50%的训练流水线,其输入读取所需的CPU资源尚可由TPU主机本身满足(比值为0.7)。但对于90分位的任务,其需求是单个TPU主机资源的3.5倍;对于99分位的任务,甚至高达16倍。这意味着,仅靠TPU主机附带的CPU资源是远远不够的。
因此,我们实现了水平可扩展的输入读取器服务。该服务可以独立于TPU Pod进行部署和伸缩。其核心是一个控制器,它持续监控两个指标:
- 数据队列深度:TPU端等待处理的数据批次队列是否即将排空?
- 输入读取器利用率:当前输入读取器节点的CPU/内存使用率是否持续过高?
基于这些指标,控制器可以动态地增加或减少输入读取器实例的数量。例如,当系统进入追新训练阶段,控制器会自动扩容,增加输入读取器以应对LIG模式下的高负载;当训练进入稳定阶段或切换回使用SIG的初始训练时,则可以缩容以节省资源。
实操心得:动态伸缩的粒度不宜过细。我们通常以“任务组”为单位进行伸缩,并设置一个冷却期,避免因瞬时波动导致频繁的启停操作,这反而会引入开销和不稳定。同时,需要为输入读取器配置足够的网络带宽和低延迟存储访问,防止其成为新的瓶颈。
3. 嵌入表操作优化:征服TPU上的稀疏计算
对于推荐模型,嵌入表查找和更新是训练过程中最核心、也最耗时的稀疏操作。TPUv4通过集成专用的SparseCore硬件来高效处理这些操作,但如何用好SC,并与负责密集计算的TensorCore协同工作,是性能优化的重中之重。
3.1 理解计算流程与瓶颈
一个典型的推荐模型训练步骤中,嵌入层操作流程如下:
- 前向传播:输入读取器提供一批训练样本,每个样本包含多个特征ID。主机CPU首先对这些ID进行去重,得到一批唯一的特征值。
- 嵌入查找:去重后的ID被发送给SparseCore。SC根据这些ID,从其管理的分布式嵌入表分区中,并行地查找并收集对应的嵌入向量。
- 归约求和:由于一个特征ID可能在同一个批次的不同样本中出现多次(即“多值”特征),SC收集到的是每个唯一ID的向量。接着,这些向量被传递给TensorCore,TC根据样本-特征的映射关系,执行段求和操作,为每个样本生成其对应的聚合后的嵌入向量,作为后续MLP等密集层的输入。
- 反向传播:梯度从密集层反向传播到嵌入层。TC计算得到每个���本特征对应的嵌入梯度,这些梯度需要根据原始ID映射,分散更新回SC中对应的嵌入表行。
这个流程的瓶颈非常明显:
- 负载不均衡:不同特征ID的出现频率(热度)差异巨大,导致某些SC核心需要处理的热门ID远多于其他核心。
- 串行执行:在朴素的实现中,前向传播必须等待所有SC嵌入查找完成后,TC才能开始密集计算;反向传播也必须等待TC梯度计算完成后,SC才能开始更新。SC和TC交替空闲,硬件利用率低。
- 内存带宽压力:嵌入表通常远超单个SC的HBM容量,必须分区存放。低效的分区策略会导致跨芯片通信频繁,挤占宝贵的ICI带宽。
3.2 核心优化策略一:混合分区
单纯按行分区(将不同的嵌入表行分布到不同的SC上)是最直观的方法,但面对高度倾斜的访问分布时,效果很差。假设有4个SC,两个嵌入表T1和T2,每表4行。每行的平均访问次数分别为0.6, 0.3, 0.2, 0.1。如果采用行分区,最热门的行(0.6)落在某个SC上,最冷的行(0.1)落在另一个SC上,那么负载不均衡因子高达(4 * 0.6) / (0.6+0.3+0.2+0.1) = 2,意味着最忙的SC工作量是平均值的2倍。
我们引入了混合分区策略,它结合了三种维度:
- 表级分区:将不同的整个嵌入表放置到不同的SC集合上。适用于表之间大小和访问模式差异大的场景。
- 行级分区:将一个大表的行分散到多个SC上。这是最基础的负载分散方法。
- 列级分区:将一个嵌入向量的不同维度(列)切分到不同的SC上。这是提升性能的关键。
继续上面的例子,采用混合分区:首先进行表分区,T1放在{SC0, SC1},T2放在{SC2, SC3}。然后对每个表进行列分区,将每个64维的向量从中间切开,前32维放在一组SC(SC0, SC2),后32维放在另一组SC(SC1, SC3)。这样,每个SC最终负责存储:SC0: T1的前32列,SC2: T2的前32列,SC1: T1的后32列,SC3: T2的后32列。计算负载被完美均摊,负载不均衡因子降为1。
列分区的额外好处:它降低了内存访问的粒度。在行分区中,即使只需要嵌入向量的一部分,也必须读取整行(例如64个浮点数)。而在列分区中,可以只读取所需的那部分列。这减少了对HBM带宽的占用,虽然可能增加一次通信(如果需要完整的向量),但在带宽成为瓶颈的场景下,收益显著。
注意事项:混合分区策略的寻址逻辑变得复杂。在查找时,系统需要知道目标ID的行分布在哪个SC,以及其列切分情况,然后向多个SC发起并行的查找RPC请求,最后在TC侧将部分向量拼接成完整的嵌入向量。这需要运行时系统有精密的元数据管理和路由机制。
3.3 核心优化策略二:反馈导向的分区
混合分区解决了静态的负载不均衡问题,但还有一个动态挑战:多值特征。例如,“用户最近点击的100个商品”这个特征,其值的数量(称为“价态”)在运行时是变化的,且依赖于数据分布,编译时无法预知。价态高的特征,其嵌入查找和归约计算量也大。
反馈导向的分区的核心思想是利用运行时收集的剖析信息来指导分区决策。系统在训练过程中会持续采样并记录:
- 每个训练批次中,各个特征出现的总次数。
- 每个特征ID的唯一值数量(即去重后的价态)。
- 各SC核心的实时负载和通信量。
这些统计信息被汇总到一个“剖析数据库”中。在定期或触发式的分区决策时刻,系统会利用这些数据,将高频、高价态的特征所对应的嵌入表行或列,进行更细粒度的拆分或迁移,以平衡SC间的计算和通信负载。在我们的实验中,FDP为某些模型带来了额外的19%到21%的性能提升。
3.4 核心优化策略三:TC/SC流水线执行
这是提升系统吞吐量的“神来之笔”。观察图5,在严格的串行执行中,SC和TC如同两个必须交接棒的运动员,大部分时间总有一个在等待。流水线执行打破了这一步的严格依赖。
其原理是:允许SC提前开始下一步(Step N+1)的嵌入查找,而TC仍在处理当前步(Step N)的密集层计算。从数学上看,这相当于在反向传播更新嵌入时,使用的是上一步(Step N)的梯度,而非当前步(Step N+1)的梯度,即梯度延迟了一拍。
为什么这可行且有效?
- 梯度延迟的容忍性:在推荐模型这种超大规模、数据噪声丰富的场景下,训练过程本身具有很强的随机性(如大规模SGD)。延迟一拍的梯度可以看作是在梯度中引入了一个微小的、有偏的噪声。大量实验表明,这种延迟对模型的最终收敛质量和效果没有可观测的负面影响。
- 硬件利用率大幅提升:如图7所示,流水线化后,训练每一步的时间从
TC_time + SC_time缩短为max(TC_time, SC_time)。只要TC和SC的计算时间不是严重失衡,就能获得接近线性的加速。这对于那些嵌入层计算和密集层计算耗时相近的模型尤其有效。
实操心得:启用流水线会加剧对共享资源(如HBM和ICI)的争用,因为TC和SC同时在活跃地访问内存和通信。因此,需要仔细监控这些资源的利用率,并可能需要对混合分区策略进行微调,以平衡计算负载和通信压力。通常,这是一个迭代调优的过程。
4. 系统级保障:鲁棒的资源与容错管理
优化性能的同时,必须保证系统的稳定性和资源效率。在共享数据中心,训练任务动辄运行数周甚至数月,期间必然会遇到各种干扰。
4.1 智能错误处理与训练挂起
训练任务可能进入无法继续的状态。我们将其分为两类,并采取不同策略:
- 永久性错误:如模型配置错误、编译失败、内存溢出、数值溢出(NaN)。系统检测到此类错误后,会主动放置一个训练挂起信号。这意味着任务会暂停并释放所有TPU和输入读取器资源,等待工程师介入排查。这避免了宝贵的TPU资源被一个注定失败的任务长期占用。
- 瞬时性停滞:最常见的原因是输入数据尚未就绪(例如,SIG仍在物化当前训练所需的数据范围)。训练管道能检测到这种“数据未准备好”的状态,并同样触发训练挂起,释放计算资源。但与永久错误不同,系统的控制器会保持活跃,持续轮询数据可用性。一旦数据准备就绪,控制器会自动解除挂起,任务无缝恢复。
这种机制至关重要。如表II所示,在实际系统中,处于“挂起”状态的模型所需求的TPU芯片量,是正在活跃训练芯片量的2.49倍。如果不能快速识别并挂起出错任务,这些“僵尸”任务将严重浪费集群资源。
4.2 优雅的抢占处理
在共享环境中,高优先级任务、软件滚动更新或硬件维护都可能要求当前任务提前终止。粗暴地杀死任务会导致当前训练周期(epoch)的进度全部丢失。
我们实现了抢占通知协议。外部调度系统在决定要抢占某个任务时,会提前发送一个“预抢占通知”。任务接收到通知后:
- 通知控制器:任务主进程通知中央控制器即将关闭。
- 广播与收尾:控制器广播此消息给该任务的所有工作节点(如输入读取器)。所有节点尝试完成手头正在处理的工作单元。
- 检查点保存:训练任务迅速将当前进度(模型参数、优化器状态等)保存到持久化存储中。
- 优雅退出:保存完成后,任务主动退出。
- 等待重启:控制器会阻塞下一个训练周期的开始,直到该任务在新的资源上重启并重新加入训练管道。这确保了被抢占的epoch进度得以保留,实现了“断点续训”。
在我们的实践中,约61%的被抢占epoch成功保存了进度并得以恢复,极大地减少了因中断造成的计算浪费。
5. 性能评估与效果分析
我们在一套由128个TPUv4芯片组成的系统上,选取了五个占生产负载超过50%的代表性推荐模型进行评估。这些模型的密集部分参数量在5000万到3亿之间,并包含数百个大小、维度、访问模式各异的嵌入表。
5.1 嵌入优化效果分解
图7清晰地展示了各项优化技术的累积效果:
- 基线:仅使用行分区,且TC/SC串行执行。这是最朴素的实现。
- 流水线:启用TC/SC流水线执行。所有模型均获得提升,但对于SC瓶颈显著的模型A和E,提升幅度相对较小,因为其SC耗时远大于TC,流水线后整体耗时仍由SC决定。
- 混合分区:在流水线基础上,启用行、列、表混合分区。通过更精细的负载均衡,所有模型性能进一步上涨。特别是对于之前SC瓶颈的模型,负载被均摊后,SC耗时下降,与TC的耗时更加匹配,从而让流水线的效果得以充分发挥。
- 反馈导向分区:在前两者基础上,加入基于运行时剖析的FDP。对于模型B和E,由于存在价态变化剧烈的多值特征,FDP带来了额外的19%和21%的性能飞跃。模型A因主要负载集中在少数几个表上,即使不用FDP也能较好分布,因此提升有限。模型C和D在混合分区后已变为TC瓶颈,因此FDP对它们没有进一步帮助。
最终,这一系列嵌入优化技术为五个模型带来了58%到180%的性能提升,几何平均提升高达116%。
5.2 成本与效率的权衡
优化不仅是追求速度,更是追求性价比。SIG通过共享计算,将输入读取器的成本降低了数倍,虽然对以TPU成本为主的总成本影响比例看似不大(平均降低18%),但其战略意义在于:它使得为庞大的TPU集群持续供给数据成为可能,从而将TPU的利用率提升到了一个新的高度。没有SIG,TPU将因为输入瓶颈而大量闲置,实际训练成本会成倍增加。
6. 经验总结与未来展望
回顾整个优化历程,有几点深刻的体会:
- 全链路视角至关重要:不能只盯着TPU矩阵乘的峰值算力。输入管道、嵌入查找、通信、容错,任何一个环节的短板都会成为整个系统的瓶颈。必须像对待核心算法一样,对待这些系统工程问题。
- 拥抱“不精确”以换取效率:流水线执行带来的梯度延迟,本质上是一种用极小的、经验证可忽略的精度代价,换取巨大硬件利用率提升的权衡。在工业级系统中,这类基于对问题深刻理解的、可控的近似优化,往往是突破性能瓶颈的关键。
- 数据驱动的动态调优:无论是反馈导向的分区,还是输入读取器的动态伸缩,都依赖于对运行时数据的持续收集和分析。静态的、一刀切的配置无法应对生产环境中复杂多变的工作负载。
未来的探索方向也基于当前实践的延伸:
- 更智能的混合存储:目前嵌入表全部驻留在TPU HBM中。未来可以探索分层存储,将访问频率极低的“冷”嵌入行或表,卸载到主机内存甚至SSD上,从而在有限的HBM容量下支持更大的模型。这其中的挑战在于如何高效、动态地识别和迁移“冷热”数据。
- SIG的演进:一是进一步减少存储开销,例如只物化计算图中最昂贵的部分子图,而非全图。二是探索支持可变数据的SIG,使得追新训练也能受益于缓存共享,但这会引入数据一致性和生产风险共享的复杂问题。
优化大规模推荐系统的TPU训练,是一场持续在算法、系统、硬件交叉地带进行的工程探险。每一次性能的百分比提升,背后都是对数据流、计算图、硬件资源更精细的雕刻与编排。希望这些从实战中总结出的思路和细节,能为同行们提供一些有价值的参考。