news 2026/1/23 10:58:35

DanceGRPO+FLUX:多模态生成强化学习模型的高效

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DanceGRPO+FLUX:多模态生成强化学习模型的高效

一、背景介绍

Flux 模型小模型高效生成高质量图像的基础

Flux 虽是百亿级参数的大模型家族,但其中的轻量化变体(如 Flux.1 (schnell))以及核心技术,为小尺寸模型提供了高效生成的范式。其关键技术优势适配小模型的优化需求,具体体现在两点。一是采用 Rectified Flow(校正流)技术,拉直了传统扩散模型从噪声到图像的生成路径,将生成过程优化为近似直线的最短路径,大幅减少采样步数。像 Flux.1 (schnell) 仅需 4 步左右采样就能生成合理图像,这对小模型而言,意味着在降低计算成本的同时,还能避免多步迭代带来的精度损耗。二是创新的多模态融合架构,通过双文本编码器(CLIP+T5)精准解析文本语义,再结合双流转单流的 Transformer 注意力机制,实现文本与图像特征的深度交互。这种设计让小模型无需复杂结构,就能高效捕捉图文关联,提升生成图像的内容一致性。

DanceGRPO 框架:通过强化学习进一步提升小模型性能

DanceGRPO 是专门针对视觉生成领域 RLHF 方案不成熟的问题设计,能精准解决小模型训练中质量提升的核心痛点,具体优势有三。其一,兼容性强,适配 Flux 的核心范式。该框架创新性地将扩散模型和校正流模型(如 Flux)统一视为随机插值的特殊情况,二者的采样过程均可通过 SDE 实现,这让它能无缝对接 Flux 模型,针对性地开展强化学习优化,无需对 Flux 的基础架构做大幅修改,降低了小模型适配强化学习的成本。其二,显存压力低,适配小模型训练资源限制。此前 ReFL 等强化学习方案需对奖励模型和 VAE 解码特征反向传播,在视频生成等场景中显存压力极大,根本不适合小模型。而 DanceGRPO 通过采样部分时间步加速训练、去除作用不大的 KL 散度正则项等设计,大幅降低了计算和显存开销,同时还能让小模型在更多提示词样本上学习,提升泛化能力。其三,强化学习效果显著,精准优化核心指标。该框架通过多奖励模型叠加(图像美感、图文匹配等五类指标),让小模型能针对性提升薄弱项;同时通过固定初始化噪声、控制梯度更新频率等优化手段,避免训练中的奖励作弊和多样性下降问题。

强化学习框架对比

二、环境依赖

三、DanceGRPO+FLUX 整体流程

推理阶段(去噪生成图片,用于训练过程观察)

  1. **加载文本信息:**获取初始数据,将数据复制成 N 份作为输入。
  2. **去噪:**生成初始噪声,input 和当前噪音输入到 policy mode 预测噪声成分,去噪生成 latents。
  3. **图片生成保存:**基于推理阶段输出 latents,经过 vae mode 解码成 image,保存为文件用于观察过程。

关键点

  • 推理去噪生成图像:该模型中,默认一组生成 12 哥样本,即一个 prompt 会生成 12 哥大体相似而细节不同的图像,每个图像默认经过 16 步迭代去噪生成。
  • 去噪步长:步长随时间步长从大到小,因为初始噪声成分较多,相当于勾勒轮廓去噪步长可以大些,后面要收敛到正确终点,相当于描绘细节,需要慢慢去噪。
  • 图像多样性:去噪过程会加入随机扰动,局部优化,因此会有一组默认 12 张图片,每张整体相似而细节有差异的图片;一组内会进行对比,提升优势动作的概率。

Reward 阶段(jisaunq reward 值)

  1. **计算奖励值:**image 和 prompt 输入到 reward model,计算得到 reward 值。
  2. **计算相对优势值:**计算 reward 的组内平均值,每个 reward 和平均值比较,得到 advantage(组内相对优势)。

reward 详细流程

  • 计算 reward 值:基于 prompt,图例阶段得到的完全去噪的 image 值,输入到 reward mode 中,经过一系列计算得到每个 image 的 reward 值。
  • 计算 advantage 值:reward 值经过组内平均得到平均值,再用每个 iamge 的 reward 值和平均值对比,得到 advantage (相对优势值)。

