谷歌近期发布的 Gemma 3n 模型,凭借其卓越的性能和创新的设计,迅速引起了业界的广泛关注。其中,MatFormer(Matryoshka Transformer,套娃Transformer)架构作为其核心技术之一,更是为模型的算力效率和灵活部署带来了革命性的突破。本文将深入探讨 MatFormer 的原理、优势以及在 Gemma 3n 中的应用,揭示其如何通过 嵌套Transformer 结构实现模型 Mix’n’Match,为资源受限环境下的模型部署开辟新途径。

MatFormer:嵌套Transformer的核心思想

MatFormer 的核心思想在于训练一个大型模型(例如 Gemma 3n 的 E4B 版本),同时将多个小型、功能完整的子模型(例如 E2B 版本)嵌入其中。这就像俄罗斯套娃一样,每个子模型都是一个独立的Transformer,但它们共享权重,无需单独训练或蒸馏。这种 嵌套Transformer 结构避免了传统模型缩放过程中需要进行的重新训练或复杂的专家路由机制,极大地提高了模型训练和部署的效率。

传统的 Transformer 架构包含两个主要的学习模块:多头自注意力机制(Multi-head self-attention)和前馈网络(Feed-forward network, FFN)。MatFormer 主要针对 FFN 模块进行优化,因为它通常是参数数量和推理延迟的主要贡献者。虽然 MatFormer 的优化可以应用于任何可学习的权重,但FFN的优化在实践中效果最为显著。

在诸如 Llama 3 和 Gemma 3n 这样的大型语言模型中,FFN 通常被实现为一个两层 MLP(Gated MLP):

  1. 第一层线性层(扩展/中间层):将输入嵌入投影到更高维度的隐藏空间。
  2. 激活函数:通常使用 GELU、ReLU 或者最近流行的 Gated 变体,例如 SwiGLU 或 SiLU。
  3. 第二层线性层(投影层):将隐藏表示投影回原始嵌入维度。

这里,d 是 token 嵌入维度(模型维度),d_ff 是 FFN 的隐藏维度。d_ff 通常远大于 dd_ff >> d),通常是 4 到 6 倍。例如,在一个 LLaMA 风格的 Transformer 中,如果 d = 4096,那么 FFN 通常使用 d_ff = 11008 甚至更大。这些层占据了模型大小和内存使用的大部分。

MatFormer 的关键创新在于它允许每个Transformer块中存在多个不同宽度的 FFN。不再为每一层固定一个 FFN 宽度,而是 嵌套 多个不同宽度的 FFN。

假设 d_ff 是我们想要支持的最大 FFN 宽度(例如 16384),那么 FFN 宽度为 m_i 的子模型仅仅使用完整权重矩阵的左上角子矩阵。例如,如果 d = 4096 并且 d_ff = 16384,我们可以选择一组粒度来训练:m1 = 4096, m2 = 8192, m3 = 12288m4 = 16384。这里的核心思想是,较小宽度的子模块是较大宽度的子模块的子集。

联合训练:确保所有粒度的性能

训练所有粒度的关键在于如何确保所有子模型的性能都良好。MatFormer 并没有在每个训练步骤中联合计算所有子模型的损失,而是采用了一种更高效的方法:在训练期间随机抽样子模型的粒度。

MatFormer 使用多粒度损失,计算公式如下:

MatFormer Loss Function

其中,M_i 是对应于第 i 个粒度的子模型(即,每个 FFN 中只使用前 m_i 个神经元)。

这种设计有两个关键优势:

  1. 避免了同时计算所有粒度的开销:每个步骤只评估一个子模型,从而避免了额外的计算负担。
  2. 更新共享权重矩阵:由于每个子模型都与更大的模型共享参数,因此在每个步骤中训练不同的子模型仍然会更新共享权重矩阵。

在原始论文中,作者以相同的概率均匀抽样每个子模型来计算损失。在论文的附录中,作者还尝试调整各个粒度的抽样概率。他们发现,增加最大粒度的损失权重可以提升较大模型的性能,同时对较小模型的性能产生轻微的降低。

Mix’n’Match:模型部署的无限可能

MatFormer 的一大亮点在于,它不仅可以生成多个经过训练的嵌套子模型,还可以在推理时通过一种简单的 Mix’n’Match 技术,免费生成数百个额外的子模型。这为模型部署带来了极大的灵活性和适应性。

