前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >超GFlowNet 4个数量级加速

超GFlowNet 4个数量级加速

作者头像
CreateAMind
发布2024-06-04 13:57:39
570
发布2024-06-04 13:57:39
举报
文章被收录于专栏:CreateAMindCreateAMind

Generative Marginalization Models

https://github.com/PrincetonLIPS/MaM https://arxiv.org/pdf/2310.12920

comparison: Any-order ARM (AO-ARM) [20], ARM [30], GFlowNet [39, 76], Discrete Flow4[65] and Probabilistic Circuit (PC)5[45].

4个数量级的加速

摘要

我们介绍边缘化模型(MMs),这是一种新的用于高维离散数据的生成模型家族。它们通过明确地建模所有引发的边缘分布,提供可扩展和灵活的生成建模,具有可处理的似然。边缘化模型通过神经网络的单次前向传递快速评估任意边缘概率,克服了具有精确边缘推断的方法(如自回归模型(ARMs))的一个主要限制。我们提出了一种可扩展的学习边际的方法,基于“边际自一致性”marginalization self-consistency的概念。与先前的方法不同,MMs 还支持任意阶的生成模型的可扩展训练,用于高维问题的能量训练设置,目标是将学习到的分布与给定的期望概率匹配(由非标准化(对数)概率函数(如能量或奖励函数)指定)。我们展示了所提出模型在各种离散数据分布上的有效性,包括二进制图像、语言、物理系统和分子,适用于最大似然和基于能量的训练设置。在两种设置中,MMs 在评估边际概率方面实现了数量级的加速。对于基于能量的训练任务,MMs 可以实现高维问题的任意阶生成建模,超越了先前方法的能力。代码位于 https://github.com/PrincetonLIPS/MaM。

1 引言‍‍

深度生成模型已经在包括图像生成、音频合成、自然语言建模和科学发现在内的各个领域取得了显著进展。然而,迫切需要更好地支持对涉及边际概率 p(x)和条件概率 p(x|x)的关键问题进行高效概率推断,适用于变量的适当子集 S、U、V。直接处理这些数量的能力在应用中至关重要,如异常检测[50, 40]、掩码语言建模[11, 73]、图像修复[74]和受限蛋白质/分子设计[69, 55]。此外,具有对任意变量子集进行推理的能力使用户能够根据其特定需求和偏好利用模型。例如,在蛋白质设计中,科学家可能希望手动引导从用户定义的子结构到相关变量上的特定路径下生成蛋白质。这需要生成模型执行任意边际推理。为此,神经自回归模型(ARMs)[3, 30]已经被开发出来,以便根据将高维联合分布建模为使用概率链规则对单变量条件进行因式分解的思想来促进条件/边际推理。

为了在最大似然估计(MLE)的设定下扩展 ARMs 并实现任意顺序的生成建模,已经做出了许多努力[30,66,20],在应用领域取得了巨大进展,如遮蔽语言建模[73]和图像修复[20]。然而,在最广泛使用的现代神经网络架构(例如 Transformers [68]和 U-Nets [53])中,边际似然评估受到 O(D)神经网络传递次数的限制,其中 D 是序列的长度。这种扩展使得在自然语言和蛋白质等数据中出现的长序列上评估似然变得困难。与 MLE 相比,在基于能量训练(EB)的设置中,我们只能访问未归一化(对数)概率函数(由奖励或能量函数指定),而不是经验数据样本,可以逐点评估以使生成模型匹配。在这种情况下,ARMs 仅限于固定顺序的生成建模,并且在训练中缺乏可扩展性。

当在基于能量的训练中匹配对数概率时,用于扩展MLE条件式训练的子采样技术不再适用(详情见第4.3节)。

为了提高离散数据生成建模的可扩展性和灵活性,我们提出了一种新的生成模型家族,即边缘化模型(MAMs),它直接对变量x中的任何子集xS 的边缘分布p(xS )进行建模。直接访问边缘具有两个重要优点:1)显著加快了对任何边缘的推断速度:和2)使得在MLE和EB设置下可扩展地训练任意阶生成模型成为可能。

