FM × scDynamics(4): Metric Flow Matching for Smooth Interpolations on the Data Manifold

3 minute read

Published:

FM系列的第四篇,主要探讨了这样一个问题,传统的条件流匹配(CFM)所依赖的欧几里得空间中的直线插值路径,是不是最优的?有没有能够更忠实的反应数据内在动态结构的插值路径?论文《Metric Flow Matching》就是在讨论这个问题,它认为,当数据本身分布在一个弯曲的低维流形上时,直线路径会不可避免地“抄近道”穿过数据从未涉足的低密度区域。在这种区域训练向量场,不仅会引入高度的不确定性,更使得生成的中间状态缺乏真实的物理或生物学意义。所以MetricFM就提出来一个数据依赖的黎曼度量(Riemannian metric),重新定义了空间的几何结构,使得数据密集区域的“通行成本”更低。并在此基础上,最小化该度量下的路径动能,寻找近似的测地线(geodesics)作为CFM的插值路径。由于测地线天然倾向于沿着“成本最低”的路径前进,它会巧妙地贴合数据流形,从而让生成过程在数据支撑的区域内平滑演化,最终获得物理意义更明确、质量更高的生成结果。

1. 背景

在单细胞测序过程中,细胞会遭到不可逆的破坏,从而我们无法像拍电影一样追踪同一个细胞的完整生命历程,只能得到几张离散时间点下的单细胞快照。Single-Cell Trajectory Inference的目标就是基于这几张细胞快照,重建出连续的细胞状态演化过程。生成建模,特别是基于“匹配目标”的范式,如扩散模型、流匹配(Flow Matching, FM)和薛定谔桥(Schrödinger Bridge),为这一难题提供了强大的框架。其中,条件流匹配(CFM)作为一个无需模拟(simulation-free)的框架,通过为起始点和终点样本对构造插值路径,直接学习一个时间依赖的向量场,效率很高。但传统CFM所采用的直线插值 $x_t = (1-t)x_0 + t x_1$ 存在一个根本性的几何缺陷:它假设数据空间是平坦的。当数据实际位于一个弯曲的低维流形 $\mathcal{M}$ 上时,直线路径会穿过流形之外的数据稀疏区域。这迫使模型在一个“凭空想象”的位置上学习速度向量,不仅增加了学习难度,也让生成的中间分布 $p_t$ 丧失了解释性,无法真实地模拟数据内在的演变规律。

而作者发现,像单细胞基因表达这样的高维数据,其内在结构非常类似所谓的“镶嵌在高维空间中的低维流形”。这是因为细胞的分化和发育是一个连续的过程,因此代表不同细胞状态的数据点会形成一条或多条连续的、弯曲的路径,而不是随机散落在高维基因空间中。对于这个“流形假设”,我们可以粗浅的这样理解:地球本身是三维的,但对我们人类而言,我们在地球表面活动,地球就是一个二维的有曲度的平面。而且,连接地球表面两点的最短路径是沿着地表的曲线(测地线),而非穿透地壳的直线,如果强制人在地球上走直线,就会离开人类活动的地球表面,走进地壳里。所以同样的,在数据流形上,直接的直线插值也可能会“离开”数据所在的真实区域。

所以这篇MetricFM,主要的贡献就是设计了一种能够感知并贴合数据流形的插值机制,从而将向量场的训练过程约束在数据支撑的高置信度区域。作者巧妙地避开了直接参数化流形 $\mathcal{M}$ 所带来的坐标系选择、数值不稳定等难题,转而在原始空间 $\mathbb{R}^d$ 上引入一个数据依赖的黎曼度量 $g$,通过几何手段间接但有效地约束插值路径。

2. 黎曼几何:在弯曲空间中定义“直”

为了理解MFM,我们首先需要明确一些黎曼几何的基本概念。

2.1 黎曼流形与度量:重新定义局部距离