训练阶段(计算 loss,更新梯度)

  1. **记录去噪过程:**前面步骤会记录每个样本的去噪过程状态,包括 reward 值,advantag 值,log_p 值(代表当时策略的对数)。
  2. **计算新策略对数:**此时 policy model 会生成新预测值,根据新预测值计算出 new_log_p 值(代表新策略的对数)。
  3. **计算旧策略比率:**f(new_log, old_log) = ratio,代表某行为在新旧策略的概率比。
  4. **计算 loss 值:**基于 ratio 和 advantage 计算出 loss 值。

训练详细流程

  • loss 是基于 advantage 和 ratio 计算得出的,当 advantage 和 ratio 处于不同值时代表不同的含义
advantageratio含义
>0>1该动作为优势动作,且新策略该动作概率更大,新策略正确的提升了该动作的概率,新策略更优
>0<1该动作为优势动作,且新策略该动作概率更小,新策略错误的抑制了优势动作,后续需要提高 ratio

  • 第一次计算 loss 时,policy mode 还没有更新权重,此时 new_log_p 和 old_log_p 实际上是一样的,就是虽然定义上是新旧策略,但实际上新旧策略的权重一样。
  • loss 值会基于 advantage 和 ratio 一并计算,所以开始的 loss 值依赖于样本的 advantage 值,默认的梯度更新频率为 4 哥样品一次,当处理本组第四个样品后 ratio 就会开始变化了。

存在两个 loss 值,clipped_loss 和 unclipped_loss,都是基于 advantage 和 ratio 计算得到的,但是 clipped_loss 的计算中加入了 clip_range,约束了最终计算值的范围,防止局部过度优化。

四、模型部署流程

  1. 拉取代码:GitHub - XueZeyue/DanceGRPO: An official implementation of DanceGRPO: Unleashing GRPO on Visual Generation
gitclone https://github.com/XueZeyue/DanceGRPO.git

  1. 下载权重

FLUX:https://huggingface.co/black-forest-labs/FLUX.1-dev

HPS:https://huggingface.co/xswu/HPSv2/tree/main

open_clip:https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/tree/main

  1. 其它依赖安装
  1. 仓库未实现懒加载,所以会导入许多用不到的三方库,可以直接注释,避免引入太多无用的依赖,耗费开发时间。
  2. 一些为调用的接口也可以进行规避,例如 flashatth 三方库接口等。
# DanceGRPO/fastvideo/models/mochi_hf/modeling_mochi.py# 注释掉以下from liger_kernel.ops.swigluimportLigerSiLUMulFunction;flash_attn_no_pad.py# flash_attn_no_pad.py# 注释掉flash_attn的导包,flash_attn_no_pad注释掉中间逻辑,直接return;

执行安装脚本:

./env_setup.sh fastvideo
  1. 修改<font style="color:rgb(37, 43, 58);background-color:rgb(246, 247, 249);">preprocess_flux_embedding.py</font>
# # 引入torch_npuimporttorch_npu from torch_npu.contribimporttransfer_to_npu# "./data/flux"写死的路径改成参数# 原 : pipe = FluxPipeline.from_pretrained("./data/flux", torch_dtype=torch.bfloat16).to(device)pipe=FluxPipeline.from_pretrained(args.model_path,torch_dtype=torch.bfloat16).to(device)
  1. 修改<font style="color:rgb(37, 43, 58);background-color:rgb(246, 247, 249);">train_grpo_flux.py</font>
# # 引入torch_npuimporttorch_npu from torch_npu.contribimporttransfer_to_npu
  1. 执行 Flux GRPO 脚本:
bash./scripts/finetune/finetune_flux_grpo.sh

五、模型验证

验证流程将 GRPO 的推理、reward、训练三个阶段单独抽离对齐,再进行全流程验证,采用“分 - 合” 验证策略:

  • 单独阶段对齐能隔离不同模型和框架的差异,聚焦每个环节的前向计算准确性(比如推理阶段的动作生成、reward 阶段的评分计算、训练阶段的梯度更新),避免因单个阶段误差累积掩盖问题。
  • 全流程对齐则能验证阶段间数据传递的一致性,尤其要关注跨框架交互时的数据格式、精度损失等细节。