marginalization models (MaMs), that directly model the marginal distribution 𝑝(𝐱𝒮) for any subset of variables 𝐱𝒮 in 𝐱. Direct access to marginals has two important advantages: 1) significantly speeding up inference for any marginal, and 2) enabling scalable training of any-order generative models under both MLE and EB settings.

该模型的独特结构使其能够同时表示给定离散联合概率质量函数的所有边缘分布的耦合集合。为了使模型有效,它必须与概率的求和规则一致,我们将这个条件称为“边缘化自洽性”(见图1);学习如何通过可扩展的训练目标来强制执行这一条件是本工作的关键贡献之一

图1:边缘化模型(MAMs)使得能够使用神经网络θ来估计任何边缘概率,该神经网络学习“边缘化掉”变量。该图说明了在位串上(代表分子)边缘化单个变量的过程,为清晰起见,选择了两个替代方案(通常是K)。条形图表示概率质量。

我们展示了MAMs可以在最大似然和基于能量的训练设置下使用可扩展的学习目标进行训练。我们在各种离散数据分布上展示了MAMs在这两种设置下的有效性,包括二进制图像、文本、物理系统和分子。我们通过实验证明,MAMs在边缘似然评估中实现了数量级的加速。对于基于能量的训练,MAMs能够将任意阶生成模型的训练扩展到先前方法无法实现的高维问题。

2 背景

我们首先回顾了训练生成模型的两种主流设置:最大似然估计和基于能量的训练。然后我们介绍了自回归模型。

最大似然估计(MLE)

给定一个数据集

,从数据分布

独立同分布(i.i.d.)地抽取,我们的目标是通过最大似然估计来学习分布

这等价于在经验分布下最小化Kullback-Leibler散度,即最小化

。这是在生成图像(例如,扩散模型[59, 18, 60])和语言(例如GPT [49])中最常用的设置,在这些领域,我们可以很容易地从分布中绘制观察到的数据。

奖励函数 r(x) 可以根据人类偏好或从第一原理上由物理系统定义。例如,(a) 在对齐大型语言模型时,r(x) 可以表示人类偏好[43, 42];(b) 在分子/材料设计中,它可以指定样本的测量或计算属性与某些功能要求的接近程度[2];(c) 在建模物理系统的热力学平衡系集合时,它是给定状态的(负)能量函数[41, 72, 9]。

方程(2)中的训练目标可以使用 REINFORCE 算法[71]的蒙特卡罗梯度估计进行优化。学习到的生成模型 θ 允许我们高效地从感兴趣的分布中近似地生成样本,否则通过运行具有能量函数 f(·) 的 MCMC 将更加昂贵。

自回归模型

自回归模型(ARMs)[3, 30]通过使用链规则将复杂的高维分布 p(x) 分解为单变量条件概率来建模:

任意阶自回归模型(AO-ARMs)Uria等人[66]提出学习自回归模型的条件概率,其包括{1,...,D}的所有排列。在MLE设置下,模型ϕ通过最大化一个下界目标[66, 20]来进行训练,该目标使用对排序的均匀分布的期望。这个目标允许可扩展地训练AO-ARMs,利用有效的并行评估,在一次前向传递中为每个标记评估多个一步条件概率,使用U-Net [53]和Transformer [68]等架构。然而,在EB设置中,仅对任意阶条件概率建模本身就会带来训练挑战。我们将在第4.3节中更详细地讨论这个问题。

3 边缘化模型

我们提出了边缘化模型(MAMs),这是一种新型的生成模型,可以在高维问题上实现可扩展的任意阶生成建模,同时实现了最大似然和基于能量的训练的有效边缘评估。边缘化模型的灵活性和可扩展性得益于对边缘分布的显式建模和边缘化自洽性的强制执行

在本文中,我们专注于使用离散变量向量进行离散结构的生成建模。这种向量表示涵盖了各种具有离散结构的现实世界问题,包括语言序列建模、蛋白设计以及具有基于字符串的表示形式的分子(例如,SMILES [70] 和 SELFIES [29])。此外,向量表示本质上适用于任何离散问题,因为将任何离散对象编码为离散变量向量是可行的。