我们熟悉的欧几里得空间拥有统一的、固定的几何规则:两点之间直线最短,勾股定理四海皆准,不管空间位置怎么改变,至少距离标准不会变。而黎曼流形 $(M,g)$*则将这一概念推广到了任意弯曲的空间。它由两部分构成:

  1. 流形 $M$:一个点的集合。虽然整体上可以弯曲(如球面),但其任何一个微小的局部都近似于一个平坦的欧几里得空间。
  2. 黎曼度量 $g$:一个在流形上每一点 $x$ 都定义了的“局部标尺”。它具体表现为一个对称正定矩阵 $G(x)$,也称为度量张量。这个矩阵 $G(x)$ 定义了在点 $x$ 附近的无穷小距离和角度,告诉我们如何在局部衡量路径的长度。欧几里得空间可以看作是黎曼流形的一个特例,其度量在任何地方都是单位矩阵 $G(x) \equiv I$,因此其几何规则处处统一。 总体来说,就是让这个空间中的不同位置有了不同的距离度量,这为我们找到任意两点间的最短距离造成了麻烦。区别于欧几里得空间中直接找直线,我们在黎曼流形中需要通过复杂的办法,来寻找测地线作为最短路径。

2.2 测地线:弯曲空间中的“最短路径”

在黎曼流形中,连接两点的最短路径被称为测地线(geodesic) $\gamma^*$。它通过最小化路径的长度泛函得到:

\[\gamma_{t}^{*} = \arg\min_{\gamma_t} \int_{0}^{1} \lVert \dot{\gamma}_{t} \rVert_{g(\gamma_{t})} \, dt,\]

其中:

  • $\dot{\gamma}_{t}$ 是路径在 $t$ 时刻的速度向量。
  • \(\lVert \dot{\gamma}_{t} \rVert_{g(\gamma_{t})}\) 是在 $t$ 时刻,根据当前点 $\gamma(t)$ 的局部几何规则 $g$ (即矩阵 $G(\gamma(t))$) 计算出的速度大小。其计算方式为: \(\lVert \dot{\gamma}_{t} \rVert_{g(\gamma_{t})} = \sqrt{\langle \dot{\gamma}_{t},\, G(\gamma_{t}) \dot{\gamma}_{t} \rangle}.\)

从这个公式就可以看出来,在度量矩阵 $G(x)$ 范数很大的区域,经过的路径长度就会被放大,移动成本也会对应的增加。因此,最短的测地线会倾向于绕开这些高成本区域,转而沿着成本较低(我们期望是数据密集)的区域前进。这正是MFM利用黎曼几何来塑造插值路径的核心数学底层逻辑。

3. Metric Flow Matching

MFM的构造分为两个核心步骤:首先,利用数据依赖的度量构建近似的黎曼流形空间,学习近似的测地线路径;然后,在这些更合理的路径上构造更好的插值点\(x_t\),学习向量场,进行训练。

3.1 近似测地线:从动能最小化出发

最自然朴素的想法肯定是,直接在这样的低维数据流形上建模,但这在技术上太复杂了。作者转而考虑一个等价的目标,不要求路径在抽象的流形上,而是学习一个黎曼度量,让数据密集的区域距离变短,稀疏区域距离变长。这样,就可以让路径时刻靠近我们已有的数据点。但是,在这样复杂的度量下,直接求解复杂的测地线微分方程在计算上非常昂贵,而且很难像CFM一样做到simulation-free。作者就直接使用神经网络来学习这个最短路径测地线:

\[x_{t,\eta} = (1-t)x_0 + t x_1 + t(1-t)\,\varphi_{t,\eta}(x_0, x_1),\]

其中,$\varphi_{t,\eta}$ 是一个由神经网络参数化的修正项。可以看到,这个路径其实就是在直线路径的基础上,添加了一个非线性的修正项\(t(1-t)\,\varphi_{t,\eta}(x_0, x_1)\),并且这个修正项在 $t=0$ 和 $t=1$ 时为零,保证了路径的端点始终固定在给定的 $(x_0, x_1)$,同时允许路径在中间过程发生最大程度的弯曲,以贴合数据流形。

