2305.16300_Random-Access Infinite Context Length for Transformers

Abstract

本文提出了一种新的注意力机制——Landmark Attention(地标注意力),旨在解决Transformer模型在处理长上下文时的内存瓶颈问题。传统方法如循环记忆或基于检索的扩展虽然能在一定程度上处理长序列,但要么牺牲了注意力机制的随机访问能力,要么依赖于独立的检索机制,难以与现有注意力机制兼容。本文方法通过引入“地标标记”来代表输入序列中的每个块,并训练注意力机制通过这些地标标记选择相关块,从而在不依赖额外机制的情况下实现长上下文的直接检索。该方法与系统的数据结构和内存层次无缝集成,支持任意长度的上下文处理。实验表明,该方法在性能上可与Transformer-XL媲美,同时显著减少每一步所需的检索标记数量。此外,使用该方法对LLaMA 7B进行微调后,其上下文长度可扩展至超过32k token,达到GPT-4的水平。

1 Introduction

本章介绍了论文的研究背景、提出的问题以及所提出的解决方法。主要内容总结如下:

  1. 问题背景

    • 大型Transformer模型在语言建模中表现出色,但其基于注意力机制的计算复杂度为平方级,导致内存占用大,限制了上下文长度。

    • 尝试通过引入循环记忆(如Transformer-XL)来扩展上下文长度,但牺牲了注意力机制的灵活性。

    • 基于检索的方法虽然能引入外部知识,但依赖额外的检索模型,难以实时适应长输入,且与标准注意力机制不兼容。

  2. 提出的方法

    • 作者提出了一种新的方法,通过在注意力机制中引入“landmark”标记,使得模型可以直接访问早期输入块,从而突破上下文长度限制。

    • 模型将输入分割为固定长度的块,每个块对应一个landmark,通过landmark的注意力得分决定是否访问该块,实现对远距离信息的随机访问。

    • 该方法在推理时仅对相关块进行计算,显著降低了计算和内存开销,计算复杂度降低比例等于块长度(例如50倍)。

  3. 方法优势

    • 任意上下文长度推理:无需额外训练,即可处理远长于训练时上下文长度的输入。

    • 降低推理时间和内存消耗:相比固定长度训练模型,节省大量计算资源。

    • 兼容扩展结构:可结合如FAISS等高效近邻检索结构,进一步提升效率。

  4. 实验结果

    • 在从头训练和微调预训练模型(如LLaMA 7B)中均验证了方法的有效性。

    • 微调后模型在32k token的超长上下文中仍能有效检索信息,性能可与基于循环记忆的Transformer-XL模型媲美。

  5. 图示说明

    • 图1对比了标准注意力与引入landmark的注意力机制,展示了landmark如何根据块内相似性动态控制注意力权重,实现对不同块中相同token的差异化处理。

总结:本章提出了一个基于landmark标记的高效扩展Transformer上下文长度的方法,在保持注意力灵活性的同时,显著降低了计算与内存消耗,并在多个实验中验证了其有效性。

3 Methodology

本文的方法论部分主要介绍了一种用于扩展Transformer模型上下文长度的新方法,重点在于如何通过引入“landmark token”(地标标记)和“Grouped Softmax”机制,使模型能够在处理超长序列时,仍然有效利用注意力机制。以下是该部分的总结:


总体思路

  1. 问题背景

    • Transformer在处理长文本时,计算量随上下文长度呈平方增长,难以处理非常长的输入。

    • 但由于注意力权重总和为1,实际上只有少数几个标记的权重较高。因此,可以通过只保留这些高权重的键(keys)来近似全上下文注意力。

  2. 核心思想

    • 将长文本划分为多个块(blocks),并在每个块后插入一个landmark token

    • 每个landmark token的表示作为其所在块的代表向量

    • 在推理时,通过计算当前token与landmark token的注意力,选择出相关性较高的块,从而实现对远程上下文的“随机访问”。


方法详解

