FM × scDynamics(3): Meta Flow Matching — Integrating Vector Fields on the Wasserstein Manifold
Published:
FM系列的第三篇,来讲一篇与OTCFM作者团队相关的新作:Meta Flow Matching(MFM)。论文定义了生物学与物理学中的一个共性问题:对由相互作用的实体(如细胞或粒子)组成的系统,建模其随时间的连续演化。若能学到这样的动力学规律,便可在不同样本与未见环境下预测整个人群分布的时间演化。
基于Flow的生成模型天然适合这一任务,因为它们直接学习“分布到分布”的演化;但现有基于Flow的方法在该领域普遍存在两点限制:(1)通常只针对单一初始人群(或少量预定义条件)训练,难以泛化到全新的起始人群;(2)多将个体视为独立粒子,忽略群体内部的相互作用。作者指出,许多自然过程的动力学不仅随时间变化,还取决于“当前的群体分布状态”(可形式化为Wasserstein概率分布流形上的位置)。据此提出MFM:用一个GNN对整个人群(初始分布)进行嵌入,将其作为条件输入流匹配模型,从而学习“密度依赖”的向量场,并能泛化到未见的人群。
Background
作者首先聚焦于单细胞生物学中的核心挑战:在缺乏纵向追踪情况下,建模细胞对环境(尤其是与其他细胞)响应时的动力学。细胞通过邻近信号传导相互影响,导致个体轨迹并非独立。同样的现象在许多自然科学的多体系统中也存在。现有多数方法(OT、扩散、FM/正规化流等)多在“种群层面”拟合分布配对,往往将细胞视作独立样本、难以纳入群体相互作用,并且即使能做条件生成,也受限于训练时见到的条件,泛化到“新的人群/新条件”能力不足。MFM通过在Wasserstein流形上学习“密度依赖”的向量场,显式纳入群体状态,从而提升跨人群的泛化。
Flow Matching
Flow Matching(FM)是一类生成建模框架。流被定义为两种概率密度 $p_0(x)$ 与 $p_1(x)$ 在样本空间中的连续插值。令 $x_t = f_t(x_0, x_1)$,其中 $(x_0, x_1) \sim \pi(x_0, x_1)$ 的边缘满足 $\int \pi(x_0, x_1) dx_1 = p_0(x_0)$、$\int \pi(x_0, x_1) dx_0 = p_1(x_1)$,则 $p_t$ 可写为 $p_t(x) = \iint \pi(x_0, x_1) \, \delta(x - f_t(x_0, x_1)) \, dx_0 dx_1$。
映射 $f_t$ 需要满足边界条件 $f_0(x_0, x_1) = x_0$ 和 $f_1(x_0, x_1) = x_1$。一个常见的选择是线性插值:$f_t(x_0, x_1) = (1-t)x_0 + t x_1$。那么,对应的边际密度 $p_t(x)$ 就是 $p_t(x) = \iint \pi(x_0, x_1) \delta(x - f_t(x_0, x_1)) dx_0 dx_1$。
FM 的核心是学习一个速度场 $v_t(x)$ 来描述从 $p_0$ 到 $p_1$ 的推进。根据连续性方程,理想速度场 $v_t^*(\xi)$ 可写为: \(v_t^*(\xi) = \frac{1}{p_t(\xi)} \mathbb{E}_{\pi(x_0,x_1)} \left[ \delta(\xi - f_t(x_0, x_1)) \frac{\partial f_t(x_0,x_1)}{\partial t} \right]\) 该期望难以直接计算,通常以神经网络 $v_t(x;\omega)$ 近似之。相应可得可积目标: \(\mathcal{L}_{\rm FM}(\omega) = \mathbb{E}_{\pi(x_0,x_1)}\int_0^1 \! dt \, \big\| \tfrac{\partial}{\partial t} f_t(x_0,x_1) - v_t(f_t(x_0,x_1);\omega) \big\|^2\) (线性插值时 $f_t=(1-t)x_0+tx_1$,故 $\partial_t f_t = x_1-x_0$)。标准FM多将样本独立处理,忽略群体相互作用的信息。
Conditional Flow Matching
在实际中,FM可扩展为条件生成,即 Conditional Flow Matching(CGFM)。对条件族 $p_t(x_t\,|\,c)$,其中 $(x_0,x_1)\sim\pi(x_0,x_1|c)$,其条件速度场为: \(v_t^*(\xi | c) = \frac{1}{p_t(\xi | c)} \mathbb{E}_{\pi(x_0, x_1|c)} \left[\delta(f_t(x_0, x_1) - \xi) \frac{\partial f_t(x_0, x_1)}{\partial t}\right]\) 于是,训练目标改为: \(\mathcal{L}_{CFM}(\omega) = \mathbb{E}_{p(c)}\mathbb{E}_{\pi(x_0, x_1 | c)} \int_0^1 dt \left\| \frac{\partial}{\partial t} f_t(x_0, x_1) - v_t(f_t(x_0, x_1) | c; \omega) \right\|^2\) 与原始FM相比,CGFM需先采样 $c\sim p(c)$,再从 $\pi(x_0,x_1|c)$ 采样配对,并将 $c$ 作为额外输入传给向量场模型。它可引入“已知的条件信息”,但通常仅对训练中见过的条件有效,难以泛化到全新的群体/条件。
Meta Flow Matching
作者进一步引入Wasserstein流形 $\mathcal{P}_2(\mathcal{X})$:可粗略理解为“分布的几何空间”,其中每个点是一个概率分布,一条曲线 $p_t$ 表示群体随时间演化的轨迹,曲线上每处的切向量对应动力学向量场。我们的目标是学习定义在 $\mathcal{P}_2(\mathcal{X})$ 上的“密度依赖”向量场,即: \(\frac{\partial p_t(x)}{\partial t} = - \langle \nabla_x, \, p_t(x) \cdot v_t(x, p_t) \rangle, \quad p_{t=0}(x) = p_0(x)\) 其中 $v_t(x, p_t)$ 依赖于当前的群体密度 $p_t$(如均场相互作用或Fokker–Planck型扩散所示)。这与仅在单条曲线上学习切向量(FM)或有限多条曲线(CGFM)不同:MFM意在学习“流形上的向量场”。
Methods
Meta Flow Matching(MFM)的核心是将Flow Matching在“输入分布”上进行摊销(amortization)。传统FM是:给定一个特定联合 $\pi(x_0,x_1)$,解如下最优化问题以得到对应该 $\pi$ 的最优向量场 \(v_t^*(\cdot, \pi) = \arg\min_{v_t} \mathcal{L}_{GFM}(v_t(\cdot), \pi(x_0, x_1))\) (此处 $\mathcal{L}_{GFM}$ 即FM目标的记法)。CGFM虽可引入外部条件 $c$,但当 $c$ 是预定义标签时,难以泛化到未见条件。
MFM不同之处在于:学习一个“元模型”,输入为初始分布(或其样本)本身,经由嵌入函数 $\varphi(\pi)$ 提取表征,使 $v_t(\cdot, \varphi(\pi)) \approx v_t^*(\cdot, \pi)$,实现对不同起始人群的“摊销”与泛化。
MFM的目标函数
具体地,设数据为若干人群联合分布 ${\pi(x_0,x_1\mid i)}_i$($i$ 可指一个实验/处理/个体)。若用CGFM并直接以 $i$ 作为条件,模型难以泛化到新的 $i$。
为实现泛化,MFM对初始边际 $p_0$ 进行嵌入 $\varphi(p_0)$,并以此为条件拟合速度场: \(\mathcal{L}_{\text{MFM}}(\omega; \varphi) = \mathbb{E}_{i \sim D} \mathbb{E}_{\pi(x_0, x_1 | i)} \int_0^1 dt \left\| \frac{\partial}{\partial t} f_t(x_0, x_1) - v_t(f_t(x_0, x_1) | \varphi(p_0); \omega) \right\|^2\) 含义为:
- 从所有实验的数据集 $D$ 中随机抽取一个实验 $i$。
- 从该实验的联合分布 $\pi(x_0, x_1 \mid i)$ 中随机抽取一个粒子的“起点-终点”对 $(x_0, x_1)$。
- 在实践中,对时间的积分通过从 $[0,1]$ 区间内随机采样一个时间点 $t$ 来近似。
- 计算真实瞬时速度与模型预测速度的平方误差(线性插值时 $\partial_t f_t = x_1 - x_0$)。
此外,论文命题(Prop. 1)证明:当条件变量 $c$ 的分布及其与边际的依赖已知时,存在嵌入使MFM退化为CGFM(取 $\varphi(p_0(\cdot\mid c)) = c$)。
基于GNN的群体嵌入
在生物数据中,$\pi$ 多以样本点集给出而非解析密度,因而 $\varphi$ 需处理大小可变、无序的点集。作者采用图神经网络(GNN)实现群体嵌入: \(\varphi(p_0, \theta) = \varphi\big(\{x_0^j\}_{j=1}^{N_i}; \theta\big), \quad (x_0^j, x_1^j) \sim \pi(x_0, x_1 \mid i)\) 实现要点:
- 对初始群体 ${x_0^j}$ 构建 $k$ 近邻图(特征空间度量)。
- 送入GNN(多轮消息传递),并配合 $k$NN 边池化(edge pooling)。
- 通过节点平均池化得到固定维度的群体嵌入向量。
GNN嵌入(参数 $\theta$)与向量场网络(参数 $\omega$)端到端联合训练,最小化 $\mathcal{L}_{\text{MFM}}$。除参数化嵌入外,作者亦讨论了其他嵌入形式(如直接用 $p_0$ 的密度值或核密度估计),体现“密度依赖”思想的一般性。
训练与采样算法
论文给出训练与采样伪代码:
Algorithm 1: Meta Flow Matching (training)
Input: 数据集 {(π(x₀, x₁|i), cⁱ)}_i, 速度场模型 v_t(·; ω), 群体嵌入模型 φ(·; θ)
for 训练迭代 do
i ~ U{1,N}(i) // 随机采样一个实验批次
(x₀ʲ, x₁ʲ, tʲ) ~ π(x₀, x₁|i)U[0,1](t) // 为每个实验采样粒子和时间点
f_t(x₀ʲ, x₁ʲ) ← (1-tʲ)x₀ʲ + tʲx₁ʲ
hⁱ(θ) ← φ({x₀ʲ}; θ) // 嵌入初始群体 {x₀ʲ}
L_MFM(ω, θ) ← mean(||f_t'(x₀ʲ, x₁ʲ) - v_tʲ(f_t(x₀ʲ,x₁ʲ)|hⁱ(θ), cⁱ; ω)||^2)
ω' ← Update(ω, ∇_ω L_MFM(ω, θ)) // 更新流模型的参数
θ' ← Update(θ, ∇_θ L_MFM(ω, θ)) // 更新嵌入模型的参数
ω ← ω', θ ← θ' // 应用更新
end for
return v_t(·; ω*), φ(·; θ*)
Algorithm 2: Meta Flow Matching (sampling)
Input: 初始群体 {x₀ʲ}, 处理条件 cⁱ, 模型 v_t(·; ω*) and φ(·; θ*)
h = φ({x₀ʲ}; θ) // 嵌入给定的初始群体
x₁ʲ = ∫₀¹ v_t(x_tʲ|h, cⁱ; ω) dt + x₀ʲ // 使用ODE求解器积分得到最终状态
return 预测的最终群体 {x₁ʲ}
实现细节(与实验部分一致):$v_t(\cdot\,|\,\cdot;\omega)$ 采用MLP 参数化;$\varphi(\cdot;\theta)$ 采用GCN(含 $k$NN 边池化)。
Related Work
Meta-learning over Distributions and Meta OT
“Meta”在此意为“在分布上摊销学习”,与Meta Optimal Transport(Meta OT)一脉相承:后者在多对输入边际上摊销OT求解策略。关键差异在于:
- 输入需求不同:Meta OT 需要起点与终点两个边际;MFM 仅以起始分布 $p_0$ 为输入,即可预测其发展(假设柯西问题唯一解)。这非常契合“仅知处理前,需预测处理后”的生物场景。
- 架构约束不同:多种OT方法(含Meta OT)常依赖ICNN以保证势函数的凸性;MFM不受此限,直接学习向量场,设计更灵活。
Single-Cell Generative Modeling
单细胞生成建模方面,AE/变分方法(scVI、scGen 等)常用于嵌入与OOD推断;OT/连续流方法(CellOT、TrajectoryNet、FM变体等)用于轨迹预测。MFM的优势在于:
- 纳入相互作用:通过GNN对整个人群嵌入,将群体结构融入动力学预测。
- 跨群体泛化:学习“规律”(向量场)而非仅拟合单一人群,可泛化到未见人群/条件。
Generative Modeling for Physical Processes
在物理多体建模中,GNN用于预测相互作用较为常见,但通常可获得连续时间的粒子轨迹(密集监督)。而在单细胞场景,我们往往仅有“起点—终点”两时刻(破坏性测序难以纵向追踪),因此需在稀疏监督下学习连续动力学,难度更高。MFM正是面向此设定而设计。
Experiments
作者在两类任务上验证MFM的泛化能力:(i)合成字母去噪任务;(ii)真实大规模单细胞药物扰动数据(器官类肿瘤类器官药筛,含多患者多处理条件)。
- 模型设定:$v_t$ 用MLP参数化;$\varphi$ 用GCN(含 $k$NN 边池化)参数化。报告多种 $k$ 的设置,并进行消融。
- 合成实验:相比FM与CGFM,MFM能将“去噪动力学”泛化到未见字母轮廓(测试集),生成质量以 $\mathcal{W}_1/\mathcal{W}_2$/MMD 度量显著优于FM/CGFM;同时探讨了用高斯源分布与OT耦合的变体。
- 生物实验:在“跨复现实验/跨患者”设定下,MFM在非OT方法中整体最佳,并与引入OT的FM/CGFM变体表现相当;在跨患者设定中,MFM在 $\mathcal{W}_1/\mathcal{W}_2$/MMD 与 $r^2$ 等指标上优于FM/CGFM与ICNN基线,显示能捕捉患者特异性响应并泛化至未见个体。
总体结论:MFM学习的是“分布的动力学规律”(流形上的向量场),而不局限于单一数据集/条件,因而能够在新的群体与环境下进行更稳健的分布级预测。