我们将(4)称为“边缘化自洽性”,任何有效的分布都应该遵循这个规则。边缘化模型 θ 的目标是尽可能准确地估计任何变量子集 xS 的边缘概率 p(xS )。为了实现这一目标,我们训练了一个深度神经网络 pθ,该网络在最小化完整联合分布上的 pθ(x) 和 p(x) 之间的距离的同时,强制执行边缘化自洽性。换句话说,MAM 可以通过单次前向传递学习对任意变量子集进行边缘推断。

参数化

一个由神经网络 θ 参数化的边缘化模型接收 xS 并输出边缘对数概率 fθ(xS ) = log pθ(xS )。请注意,对于不同的子集 S 和 S',xS 和 x'S 属于不同的向量空间。为了统一输入到神经网络的向量空间,我们引入了一个增强的向量空间,该空间还额外包括了“边缘化掉”的变量 xSc。通过定义一个特殊符号“□”来表示“边缘化掉”的变量的缺失值,增强的向量表示是D维的,并定义为:

采样

使用边际化模型进行采样,可以通过选择任意顺序,一次对一个或多个变量进行采样,从而从已知分布中进行采样。在本文中,我们集中讨论一次生成一个变量的采样过程。为了得到每一步生成的条件,我们可以使用概率的乘积法则:

然而,如果不严格执行(5)中的单步边缘化一致性,上述抽样就不是一个有效的条件分布。

因为它可能不会严格地总和为一。因此,我们使用以下归一化的条件分布:

边缘化自洽性的可扩展学习

在训练中,我们通过在对数空间中最小化(5)中约束条件的平方误差来施加边缘化自洽性。评估(5)中的每个边缘化约束需要K次神经网络前向传递,其中K是离散值xd可以取的数量。当K很大时,这使得小批量训练具有挑战性,为了解决这个问题,我们用另一个神经网络ϕ参数化的可学习条件概率增强了边缘化模型。(5)中的边缘化约束可以进一步分解为K个并行的边缘化约束。

通过将方程(4)中的原始边缘化自洽性分解为方程(7)中高度并行的边缘化自洽性,我们得到了O(KD·D·D!)个约束。尽管约束数量增加了,但通过对约束进行抽样,训练边缘化自洽性变得高度可扩展。在训练过程中,我们指定了一个分布q(x)来采样边缘化约束。在实践中,它可以设置为感兴趣的分布,以执行边缘推断,例如𝑝data 或生成模型的分布 𝑝𝜃,𝜙 的分布。在实证实验中,我们发现使用分解为高度并行自洽性错误的目标是学习具有可扩展性的边缘的关键因素。

4 训练边缘化模型

4.1 最大似然估计训练

在这种设置下,我们使用最大似然目标训练MAMs,同时额外强制执行方程(5)中的边缘化约束:

两阶段训练解决上述优化问题的一种典型方法是将约束转化为惩罚项,并优化受惩罚的目标,但我们在实践中发现学习速度慢且不稳定。相反,我们确定了一种替代的两阶段优化形式,理论上等价于方程(8),但导致更高效的训练:

命题1. 在对使用的神经网络是通用逼近器的温和假设下,解决(8)中的优化问题等价于以下两阶段优化过程:

第一阶段可以解释为与AO-ARMs [66, 20] 中的拟合条件概率的方式相同,而第二阶段则是从条件概率中提取边缘概率。这种直觉来自于概率的链式规则:最优的条件概率ϕ与边缘概率θ之间存在一一对应关系,即对于任何σ和x,

。通过假设神经网络是通用逼近器,我们可以先优化最优条件概率,然后优化相应的最优边缘概率。我们在附录A.1中提供了证明细节。

4.2 基于能量的训练

在这种设置下,我们使用方程(2)中的基于能量的训练目标来训练MAMs,并使用一个惩罚项来强制执行方程(5)中的边缘化约束:

对于自洽性惩罚项,我们从指定的感兴趣的数据分布中采样数据x,并从均匀分布中采样排序σ、步骤d。

通过持续的MCMC进行高效采样