在训练过程中,MatFormer 明确优化了 g 个子模型,每个子模型对应于所有层中特定的粒度(即 FFN 宽度)。例如,一个子模型可能在所有层中使用宽度 m_1(例如 4096),另一个子模型可能在所有层中使用 m_3(例如 12288)。但这并不限制我们在推理期间使用统一的宽度。

在推理时,我们可以逐层改变 FFN 宽度。例如,一层可以使用 m_2,下一层可以使用 m_3,再下一层又可以使用 m_2,以此类推。由于所有的 FFN 块都是嵌套的,这些混合模型无需重新训练、无需微调,也无需重新定义架构。

这种逐层配置的空间非常巨大。对于 L 层和 g 个粒度,有 g 的 L 次方个可能的子模型。即使对于较小的值(例如 32 层,4 个粒度),也有数十亿个选项。

然而,并非所有这些模型的性能都一样好。为了选择一个好的配置,作者提出了一个简单的启发式方法:

  1. 偏好 FFN 宽度在层之间缓慢且一致变化的模型:避免从一层到下一层在小粒度和大粒度之间跳跃。
  2. 偏好单调或逐渐增加的序列:例如,在早期层中使用 m_2,然后在后面的层中逐渐增加到 m_3

这种启发式方法与模型的训练方式非常吻合,因为训练期间抽样的子模型总是使用统一的宽度。因此,虽然混合层组合从未经过明确训练,但与训练机制相似的配置通常具有更好的泛化能力。

在论文中,MatFormer 的作者将此方法应用于两种基于 Transformer 的模型:LLM 和 ViT(分别称为 MatLM 和 MatViT)。对于使用 MatFormer嵌套 FFN 训练的 LLM,作者发现 MatFormer 的所有粒度子模型的性能都优于各自单独或独立训练的类似大小的 vanilla Transformer LLM。在 MatLM 的实验中,他们训练了 4 个嵌套粒度,d_ffn(FFN 维度)/ d(模型维度)的比率为 → {0.5, 1, 2, 4}。他们将相应的模型命名为 S、M、L 和 XL。在这里,XL 是完整模型。

虽然 MatFormer 论文中训练的最大 Transformer 模型为 850M 参数,但 Gemma 3n 使用 MatFormer 实现 4B(有效 5B)参数表明了这种方法的可扩展性。

Mix’n’Match 模型的性能

作者还观察到,Mix’n’Match 模型倾向于落在精度-计算权衡曲线(帕累托最优精度-模型大小)上。这意味着,基于可用的计算资源,可以在推理期间创建尽可能大的 Mix’n’Match 模型,这可能会比相对较小的经过训练的子模型具有更好的精度。

这具有有趣的含义。例如,在许多部署用例中,我们有足够的计算资源(GPU 内存)来加载一个 20B 参数的模型,但最接近的变体是 14B 或 33B。使用 14B 参数的模型会导致资源利用不足和低于标准的性能,而 33B 参数的模型则无法放入内存。这种 Mix’n’match 方法允许免费创建几乎完美大小的变体,并具有相应的精度权衡。

Gemma 3n 基于 MatFormer 架构,也使我们能够执行 Mix’n’Match,并在 2B 和 4B 之间创建更多模型大小。在性能方面,Gemma 3n E4B 的 MMLU 精度为 62.3%,E2B 的精度为 50.90%。Mix’n’Match 的变体,例如 2.54B 的 MMLU 精度为 55.40%,2.69B 的精度为 57.70%,这明显优于 2B,略低于 4B。这使我们能够获得完美尺寸的精确模型,用于设备部署用例 – 智能手机、平板电脑、Mac 等。

为了尝试创建 Gemma 3n 模型的 Mix’n’match 变体,Gemma 3n 团队发布了 Matformer Lab,这是一个 Colab 笔记本,您可以在其中对模型进行切片并将其推送到 Hugging Face。他们还在 Hugging Face 上发布了一些最佳切片配置(正如我们上面所看到的,尽管我们可以生成 g 的 L 次方个模型,但并非所有 Mix’n’Matches 都可能对所需的尺寸是最佳的) – google/gemma3n-slicing-configs。这些配置包括在层级级别(更细粒度)或在块级别(包括 4 个本地层和 1 个全局层)更改 FFN 维度。

一致性预测:加速推理解码

作者关注的另一个方面是将较小模型作为较大模型的 嵌套 子模型进行训练,这在于预测 tokens 的一致性。验证这种一致性的一种方法是查看推测解码的拒绝率。