1. Landmark Token 的训练

  • 插入方式:每隔 ℓ_block 个token插入一个landmark token。

  • 训练目标:让landmark token的键向量能够代表其所在块的内容。

  • Grouped Softmax机制

    • 将常规token和landmark token分为不同的组,分别进行softmax计算。

    • 每个token的注意力权重是其所在组内softmax的结果,再乘以其块的landmark token的softmax结果。

    • 这样模型必须在当前块和远程块之间做出权衡,从而实现对远程块的注意力控制。

2. 推理过程

  • 输入处理:同样每隔 ℓ_block 个token插入一个landmark token。

  • 分块处理:将长输入划分为多个 ℓ_local 长度的块,逐步处理。

  • 缓存机制

    • 使用缓存存储前面块的键值向量(包括landmark token)。

    • 通过只保留landmark token的缓存,可以大幅节省内存。

  • 注意力检索

    • 每个token在每个注意力层中计算与缓存中所有landmark token的注意力分数。

    • 选择得分最高的 k 个块,将其内容作为额外的上下文。

    • 与本地注意力矩阵拼接后,应用Grouped Softmax,计算加权值向量作为最终表示。

  • 灵活性控制

    • 可以限制每个token或每个注意力头的检索块数量。

    • 例如,合并不同头的注意力得分,或限制每个窗口查询的块数量,提高推理效率。


位置编码处理

  • 问题:Transformer依赖位置编码,但无法外推到训练中未见的长度。

  • 解决方案

    • Stingy Position Mapping(吝啬位置映射)

      • 为每个块分配固定长度的位置空间,保证最近的 k 个块具有唯一位置索引。

      • 旧块的位置映射到固定位置,避免索引冲突。

      • 通过这种方式,模型仍能通过位置信息访问到最近的块。

  • 实现细节

    • 使用Rotary Positional Encoding(旋转位置编码)。

    • 在缓存中存储无位置信息的键向量,在检索时再添加对应位置信息。


与其他方法的对比

  • FAISS等近邻搜索方法

    • 虽然可以高效检索相似向量,但缺乏语义控制。

    • 本方法通过Grouped Softmax直接由模型控制检索,更具语义相关性。

  • 块级检索的优势

    • 比单个token检索更符合传统注意力模式,能保留局部上下文信息。


总结

  • 本文提出了一种基于landmark token和Grouped Softmax的注意力机制,能够在保持模型语义理解能力的同时,显著扩展Transformer的上下文长度。

  • 通过分块、缓存和位置映射的组合,实现了**“随机访问”无限上下文长度**的目标。

  • 该方法在训练和推理过程中均保持高效,具有良好的扩展性和实用性。

3.3 Memory & Computation

本节讨论了所提方法在内存与计算效率方面的优势,主要总结如下:

  1. 训练效率
    所提方法在训练时计算开销极小,几乎没有额外负担。与传统Transformer不同,它不需要维护先前键值缓存(KV-cache),并且训练时的上下文长度与推理时的上下文长度解耦。这意味着无论推理时的上下文长度如何,训练时间仅是常数(𝒪(1)),而传统Transformer的训练时间会随上下文长度二次增长。

  2. 推理效率
    在推理阶段,特别是自回归生成时,传统方法需要对所有先前token计算注意力分数,而本方法只需关注关键标记(landmark tokens)和检索到的块(blocks)。由于每个块的大小固定,因此计算注意力的开销保持不变。虽然查找关键标记的开销线性增长,但其频率为每 ℓblock + 1 个 token 才增加一次,从而整体计算量减少约 ℓblock 倍。例如,当 ℓblock = 50 时,可带来约50倍的效率提升。

  3. 内存优化
    传统Transformer需要缓存所有先前的键值对以进行高效生成,而本方法只需访问关键标记,并将其他块卸载到慢速内存(如CPU),按需加载检索到的块,从而显著减少内存占用。

  4. 扩展性与结合能力
    块的检索可通过更高效的数据结构(如FAISS)进一步优化。此外,本方法可自然结合Flash Attention以进一步降低开销。尽管本文出于实验灵活性考虑未使用该组合,但作者已发布基于Triton的高效实现版本。

综上所述,该方法在训练和推理阶段均实现了显著的计算与内存效率提升,尤其适用于长上下文场景。

4 Experiments