为了指导 $\varphi_{t,\eta}$ 的学习,作者利用了物理学和几何学的一个重要原理:测地线也是最小化动能泛函 (kinetic energy functional) 的路径。所以作者将优化目标设定为最小化其在度量 g 下的动能,从而定义了测地线损失,即路径的期望动能:

\[\mathcal{L}_g(\eta) = \mathbb{E}_{t,(x_0,x_1)\sim q}\!\left[(\dot{x}_{t,\eta})^{\top} G(x_{t,\eta};\mathcal{D}) \dot{x}_{t,\eta}\right],\]

通过梯度下降最小化 \(\mathcal{L}_g(\eta)\),我们就能学到一条近似测地线插值路径 \(x_{t,\eta^*}\)。这个过程完整地保留了CFM无需模拟的优势,同时将数据的几何结构显式地编码到了路径的定义中。

3.2 向量场学习与整体框架

在获得最优的近似测地线路径 \(x_{t,\eta^*}\) 后,第二阶段的向量场学习就顺理成章了。其目标是在这条路径上,让学习到的向量场 \(v_{t,\theta}\) 精确匹配路径的速度场 \(\dot{x}_{t,\eta^*}\)。损失函数也相应地调整为在黎曼度量下计算L2范数:

\[\mathcal{L}_{\text{MFM}}(\theta) = \mathbb{E}_{t,(x_0,x_1)\sim q}\left[\lVert v_{t,\theta}(x_{t,\eta^*}) - \dot{x}_{t,\eta^*} \rVert_{g(x_{t,\eta^*})}^2 \right], \quad \text{其中} \quad \eta^* = \arg\min_\eta \mathcal{L}_g(\eta).\]

MFM的完整流程因此分为两个阶段:

  1. 路径学习阶段:通过最小化 \(\mathcal{L}_g(\eta)\)预训练出最优的路径生成器 \(\varphi_{t,\eta^*}\)。
  2. 流场学习阶段:在这些学好的、沿着数据流形的路径 $x_{t,\eta^*}$ 上,通过最小化 \(\mathcal{L}_{\text{MFM}}(\theta)\) 来训练主速度场网络 \(v_{t,\theta}\)。

这个框架将一个困难的微分几何求解问题,巧妙地转化为了两个标准的神经网络训练问题,兼具理论的优雅性和实践的可行性。

3.3 与 Riemannian Flow Matching (RFM) 的关键区别

MFM并非首个将黎曼几何引入流匹配的工作,更知名的类似工作应该是Riemannian Flow Matching (RFM),但它与RFM存在本质区别,主要在于度量的来源。MFM依赖的是数据本身驱动生成的度量,通过学习一个数据依赖的度量来定义距离,引导路径贴近数据流形。而RFM假设空间的度量预先给定,和数据无关。在度量位置的情况下,需要一个类似AE的工具来参数化流形,再再上面定义几何结构。

4. 学习数据诱导的黎曼度量

现在已经明确了,作者的思路是通过学习一个数据依赖的黎曼度量g来指导插值路径,从而让他贴近数据流形,但是g具体应该怎么定义,还有待具体化。

4.1 LAND 度量:基于核密度估计的局部自适应

对于低维数据,作者提出了LAND(Locally Adaptive Normal Distribution)度量:

\[G_{\varepsilon}(x) = \left(\operatorname{diag}(h(x)) + \varepsilon I\right)^{-1}, \quad h_\alpha(x) = \sum_{i=1}^{N} (x_i^\alpha - x^\alpha)^2 \exp\!\left(-\frac{\lVert x - x_i\rVert^2}{2\sigma^2}\right).\]

其直观理解是:\(h(x)\) 在靠近数据点 $x_i$ 的区域值会更大。由于 $G(x)$ 与 $h(x)$ 成反比,这就使得数据密集区的度量值 $G(x)$ 变小,从而引导测地线倾向于穿过这些区域,以达到贴近数据流形的目的。LAND度量适用于维度较低、且核宽度 $\sigma$ 可以通过经验调节的场景。

4.2 RBF 度量:在高维场景下学习自适应尺度