在推测解码中,不是总是从大型目标模型中自回归地预测 token,而是有一个较小的“草稿模型”和一个较大的“验证模型”(目标模型)。草稿模型执行 token 推测,即,它预测接下来的 n 个 token 及其概率,目标模型并行地执行对预测 token 的验证。通过比较目标模型对草稿模型预测的相同 token 的概率,执行拒绝抽样。当草稿模型与目标模型相比差异很大时,草稿模型的预测将被拒绝。

正如人们可以看到的,这里最重要的加速来自于草稿模型一致地预测与目标模型相似的 token 分布,从而导致更少的拒绝。如果拒绝率更高,那么我们将无法获得推测解码所期望的推理加速。

MatFormer 的作者表明,由于较小的子模型 嵌套 在较大(XL)模型中并与其一起训练,因此它们显示出草稿模型和验证模型所需的更一致的性质。当使用独立训练的“S”大小的模型与 XL 模型作为目标进行传统推测解码时,与目标模型的自回归解码相比,推理速度提高了 10%,而使用“S”大小的 MatLM 模型时,速度提高了 16%。

此外,由于两个 MatLM 模型(S 和 XL)是共同训练的,因此它们允许共享注意力缓存(这无法使用独立训练的模型来完成,因为潜在表示将非常不同)。

我发现更有趣的是,在部署期间,通用 XL 模型可以存储在内存中,并根据计算约束用于提取可适应的较小模型。这种动态资源自适应模型提取和推理可能对边缘用例非常有用。

MatViT:图像检索的强大工具

作者还训练了 ViT(Vision transformers)编码器的 MatFormer(MatViT)。除了观察到与 MatLM 类似的趋势外,MatViT 的一个有趣的用例是快速图像检索。

在典型的图像检索中,获取图库图像的嵌入(图库创建)是离线发生的并被存储,而查询(创建查询图像的嵌入)是实时发生的。这意味着图库嵌入可以使用更大、更精确的模型,而查询嵌入可能需要更小的模型以实现快速推理。

由于 MatFormer 中的较小子模型是 嵌套 且联合训练的,因此它们倾向于与较大模型共享相同的嵌入空间,并导致基于最近邻居的检索产生有意义的距离。这允许使用 XL 模型进行图库编码,同时使用较小(例如 S)模型进行查询。

MatFormer 与 MoE 的区别

几乎所有主要的最新开源 LLM Transformer 都是 MoE 架构。混合专家(MoE)模型通过向每个 Transformer 层添加多个 FFN(称为专家)来增加容量。在推理时,路由器决定每个 token 使用哪个(哪些)专家,并且仅执行那些选择的路径。这降低了每次前向传递的计算成本,因为大多数专家保持不活动状态。但是,仍然需要将所有专家权重加载到内存中,因为路由器可以在运行时选择它们中的任何一个。这使得 MoE 模型难以部署在内存受限的设备上 – 虽然计算是稀疏的,但内存使用量仍然很高。

相比之下,MatFormer 不依赖于路由或多个专家。相反,每个 FFN 都被设计为通过在较大的子网络中 嵌套 较小的子网络来支持多个粒度。在推理时,可以选择并独立运行特定的子模型。只需要将该子模型的参数加载到内存中。这使得 MatFormer 更适合于设备或低内存推理,在这种情况下,存储完整模型是不可行的。

结论:MatFormer引领模型部署新范式

MatFormer 代表了我们在模型训练、推理和部署方面的一种重大转变。传统上,大型模型使用大量的计算预算进行训练,而较小的变体是在之后创建的 – 要么单独训练,要么从较大的模型中蒸馏出来。MatFormer 通过共享权重训练在一个 Transformer 中 嵌套 多个功能子模型来打破这种模式,从而消除了对昂贵的蒸馏、重新训练或像混合专家这样的复杂路由的需求。

这种设计解锁了一个平滑的资源自适应模型连续体 – 从轻量级、随时可用的移动部署到全容量推理引擎 – 所有这些都来自一个基础模型。Gemma 3n 的 E2B 和 E4B 的成功、Mix’n’Match 变体的多功能性以及 MatViT 编码器证明的跨模式泛化性表明,MatFormer 风格的架构可能成为研究实验室发布可扩展模型系列的新默认方式。通过 嵌套Transformer 结构,MatFormer 不仅提高了 Gemma 3n 的算力效率,更重要的是,为未来模型部署带来了无限的可能性,尤其是在资源受限的边缘计算场景下。