该章节主要探讨了如何通过引入“地标标记(landmark tokens)”来提升Transformer模型在长上下文场景下的语言建模和微调能力。内容分为两个部分:


4.1 语言建模实验

本部分评估了地标标记在两个具有长距离依赖的语言建模任务上的效果:

  • PG-19(英文书籍,37亿token)

  • arXiv数学论文(56亿token)

模型与训练

  • 使用了类似GPT-2的架构,12层解码器,128维嵌入,1024维度嵌入层,FFN隐藏层4096。

  • 采用AdamW优化器,学习率0.002,余弦调度器,权重衰减0.001。

  • 训练时使用混合精度(bfloat16)和梯度累积。

  • 使用地标标记时,这些标记被加入数据集并参与训练,保持批量机制不变。

结果

  • 模型在使用地标标记并检索相关块后,可以获得与Transformer-XL相当的perplexity(困惑度),但计算量(FLOPs)更低。

  • 模型在推理时可以处理比训练时更长的上下文长度(例如2048或4096),性能依然良好。

  • 通过调整检索块的数量(k值)和存储的块数量,可以在性能和效率之间取得平衡。

  • 例如,使用局部上下文长度250和k=2时,模型的性能与上下文长度512相当。

块检索的粒度

  • 每个注意力头甚至每个token都可以检索不同的块(最细粒度)。

  • 减少检索粒度(如固定检索块对所有token或头)会略微影响性能,但仍优于基线模型。

  • 当为所有token使用相同的检索块时,模型性能仅下降0.23点。


4.2 微调预训练模型

本部分探讨了如何使用地标标记对预训练模型(如LLaMA 7B)进行微调,以支持更长的上下文长度。

实验方法

  • 微调LLaMA 7B 15000步,上下文长度为512。

  • 使用RedPajama数据集进行微调。

  • 通过设计特殊格式的提示(prompt)测试模型能否从长文本中准确检索隐藏的“通行密钥(pass key)”。

评估方式

  • 生成包含大量无关文本的长提示,中间嵌入一个随机生成的通行密钥(1-50000之间)。

  • 模型需从长文本中检索该密钥,并在生成的前100个token中正确输出。

  • 比较使用地标标记的模型和原始LLaMA在不同上下文长度下的检索准确性。

结果

  • 原始LLaMA在上下文长度超过2048时检索失败。

  • 使用地标标记的模型在更长的上下文(如32K)中仍能准确检索密钥。

  • 为减少内存使用,作者还引入了将KV缓存卸载到CPU的技术(详见附录G)。


总结

本章节展示了通过引入地标标记,可以有效提升Transformer模型在长上下文语言建模和微调中的性能。模型不仅在更长的上下文中保持良好的语言建模能力,还能高效地检索关键信息,且具备更高的可解释性。此外,通过调整块检索策略,可以在性能与计算效率之间达到良好平衡。

5 Future Work

本章节“未来工作”探讨了几种可能的改进方向和未解决的问题,主要包括以下几点:

  1. 位置编码的外推(Extrapolating Positional Encoding)
    当前模型在处理远超训练时长度的上下文时存在局限,主要在于位置编码无法有效处理超出训练长度的序列。本文提出了一种特殊索引方法结合landmark tokens来解决这一问题,但该方法仅能基于语义而非位置进行注意力分配。虽然这为长上下文建模提供了重要改进,但如果能够实现基于精确索引的注意力机制,性能将进一步提升。目前已有方法限制了远距离token的注意力,这与目标背道而驰。作者在附录中简要探讨了可能的解决方案,但更深入的研究留待未来完成。一旦该方法成熟,可直接与landmark tokens结合,实现任意长度的推理。

  2. 层次化landmark tokens(Hierarchical Landmarks)
    在大规模场景中,landmark tokens可以存储在k近邻数据结构中以提升检索效率和减少内存占用。另一种思路是引入层次结构,高层landmark tokens控制对低层landmark的注意力。作者在附录中尝试加入一种“门控”token,用于决定是否需要检索,或在不同内存层级中实现缓存命中/未命中的判断。该方向仍需进一步探索。

  3. 训练中使用缓存(Training with Cache)
    本文主要采用标准训练方法,虽然推断时使用的检索机制与训练中的softmax机制相似,但由于特殊的索引方式,训练时引入缓存可能带来额外收益。对此类训练方式的探索也被列为未来工作。