在单细胞基因表达等高维场景中,单一的全局核宽度 $\sigma$ 难以调优。为此,作者引入了更强大、更灵活的RBF度量

\[G_{\text{RBF}}(x) = \left(\operatorname{diag}(\tilde{h}(x)) + \varepsilon I\right)^{-1}, \quad \tilde{h}_\alpha(x) = \sum_{k=1}^{K} \omega_{\alpha,k}(x)\exp\!\left(-\frac{\lambda_{\alpha,k}}{2}\lVert x-\hat{x}_k\rVert^2\right),\]

这里的 \(\hat{x}_k\) 是预先对数据进行了一次k-means聚类得到的k个聚类中心,而每个中心点关联的权重 $\omega_{\alpha,k}$ 与尺度(逆方差)$\lambda_{\alpha,k}$ 都是可学习的参数。这使得度量的局部形状能够由数据自适应地决定,大大减少了对人工超参数的敏感性。当使用RBF度量时,测地线损失清晰地展示了其几何正则化效果: \(\mathcal{L}_{g_{RBF}}(\eta)=\mathbb{E}_{t,(x_{0},x_{1})\sim q}\left[\sum_{\alpha=1}^{d}\frac{(\dot{x}_{t,\eta})_{\alpha}^{2}}{\tilde{h}_{\alpha}(x_{t,\eta})+\epsilon}\right]\) 这个公式清晰地表明:路径的速度 \(\dot{x}_{t,\eta}\) 在远离数据中心(即 $\tilde{h}_{\alpha}$ 值小)的区域会受到更大的惩罚,从而迫使路径保持在数据流形上。

4.3 结合最优传输:OT-MFM

为了给MFM提供更有意义的端点配对 $(x_0,x_1)$,作者采用了类似OT-CFM的策略。将OT耦合、测地线损失与RBF度量相结合,便得到了论文在实验中效果最佳的主力模型 OT-MFM$_{\text{RBF}}$

\[\mathcal{L}_{\text{OT-MFM}_{\text{RBF}}}(\theta) = \mathbb{E}_{t,(x_0,x_1)\sim \pi^*} \left[\lVert v_{t,\theta}(x_{t,\eta^*}) - \dot{x}_{t,\eta^*} \rVert_{g_{\text{RBF}}(x_{t,\eta^*})}^2\right].\]

4.4 两阶段训练策略

阶段一:用“几何/测地线损失”(Eq. 11)训练插值器 γ_η

Inputs

  • 数据:相邻时间点经验分布 q0, q1
  • 数据依赖度量 G_φ(x):如 RBF/LAND(可固定或与 η 一起训练)
  • 插值器 γ_η(t; x0, x1)(Eq. 4 的“直线 + 可学习弯曲项”)
  • 学习率, 迭代步数 T, 批量大小 B
  • Sinkhorn 正则 ε、迭代步数 K(仅用于 minibatch-OT)

Procedure for step = 1..T:

  1. 从 q0, q1 各采样 B 条样本,得批量 X0, X1
  2. π = MinibatchOT(X0, X1, cost = ||x0 - x1||², ε, K) 得到 OT 耦合 注意:代价使用欧氏距离,与数据度量解耦
  3. 从 π 中抽样 B 对匹配 (x0ᵢ, x1ⱼ)
  4. 采样 t ~ Uniform[0, 1](或每对各采一个 t)
  5. 前向:x_t = γ_η(t; x0ᵢ, x1ⱼ) 速度:v_t = ∂/∂t γ_η(t; x0ᵢ, x1ⱼ) # 自动微分
  6. 计算几何/测地线损失(Eq. 11): L_geo = mean_over_batch( v_tᵀ · G_φ(x_t) · v_t ) 若 G 为对角度量,可写作 sum_d G_d(x_t) * v_t[d]^2
  7. 反向传播,更新 η(以及可选的 φ) end for

输出:已训练好的插值器 γ_η(“geopath”),后续阶段固定使用

阶段二:在固定 γ_η 下做(OT-)CFM 回归向量场 v_θ