我们需要从pθ中获得廉价而有效的样本,以便执行REINFORCE算法,因此通过随机选择一个排序并使用条件分布pϕ(xσ(d)|xσ(<d))进行块Gibbs采样步骤来维护一组持续的马尔可夫链(完整算法见附录A.4),类似于持续性对比散度[64]。当条件网络pϕ的样本接近于边缘网络pθ时,这些样本可以作为来自边缘网络pθ的近似样本。否则,我们还可以使用重要性采样来获得无偏估计。

4.3 解决ARM限制

我们详细讨论了MAMs如何解决ARMs的一些局限性。前两者适用于两种训练设置,而后两者仅适用于基于能量的训练。

1) 慢速的边缘推断

由于顺序条件建模,使用ARMs(或AO-ARMs估计任意边缘)评估边缘pϕ(xo)需要将神经网络ϕ应用多达D次,这在高维数据的时间和内存上是低效的。相比之下,MAMs能够通过一次神经网络前向传递估计任何任意边缘。

MAMs通过强制自洽性约束来解决这个问题。MAMs不受固定排序的限制,因为边缘是与排序无关的,我们可以针对边缘化自洽性约束的期望进行优化。

3) 高维问题上的训练不可扩展

当最小化log pϕ(x|σ)与目标log p(x)之间的差异时,ARMs需要对条件概率求和以评估log pϕ(x|σ)。一个可能的解决方法是对一步条件概率pϕ(xσ(d)|xσ(<d)进行子抽样以估计pϕ(x),但这会导致由于得分函数和距离项的乘积,即方程(9)中的REINFORCE梯度的高方差,这两者都具有高方差(我们在实验中验证了这一点,见图3)。因此,为了计算目标函数的梯度,基于能量的训练需要对条件进行D次序列评估。这个约束导致了批量B样本的有效批量大小为B×D,严重限制了ARMs在高维问题上的可扩展性。此外,当维数很高时,从ARMs获取REINFORCE梯度估计器的蒙特卡洛样本是很慢的。由于固定的输入顺序,这个过程需要D个顺序抽样步骤,使得更具成本效益的抽样方法如持续性MCMC变得不可行。边缘化模型通过直接使用边缘神经网络来估计对数似然,避开了这一挑战。此外,任意阶训练的支持通过利用持续性MCMC方法实现了高效的抽样。

5Related Work

Autoregressive models

深度学习的发展极大地提升了ARMs在不同模态(包括图像、音频和文本)上的性能。任意阶(与排序无关)ARMs首次在[66]中引入,通过使用任意阶的下界目标在最大似然设置下进行训练。最近的工作,如ARDM [20],展示了任意阶离散建模图像/文本/音频的最新性能。Germain等人[16]训练了一个带有掩码的自动编码器,输出给定排序的所有一步条件概率序列,但是其生成效果不如[67, 73, 20]等方法在给定掩码下预测一步条件概率。Douglas等人[14]训练了一个AO-ARM作为提议分布,并使用重要性采样来估计DAG结构贝叶斯网络中的任意条件概率,但在合成数据集上的实验验证有限。Shih等人[57]利用修改后的ARMs训练目标提高了边缘推断性能,但失去了任意阶生成能力。在基于能量的训练设置中,ARMs被应用于科学问题[9, 72],但在D较大时缩放性受到影响。在第4.3节中详细比较了MAMs和ARMs。

任意条件/边缘模型

对于连续数据,VAEAC [25]和ACFlow [31]将条件变分编码器和归一化流的概念扩展到了模拟任意条件的建模。ACE [62]通过直接建模能量函数提高了任意条件模型的表现力,这减少了参数化的约束,但伴随着近似归一化常数的额外计算成本。概率电路(PCs)[6, 45]不使用神经网络作为函数逼近器,而是通过构建一个具有特定结构约束的计算图,使用求和和乘积运算来为条件和边缘提供可行的概率模型。PC的示例包括Chow-Liu树[7]、算术电路[10]、和和积网络[47]等。Peharz等人[45]通过将算术运算合并为单个单块einsum运算和自动微分来提高PCs的可扩展性。[33, 34]展示了在连续图像数据上,通过从训练的深度生成模型中提炼潜在变量,PCs的潜力。然而,表达能力仍受到结构约束的限制。上述所有方法都专注于MLE设置。

GFlowNets [2, 4]将生成问题阐述为将终端状态处的概率流与目标归一化密度匹配。与ARMs相比,GFlowNets通过假设可学习的生成路径通过有向无环图(DAG)来灵活地建模生成过程。可学习生成路径的优点是能够灵活地建模生成过程,但代价是牺牲了任意阶生成和精确似然评估的灵活性。在固定的生成路径下,GFlowNets退化为固定阶ARMs [75]。在附录A.3中,我们进一步讨论了GFlowNets与AO-ARMs/MAMs之间的联系和区别。对于离散问题,Zhang等人[76]利用轨迹平衡目标[38]在平方距离损失上对GFlowNets进行了训练。这在D较大时不可扩展(与第4.3节中的ARMs相同的原因),并使得直接访问边缘不可用。

在MLE设置中,还额外从数据中学习了能量函数,以便使用基于能量的训练来训练模型。

6 实验

我们在包括二进制图像、文本、分子和物理系统在内的离散问题上,针对最大似然估计(MLE)和基于能量的训练(EB)设置进行了边缘化模型(MAM)的实验。我们考虑以下基线进行比较:任意阶ARM(AO-ARM)[20]、ARM [30]、GFlowNet [39, 76]、离散流 [65] 和概率电路(PC)[45]。MAM、PC和(AO-)ARM支持任意边缘推断。离散流允许精确的似然评估,而GFlowNet需要使用重要性样本对似然进行近似。为了评估AO-ARM的边缘推断,我们可以使用一个集成模型通过对几个随机排序进行平均(AO-ARM-E),或者使用一个随机排序(AO-ARM-S)。一般来说,AO-ARM-E应该始终优于AO-ARM-S,但成本更高。神经网络架构和训练超参数的详细信息可以在附录B中找到。

6.1 最大似然估计训练

二进制MNIST 我们在表1中报告了负测试似然(位/数字)、边缘估计质量和每个小批量的边缘推断时间(大小为16)。为了保持GPU内存使用相同,我们顺序评估ARMs的似然。MAM和AO-ARM都使用了一个包含4个ResNet块和注意力层的U-Net架构(详见附录B)。由于GFlowNets无法扩展到大型架构如U-Net,因此我们使用了张等人[76]的MLP报告GFlowNet的结果。对于MAM,我们使用条件网络来评估测试似然(因为这也是MAM生成数据的方式)。边缘网络用于评估边缘推断的质量。边缘似然估计质量的评估,我们进行了一项受控实验,其中我们随机遮挡测试图像的部分,并生成具有不同遮挡水平的多个样本(参见图4)。该过程允许我们获得一组不同但可比较的样本,每个样本与不同的似然值相关联。对于每个模型,我们评估生成样本的似然,并将其与AO-ARM-E的估计进行比较,因为它在测试数据上实现了最佳的似然。我们在一组随机测试图像上重复这个受控实验。报告了平均斯皮尔曼和皮尔逊相关系数,以衡量给定模型与AO-ARM-E之间边缘推断似然的相关性强度。MAM在边缘推断中实现了接近4个数量级的加速,同时与AO-ARM-S的质量相当。PC在边缘推断中也非常快,但在质量方面存在差距。生成的样本和关于部分图像的额外边缘推断在附录B中。

分子集 (MOSES) 我们在从 ZINC 数据库 [61] 提炼的基准分子数据集 [46] 上测试 MAM 的生成模型。相同的指标被报告为 Binary-MNIST。似然质量的测量方式类似,但是是在随机的测试分子组上而不是生成的分子上进行测量。MAM 和 AO-ARM 生成的分子与标准的最先进的分子生成模型相当,例如 CharRNN [56]、JTN-VAE [26] 和 LatentGAN [48](参见附录 B), 任意订单生成的额外可控性和灵活性。MAM 支持更快的边际推理,这对于领域科学家推断(子)结构的可能性非常有用。生成的分子和性质直方图可在附录 B 中找到。

Text8 Text8 [37] 是一个广泛使用的字符级自然语言建模数据集。该数据集包括来自维基百科的1亿个字符,分成250个字符的块。我们遵循与二进制MNIST相同的测试程序,并报告相同的指标。离散流的测试NLL来自[65],由于没有开源实现,无法评估其他指标。

6.2 基于能量的训练

我们将MAM与使用条件总和评估log pϕ的ARM和使用一步条件估计log pϕ的ARM-MC进行比较。ARM可以被视为学习自回归条件的黄金标准,因为其梯度需要在完整的生成轨迹上进行评估,这是最具信息性和昂贵的。MAM使用边缘网络评估log pθ,并为批处理中的每个数据点子采样一个一步边缘化约束。对于ARM和GFlowNet,有效批处理大小为B×O(D),对于批处理大小为B的ARM-MC和MAM,为B×O(1)。MAM和ARM使用带基线的REINFORCE梯度估计器来优化KL散度。GFlowNet是基于平方距离的每个样本梯度进行训练的[76]。

Ising model[24]模拟相互作用的自旋,在数学和物理学中广泛研究(参见MacKay[35])。我们研究了一个正方形晶格上的伊辛模型。D个站点的自旋由一个D维二进制向量表示,其分布为

),其中