记录关键节点的对齐数据(如中间特征、概率分布、loss 值、梯度等),既能作为阶段验证的基准,也能在全流程中快速定位误差来源。

随机性固定

load 版本(准确但麻烦)

通过torch.savetorch.load的方式将程序中涉及随机性的变量,在 NPU 和 GPU 上保持一致。

  1. 关闭shuffle,固定训练的数据顺序
# fastvideo/train_grpo_flux.py中,shuffle设为falsesampler=DistributedSampler(train_dataset,rank=rank,num_replicas=world_size,shuffle=False,seed=args.sampler_seed)
  1. prev_sample 固定
    1. GPU代码修改如下,在GPU上运行后保存下来
# 1. 添加全局变量COFF_STEP,控制coff生成的step数COFF_STEP=0def flux_step(): global COFF_STEP......ifgrpo and prev_sample is None: coff=torch.randn_like(prev_sample_mean)torch.save(coff, f"saves/coff_{COFF_STEP}_{torch.distributed.get_rank()}.pt")prev_sample=prev_sample_mean + coff * std_dev_t COFF_STEP+=1
2. NPU 上加载
coff=torch.load(f"saves/coff_{COFF_STEP}_{torch.distributed.get_rank()}.pt",map_location=f"cuda:{torch.cuda.current_device()}")
  1. input_latents 固定
    1. GPU代码修改如下,在GPU上运行后保存下来