总结:本节提出了三个未来研究方向,分别围绕位置编码的外推、landmark tokens的层次化结构以及训练过程中缓存的引入,旨在进一步提升模型在长上下文处理中的性能和效率。

6 Conclusion

本章总结如下:

本文提出了一种新颖的方法,通过注意力机制从记忆中检索相关块,无需依赖递归机制构建记忆。该方法能够直接访问先前的token,从而实现准确的信息检索,避免了传统方法中因递归导致的“遗忘历史数据”问题。实验表明,该方法在性能上可与Transformer-XL等递归模型相媲美,同时计算资源消耗更少。此外,基于注意力的检索过程具有可追踪性和可解释性,有助于理解模型生成输出所依赖的信息。研究结果还表明,该方法能够处理远超训练阶段所见的长上下文,并可通过微调高效集成到现有预训练模型中(如LLaMA 7B),从而提升其检索能力。总体而言,该方法支持对任意长度上下文的高效推理,适用于处理大规模输入和细粒度信息。

Acknowledgment

本章节为致谢部分,作者对Matteo Pagliardini的深入讨论和Olivia Simin Fan在论文初稿阶段的宝贵反馈表示感谢。同时,作者也感谢瑞士国家科学基金会(SNSF)提供的项目资助(资助编号:200020_200342)。

Appendix A Grouped Softmax Example

本附录通过一个示例说明了使用“landmark tokens”(地标token)来实现注意力机制的方法。在该例子中应用了因果掩码(causal mask),索引为2、5和8的token被选为地标token,可通过条件 \(p_i = i\) 来识别。假设索引6处的token的查询向量与所有其他键向量的点积为全1向量。在计算其注意力时,索引8处的地标token被忽略,而其他两个地标token则作为其对应块内token的注意力“门控”。最终,注意力权重会分配到块内的普通token,因此地标token本身的注意力权重始终为零。

Appendix B Dataset Description

本附录描述了两个用于模型训练和评估的数据集:

  1. arXiv Math 数据集
    该数据集来自 proof-pile 项目,包含了 arXiv 上数学类论文的 TeX 文件。数据经过清洗,去除了使用非 utf-8/16/32latin1 编码的文件,并通过一些启发式规则保留结构良好的论文(例如包含章节、节、小节的文件)。最终数据集包含约 56 亿个 token,适用于数学证明相关的训练任务。

  2. PG-19 数据集
    PG-19 是一个大型英文书籍数据集,包含 1919 年前出版的书籍,源自 Project Gutenberg。训练数据集包含约 37 亿个 token。该数据集常用于评估模型在处理长距离上下文关系方面的能力,已有相关研究引用其测试结果。

Appendix C Number of Unique Retrieved Blocks

本附录主要研究在使用PG19验证集进行语言建模时,模型在不同检索灵活性设置下所检索到的唯一缓存块数量的分布情况,以评估检索策略对模型性能和效率的影响。以下是关键内容的总结:

  1. 实验设置

    • 输入按2048个token的长度分段处理,每250个token为一个块。

    • 每次检索前K=4个缓存块,并记录倒数第二个块中检索到的唯一块数量

    • 为保证统计一致性,选择倒数第二块而非最后一块,因为最后一块可能token数较少,影响分析稳定性。

    • 每个样本、每层均单独记录,并绘制不同灵活性下的分布图(图5)。

  2. 关键发现

    • 在最灵活的检索设置下,模型在多数情况下几乎访问了缓存中的所有块(共35个),在分布图中表现为最后一个bin的峰值。

    • 一个有趣的发现是,在柔性检索中,初始层表现出“低块数检索”的峰,这与早期层注意力的局部性(Locality of Attention)一致。

    • 当限制检索灵活性(如只允许不同head检索不同块)时,唯一块数量显著下降(通常低于10个),尽管每个head可检索4个块,理论上最多可访问32个。

    • 这种减少虽然会略微影响模型的困惑度(perplexity),但能够显著提升性能,因为减少了缓存块加载所需的带宽,从而有利于将部分计算卸载到较慢设备上。

  3. 缓解策略

    • 增加检索块的数量(例如从k=2增至k=4)可以在一定程度上补偿灵活性降低带来的困惑度损失。

  4. 进一步优化

    • 当允许不同token检索不同块(但不同head固定)时,也能在保持较低带宽需求的同时,实现较好的困惑度性能。