,J是二进制邻接矩阵。这些模型,尽管简单,与高熵合金的复杂行为有类似之处[9]。我们将MAM与ARM、ARM-MC和GFlowNet在一个10×10(D = 100)和一个较大的30×30(D = 900)伊辛模型上进行比较,其中ARM和GFlowNets无法扩展。按照Grathwohl等人[17]的方法生成了2000个基础真实样本,并对这些样本进行测试负对数似然度的测量。我们还通过从学习模型中采样并评估

来测量

图5包含了生成样本的−EJ(x)的KDE图。正如在第4.3节中描述的那样,ARM-MC梯度具有高方差,无法收敛。它也往往会崩溃并收敛到单个样本。MAM在边缘推断方面具有显著的加速,并且是唯一支持任意阶生成建模的模型

从KL散度和似然度的表现来看,与固定/学习顺序的模型相比,只略逊色,这是预期的,因为任意阶建模比固定阶建模更难,而MAM正在解决一个更复杂的任务,即联合学习条件和边缘。在一个30×30(D = 900)的伊辛模型上,MAM在基础真实样本上实现了0.835 bpd,而ARM和GFlowNet则无法扩展。生成样本的分布显示在图5中。

具有目标特性的分子生成 在这个任务中,我们对训练生成模型以达到特定目标性质g(x),如亲脂性(logP)、合成可达性(SA)等感兴趣。我们定义分子的分布遵循