def sample_reference_model(args, device, transformer, vae, encoder_hidden_states, pooled_prompt_embeds, text_ids, reward_model, tokenizer, caption, preprocess_val, step,# # # 增加参数输入,用于序列文件记录,找到相关调用处,加上该入参)def train_one_step(args, device, transformer, vae, reward_model, tokenizer, optimizer, lr_scheduler, loader, noise_scheduler, max_grad_norm, preprocess_val, step,# # # 增加参数输入,用于序列文件记录,找到相关调用处,加上该入参)def sample_reference_model();......ifargs.init_same_noise: input_latents=torch.randn((1, IN_CHANNELS, latent_h, latent_w),# (c,t,h,w)device=device,dtype=torch.bfloat16,)torch.save(input_latents, f"saves/input_latents_{step}_{torch.distributed.get_rank()}.pt")
2. NPU上加载
input_latents=torch.load(f"saves/input_latents_{step}_{torch.distributed.get_rank()}.pt",map_location=f'cuda:{device}')
  1. perms 固定
    1. GPU代码修改如下,在GPU上运行后保存下来
def train_one_step():......perms=torch.stack([torch.randperm(len(samples["timesteps"][0]))for_inrange(batch_size)]).to(device)torch.save(perms, f"saves/perms_{step}_{torch.distributed.get_rank()}.pt")
2. <font style="color:rgb(37, 43, 58);">NPU上加载</font>
perms=torch.load(f"saves/perms_{step}_{torch.distributed.get_rank()}.pt",map_location=f'{device}')

使用 CPU 进行随机性固定

固定seed可用于模型训练复现,但是不同的设备如GPU和NPU在同样的seed下生成的值也是不一样的,但是不同设备上都有CPU,因此可以固定seed后使用CPU生成张量,以此让GPU和NPU上生成的张量输入保持相同

  1. fastvideo/train_grpo_flux.py:91修改为
ifgrpo and prev_sample is None: prev_sample=prev_sample_mean + torch.randn_like(prev_sample_mean.cpu()).to(prev_sample_mean.device)* std_dev_t
  1. <font style="color:rgb(59, 62, 85);">fastvideo/train_grpo_flux.py:270</font>修改为
ifargs.init_same_noise: input_latents=torch.randn((1, IN_CHANNELS, latent_h, latent_w),# (c,t,h,w)dtype=torch.bfloat16,).to(device)
  1. fastvideo/train_grpo_flux.py:657修改为
sampler=DistributedSampler(train_dataset,rank=rank,num_replicas=world_size,shuffle=False,seed=args.sampler_seed)
  1. fastvideo/train_grpo_flux.py:1061增加
importrandom def seed_all_own(seed=1234,mode=True,is_gpu=True): random.seed(seed)os.environ['PYTHONHASHSEED']=str(seed)os.environ['GLOBAL_SEED']=str(seed)np.random.seed(seed)torch.manual_seed(seed)torch.use_deterministic_algorithms(mode)ifis_gpu: os.environ['CUBLAS_WORKSPACE_CONFIG']=':4096:8'os.environ['CUDA_LAUNCH_BLOCKING']='1'torch.cuda.manual_seed_all(seed)torch.cuda.manual_seed(seed)torch.backends.cudnn.deterministic=True torch.backends.cudnn.enable=False torch.backends.cudnn.benchmark=False else:importtorch_npu os.environ['HCCL_DETERMINISTIC']='true'os.environ['CLOSE_MATMUL_K_SHIFT']='1'torch_npu.npu.manual_seed_all(seed)torch_npu.npu.manual_seed(seed)print("====== seed all ========")seed_all_own(is_gpu=False)from msprobe.pytorchimportseed_all seed_all(mode=True)

推理流程对齐

推理流程对齐的内容主要是 GRPO 去噪后生成的 latents,latents 解码成图片后对比:固定随机性,将GPU、NPU上使用相同noise的latents使用vae解码,再保存,此时只需要对比生成图片的差异。 关键代码:decoded_image[0].save(img_path),这里会保存训练过程中,模型在每个step,每次generation中生成的图片,可以直观的看到训练过程中的的变化。

# # # # sample_reference_model函数def sample_reference_model(): with torch.inference_mode(): with torch.autocast("cuda",dtype=torch.bfloat16): latents=unpack_latents(latents, h, w,8)latents=(latents /0.3611)+0.1159image=vae.decode(latents,return_dict=False)[0]decoded_image=image_processor.postprocess(image)decoded_image[0].save(f"./images/flux_{step}_{rank}_{index}.png")

Reward Model 对齐

DanceGRPO 模型涉及多个 model,强化学习中需要对齐的主要是loss和reward值,这里讲的是如何对齐reward。

此处采取的方法是把reward model单独拿出来,for循环多步,对比GPU和NPU的值reward值,代码修改如下:

forstepinrange(1,1001):# text = tokenizer([batch_caption[0]]).to(device=device, non_blocking=True)image=torch.load(f"/home/grpo/DanceGRPO/save/images-1/image_{step}_{rank}.pt")text=torch.load(f"/home/grpo/DanceGRPO/save/texts-1/text_{step}_{rank}.pt")# torch.save(image, f"/home/GRPO/DanceGRPO/save/images-1/image_{step}_{rank}.pt")# torch.save(text, f"/home/GRPO/DanceGRPO/save/texts-1/text_{step}_{rank}.pt")ifrank==0: print(f"image_{rank}_{step}: ", image,"\n\n")print(f"text_{rank}_{step}: ", text,"\n\n")with torch.no_grad(): with torch.amp.autocast("cuda"): outputs=reward_model(image, text)ifrank==0: print(f"output_{rank}_{step}: ", outputs,"\n\n")image_features, text_features=outputs["image_features"], outputs["text_features"]logits_per_image=image_features @ text_features.T hps_score=torch.diagonal(logits_per_image)all_rewards=[]all_rewards.append(hps_score)all_rewards=torch.cat(all_rewards,dim=0)samples={"rewards":all_rewards.to(torch.float32)}ifrank==0: print(f"samples_{rank}_{step}: ", samples,"\n\n")gathered_reward=gather_tensor(samples["rewards"])ifrank==0: print(f"gather_reward_{rank}_{step}: ", gathered_reward,"\n\n")ifdist.get_rank()==0: print("gathered_hps_reward", gathered_reward)with open('./hps_reward.txt','a')as f: f.write(f"{gathered_reward.mean().item()}\n")samples_batched={k: v.unsqueeze(1)fork,vinsamples.items()}samples_batched_list=[dict(zip(samples_batched, x))forxinzip(*samples_batched.values())]fori, sampleinlist(enumerate(samples_batched_list)):ifrank==0: print(f"sample_{rank}_{step}: ", sample["rewards"],"\n\n")ifdist.get_rank()%8==0: print("hps reward", sample["rewards"].item(),"\n\n\n\n\n")# print("ratio", ratio)# print("advantage", sample["advantages"].item())# print("final loss", loss.item())

生成1000个reward值,其精度对比效果如下(绝对误差≈0.015%):

数据、图片来自昇腾官方数据。

端到端对齐

对齐标准

固定随机性后,需要按照如下标准关注对齐结果:

  • 关注推理阶段生成的图片,主观对齐
  • 关注训练过程中的loss(生成模型loss较小,参考价值有限)
  • 关注reward scores,200步误差5%以内

对齐步骤

端到端对齐流程主要关注两方面,一方面是综合度量模型训练的指标:推理阶段图片+loss+rward scores,另一方面是下游任务推理效果。

全流程对齐具体步骤:

  • 两边加载相同的预训练权重。
  • 固定随机性:整体随机性与确定性计算固定(seed_all,mode=True),noise在cpu侧生成。
  • 保存关键信息:推理阶段的图片、reward阶段的rewardvalues、训练阶段模型loss,同时保存权重,用于对齐推理效果,此处注意需要持续关注推理阶段生成图片的效果,具体例子为在替换rope融合算子时,loss结果与reward差异不大,但推理阶段出现了花图。

端到端流程结构

六、常见问题

如遇到ROPE部分不支持complex128计算问题,NPU场景需要适配修改___CODE_BLOCK_PLACEHOLDER___211250

is_mps=ids.device.type=="mps"is_npu=ids.device.type=="npu"#增加改行##下面增加is_npu判断freqs_dtype=torch.float32ifis_mps or is_npuelsetorch.float64

七、总结

DanceGRPO+FLUX 模型在 AI 生图领域,解决 FLUX 在生成过程中与人类审美、语义对齐等方面的适配问题,大幅提升其生图质量与稳定性。展望未来,多模态生成强化学习模型有望在更多领域开花结果,如影视特效制作中实现更逼真的虚拟场景与角色创建,教育领域中打造沉浸式的学习环境,医疗领域辅助医生进行手术模拟与病情可视化分析等 。同时,随着技术发展,模型将不断优化,生成效率与质量进一步提升,在处理复杂任务、理解模糊指令等方面取得更大突破,为各行业数字化转型与创新发展注入强大动力 。

注明:昇腾PAE案例库对本文写作亦有帮助。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/1/17 0:09:57

day42(12.23)——leetcode面试经典150

86. 分隔链表 86. 分隔链表 咱也是成功发现leetcode的bug了哈哈哈 题目&#xff1a; 题解&#xff1a; /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}* ListNode(int val) { this.val val;…

作者头像 李华
网站建设 2026/1/17 1:49:22

html转盘抽奖程序

网页代码如下&#xff1a; <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>幸运转盘抽奖</tit…

作者头像 李华
网站建设 2026/1/20 15:31:58

Excel表格大全:模板+教程合集(每日更新)

本期介绍&#xff08;预览图在下方&#xff09;&#xff1a; Excel 表格模板包、视频教程、图文教程及配套练习素材&#xff0c;核心覆盖个人工作计划、企业多部门办公的全场景表格模板&#xff0c;以及从基础到进阶的 Excel 技能教程。适用人群包括职场办公族、财务人员、企业…

作者头像 李华
网站建设 2026/1/21 22:10:34

基于langchain1.X构建企业级智能体开发平台之环境和项目搭建

前提说明&#xff1a;由于langchain1.0之前的版本和现在的1.0有非常大的调整&#xff1b;我这边的langchain指的是langchain1.0及以后的版本; 项目说明&#xff1a;我们这个教程并不是一步步从0开始教大家上手langchain框架&#xff0c;而是要求大家具备了一定的了解基于这个项…

作者头像 李华
网站建设 2026/1/14 19:27:16

基于SpringBoot的冷链运输生鲜销售系统计算机毕业设计项目源码文档

项目整体介绍在生鲜电商规模化、冷链管控精细化需求升级的背景下&#xff0c;传统生鲜销售存在 “冷链轨迹不可溯、损耗率高、订单履约低效” 的痛点&#xff0c;基于 SpringBoot 构建的冷链运输生鲜销售系统&#xff0c;适配消费者、冷链运维人员、商家、平台管理员等角色&…

作者头像 李华