总体结论:通过调整检索的灵活性(如允许不同token或不同head检索不同块),可以在模型性能与计算效率之间取得平衡。减少检索块的唯一性虽可能影响模型表现,但通过增加检索数量可部分补偿这一损失,同时显著提升系统的带宽效率。

Appendix D Context Miss Token

本节主要介绍了一种新的方法,通过引入“上下文缺失标记”(Context Miss Token, CMT)来增强Transformer模型的无限上下文处理能力。主要内容总结如下:

  1. CMT的定义与位置

    • CMT是一个特殊的标记,始终位于输入的起始位置(位置 -1)。

    • 它的作用是向模型发出信号,表明当前上下文中的信息不足以进行预测,需要从记忆中检索相关的landmarked块。

  2. CMT的训练机制

    • 通过调整landmark tokens的分组方式,将部分landmark tokens的注意力控制权交给CMT。

    • 在训练过程中,使用一种新的分组规则(Grouping Scheme),根据不同的条件将不同token分配到相应的组中,从而控制注意力权重的计算。

    • 注意力权重的计算公式(公式6)中,CMT的激活会影响与控制的landmark tokens的注意力值,从而决定是否需要检索记忆中的内容。

  3. CMT的直观理解

    • CMT类似于缓存未命中(cache miss)的信号,当模型无法从当前的landmark tokens中获得足够信息时,它会被激活。

    • 由于CMT是第一个token,它无法接收其他token的信息,因此其表示是固定的,不依赖于输入内容。

    • CMT的引入允许模型构建多层次的landmark块存储,类似于多级缓存系统,可在不同层次之间进行跳转。

  4. 实验与评估

    • 在PG19数据集上训练包含CMT的模型,并在推理阶段使用CMT来决定是否进行检索。

    • 通过设置不同的阈值,可以忽略某些CMT的注意力得分,从而模拟不进行检索的情况。

    • 实验结果显示,通过CMT可以显著减少检索次数,同时保持较高的语言建模性能。例如,使用0.3的阈值时,可以减少57%的检索次数,而困惑度(perplexity)仅轻微上升。

    • 与不使用CMT的基线模型相比,使用CMT的模型在训练初期效果略有下降,但作者认为可以通过延长训练时间来改善。

总结:CMT是一种有效的方法,用于帮助Transformer模型在长距离上下文中判断是否需要检索记忆中的信息。它通过引入新的注意力控制机制,实现了对landmark结构的增强,并能够在保持性能的同时减少不必要的计算资源消耗。

Appendix E Positional Augmentation

本章《附录E:位置增强》主要讨论了如何通过改进Transformer模型中的位置编码方法,使其能够更好地处理比训练时更长的上下文序列。核心内容总结如下:

  1. 背景与动机

    • 标准Transformer模型难以外推到训练时未见过的位置(位置编码无法泛化到更长的序列)。

    • 现有方法通过限制或衰减远距离位置的注意力来处理长序列,但这与本文的目标相冲突,因为本文希望模型能够访问任意距离的token。

  2. 提出的解决方案:Stingy Mapping 与位置增强

    • 本文提出了一种位置增强方法(positional augmentation):在每个“地标token”之后,随机增加后续token的位置编码值,跳跃幅度在1到p_jump之间。

    • 这种位置跳跃的增强方式,使模型在训练中接触到更长的位置跨度,从而提升其对更长上下文的外推能力。

  3. 实验设置与结果

    • 使用p_jump=100,在PG19数据集上训练语言模型,最大训练上下文长度为512。

    • 实验表明,随着上下文长度的增加,使用位置增强的模型在长度达到1400 token时仍表现出较低的困惑度(perplexity),接近理论估计的1612 token上限。

    • 相比之下,使用标准位置编码的模型在1024 token时性能开始下降。

    • 在训练长度(512 token)上,位置增强模型的性能略低于标准模型,这是由于其学习任务更复杂,可能需要更多训练步长。

  4. 结论与展望

    • 位置增强方法在长上下文任务中表现出良好的潜力,有助于提升Transformer模型的外推能力。

    • 更深入的评估与改进位置编码方案留作未来工作。

