FM × scDynamics(1): Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport
Published:
接下来开一个新系列,记录自己读的有关flow matching和single cell dynamics交叉领域的paper。第一节记录一下这个领域最经典的paper:Improving and Generalizing Flow-Based Generative Models with Minibatch Optimal Transport,从最经典朴素的I-CFM,OT-CFM,SB-CFM开始讲起。
1.背景介绍
Alexander Tong这篇文章,关注用连续时间流模型连接源分布与目标分布,并在不需要“模拟时间动态”进行训练的前提下,统一并推广现有的 conditional flow matching 思路。这篇文章算是CFM系列非常重要和基础的一篇文章,同时文章对应的代码库,TorchCFM,算是flow matching和diffusion领域非常重要的几个开源库之一。所以FM系列的第一篇就主要介绍一下OTCFM这篇文章,对于CFM的几个基础变体有一些基本的了解。
1.1 从基于模拟到免模拟的范式转变
在Flow Matching出来之前,flow-based model已经作为生成模型的一种,登上了生成模型的舞台。通常意义上的归一化流(Normalizing Flows)通过组合静态可逆模块来构建分布间的映射,并不需要数值积分;而连续归一化流(Continuous Normalizing Flows, CNF)利用神经网络驱动的常微分方程(Neural ODEs)来实现这一过程,虽然理论上更灵活,但其训练依赖于数值积分和对求解器进行反向传播,这带来了巨大的计算开销和数值不稳定性,限制了其扩展性。正是这类基于ODE数值模拟的CNF方法在速度和scaling上遭遇瓶颈,难以很好应用于大规模数据。
与此同时,扩散模型(Diffusion Models)在前几年爆火起来,其成功部分归功于一个简单且稳定的训练目标:直接对随机微分方程(SDE)的漂移项或得分函数进行回归,而无需在训练时模拟整个轨迹。这一思想启发了针对CNF的改进:能否也为其设计一个类似的“免模拟”训练框架呢?这正是近期流匹配(Flow Matching, FM)工作的出发点,但在OTCFM之前的FM方法通常假设源分布必须为高斯分布,限制了其通用性。OTCFM这篇文章就提出了一种更普世的CFM框架,可以灵活的从一个分布转换到另一个分布,摆脱了对于原分布是高斯分布的限制。相比较于传统的 diffusion 训练设定,CFM 不要求高斯源分布,更具通用性;当然,基于 SB/bridge 的扩散方法也可拓展到非高斯源分布。
此外,无论是经过改进的CNF还是扩散模型,都面临一个共同的瓶颈:推理(采样)速度可能较慢,因为生成一个高质量样本通常需要对ODE/SDE进行多步数值求解。CFM属于确定性ODE推理范式,具备高效推理的潜力;尤其当采用OT‑CFM学到更“直”的流时,往往能以更少的函数评估步数(NFE)达到相当或更好的质量。
2. 核心思想:从流匹配到条件流匹配
本文的核心任务是学习一个映射 $f: \mathbb{R}^d \to \mathbb{R}^d$,它能够将一个源概率分布 $q_0$ 精确地变换为目标概率分布 $q_1$。即,若 $x_0 \sim q_0(x_0)$,则 $f(x_0) \sim q_1(x_1)$。这一设定非常通用,既涵盖了从简单先验(如高斯分布)生成复杂数据(如图像)的传统任务,也适用于源和目标均为经验数据分布的场景(如单细胞轨迹推断)。为避免记号混淆,下文端点分布记为 $q_0,q_1$,由向量场生成的中间路径记为 $p_t$。
在连续归一化流(CNF)中,这种变换是通过一个由神经网络参数化的常微分方程(ODE)来实现的:
\[dx = u_t(x)\,dt \quad (1)\]这里的 $u_t(x)$ 是一个时间依赖的向量场 (vector field),它为时空中的任意一点 $(t, x)$ 指定了一个“速度”或“运动方向”。通过从 $t=0$ 到 $t=1$ 对这个ODE进行积分,我们可以得到其解 $\phi_t(x)$,这是一个从初始点 $x_0$ 到 $t$ 时刻位置 $x_t$ 的连续变换。当我们将这个变换应用于整个分布 $p_0$ 时,该分布会沿着由 $u_t$ 定义的流线进行演化,这个过程被称为概率流 (probability flow)。这个概率流和对应的速度场的关系,由物理学中著名的连续性方程所描述:
\[\frac{\partial p_t}{\partial t} = -\nabla\cdot\big(p_t(x)\,u_t(x)\big) \quad (2)\]这个方程建立了概率密度 $p_t$ 的时间演化与驱动它的向量场 $u_t$ 之间的直接联系。
基于此,Lipman et al. (2023) 提出了流匹配 (Flow Matching, FM) 的目标。其核心思想是,如果我们已经知道了一条连接 $q_0$ 和 $q_1$ 的理想路径 ${p_t, u_t}$,我们就可以通过一个简单的回归损失来训练神经网络 $v_\theta(t,x)$ 去逼近真实的向量场 $u_t(x)$:
\[\mathcal{L}_{\rm FM}(\theta)=\mathbb{E}_{t\sim\mathcal{U}(0,1),\,x\sim p_t(x)}\,\|v_\theta(t,x)-u_t(x)\|^2 \quad (3)\]在训练出能够很好拟合速度场的神经网络之后,就可以通过ODE采样器和神经网络构建出路径,从而完成生成任务。
这个想法看起来很美好,但在实践中,这个目标是难以计算的 (intractable),因为对于任意的 $q_0$ 和 $q_1$,其中间时刻的概率密度 $p_t$ 和向量场 $u_t$ 通常没有解析表达式。唯一的特例是当概率流 $p_t$ 是一条连接两个高斯分布的路径,即 $p_t=\mathcal{N}(\mu_t,\sigma_t^2 I)$ 时,其对应的向量场 $u_t(x)$ 具有一个已知的闭式解 (closed-form solution):
\[u_t(x)=\frac{\sigma_t'}{\sigma_t}\big(x-\mu_t\big)+\mu_t' \quad (4)\]其中 $\mu_t’$ 和 $\sigma_t’$ 分别是均值和标准差路径关于时间 $t$ 的导数。
为了解决一般分布下的难题,本文提出了条件流匹配 (Conditional Flow Matching, CFM)。核心思想是引入一个条件变量 $z$,将宏观上复杂的边际概率路径 $p_t(x)$ 分解为一系列微观上简单的“条件”概率路径 $p_t(x \mid z)$ 的加权混合:
\[p_t(x) = \int p_t(x \mid z)q(z)dz \quad (5)\]其中 $q(z)$ 是条件变量的分布。通过这种方式,我们就可以自由设计那些易于处理的条件路径 $p_t(x \mid z)$ 及其对应的、形式简单的条件向量场 $u_t(x \mid z)$来简化计算了。
基于上述设定,生成该混合路径的边际向量场 $u_t(x)$ 也可以写作条件混合的形式:
\[u_t(x) := \mathbb{E}_{q(z)}\frac{u_t(x|z)p_t(x|z)}{p_t(x)} \quad (6)\]虽然我们得到了理想的回归目标 $u_t(x)$,但在实践中,我们很难求解出分母中的边际密度 $p_t(x)$,从而也就很难计算出 $u_t(x)$ 作为训练时的靶标。
为了绕过这一障碍,CFM 采用了一种新的目标函数——条件流匹配损失:
\[\mathcal{L}_{CFM}(\theta) := \mathbb{E}_{t,q(z),p_t(x\mid z)}\|v_\theta(t,x) - u_t(x\mid z)\|^2 \quad (7)\]使用这个目标函数,就可以将神经网络 $v_\theta$ 的回归目标从那个无法计算的边际向量场 $u_t(x)$,替换为了我们可以自由设计且易于计算的条件向量场 $u_t(x\mid z)$。
这种替换操作的数学合法性由定理 3.2保证:在 $p_t(x) > 0$ 的条件下,虽然 \(\mathcal{L}_{CFM}\) 和 \(\mathcal{L}_{FM}\) 的值不同,但它们关于模型参数 $\theta$ 的梯度是完全相等的:
\[\nabla_\theta \mathcal{L}_{CFM}(\theta) = \nabla_\theta \mathcal{L}_{FM}(\theta) \quad (8)\]这意味着,我们可以通过最小化一个简单、可计算的CFM损失,来实现与最小化那个复杂、不可计算的理想FM损失完全相同的优化效果。因此,我们只需设计出合适的条件路径,便可利用CFM目标,让模型自动学习到正确的、复杂的边际动力学,而无需关心其背后难以名状的真实形式。
3. 方法:一个框架,三种实现
上一节我们建立了一个通用的、灵活的条件流匹配 (CFM) 框架,它为解决传统流模型训练不稳、通用性差的问题提供了理论基础。其核心在于,通过巧妙地选择配对策略 $q(z)$、条件路径 $p_t(x \mid z)$ 和路径速度向量 $u_t(x\mid z)$ 这三大构建模块,我们可以设计出不同特性、满足不同需求的流模型。本节将介绍论文提出的三种重要实现:I-CFM、OT-CFM 和 SB-CFM,它们分别代表了对该框架的一次基础验证、一次效率优化和一次通用性拓展。
3.1 I-CFM:框架可行性的基础验证
为了首先验证CFM框架的基本可行性,作者提出了最简单、最直观的实现方式——I-CFM (Independent CFM)。它采用最朴素的方法来定义三大模块,旨在证明即便不依赖复杂的技巧,CFM框架也能成功运作。
- 配对策略 $q(z)$:独立耦合 (Independent Coupling)
- 采用最简单的随机配对,$z=(x_0, x_1)$ 中的始末点分别从源分布 $q_0$ 和目标分布 $q_1$ 中独立随机抽取,即 $q(z)=q(x_0)q(x_1)$。
- 条件路径 $p_t(x\mid z)$ 与 速度向量 $u_t(x\mid z)$:高斯平滑的直线
路径:对于配好的一对 $(x_0, x_1)$,路径被假定为从 $x_0$到 $x_1$ 的直线,并加入固定的高斯噪声(方差为$\sigma^2$)以平滑学习过程。
\[p_t(x\mid z)=\mathcal{N}\big(x\,\big|\,(1-t)\,x_0+t\,x_1,\;\sigma^2\mathbf{I}\big) \quad (9)\]- 速度:由于路径是匀速直线,速度向量是恒定的,即从起点指向终点的向量 $(x_1-x_0)$。
- \[u_t(x\mid z) = x_1-x_0 \quad (10)\]
对于给定的样本对 $z=(x_0,x_1)$,训练目标在条件路径上对 $x$ 与 $t$ 为常量,即 $(x_1-x_0)$。网络仍以 $(t,x)$ 为输入进行学习。
- 结论:I-CFM 的成功证明了 CFM 框架的有效性,并成功地将流匹配方法推广到了任意数据分布,为后续更精巧的设计打下了坚实的地基。然而,其“随机配对”的策略也带来了明显的缺点:可能会匹配相距很远的点,导致训练信号方差较大,学习效率低,且最终学到的全局流比较“弯曲”,推理时需要更多的计算步数(NFE)。
3.2 OT-CFM:追求效率与“最优”的确定性路径
I-CFM 验证了可行性,但其效率和路径的最优性有待提升。因此,作者提出了 OT-CFM (Optimal Transport CFM) 作为对 I-CFM 的一次重大改进,其核心动机是在 CFM 框架内追求一条成本最低的确定性路径 (deterministic path)。
这背后的理论武器是最优传输(OT)。OT 旨在寻找最高效的“运输方案”,其动态形式(Dynamic OT)寻找的是一条总动能最小的路径:
\[W_2^2(q_0,q_1)=\inf_{\{p_t,u_t\}}\int_0^1\!\int_{\mathbb{R}^d} p_t(x)\,\|u_t(x)\|^2\,dx\,dt \quad (11)\]而这等价于一个更容易处理的静态问题(Static OT),即寻找一个最优的耦合(配对方案)$\pi$,使得配对点之间距离平方的期望最小:
\[W_2^2(q_0,q_1)=\inf_{\pi\in\Pi(q_0,q_1)}\int_{\mathbb{R}^d\times\mathbb{R}^d}\|x-y\|^2\,d\pi(x,y) \quad (12)\]OT-CFM 正是利用了这一等价性,将配对方式从“随机”升级为“最优”。
- 配对策略 $q(z)$:最优传输耦合 (Optimal Transport Coupling)
- 不再独立抽取,而是根据一个从静态OT问题解出的最优传输计划 $\pi(x_0, x_1)$ 来联合抽取样本对 $z=(x_0, x_1)$。实践中,因全局OT计算成本过高,采用小批量最优传输 (Minibatch OT) 作为高效的近似策略。
- 条件路径 $p_t(x\mid z)$ 与 速度向量 $u_t(x\mid z)$:依然是高斯平滑的直线
- 这部分与I-CFM完全相同。OT-CFM的巧妙之处在于,它通过优化配对环节,使得提供给模型的这些简单的“直线路径”本身就更有意义、更“直”,从而让最终学到的全局流自然地趋向于最优。
- 结论:通过引入OT,OT-CFM 学习到的流更“直”,显著降低了训练信号的方差,实现了更快的训练收敛和更高效的推理(需要更少的 NFE)。
3.3 SB-CFM:从随机路径到统一框架
在解决了确定性最优路径后,作者进一步拓展了CFM框架的边界,以展示其更强大的通用性,即处理带随机性的路径 (stochastic path)。这需要一个新的数学工具:熵正则化最优传输。其目标函数是在标准OT的基础上加入一个熵项 $H(\pi)$,由正则化参数 $\lambda$ (在本文中对应 $2\sigma^2$) 控制权重:
\[W(q_0, q_1)^2_{2, \lambda} = \inf_{\pi_\lambda \in \Pi} \int_{\mathcal{X}^2} \|x-y\|^2 \,d\pi_\lambda(x, y) - \lambda H(\pi) \quad (13)\]该公式在最小化运输成本和最大化路径随机性(熵)之间取得平衡,是通向概率性路径的关键。这恰好与物理学中的薛定谔桥(Schrödinger Bridge, SB)问题不谋而合,该问题旨在寻找连接两个分布的“最可能”的随机过程。
因此,作者提出了 SB-CFM (Schrödinger Bridge CFM),其目标不再是寻找成本最小的路径,而是在 CFM 框架内找到概率最大的随机路径。
- 配对策略 $q(z)$:熵正则化最优传输耦合
- SB-CFM的配对策略正是由熵正则化OT给出的耦合 $\pi_{2\sigma^2}(x_0, x_1)$。
- 条件路径 $p_t(x\mid z)$ 与 速度向量 $u_t(x\mid z)$:布朗桥 (Brownian Bridge)
路径:这种配对策略自然地引出了一条随机路径——布朗桥。它像一根两端固定在$x_0$和$x_1$的琴弦,在中间会随机振动。其噪声方差随时间变化,在起点和终点为0,在中间最大。
\[p_t(x\mid z)=\mathcal{N}\big(x\,\big|\,(1-t)x_0+tx_1,\;t(1-t)\,\sigma^2\mathbf{I}\big) \quad (14)\]速度:由于路径是随机曲线,其速度向量也变得更复杂,包含一个指向终点的基础速度和一个将路径“拉回”中心线的修正项。
\[u_t(x\mid z)=\frac{1-2t}{2t(1-t)}\Big(x-\big((1-t)x_0+tx_1\big)\Big)+(x_1-x_0) \quad (15)\]
- 统一视角:熵正则化OT不仅是SB-CFM的理论基础,更提供了一个统一三者的精妙视角。通过调节正则化参数 $\lambda$(或等价的噪声水平 $\sigma$),我们可以平滑地在三个模型之间过渡:
- 当 $\lambda \to 0$ ($\sigma \to 0$) 时,我们只关心成本,模型退化为标准OT,对应 OT-CFM。
- 当 $\lambda \to \infty$ ($\sigma \to \infty$) 时,我们只关心随机性,模型退化为独立耦合,对应 I-CFM。
- 当 $\lambda$ 为特定正值时,则对应 SB-CFM。
因此,这三种看似不同的方法,实际上是同一个统一框架下、由不同正则化强度催生的三种不同表现形式,完美展现了 CFM 框架的内在一致性与强大扩展性。
