openpi使用流匹配方式来训练专家模型
流匹配原理
https://blog.csdn.net/qq_37795208/article/details/159049034https://blog.csdn.net/qq_37795208/article/details/159049034openpi论文解读:
https://blog.csdn.net/qq_37795208/article/details/159049034https://blog.csdn.net/qq_37795208/article/details/159049034
借用上述博客的相关内容
PI0.5:输入图像、任务文本、机器人当前状态,再加上一段“当前还带噪声的动作序列”,模型学习预测“应该朝哪个方向把这段动作去噪”。训练时学这个方向,推理时从纯噪声开始反复更新 10 步左右,最后得到可执行的动作 chunk
1.流匹配大致原理
2.对应的流匹配计算代码
对应openpi/sr/openpi/models/pi0.py中的损失函数计算定义如下
#flow matching 的核心 @override def compute_loss( self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False ) -> at.Float[at.Array, "*b ah"]: ## 1. 数据增强与预处理 preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3) observation = _model.preprocess_observation(preprocess_rng, observation, train=train) # 2. 采样噪声与时间步 t batch_shape = actions.shape[:-2] noise = jax.random.normal(noise_rng, actions.shape)# 纯噪声,可以理解为一个随机的目标? time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001#采样时间Beta(1.5, 1)再裁剪到 [0.001, 0.999],训练时会随机抽取不同的 t,让模型见过不同去噪阶段 time_expanded = time[..., None, None] #noise:从高斯分布采样的一段噪声动作;t:flow matching 的时间步,从 0 到 1; #x_t = t * 噪声 + (1-t) * 真实动作,真实和噪声连线间的某一点 x_t = time_expanded * noise + (1 - time_expanded) * actions #x_t真实动作和噪声在时间步 t 下的线性混合状态 u_t = noise - actions #理论上对应的目标速度场 # 4. 前向传播:图像+文本+噪声动作 → 模型预测速度场 v_t # one big forward pass of prefix + suffix at once #prefix只有环境的编码,suffix只有动作和时间编码 prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)#prefix = 固定不变的环境信息(图像 + 文本 + 机器人状态),prefix_tokens:图像 + 文本 编码后的 tokens suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)#suffix = 每次都变的动作信息(带噪动作 x_t + 时间 t),suffix_tokens:带噪动作 + 时间 编码后的 tokens input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1) ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0) attn_mask = make_attn_mask(input_mask, ar_mask) positions = jnp.cumsum(input_mask, axis=1) - 1 (prefix_out, suffix_out), _ = self.PaliGemma.llm(#(图像文本编码,动作编码)全部拼起来 → 送入大模型(Gemma / PaliGemma) [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond] ) v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) #模型预测出来的速度场(最后一层线性层 → 输出 v_t) return jnp.mean(jnp.square(v_t - u_t), axis=-1) #最终损失 MSE,Loss = MSE( 模型预测速度场 v_t , 真实最优速度场 u_t )2.1)首先进行随机采样噪声noise,根据不同的时间步随机构造插值点x_t,并计算真实的方向向量u_t
2.2)输入数据的准备
这里会在输入前准备两组数据,分别为代表环境的图像+文本数据和代表动作的带噪声动作和时间数据。
2.2.1)图像和文本的token准备
这里prefix_tokens是将obs中的图像和语言通过预训练的PaliGemma VLM模型,将图像和字符串变成token。这里由于设置了train=false,图像编码器是冻结的不可训练,文本embedding 层可训练
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)#prefix = 固定不变的环境信息(图像 + 文本 + 机器人状态),prefix_tokens:图像 + 文本 编码后的 tokens suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)#suffix = 每次都变的动作信息(带噪动作 x_t + 时间 t),suffix_tokens:带噪动作 + 时间 编码后的 tokens其中这里的prefix_tokens中的文本token是包含了机械臂的状态state(位姿,关节),在如下可以看出
def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]: cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ") if state is not None: # This is the Pi05 format, where the state is part of the discrete language input. discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 state_str = " ".join(map(str, discretized_state)) full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " tokens = self._tokenizer.encode(full_prompt, add_bos=True)2.2.2)带噪声的动作和时间的映射
suffix_tokens是带噪声动作的token,adarms_cond是时间的adaRMS
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)#suffix = 每次都变的动作信息(带噪动作 x_t + 时间 t),suffix_tokens:带噪动作 + 时间 编码后的 tokens2.2.3)完整的数据输入过程
其中self.PaliGemma.llm不是用于预测token,而是将输入信息经过预训练大模型的 Transformer 来融合。
# prefix_out: (B, P, 2048) ← prefix 对应的输出(通常丢弃) # suffix_out: (B, H, 2048) ← suffix 对应的输出(用于预测动作) # 输入:环境信息(prefix_tokens)+ 动作信息(suffix_tokens) # 输出:经过 27 层 Transformer 处理后,融合了环境上下文的动作表示(suffix_out) # 作用:利用预训练大模型的 Transformer 来深度理解和融合环境信息与动作信息 (prefix_out, suffix_out), _ = self.PaliGemma.llm(#(图像文本编码,动作编码)全部拼起来 → 送入大模型(Gemma / PaliGemma) [prefix_tokens, suffix_tokens],# 两个独立的输入 mask=attn_mask, # 注意力掩码 (B, P+H, P+H) positions=positions, # 位置索引 (B, P+H) adarms_cond=[None, adarms_cond]# 时间条件(只给 suffix) )