总的来说,本章通过引入一种简单但有效的位置跳跃(positional jumps)机制,为Transformer模型处理无限长上下文提供了可行的思路。

Appendix F Additional Extensions and Details

该章节主要补充了关于“地标注意力机制(Landmark Attention)”的扩展细节与实现方式,包括以下三个方面:


1. 掩码语言建模(Masked Language Modeling)

  • 本文主要研究下文预测任务(next-word prediction),但简要探讨了如何将地标注意力机制扩展到掩码语言建模任务中。

  • 与下文预测不同,掩码语言建模需要模型能同时访问前后的上下文,因此不能按顺序从左到右处理输入。

  • 提出了一种解决方案:按层处理整个输入,每层将输入划分为块,每个块单独处理,并从其他块中检索出 top-k 的地标块。

  • 本部分为概念性讨论,实际实验验证留待未来工作。


2. 与 Flash Attention 的结合

  • Flash Attention 是一种基于块的注意力机制,优化了内存访问效率。

  • 可以将地标注意力机制与 Flash Attention 结合,通过设置地标块的频率与 Flash Attention 的块大小相等,实现高效计算。

  • 在这种设置下,GroupedSoftmax 的计算开销很小,因为只需保留当前块和地标块的 softmax 值。

  • 进一步地,地标块可以用于决定在 block-sparse flash attention 中哪些块可以被丢弃。

  • 如果限制每个块内的 token 使用相同的检索块,则 Flash Attention 可在推理阶段直接应用。

  • 实现了这种结合,并通过 Triton 框架验证,成功将 LLaMA7B 的上下文长度从 512 提升至 2048


3. 检索块数量与块大小的权衡

  • 在推理阶段,地标块需要存储在内存中,这会略微增加内存和计算开销(与块大小成反比)。

  • 增大块大小可以减少内存使用检索耗时,但若检索块数 k 固定,则每个块的长度需要减小,以适应模型的最大上下文长度限制。

  • 因此,块大小增加会导致输入块数量增加,从而延长推理时间

  • 图 7 展示了在不同块大小和不同检索块数 k 下的推理时间,结合了地标注意力与 Flash Attention 的高效实现,并限制了块的检索灵活性。


总结

本节详细探讨了地标注意力机制在不同任务和实现方式下的扩展与优化,包括其与掩码语言建模的适配、与 Flash Attention 的高效结合,以及推理中块大小与检索块数的权衡问题。这些扩展为实现更长上下文长度的 Transformer 模型提供了实用的技术思路与实现路径。

Appendix G Offloading KV Cache to CPU

本附录介绍了一种将KV缓存(Key-Value Cache)卸载到CPU以解决大上下文长度(context length)下内存瓶颈的方法。与以往通过复杂方法压缩KV缓存的方案不同,该方法利用“landmark attention”机制,将KV缓存块卸载到CPU存储,并只在需要时将相关块加载到GPU中,同时将所有landmark块保留在GPU内存中,从而实现更低的内存占用。该方法使得LLaMA 7B模型可以支持超过32,000 token的上下文长度。

尽管该方法有效,但频繁的CPU-GPU数据传输会影响推理速度。为此,作者通过限制检索块的数量来减少数据传输量:在头部(head)之间允许选择不同的块,但在每个token之间则保持一致。具体实现方式是在所有token上计算每个块的最高得分,并经过softmax归一化后选择得分最高的前5个块进行加载。

在实验中,该方法在32,070 token的上下文长度下进行了测试,经过微调的LLaMA模型在50个随机生成的提示中实现了98%的pass key检索准确率,证明了其在接近GPT-4上下文长度上的有效性和可行性。