Inputs

  • 与阶段一相同的 q0, q1、度量 G_φ(推导/正规化用)
  • 冻结的插值器 γ_η
  • 向量场网络 v_θ(t, x)
  • 其余训练超参同上

Procedure for step = 1..T:

  1. 从 q0, q1 采样批量 X0, X1
  2. π = MinibatchOT(X0, X1, cost = ||x0 - x1||², ε, K) 得到与阶段一相同形态的 OT 配对
  3. 从 π 中抽样 B 对 (x0ᵢ, x1ⱼ);采样 t ~ Uniform[0, 1]
  4. 构造训练点: x_t = γ_η(t; x0ᵢ, x1ⱼ) u_t = ∂/∂t γ_η(t; x0ᵢ, x1ⱼ) # 目标速度(条件路径的速度)
  5. 条件流匹配(CFM)回归目标: 论文建议把回归目标按度量范数正规化,从而数值稳定, 等价于使用标准的 L2 回归形式来训练 v_θ L_flow = mean_over_batch( || v_θ(t, x_t) − u_t ||² )
  6. 反向传播,更新 θ end for

输出:已训练好的向量场 v_θ;推断时用欧氏 Euler 步近似积分得到流映射

5. 与广义能量最小化框架的联系

作者进一步揭示,MFM的测地线损失 $\mathcal{L}_{g}(\eta)$ 本质上是最小化一个数据依赖的动能。这可以看作是对传统方法的推广:标准CFM的直线路径,正是在欧氏度量 $G(x)=I$ 下最小化动能 $\lVert\dot{x}_t\rVert^2$ 的结果。MFM的点睛之笔在于,它不是简单地最小化动能,而是学习一个数据依赖的动能场,在这个场中,沿着数据流形的运动“成本”最低。

更有趣的是,MFM框架也从一个更深刻的视角统一了数据依赖势能的建模思路。像GSBM这类方法,它们通过引入一个人工设计的势能项 $V_{t}(x_{t})$ 来引导路径,其最小化的总能量为: \(\mathcal{U}(x_{t},\dot{x}_{t})=\underbrace{K(\dot{x}_{t})}_{\text{动能}} + \underbrace{V_{t}(x_{t})}_{\text{势能}}\) 这类方法的一个核心痛点是势能 $V_t$ 缺乏通用的设计原则,需要研究者为特定任务去手动构建,通用性较差。

而MFM的测地线损失函数 $\mathcal{L}_g(\eta)$ 竟然可以通过一个简单的代数变换,完美地分解为标准动能和势能之和的形式。推导过程如下:

  1. 从MFM的测地线损失出发,它是在黎曼度量 $g$ 下的广义动能: \(\mathcal{L}_g(\eta) = \mathbb{E}[\dot{x}_{t,\eta}^\top G(x_{t,\eta}) \dot{x}_{t,\eta}]\)

  2. 通过引入单位矩阵 $I$ 进行恒等变形,将 $G$ 分解为欧氏空间部分和黎曼几何修正部分: \(\mathcal{L}_g(\eta) = \mathbb{E}[\dot{x}_{t,\eta}^\top (I + (G(x_{t,\eta}) - I)) \dot{x}_{t,\eta}]\)

  3. 展开后,我们得到清晰的两项: \(\mathcal{L}_g(\eta) = \mathbb{E}[ \underbrace{||\dot{x}_{t,\eta}||^2}_{\text{标准动能}} + \underbrace{\dot{x}_{t,\eta}^\top (G(x_{t,\eta}) - I) \dot{x}_{t,\eta}}_{\text{数据驱动的势能}} ]\)

这个推导有力地证明了,MFM中的势能 \(V_{t,\eta} = \dot{x}_{t,\eta}^\top (G(x_{t,\eta}) - I) \dot{x}_{t,\eta}\) 并非人工设计,而是完全由数据驱动的度量 $g$ 自动产生。它不仅是可学习的,还与路径参数 $\eta$ 共同优化。这为我们从一个统一的几何视角来理解和构造生成模型的正则项(regularization)提供了深刻的洞见。