https://arxiv.org/pdf/2402.09353
DoRA: Weight-Decomposed Low-Rank Adaptation
DoRA: Weight-Decomposed Low-Rank Adaptation
DoRA(Weight-Decomposed Low-Rank Adaptation)是一种用于大模型微调的高效参数优化方法,通过分解权重矩阵并结合低秩适配技术,显著减少训练参数量,同时保持模型性能。
核心思想
DoRA将预训练权重矩阵分解为幅度(magnitude)和方向(direction)两部分,并对方向部分应用低秩适配(LoRA)。这种分解方式能够更精细地控制权重更新,避免直接微调全参数带来的计算开销。
数学表达形式为:
W=m⋅V∣∣V∣∣F W = m \cdot \frac{V}{||V||_F}W=m⋅∣∣V∣∣FV
其中:
- WWW是原始权重矩阵
- mmm是幅度标量
- VVV是方向矩阵
- ∣∣V∣∣F||V||_F∣∣V∣∣F表示Frobenius范数(矩阵元素的平方和的平方根)
实现方法
权重分解
将原始权重WWW分解为幅度mmm和归一化方向V∣∣V∣∣F\frac{V}{||V||_F}∣∣V∣∣FV。幅度表示权重的重要性,方向决定特征变换的性质。
低秩适配
对方向矩阵VVV应用LoRA技术,使用低秩矩阵AAA和BBB进行更新:
V=V0+BA V = V_0 + BAV=V0+BA
其中V0V_0V0是冻结的初始方向,A∈Rr×kA \in \mathbb{R}^{r×k}A∈Rr×k,B∈Rd×rB \in \mathbb{R}^{d×r}B∈Rd×r是可训练的低秩矩阵(r≪d,kr \ll d,kr≪d,k)。
训练过程
仅训练幅度参数 $ m $ 和低秩矩阵 $ A,B $,冻结原始权重 $ W $。更新公式为:
W′=m′⋅V0+BA∣∣V0+BA∣∣F W' = m' \cdot \frac{V_0 + BA}{||V_0 + BA||_F}W′=m′⋅∣∣V0+BA∣∣FV0+BA
优势特点
- 参数效率:相比全参数微调,可减少90%以上的训练参数量。
- 性能保留:在多项NLP任务中达到或超过全微调(full fine-tuning)的效果。
- 训练稳定:幅度与方向解耦使优化过程更平滑,避免梯度爆炸/消失。
- 模块化设计:可灵活应用于Transformer的各类权重矩阵(Q/K/V/FFN)。
注意事项
- 秩(rank)的选择需要平衡参数效率和性能,通常4-32之间效果较好。
- 初始化策略影响收敛速度,建议对 ( A ) 使用Kaiming初始化,( B ) 初始化为零。
- 可与其它高效微调方法(Adapter、Prefix-tuning)结合使用。
https://github.com/NVlabs/DoRA/blob/main/commonsense_reasoning/peft/src/peft/tuners/dora.py
self.weight_m_wdecomp=nn.Linear(1,out_features,bias=False)# self.weight_m_wdecomp.weight # shape: out_features, 1self.fan_in_fan_out=fan_in_fan_out self.Wdecompose=Wdecompose# whether to tune only the magnitude component of Wdecompose or notself.dora_simple=dora_simple# whether to use dora simple to save up GPU memoryifself.Wdecompose==False:ifr>0:self.lora_A=nn.Linear(in_features,r,bias=False)self.lora_B=nn.Linear(r,out_features,bias=False)self.scaling=self.lora_alpha/self.r# Freezing the pre-trained weight matrixself.weight.requires_grad=Falseself.reset_parameters()iffan_in_fan_out:self.weight.data=self.weight.data.Tdefreset_parameters(self):nn.Linear.reset_parameters(self)ifhasattr(self,"lora_A"):# initialize A the same way as the default for nn.Linear and B to zeronn.init.kaiming_uniform_(self.lora_A.weight,a=math.sqrt(5))nn.init.zeros_(self.lora_B.weight)deftrain(self,mode:bool=True):nn.Linear.train(self,mode)ifself.Wdecompose==False:self.lora_A.train(mode)self.lora_B.train(mode)self.weight_m_wdecomp.train(mode)ifnotmodeandself.merge_weightsandnotself.merged:# Merge the weights and mark itifself.Wdecompose:norm_scale=(self.weight_m_wdecomp.weight/(torch.linalg.norm(self.weight,dim=1)).unsqueeze(1))weight=norm_scale*self.weight self.weight.data.copy_(weight.detach())else:ifself.r>0:new_weight_v=self.weight+transpose(self.lora_B.weight @ self.lora_A.weight,fan_in_fan_out=self.fan_in_fan_out)*self.scaling weight=(self.weight_m_wdecomp.weight/(torch.linalg.norm(new_weight_v,dim=1)).unsqueeze(1))*new_weight_v self.weight.data.copy_(weight.detach())self.merged=Trueelifself.merge_weightsandself.merged:raiseNotImplementedErroradapters中的配置
config=DoRAConfig()model.add_adapter("dora_adapter",config=config)