,其中g∗是属性的目标值,τ是温度参数。我们针对目标值4.0和-4.0的亲脂性,分别以τ = 1.0和τ = 0.1训练ARM和MAM。两个模型都使用批量大小为512进行了4000次迭代训练。结果如图6和表5所示(附录B中还有其他图表)。研究结果与伊辛模型实验一致。再次表明,MAM的性能略低于ARM。然而,只有MAM支持任意阶建模并能够扩展到高维问题。图6(右)显示了D = 500时MAM的分子生成情况。

7 结论

总之,边缘化模型是一种新颖的用于高维离散数据的生成模型家族,提供可伸缩且灵活的生成建模,具有可计算的似然性。这些模型显式地建模了所有引出的边缘分布,允许通过神经网络的单次前向传递快速评估任意边缘概率。MAMs还支持任意阶生成建模的可伸缩训练目标,这是以往方法在基于能量的训练设置下很难实现的。潜在的未来工作包括设计新的神经网络架构,自动满足边缘化自我一致性。我们感谢智能概率系统普林斯顿实验室的成员和匿名评审人员的有价值的讨论和反馈。我们还要感谢Andrew Novick和Eric Toberer在科学应用中基于能量的训练方面的有价值的讨论。本工作部分支持由NSF资助,编号为IIS-2007278和OAC-2118201。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-05-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 CreateAMind 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 5Related Work
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档