# 1907.05242_PKM: Large Memory Layers with Product Keys * 首页: * PDF: * 引用: * 176(2026-01-19) * 组织: * Facebook AI Research * Sorbonne Universit´es, UPMC Univ Paris(索邦大学) * 链接 * ## 总结 ## From Moonlight ### 三句摘要 1. 💡 本文提出了一种名为 "Product Keys Memory" (PKM) 的新型神经网络层,旨在通过引入海量参数(最高达十亿)显著增加网络容量,同时保持可忽略的计算开销。 2. ⚙️ PKM 采用 product keys 结构,将 key 定义为两个 sub-keys 的串联,从而实现对大规模 key 空间的高效且精确的最近邻搜索,计算复杂度从 O(|K|) 降至 O(√|K|)。 3. 🚀 将 PKM 集成到 Transformer 架构中,在300亿词的语言模型任务上,一个12层的 PKM 模型在推断速度快一倍的情况下,性能超越了24层的 baseline Transformer 模型。 ### 关键词 - Product Keys: 乘积键是一种用于高效查找大量键(keys)的方法。在本文中,它指的是将一个查询(query)向量拆分成多个子查询,并分别在多个子键(sub-key)集合中进行最近邻搜索,以隐式地定义一个巨大的键空间。具体来说,本文将键表示为两个子键向量的组合(外积),例如 $K = \{(c, c') | c \in C, c' \in C'\}$, 其中 $C$ 和 $C'$ 是子键集合。这种结构允许通过分别搜索 $|C|$ 个和 $|C'|$ 个子键来近似找到与完整查询最匹配的键,从而将搜索复杂度从 $O(|K| \times d_q)$ 降低到 $O(\sqrt{|K|} \times d_q)$,其中 $|K|$ 是总键数,$d_q$ 是键的维度。这种方法极大地增加了存储容量,同时保持了快速的查询速度。 - Memory Layer: 记忆层是本文提出的一种可集成到神经网络中的模块,旨在显著增加模型的容量,同时保持较低的计算开销。它包含一个查询网络(query network)、一个用于搜索的乘积键(product keys)结构,以及一个值查找表(value lookup table)。输入首先通过查询网络生成查询向量 $q(x)$,然后该查询向量与大量的乘积键进行比较,以检索最相关的键。根据检索到的键的相似度得分,从值查找表中加权求和得到输出 $m(x)$。这种层的核心优势在于其能够通过乘积键机制支持近乎无限的参数量,而计算成本却只与 $\sqrt{|K|}$ 成正比。 - Key-Value Memory: 键值记忆是一种通用的内存架构,广泛应用于多种神经网络模型中,用于存储和检索信息。它由一个键(key)集合和一个与每个键关联的值(value)集合组成。当输入一个查询(query)时,模型会在键集合中搜索与查询最相似的键,然后根据相似度得分加权检索对应的值,将这些值组合起来作为输出。本文提出的记忆层也采用了这种键值对的模式,其中乘积键构成了键集合,而值查找表存储了对应的值。 - Nearest Neighbor Search: 最近邻搜索(Nearest Neighbor Search, NNS)是指在给定的数据集中,找到与一个查询点最相似(通常是距离最近或内积最大)的数据点(或一组点)的过程。在本文中,最近邻搜索是记忆层核心操作的一部分。通过乘积键的结构,文章实现了一种高效的、精确的最近邻搜索,能够在海量存储键中快速定位与查询向量最相关的几个键,从而实现高效的信息检索。 - Transformer: Transformer 是一种基于自注意力机制(self-attention mechanism)的深度学习模型架构,最初在自然语言处理(NLP)领域取得了巨大成功,并已广泛应用于机器翻译、语言理解和文本生成等任务。它通过堆叠多个自注意力层和前馈神经网络(FFN)层来处理序列数据。本文将提出的记忆层集成到 Transformer 架构中,通常用来替换其中的 FFN 层,以增强模型的容量和性能。 - Language Modeling: 语言建模(Language Modeling)是自然语言处理中的一个核心任务,其目标是估计一个词序列出现的概率,或者预测序列中的下一个词。成功的语言模型能够捕捉语言的统计规律和语义信息。本文选择大规模语言建模作为评估其记忆层方法的任务,因为这类任务通常需要极大的模型容量来处理海量数据,而本文提出的方法能够有效地扩展模型容量。 - Capacity: 容量(Capacity)在机器学习中通常指模型学习复杂模式的能力,或者说模型能够存储和表示的信息量。在神经网络中,模型容量通常与参数数量、网络深度或宽度等有关。本文提出的记忆层旨在通过其巨大的键值存储来显著增加神经网络的容量,使其能够更好地拟合大规模数据集,从而提高模型的预测精度。 - Computational Overhead: 计算开销(Computational Overhead)指的是模型在训练或推理过程中所需的计算资源,如计算量(FLOPs)或运行时间。本文的核心贡献之一是提出了一种能够极大增加模型容量的记忆层,但其计算开销却非常小。这是通过乘积键的快速搜索机制和对少量记忆值的稀疏读写来实现的,使得模型在提升性能的同时,计算效率并未显著下降。 - Product Quantization: 乘积量化(Product Quantization, PQ)是一种用于压缩高维向量或加速向量相似度搜索的技术。它通过将原始向量空间分解为多个子空间,并在每个子空间中使用独立的量化器(如 k-means)来量化子向量,然后将这些量化后的子向量拼接起来形成最终的压缩表示。本文提出的乘积键机制借鉴了 PQ 的思想,将大型键集合的搜索问题分解为在多个子键集合上的搜索,从而实现高效的近邻搜索。 - Sparse Reads/Writes: 稀疏读写(Sparse Reads/Writes)是指在访问记忆或参数时,只有一小部分被激活或更新。在本文中,记忆层的查询(query)只触发对少量(k 个)最近邻键及其对应值的读取,而大多数记忆单元保持不被访问。同样,在训练时,只有与当前输入相关的少量记忆单元(键和值)才会被更新。这种稀疏访问模式是实现模型高效训练和推理的关键。 - Multi-head Memory: 多头记忆(Multi-head Memory)是本文在记忆层中引入的一种增强机制,灵感来源于 Transformer 中的多头注意力(multi-head attention)。它允许模型同时并行地使用多个独立的记忆查询头(memory heads),每个头都有自己的查询网络和子键集合。这些头独立地检索记忆,然后将它们各自的输出加权求和,从而允许模型从不同的角度或在不同的表示子空间中探索和利用记忆信息,增强了模型的表达能力。 ### 摘要 本文介绍了一种名为 Product Key Memory (PKM) 的新型存储层,该层可轻松集成到神经网络中,旨在显著增加模型的容量,同时保持可忽略的计算开销。通过其设计和基于 Product Keys 的访问模式,该存储层能够实现快速且精确的最近邻搜索,从而将模型的参数量扩展到数十亿级别。 **1. 引言与背景** 随着训练数据量的不断增长,神经网络的规模也日益扩大,导致计算复杂性急剧增加。为了解决这一问题,研究人员致力于开发在有限计算预算下仍能提供高容量的架构。现有的 Memory-augmented Neural Networks 通常随着内存大小线性扩展,对于超大型内存而言开销巨大。尽管 Rae et al. [37] 曾提出利用外部索引结构实现稀疏读写,但其近似性要求在训练期间定期重新学习索引以避免漂移。本文提出的方法则提供了一种可在键空间上进行精确搜索的 Key-Value 存储层,解决了现有方法的局限性。 **2. 核心方法:Learnable Product Key Memories** PKM 层作为神经网络中的一个函数 $m: R^d \rightarrow R^n$,旨在提供巨大的模型容量。 **2.1 Memory Design** 存储层由三部分组成: * **Query Network (查询网络)**:函数 $q: x \rightarrow q(x) \in R^{d_q}$ 将 $d$ 维输入 $x$ 映射到 $d_q$ 维的潜在空间。通常,这是一个线性映射或一个 MLP。添加 Batch Normalization 层有助于增加训练期间的键覆盖率。 * **Key Selection (键选择)**:这是 PKM 的核心。与传统的 Key-Value 内存不同,PKM 中的键并非显式存储,而是通过两个子键集 $C$ 和 $C'$ 的笛卡尔积隐式定义。即 $K = \{(c, c') | c \in C, c' \in C'\}$。总键数为 $|K| = |C| \times |C'|$。每个子键集 $C$ 和 $C'$ 都包含 $d_q/2$ 维的子键。 * **Value Lookup Table (值查找表)**:与每个 Product Key 相关联的值向量集合。 **键选择过程:** 给定查询 $q(x)$,传统的键选择通过计算查询与所有键的内积来找到 $k$ 个得分最高的键: $I = \text{Tk}(q(x)^T k_i)$ (获取 $k$ 个最近邻) 其中 $I$ 是得分最高的 $k$ 个键的索引集合。接着,通过 Softmax 对这些选定键的得分进行归一化: $w = \text{Softmax} ((q(x)^T k_i)_{i \in I})$ (归一化 Top-$k$ 得分) 最后,输出 $m(x)$ 是选定键关联值的加权和: $m(x) = \sum_{i \in I} w_i v_i$ (聚合选定值) 对于 Product Keys,查询 $q(x)$ 被分成两个子查询 $q_1$ 和 $q_2$。然后,分别在子键集 $C$ 和 $C'$ 中找到与 $q_1$ 和 $q_2$ 最接近的 $k$ 个子键: $I_C = \text{Tk}((q_1(x)^T c_i)_{i \in \{1...|C|\}})$ $I_{C'} = \text{Tk}((q_2(x)^T c'_j)_{j \in \{1...|C'|\}})$ 通过这种方式,可以保证 $k$ 个最相似的键属于集合 $\{(c_i, c'_j) | i \in I_C, j \in I_{C'}\}$。 **2.2 Complexity (复杂性)** * **扁平键 (Flat Keys)**:需要 $|K|$ 次 $d_q$ 维向量比较,即 $O(|K| \times d_q)$ 操作。 * **Product Keys (积键)**:假设 $|C| = |C'| = \sqrt{|K|}$。首先,计算两个子查询与子键集的相似度,涉及 $O(|C| \times d_q/2 + |C'| \times d_q/2) = O(\sqrt{|K|} \times d_q)$ 操作。然后,在 $k^2$ 个候选 Product Keys 中找到 Top-$k$ 键,涉及 $O(k^2 \times d_q)$ 操作。 因此,总复杂度为 $O((\sqrt{|K|} + k^2) \times d_q)$。对于大型内存和较小的 $k$ 值,Product Keys 相比穷举搜索能将操作量减少约 1000 倍。 **2.3 Multi-head Memory Attention (多头内存注意力)** 为了增加模型表达能力,引入了多头机制。每个头独立计算一个查询,用于从内存中选择 $k$ 个键。所有头共享相同的值,但拥有独立的查询网络和子键集。内存的最终输出是所有头输出的简单求和:$m(x) = \sum_{h=1}^H m_h(x)$。这种机制提高了键的使用率,通常能提升性能。 **3. 实验** 文章在大型语言建模任务上对 PKM 增强的 Transformer 模型进行了实验,以展示其性能和内存使用情况。 **3.1 数据集** 使用了一个从 Common Crawl 中提取的 300 亿词级别的大型语料库(CC-News corpus),训练集包含 280 亿词。与 One Billion Word corpus 不同,该数据集未打乱句子,允许模型学习长距离依赖。数据使用 Moses toolkit 进行分词,并使用 fastBPE 进行 Byte Pair Encoding (BPE) 处理,词汇表大小为 60k。 **3.2 评估指标** * **困惑度 (Perplexity)**:衡量模型在测试集上的语言建模能力。 * **内存使用率 (Memory Usage)**:被访问值的分数 (百分比),反映内存容量的利用程度。 * **KL 散度 (KL Divergence)**:$log(|K|) + \sum z_i log(z_i)$,衡量内存访问模式的不平衡性。 * **速度 (Speed)**:以每秒处理的词数 (words per second) 衡量。 **3.3 训练细节** 使用 12、16 或 24 层的 Transformer 架构,维度为 1024 或 1600。PKM 层取代了 Transformer 中部分 FFN 层,并通过残差连接集成。模型使用 Adam 优化器训练,学习率设置为 $2.5 \times 10^{-4}$,内存值参数的学习率更高,为 $10^{-3}$。训练使用 PyTorch 和 32 块 Volta GPUs,并采用 float16 操作加速训练和减少内存占用。Product Keys 的搜索利用了 Johnson et al. [22] 的快速最近邻实现。默认配置为 4 个内存头,每个头选择 $k=32$ 个键,总内存槽数 $|K|=512^2$。 **3.4 主要结果** * **性能提升**:PKM 模型显著优于没有内存的基线模型。例如,一个带单个内存层的 12 层模型,其困惑度优于相同维度但有 24 层的无内存模型。添加 2 或 3 个内存层进一步提升了性能。 * **推理速度**:带 PKM 的模型推理速度快。一个 12 层、维度 1024 且带内存的模型,其困惑度优于 24 层(与 BERT large 配置相同)模型,但推理速度几乎快一倍。对于大型模型(维度 1600),添加内存层几乎不增加推理时间。 * **内存大小**:增加内存大小会持续降低测试困惑度。从 16k 增加到 1M,困惑度从 22.8 降至 18.0,且推理时间保持不变。 * **Query Batch Normalization**:在查询网络中加入 Batch Normalization 显著提高了内存使用率,特别是对于大型内存,例如,对于 1M 大小的内存,使用率从 25.8% 提升到 80.3%,困惑度从 19.8 降至 18.0。这表明 Batch Normalization 有助于更有效地利用内存容量。 * **内存位置**:在 Transformer 中层 (如第 4 或 5 层) 插入内存效果最佳。在太早或太晚的层插入内存效果不佳,这表明内存需要操作在更抽象的特征空间,并且需要后续层来处理和聚合信息。 * **头数 / k-NN**:增加头数或 $k$-NN 的数量都能提升困惑度和内存使用率。在 $h \times k$ 相同的情况下,模型内存使用率相似。4 个头和 32 个 k-NN 是速度和性能之间的良好折衷。 * **Product Keys vs. Flat Keys**:与扁平键相比,Product Keys 不仅速度更快(特别是对于大内存),而且内存使用率更高,困惑度更低。扁平键模型参数量更大,且计算开销随内存大小急剧增加,对于大内存而言几乎不可行。 **4. 结论** 本文提出的 Product Key Memory 层通过将键分解为 Product Set 和稀疏的内存值读写访问,极大地增加了神经网络的容量,同时保持了可忽略的计算开销。实验证明,该层在大型语言建模任务中取得了显著的性能提升,例如,一个 12 层的 PKM 增强模型能够超越 24 层的 BERT large 模型,而运行时间仅为其一半。 ## Abstract 本论文提出了一种**结构化记忆(structured memory)**,该结构可以方便地集成到神经网络中。该记忆模块设计为**容量极大**,可以在几乎不增加计算开销的前提下,将模型参数规模提升至**多达十亿级别**,从而显著提升模型的整体容量。 该记忆模块的设计基于**乘积键(product keys)**,支持**快速且精确的最近邻搜索**,使得在训练和推理过程中,模型能够在**计算效率与预测准确性之间取得更好的平衡**。 通过引入该记忆层,作者成功应用于**超大规模语言建模任务**,实验中使用了一个包含**300亿个词**的数据集,并将其集成到当前最先进的**基于Transformer的架构**中。 实验结果显示: - 一个**仅含12层**的增强记忆模型,**性能优于24层的标准Transformer模型**; - 并且在推理速度上,前者是后者的**两倍**。 最后,作者公开了代码以供复现实验结果。 ## 1 Introduction ### 背景与问题 神经网络广泛应用于机器翻译、图像分类和语音识别等复杂任务。随着训练数据的增加,模型规模不断扩大,例如视觉和自然语言处理领域的一些模型参数超过十亿。高容量模型能够更好地建模自然文本和图像数据,并提升泛化能力。然而,模型容量的提升也带来了训练和推理阶段计算复杂度的显著增加。 ### 研究趋势与挑战 近年来,研究者们致力于开发在有限计算预算下保持高容量的神经网络架构。例如,“On-device Visual Intelligence Challenge”强调了图像分类任务中准确率与计算复杂度之间的权衡。 一些研究尝试在不增加计算复杂度的前提下提升模型容量。例如,Rae 等人提出在神经网络中引入快速近邻搜索,通过稀疏读写操作使用大规模键值层。但该方法依赖于外部索引结构,存在近似性和训练过程中需定期重新学习的问题。 ### 本文方法 本文提出了一种可扩展的大容量键值记忆层,能够在键空间中实现**精确搜索**,而计算开销极低。与现有方法不同,本文将键定义为两个子键的拼接,借鉴了**乘积量化**(product quantization)的思想。这种结构隐式定义了大量键值对,其中值向量的数量随子键数量的平方增长。 尽管记忆槽(memory slot)数量庞大,但查找与输入最接近的键的效率很高,通常只需进行 $ \mathcal{O}(\sqrt{|\mathcal{K}|}) $ 次向量比较(其中 $ |\mathcal{K}| $ 为记忆槽总数)。所有记忆参数均可训练,但每次训练时仅更新少量记忆槽。这种稀疏性使得训练和推理都非常高效。 ### 应用与贡献 该方法适用于当前模型因数据量大而欠拟合或推理速度过慢的问题。作者将该记忆层集成到Transformer架构中,用于语言建模任务。选择Transformer是基于BERT和GPT-2的成功经验,表明模型容量提升能显著改善语言建模效果。 本文的主要贡献包括: 1. **提出新层结构**:在训练和测试阶段仅带来轻微计算开销,却显著提升模型容量。 2. **高效索引策略**:构造性地实现精确的最近邻搜索,避免了依赖需重新学习的索引结构。 3. **实验验证**:在一个包含24层、每层维度为1600的Transformer中,使用1个记忆层和12层即可超越24层Transformer的性能,且推理速度提高一倍。进一步实验表明,在不同复杂度的Transformer中加入更多记忆层,均能系统性地显著提升任务表现。 ### 图文说明 - **图1**:展示了键值记忆层的结构,输入经过查询网络生成查询向量,与所有键比较后输出加权值向量。 - **图2**:详细说明了查询生成和记忆设计,展示了乘积量化结构如何高效构建大规模键空间。 ## 2 Related work 本节回顾了提升神经网络容量但不显著增加计算复杂度的相关方法,主要包括以下几类: --- ### 1. **条件计算模型(Conditional Computation Models)** 这类方法通过将输入路由到大型网络中的不同子网络,使得每个输入只激活网络的一部分,从而节省计算资源。 - 典型方法包括: - 大型专家混合模型(Large Mixture of Experts) - 门控机制(Gating Techniques) - 基于强化学习的路由方法 --- ### 2. **带记忆增强的神经网络(Memory-Augmented Neural Networks)** 这类方法通过引入外部或内部记忆结构来增强模型的表示能力,尤其适用于如问答等需要处理变长输入的任务。 - 代表工作: - 基于记忆的神经层 - 支持在特征空间中读写,具有多种读写机制 - **局限性**:内存规模线性增长,难以扩展到非常大的记忆空间。 - 类似问题也出现在神经缓存模型(Neural Cache Models) --- ### 3. **离散化与近似技术(Discretization and Approximate Techniques)** 这类方法通过压缩权重或激活值来减少模型复杂度,加速推理过程。 - 示例: - Gerald 等人 提出将输入映射为低维二进制码,避免使用大型线性层 - Locality-Sensitive Hashing(LSH)用于近似点积计算 - **挑战**:训练时使用近似索引在高维空间中效果不佳,优化困难。 - **本文借鉴**:采用乘积量化(Product Quantization, PQ) 的思想,但目标不同: - 不用于构建近似索引 - 而是用少量可学习向量表示大量“键”(keys),通过反向传播更新 - 最近键的选取是**精确的**,继承了PQ的快速邻近搜索特性 --- ### 4. **稀疏表示模型(Sparsity Models)** 主要在无监督学习中研究,通过限制激活数量来提升效率。 - 代表方法: - k-稀疏自编码器(k-Sparse Autoencoder):保留编码中最大的k个值 - 胜者通吃自编码器(Winner Take All Autoencoder):利用mini-batch统计诱导稀疏性 - 稀疏访问记忆(Sparse Access Memory):通过阈值化和高效数据结构实现加速 - **问题**:依赖外部近似索引结构 ,需周期性重新训练 - **本文改进**:将键选择机制完全集成进网络结构中,避免外部索引 --- ### 5. **Transformer 与 注意力机制(Transformer and Attention Mechanisms)** Transformer 是当前 NLP 的主流模型,其基本结构由自注意力层和前馈网络(FFN)堆叠而成。 - **与本文模型的联系**: - 记忆层中的键(keys)、值(values)与自注意力机制类似 - **区别**: - 键和值不是来自输入 token,而是自由学习的嵌入向量 - 值的数量(即记忆大小)非常大 --- ### 总结 本节系统梳理了提升模型容量而不显著增加计算开销的多种方法,包括条件计算、记忆增强、离散化、稀疏建模以及Transformer结构。作者指出这些方法在扩展性、训练效率或外部依赖方面存在局限,并引出本文提出的“基于乘积键的大内存层”方法,旨在解决上述问题。 ## 3 Learnable product key memories 本节提出了一种可学习的记忆结构,作为神经网络中的一层,旨在提供**大容量**的表示能力。其核心思想是通过**结构化键(product keys)** 来高效地检索和更新记忆,从而在训练和推理过程中保持高效率。 --- ### 3.1 Memory design #### 高层结构(High-level structure) 记忆模块由三部分组成: 1. **查询网络(query network)**:将输入映射到一个低维查询向量; 2. **键选择模块(key selection module)**:包含两个子键集合,用于构建结构化的“产品键”; 3. **值查找表(value lookup table)**:存储与键对应的值。 流程如下: - 输入通过查询网络生成一个查询向量; - 该查询与所有键进行比较,选出**得分最高的k个键**; - 使用这些键的值进行加权求和,得到输出; - 所有参数均可训练,但每次只更新选中的k个键。 这种**稀疏更新机制**使得训练和推理都非常高效。 --- #### 查询生成:预处理网络(Query generation: pre-processing network) - 函数 $ q: x \mapsto q(x) \in \mathbb{R}^{d_q} $ 将输入 $ x \in \mathbb{R}^d $ 映射到一个低维空间; - 通常为线性变换或MLP,将维度从 $ d $ 降到 $ d_q = 512 $; - 添加**批归一化层**有助于提升键的覆盖范围,提升训练效果; - 实验验证见第4.5节。 --- #### 标准键分配与加权(Standard key assignment and weighting) 给定查询 $ q(x) $ 和键集合 $ \mathcal{K} = \{k_1, ..., k_{|\mathcal{K}|}\} $,执行以下步骤: 1. **选择k个最相似的键**: $$ \mathcal{I} = \mathcal{T}_k(q(x)^T k_i) $$ 其中 $ \mathcal{T}_k $ 表示取前k个最大值的索引。 2. **对得分进行Softmax归一化**: $$ w = \text{Softmax}((q(x)^T k_i)_{i \in \mathcal{I}}) $$ 3. **加权求和得到输出**: $$ m(x) = \sum_{i \in \mathcal{I}} w_i v_i $$ 其中 $ v_i $ 是与键 $ k_i $ 对应的值。 这些操作均可通过自动微分实现,使得该层可插入神经网络任意位置。 > **关键点**:操作(2)和(3)仅依赖前k个键,计算高效;但(1)涉及所有键的内积,效率低,因此引入**产品键结构**。 --- #### 产品键集合(The product key set) 定义键集合为两个子键集合的笛卡尔积: $$ \mathcal{K} = \{(c, c') \mid c \in \mathcal{C}, c' \in \mathcal{C}'\} $$ - $ \mathcal{C} $ 和 $ \mathcal{C}' $ 各包含 $ \sqrt{|\mathcal{K}|} $ 个子键; - 每个子键维度为 $ d_q / 2 $; - 总键数为 $ |\mathcal{K}| = |\mathcal{C}| \times |\mathcal{C}'| $。 **高效检索方法**: 1. 将查询 $ q(x) $ 分为两个子查询 $ q_1(x), q_2(x) $; 2. 分别在 $ \mathcal{C} $ 和 $ \mathcal{C}' $ 中找出前k个最相似的子键; 3. 构造 $ k \times k $ 个候选键; 4. 在这 $ k^2 $ 个键中选出最终的k个键。 > **保证性**:最终的k个键一定在上述候选集合中。 --- ### 3.2 Complexity #### 平面键(Flat keys)的复杂度 - 比较 $ |\mathcal{K}| $ 个键,每个键维度为 $ d_q $; - 总复杂度为 $ \mathcal{O}(|\mathcal{K}| \times d_q) $。 #### 产品键(Product keys)的复杂度 - 每个子键集合大小为 $ \sqrt{|\mathcal{K}|} $; - 比较两个子键集合,复杂度为 $ \mathcal{O}(\sqrt{|\mathcal{K}|} \times d_q) $; - 再在 $ k^2 $ 个候选键中选k个,复杂度为 $ \mathcal{O}(k^2 \times d_q) $; - 总复杂度为: $$ \mathcal{O}((\sqrt{|\mathcal{K}|} + k^2) \times d_q) $$ > **效率提升**:当 $ |\mathcal{K}| = 1024^2 $ 且 $ k $ 较小时,产品键比平面键快约 $ 10^3 $ 倍。 --- ### 3.3 Multi-head memory attention 引入**多头机制**提升模型表达能力: - 每个头独立生成查询,选择k个键; - 所有头的输出加权求和作为最终输出: $$ m(x) = \sum_{i=1}^H m_i(x) $$ 其中 $ H $ 为头的数量。 - 每个头有独立的查询网络和子键集合,但共享值集合; - 类似Transformer的多头注意力,但不是将一个查询拆分,而是生成多个独立查询; - 不同头倾向于选择不同的键,提升键的利用率; - 实验表明该机制显著提升性能,详见第4.5节。 --- ### 总结要点 | 模块 | 关键内容 | 优势 | |------|----------|------| | 查询网络 | 映射输入到低维空间,加批归一化 | 提升键覆盖范围 | | 产品键结构 | 两个子键集合的笛卡尔积 | 大幅减少检索复杂度 | | 多头机制 | 多个独立查询头 | 提升表达能力与键利用率 | | 复杂度分析 | $ \mathcal{O}((\sqrt{|\mathcal{K}|} + k^2) \times d_q) $ | 比平面键快 $ 10^3 $ 倍 | --- 如需进一步分析实验结果或实现细节,可继续提问。 ## 4 Experiments 本章节报告了在大规模 Transformer 模型中引入记忆模块的实验结果,并通过消融实验分析了不同记忆组件对模型性能和内存使用的影响。 ### 4.1 数据集 作者在大规模语言建模任务中评估记忆模块的效果。由于传统的 One Billion Word 数据集太小,容易过拟合,因此他们使用了一个更大、包含 280 亿词的 Common Crawl 新闻语料库(约 4000 万篇英文新闻),并保留 5000 篇作为验证和测试集。该数据集未打乱句子顺序,以支持长距离依赖建模。使用 Moses 工具包进行分词,并采用 BPE 编码(60k 个分割)以减少词汇量。 ### 4.2 评估指标 主要评估指标为测试集上的 **困惑度(perplexity)**。对于引入记忆的模型,还评估以下两个指标: - **内存使用率**:表示访问的值的比例,即 #非零值 / 总值数。 - **KL 散度**:衡量访问模式的分布与均匀分布的差异。KL 值越高,说明访问越集中于某些键。 ### 4.3 训练细节 - 使用 16 头注意力和学习的位置编码。 - 模型层数为 12、16 或 24 层,维度为 1024 或 1600。 - 使用 Adam 优化器(学习率 2.5e-4,β1=0.9, β2=0.98),并采用 Vaswani 的学习率调度策略。 - 内存中的键和查询网络使用相同学习率,而值使用更高的学习率(1e-3)。 - 使用 PyTorch 实现,训练在 32 块 Volta GPU 上进行,使用 float16 加速。 - 内存参数:H=4 个头,k=32 个最近邻,|𝒦|=512²=262k 个槽位。 - 内存插入位置:在 L 层中均匀分布,如 L=16、N=2 时插入第 6 和 12 层。 ### 4.4 实验结果 #### 表 1:不同模型的测试困惑度 | 维度 | 1024 | | | | 1600 | | | --- | --- | --- | --- | --- | --- | --- | | N 个记忆 | 0 | 1 | 2 | 3 | 0 | 1 | | 12 层 | 17.6566 | 15.6184 | 14.7886 | 14.4727 | 14.9994 | 13.6522 | | 16 层 | 16.6508 | 14.8761 | 14.1023 | - | 14.3643 | 13.1668 | | 24 层 | 16.0186 | 14.6188 | - | - | 13.9515 | - | **结论**: - 增加维度或层数可提升性能。 - 引入记忆比增加层数更有效。例如,12 层 + 1 个记忆的模型优于 24 层无记忆模型。 - 增加记忆数量(2 或 3 个)进一步提升性能。 #### 图 4:速度与困惑度的权衡 - 12 层 + 1 个记忆的模型比 24 层模型困惑度更低,且推理速度快近两倍。 - 对于 1600 维模型,加入记忆对推理速度影响极小。 --- ### 4.5 消融研究 #### 内存大小 - 内存越大,困惑度越低。16k → 1M,困惑度从 22.7972 降至 17.9819。 - 推理时间主要受访问的内存数量影响(由头数和 k 决定),与内存大小无关。 #### 查询的 BatchNorm - 对于小内存(16k、65k),BatchNorm 作用不大,使用率已接近 100%。 - 对于大内存(1M),BatchNorm 显著提升使用率(从 25.8% → 80.3%),困惑度从 19.79 → 17.98。 - 使用率与模型性能正相关。 #### 内存插入位置 - 插入第 4 或 5 层效果最佳,第 1 层效果最差。 - 插入第 6 层(接近输出)效果也不佳,说明需要上层网络进一步处理记忆信息。 #### 头数 / k-NN 数量 - 增加头数或 k-NN 数量可提升性能和内存使用率。 - 相同 h×k 的模型(如 (1,64), (2,32), (4,16), (8,8))性能相近。 - 4 头 + 32 个最近邻在速度与性能之间取得良好平衡。 #### 产品键 vs. 平坦键(Product Keys vs. Flat Keys) - 产品键通过分块搜索(|C|×|C|)实现高效查找,参数量比平坦键少 |C| 倍。 - 产品键显著提升内存使用率和困惑度,且推理速度更快。 - 平坦键在大内存下计算代价过高,难以训练。 #### 图 7:内存大小与推理速度 - 使用平坦键时,内存越大,速度越慢。 - 使用产品键时,内存大小对速度影响极小。 --- ### 总结 本章通过大规模实验验证了引入 **Product Key Memory(PKM)** 的有效性。PKM 能显著提升模型性能,同时保持较低的推理开销。关键发现包括: - **内存大小**:越大越好,但推理时间主要受访问数量影响。 - **BatchNorm**:对大内存至关重要,提升使用率和性能。 - **插入位置**:中间层效果最佳,输入和输出层效果较差。 - **头数与 k-NN**:增加可提升性能,但需权衡计算开销。 - **产品键**:相比平坦键,参数更少、速度更快、使用率更高。 最终,12 层 + 1 个 PKM 的模型在性能和效率上均优于 24 层无记忆模型。 ## 5 Conclusion 本节总结了论文的主要贡献和实验结果: - **核心创新**:论文提出了一种**记忆层(memory layer)**,可以在**几乎不增加计算开销**的前提下,显著提升神经网络的容量(capacity)。 - **关键技术点**:该记忆层的高效性依赖于两个关键设计: 1. **键(Key)的因子化**:将键表示为乘积集合(product set)的形式,降低了存储和计算复杂度。 2. **稀疏读写机制**:对记忆值(value)进行稀疏访问,减少了实际计算量。 - **应用方式**:该记忆层可以**集成到现有的神经网络架构中**,具有良好的兼容性。 - **实验结果**: - 在大规模语言建模任务中,使用了12层的记忆增强模型,达到了原本242424层的BERT-large模型的性能。 - **运行时间减少一半**,说明该方法在效率和性能之间取得了良好平衡。 总结:本论文提出了一种高效记忆增强机制,通过结构设计和稀疏访问策略,显著提升了模型性能,同时保持了低计算成本。