2205.13147_MRL: Matryoshka Representation Learning

总结

总结

  • 这个是我在使用 vllm 部署 embedding 模型时遇到: --hf_overrides {"is_matryoshka": true,"matryoshka_dimensions": [32,64,128,256,512,768,1024]}

Matryoshka Representation Learning (MRL)

  • 核心思想

    • 通过嵌套优化若干低维子向量,在同一个高维表示向量中学习出具有粗到细信息结构的多尺度表示。

DeepSeek 总结

论文核心思想

传统的模型通常只为每个输入(如图片、文本)输出一个单一、固定维度的向量表示(例如,512维或1024维)。这种方法存在一个明显的缺点:在所有下游任务中,无论复杂度如何,都必须使用整个高维向量,计算成本固定且高昂。

Matryoshka Representation Learning 的灵感来源于俄罗斯套娃(Matryoshka Dolls),其核心思想是:

让模型能够同时生成一个“嵌套”的表示序列,从极低维(如16维)到完整高维(如1024维)。 下游任务可以根据自身的精度和速度需求,自由选择合适维度的表示,而无需重新训练模型。

关键技术与方法

MRL的实现方法非常巧妙且简单,主要包含两个部分:

  1. 嵌套标签训练:

    • 在训练过程中(例如,图像分类任务),模型不仅使用完整的表示向量(如d维)来计算损失,还会同时从该表示中顺序地截取前k(例如 k = d, d/2, d/4, ..., 非常小的数)来分别计算分类损失。

    • 这些不同维度的表示共享同一个主干网络,但各自拥有独立的小型分类头(一个简单的线性层或浅层MLP)。

    • 总损失函数是所有这些不同维度损失值的加权和。

  2. 弹性推理:

    • 训练完成后,在推理阶段,用户可以根据需求“撕掉”表示向量的外层,使用更短的向量。

    • 例如,对于需要快速响应的预览或简单检索任务,可以使用前64维;对于需要高精度的最终分类,则使用全部1024维。

    • 这实现了精度与效率的无缝权衡

主要优势与贡献

  1. 效率大幅提升: 在几乎不损失精度的情况下,推理速度显著加快,存储和传输成本显著降低。论文显示,在ImageNet分类任务上,使用仅4.5%的原始向量维度(如从1024维降到46维),仍能保持90%以上的准确率。

  2. 向后兼容: 与现有生态系统完全兼容。任何能够处理d维向量的系统,都可以无缝处理MRL产生的d维向量,同时还能获得所有更小子维度的能力。

  3. 简单通用: MRL是一种通用框架,而非特定架构。它可以轻松地集成到现有的各种模型、任务和损失函数中,包括有监督学习、对比学习(如CLIP) 等。

  4. 强大的低维表示: 由于在训练时显式地优化了低维表示,MRL产生的低维向量比通过PCA等后处理方式从高维向量压缩得到的向量性能要好得多

应用场景

  • 大规模检索系统: 使用低维向量进行初步快速粗排,再用高维向量进行精细重排,极大降低计算负担。

  • 边缘设备部署: 在手机、IoT设备等算力有限的场景下,使用低维向量进行高效推理。

  • 多精度服务: 为付费等级不同的用户提供不同精度的服务(如免费用户使用低维结果,付费用户使用高维高精度结果)。

  • 降低CLIP等双塔模型成本: 论文特别展示了MRL如何显著提升CLIP模型的效率,使其在图像-文本检索任务中快得多。

总结

《Matryoshka Representation Learning》提出了一种极其优雅且实用的表示学习新范式。它打破了“一个表示对应一个维度”的传统思维,通过一次训练获得一套可伸缩的表示,完美地实现了精度与效率的灵活权衡。由于其简单性、有效性和通用性,该工作迅速成为表示学习领域的一个重要里程碑,并被广泛应用于后续研究和工业实践中。

Abstract

Abstract 部分主要介绍了一种新型的表示学习方法——Matryoshka Representation Learning (MRL),它旨在解决传统固定容量表示在面对不同下游任务时的灵活性不足问题。

问题背景

  • 现代机器学习系统中,学习表示(learned representations)是核心组成部分,广泛服务于多种下游任务。

  • 但在训练这些表示时,通常无法提前知道每个下游任务的计算和统计约束条件

  • 传统的固定容量表示(rigid fixed-capacity representations)往往在任务适配性上表现不佳,可能过大或过小,造成资源浪费或性能下降。

提出的问题

因此,作者提出关键问题:

能否设计一种灵活的表示,能根据下游任务的计算资源进行自适应调整?

核心贡献

本文的核心贡献是提出 Matryoshka Representation Learning (MRL),其主要特点包括:

  • 多粒度编码:MRL 将信息以粗粒度到细粒度的方式进行编码,使得单个嵌入(embedding)可以适应不同计算资源的下游任务。

  • 最小修改:MRL 对现有表示学习流程的改动极小,并且在推理和部署阶段不增加额外成本

  • 表示质量保障:MRL 学得的表示在准确性和丰富性上不劣于独立训练的低维表示。

MRL 的优势

MRL 提供了三方面的性能优势:

  1. 更小的嵌入尺寸:在 ImageNet-1K 分类任务中,使用相同准确率的情况下,嵌入尺寸可减少14 倍

  2. 更高的推理速度:在大规模检索任务(ImageNet-1K 和 ImageNet-4K)中,实际速度提升可达 14 倍

  3. 更高的准确性:在长尾小样本分类任务中,精度提升最多可达 2%

同时,MRL 在鲁棒性方面与原始表示相当。

应用广泛性

  • MRL 可以无缝扩展到大规模网络数据集(如 ImageNet、JFT)。

  • 跨多种模态:包括视觉(ViT、ResNet)、视觉+语言(ALIGN)和语言(BERT)。

总结:MRL 是一种高效、灵活且通用的表示学习方法,能够在不同计算资源条件下提供高质量的嵌入表示,适用于多种任务和模态。

1 Introduction

1. 背景与动机

学习到的表示(learned representations)是现实世界机器学习系统中的核心组件。一旦训练完成并冻结,这些 d 维表示可以用于多个下游任务,并且包含了丰富的信息。深度表示的部署通常分为两个步骤:

  1. 特征计算阶段:计算表示本身,其成本较高但是一次性支出;

  2. 表示利用阶段:在下游任务中使用这些表示,其计算成本与表示维度、数据量(N)和标签空间(L)相关。

在大规模(web-scale)场景下,表示的使用成本远高于特征计算成本。现有的固定维度表示缺乏灵活性,难以适应不同任务对精度和计算资源的差异化需求。

此外,人类的感知具有从粗到细的分层结构,但基于梯度的深度学习模型倾向于在整条表示向量中弥散信息,导致难以压缩或选择性使用部分维度。当前的弹性表示方案(如训练多个低维模型、子网络联合优化、后处理压缩等)虽然在一定程度上缓解了问题,但往往面临训练/维护成本高、多轮前向传播、存储开销大、实时特征选择计算复杂或精度显著下降等问题。

2. 本文方法与创新点

为了解决上述问题,本文提出了一种新的表示学习方法 Matryoshka Representation Learning (MRL),其核心思想是通过嵌套优化若干低维子向量,在同一个高维表示向量中学习出具有粗到细信息结构的多尺度表示。

  • MRL 在训练过程中显式优化 O(log d) 个不同维度的表示子向量,这些子向量嵌套在同一个高维表示中,因此得名 Matryoshka(套娃)。

  • MRL 可以适配任何现有的表示学习框架,并扩展到多种标准任务(计算机视觉、NLP 等)。

  • 该方法在部署时具有零额外成本的弹性,根据精度和资源限制动态选择表示维度,从而实现近似最优的精度-计算权衡。

3. 方法优势与应用

MRL 学习的 Matryoshka 表示具有如下优势:

  • 无需额外训练成本:前 m 维表示在精度上与独立训练的 m 维表示相当;

  • 信息递增:随着维度增加,表示信息逐渐丰富,具有从粗到细的层次结构;

  • 灵活性与多保真度:适用于多种部署场景,实现高效的自适应部署。

在实际应用中,MRL 在 大规模分类检索 任务中表现出色:

  • 在 ImageNet-1K 上,通过自适应分类,MRL 能在相同精度下将表示维度减少 14 倍

  • 在自适应检索中,通过多阶段检索策略,MRL 实现了 128 倍的理论 FLOPS 和 14 倍的实际耗时提升,且精度与传统方法相当;

  • MRL 的表示子向量间具有较强的语义相关性,有助于提升长期学习任务(如长尾继续学习)的性能;

  • MRL 的粗到细结构还可以用于分析分类难度和信息瓶颈问题。

4. 主要贡献

本文的主要贡献如下:

  1. 提出 Matryoshka Representation Learning (MRL),用于生成灵活的表示,支持自适应部署(第 3 节);

  2. 在大规模分类与检索任务中,MRL 实现了 14 倍的速度提升且保持高精度(第 4 节);

  3. MRL 可无缝适配多种模态(视觉 - ResNet & ViT、视觉+语言 - ALIGN、语言 - BERT)和网页级数据(如 ImageNet-1K/4K、JFT-300M);

  4. 提供了 MRL 表示在其他下游任务中的进一步分析(第 5 节)。


小结:本文系统性地提出了一种新的表示学习框架 MRL,通过嵌套优化多尺度表示向量,实现表示的弹性与高效部署,为大规模机器学习系统中的精度-资源平衡问题提供了有效解决方案。

3 Matryoshka Representation Learning

这篇论文章节主要介绍了Matryoshka Representation Learning(MRL),这是一种新型的多层次、多粒度的表示学习框架。以下是章节的结构化总结,尽量保留了原文标题结构,并突出重点内容。


目标:

  • 给定一个输入数据点 \( x \in \mathcal{X} \),我们希望学习一个 \( d \) 维的表示向量 \( z \in \mathbb{R}^d \)

  • 对于每一个嵌套维度 \( m \in \mathcal{M} \subset [d] \),表示向量的前 \( m \) 个维度 \( z_{1:m} \in \mathbb{R}^m \) 应该是可迁移、通用的表示

  • 表示学习的多个粒度由集合 \( \mathcal{M} \) 控制,其元素个数通常小于 \( \log(d) \),如从 \( d \) 开始持续减半,直到达到信息瓶颈(低维表示)。

模型结构:

  • 使用深度神经网络 \( F(\cdot; \theta_F): \mathcal{X} \rightarrow \mathbb{R}^d \),参数为 \( \theta_F \)

  • 表示向量 \( z = F(x; \theta_F) \)

  • 每个嵌套维度 \( m \in \mathcal{M} \) 都会与一个独立的线性分类器 \( \mathbf{W}^{(m)} \in \mathbb{R}^{L \times m} \) 进行组合,用于分类任务。

目标函数:

  • 使用多分类交叉熵损失函数 \( \mathcal{L} \),对每个嵌套维度 \( m \) 分别优化损失。

  • 总体目标是最小化加权损失的平均值:

\[ \min_{\left\{\mathbf{W}^{(m)}\right\}_{m\in\mathcal{M}},\ \theta_{F}} \frac{1}{N}\sum_{i\in[N]}\sum_{m\in\mathcal{M}} c_{m} \cdot \mathcal{L} \left( \mathbf{W}^{(m)} \cdot F(x_i; \theta_F)_{1:m} \ ;\ y_i \right) \]
  • 其中,\( c_m \geq 0 \) 是维度 \( m \) 的重要性权重。

  • 默认情况下,所有 \( c_m = 1 \),即所有维度具有相等的重要性。

  • 该目标函数可以用次梯度下降法求解。


MRL-E: 高效版本 MRL

  • 为提高内存效率,MRL 提出了一种变体:MRL-E (Efficient MRL)

  • 它通过权重共享(weight-tying),将所有线性分类器的权重统一为一个大的矩阵 \( \mathbf{W} \in \mathbb{R}^{L \times d} \)

  • 即,每个维度 \( m \) 使用该矩阵的前 \( m \) 列作为分类器:\( \mathbf{W}^{(m)} = \mathbf{W}_{1:m} \)

  • 这种方式显著降低内存消耗,尤其适用于输出空间非常大的任务。

  • 相应的算法细节见附录中的 Algorithm 1 和 Algorithm 2


MRL 在不同学习框架中的适应性

  • MRL 可以无缝适应大多数表示学习框架,包括大规模 Web 数据训练任务。

  • Masked Language Modeling (MLM)

    • MRL-E 是其自然扩展,因为输入嵌入矩阵和分类器权重可以共享。

  • Contrastive Learning(对比学习)

    • MRL 应用于对比的两个嵌入向量。

    • 每个维度的归一化需要独立处理,以获得最佳效果。

    • 更多细节参见 Appendix C(模型训练部分)。


总结(重点内容)

  • MRL 的核心思想:每个嵌套维度的表示向量都应是独立且通用的。

  • MRL 的优势:多粒度表示、提高模型泛化能力、支持多种学习框架(如分类、对比学习、语言建模)。

  • MRL-E 的优势:通过权重共享,显著降低内存消耗,适合大规模应用。


实验结果(简要)

  • 尽管只显式优化了 \( \log(d) \) 个维度,MRL 的表示在非嵌套维度上也具有良好的插值性能

  • 更多实验和消融研究参见第 5 节和附录。


术语对照(中文翻译)

  • Matryoshka Representation Learning (MRL):套娃式表示学习

  • Linear Classifier:线性分类器

  • Empirical Risk Minimization:经验风险最小化

  • Cross-Entropy Loss:交叉熵损失

  • Weight-Tying:权重共享

  • Contrastive Learning:对比学习

  • Masked Language Modeling (MLM):掩码语言建模

  • Information Bottleneck:信息瓶颈


该章节为 MRL 提供了完整的理论框架,并说明了其在不同学习任务中的应用方式,强调了其多粒度、通用表示能力和高效实现,是后续实验和应用的基础。

4 Applications

本节主要介绍了Matryoshka Representation Learning (MRL) 的多种应用,并对学习到的多保真度表示进行了广泛的评估。此外,还展示了 MRL 在大范围部署中的下游应用,包括:

  1. Adaptive Classification(自适应分类,AC)

  2. Adaptive Retrieval(自适应检索,AR)


4.1 表示学习(Representation Learning)

本部分详细说明了 MRL 在多种表示学习设置下的应用,包括:

  • 视觉任务的监督学习:使用 ResNet50 在 ImageNet-1K 和 ViT-B/16 在 JFT-300M 上的训练。

  • 视觉与语言的对比学习:使用 ALIGN 模型(ViT-B/16 视觉编码器 + BERT 语言编码器)。

  • 语言建模:使用 BERT 在英文维基百科和 BooksCorpus 上训练。

MRL 的关键优势在于,在不搜索最优超参数的情况下,使用与独立训练基线相同的超参数,即可实现高效的多保真度表示学习。实验中,ResNet50 输出 2048 维表示,ViT-B/16 和 BERT 输出 768 维表示,而 MRL 显式优化了嵌套维度(如 8, 16, 32, …, 2048)。

重点内容

  • MRL 与 MRL–E(一种变体)模型在多个任务中优于独立训练的低维模型(FF)、降维方法(SVD)、子网络方法(slimmable networks)和随机特征选择。

  • MRL 模型在所有表示维度中保持较高的准确率,同时减少了对高维模型的依赖,降低了训练和部署成本。


4.2 分类(Classification)

本节通过 线性分类(Linear Probe, LP)1-NN(1 近邻) 的方式,评估了 MRL 学习到的表示在分类任务中的质量和容量。

重点内容

  • 图 2 和 图 3 显示:MRL 模型在所有表示维度上的准确率与独立训练的 FF 模型相当,甚至在低维表示中表现更优(高 2%)。

  • MRL 模型在 ViT-B/16 和 ALIGN 模型上的实验也展示了其在大规模数据(如 JFT-300M)上的良好扩展性。

  • MRL 显式优化了 \( O(\log(d)) \) 个嵌套维度,但在所有维度上均表现出插值行为(interpolating),从而实现了灵活的粒度控制。

补充说明

  • 后处理方法(如 SVD、随机特征)和子网络方法在低维表示下准确率下降显著,进一步突显了 MRL 的优势。

4.2.1 自适应分类(Adaptive Classification, AC)

MRL 通过建立模型级联(cascades),在保持高准确率的同时,显著降低了计算成本。

重点内容

  • 使用验证集训练出每个嵌套分类器的阈值,当模型对当前低维表示的置信度不足时,自动切换到更高维表示(如 8→16→32)。

  • 图 7 显示:MRL–AC 模型在 ImageNet-1K 上的准确率与 512 维 FF 模型相当,但平均表示维度仅为 37,比 2048 维基线仅低 0.8%。

  • 随着类别数增加,MRL–AC 的计算效率优势更加明显。


4.3 检索(Retrieval)

本节评估了 MRL 在图像检索任务中的表现,特别是在大规模数据集 ImageNet-1K 和 ImageNet-4K 上的性能。

背景知识

  • 图像检索通常依赖于最近邻搜索(NN)或近似最近邻搜索(ANNS)方法(如 HNSW)。

  • 检索成本随数据库规模线性增长,尤其在 ImageNet-4K 这类大规模数据集上,检索成为计算瓶颈。

重点内容

  • MRL 模型在所有表示维度上均优于基线方法,尤其是在低维表示中表现更好。

  • 图 7 显示:MRL 模型在 mAP@10 指标上显著优于 FF 和 SVD 等后处理方法。

  • MRL 模型支持多粒度检索,无需额外训练多个模型或进行多次前向传播。

4.3.1 自适应检索(Adaptive Retrieval, AR)

AR 通过分阶段检索 + 重排序(re-ranking),在保证准确率的前提下显著降低计算和存储成本。

核心机制

  • 使用低维表示(如 Ds=16)筛选出候选图像(如 K=200),再用高维表示(如 Dr=2048)进行重排序。

  • 图 8 表明:AR 模型在相同准确率下,计算成本可降低 14 倍(理论)和 14 倍(实际)。

  • Funnel Retrieval 提出了一种一致性级联机制,通过逐步缩小候选集和增加表示维度,进一步提升效率。

补充说明

  • AR 和 Funnel Retrieval 方法适用于大规模多阶段检索系统(如搜索引擎、推荐系统)。

  • MRL 与 ANNS 技术(如 HNSW)高度兼容,可进一步提升效率-准确率的平衡。


总结

本章展示了 MRL 在多个任务(分类、检索)和多种模型架构(ResNet50、ViT、BERT)中的广泛应用。MRL 的核心优势在于:

  1. 多保真度表示学习:在不同维度上保持高准确率,减少对高维模型的依赖。

  2. 计算效率:通过自适应机制(AC/AR)显著降低推理成本。

  3. 灵活部署:适用于大规模数据和多阶段检索系统,具备良好的扩展性和实用性。

5 Further Analysis and Ablations

Robustness(鲁棒性)

本节评估了在 ImageNet-1K 上训练的 MRL 模型在 out-of-domain 数据集(如 ImageNetV2/R/A/Sketch)上的鲁棒性,并与 FF 基线模型进行比较。实验结果显示:

  • MRL 模型的分类性能在 ImageNet-A 数据集上比原始表示提高了 0.6%,相对提升了 20%,表明其具有更强的鲁棒性。

  • 在基于检索的鲁棒性评估中,MRL 模型在 mAP@10 指标上优于 FF 基线模型约 3%,表明其在图像检索任务中也更稳健。

  • MRL 模型的零样本鲁棒性与 Wortsman 等人的研究一致。

  • MRL 模型在图像-文本对的余弦相似度方面,正样本对与随机样本对之间的区分度也有所提升。

Few-shot and Long-tail Learning(少样本与长尾学习)

在少样本学习方面,MRL 模型使用最近类中心法(Nearest Class Mean)进行了详尽评估,结果显示:

  • MRL 模型在不同样本数量和类别数量下的性能与 FF 表示模型相当。

在长尾学习框架 FLUID 中,MRL 模型表现出独特的优势:

  • MRL 模型在长尾分布中新增类别的准确率高出其他模型约 2%,且不会降低其他类别的准确率。

  • 对于预训练类,不同维度表示的准确性差异较小,这表明在样本较少的情况下,高维表示可能更有效。

  • 这些结果进一步支持了不同任务对模型容量需求不同的观点。

图 9(Grad-CAM 分析)

Grad-CAM 分析展示了 MRL 模型在不同维度下的预测行为:

  • 在低维度(如 8 维)时,模型容易受到场景中其他相关对象的干扰。

  • 随着维度增加,模型能更准确地聚焦关键区域(如正确识别“sweatshirt”而非“sunglasses”)。

  • 这表明 MRL 模型在低维度下虽然表现不佳,但不会完全失效,具备“优雅失败”的特性,并且不同维度之间可能存在有用的信息差异。

Disagreement across Dimensions(维度间差异)

本节探讨了 MRL 模型在不同维度下的性能差异:

  • 通常情况下,随着表示维度的增加,准确率逐渐提高。

  • 但某些实例在低维度下反而表现更好,这表明 MRL 模型存在维度间性能的“不一致”。

  • 理想的路由策略(将样本分配到合适维度)可使分类准确率提高最多 4.6%。

  • 低维模型在区分同级类别或处理多目标场景时表现不佳。

  • 图 9 和相关附录展示了这类情况的典型案例,强调了 MRL 模型在信息瓶颈分析中的潜力。

Superclass Accuracy(超类精度)

本节研究了不同维度下的超类分类表现:

  • 随着信息瓶颈变紧(表示维度变低),细粒度分类的准确率下降迅速,但超类分类的准确率下降较缓。

  • MRL 模型在不同维度下均保持较高的超类分类准确率,表明其能捕捉关键语义信息,可作为粗分类任务的高效方法。

  • 图 11 展示了 MRL 模型在不同超类上的准确率趋势,表明在某些超类中(如“garment”),从低维到高维的表示转换能显著提升分类性能。

  • 对于视觉区分明显的类(如“oscine (songbird)”),即使在低维度下也能取得良好的类间区分效果。

5.1 Ablations(消融实验)

本节进行了多项消融实验,验证了 MRL 框架的设计选择:

  • 微调可行性:MRL 可以通过对现有预训练模型进行低成本的微调实现,表明其实用性强,易于部署。

  • 损失权重优化:适当调整嵌套损失的权重可以提升低维表示的性能,而不会牺牲高维表示的准确性。

  • 维度选择策略:实验验证了 MRL 不使用极低维度作为初始粒度的合理性,并支持使用对数间隔代替均匀间隔,以更高效地覆盖准确率提升区域。

  • 检索性能饱和:随着表示维度和检索长度的增加,检索性能在一定阈值后趋于饱和,表明存在最优的检索配置。


总结

第五章通过多个角度对 MRL 模型进行了深入分析与实验验证,重点包括:

  1. 鲁棒性:MRL 模型在多种 out-of-domain 场景下表现优越,尤其在 ImageNet-A 和图像检索任务中。

  2. 少样本与长尾学习:MRL 模型在样本有限和长尾分布场景下展现出良好的适应性。

  3. 维度差异:MRL 模型在不同维度下的表现存在差异,低维模型可能在某些场景下表现更好,支持信息瓶颈分析。

  4. 超类分类:MRL 模型在超类分类任务中保持良好的性能,表明其能有效捕捉语义结构。

  5. 消融实验:验证了 MRL 的设计选择(如维度选择、微调策略等)的有效性,并提供了实用部署建议。

整体来看,MRL 模型不仅在性能上具有竞争力,还在适应性、鲁棒性和可解释性方面展现出独特价值,为多维度表示学习提供了新的研究方向。

6 Discussion and Conclusions

本节对前一节的实验结果进行了深入分析,并提出了未来可能的研究方向。同时也对 MRL 方法的核心思想和实验成果进行了总结。


1. MRL 的潜在改进方向

实验结果揭示了 MRL 的一些潜在弱点,这些方向可以作为未来研究的重点:

  • 1. 嵌套损失的权重优化
    如何在准确率与效率之间找到帕累托最优(Pareto optimal)的平衡,是一个值得研究的方向。可能的解决方法之一是借鉴任意时间神经网络(anytime neural networks)中的自适应损失平衡机制 [[41]]。

  • 2. 多保真度下的损失设计
    在不同嵌入维度下使用不同的损失函数,可以更针对性地优化特定应用场景的需求。例如,888 维可以注重高召回率(high recall),而 2048 维则可以提高鲁棒性。

  • 3. 学习可微的搜索数据结构
    在 Matryoshka 表示之上学习一种可微的 k-d 树(differentiable k-d tree)等数据结构,可以提升对大规模数据集和表示的感知能力,从而实现更高效的检索。

  • 4. 多目标 MRL 与端到端可学习搜索结构的联合优化
    将 MRL 与端到端可学习的搜索结构联合优化,有望在大规模网络搜索应用中实现数据驱动的自适应检索,从而满足实际部署需求。


2. MRL 方法的总结与贡献

MRL(Matryoshka Representation Learning)是一种灵活的表示学习方法,能够在单个嵌入向量中编码多粒度的信息。这使得 MRL 能够根据下游任务的统计复杂性计算资源进行自适应调整。

通过实验验证,MRL 能够用于大规模自适应分类自适应检索任务。在标准基准测试中,MRL 在平均使用14× 更小的表示尺寸的情况下,达到了与固定特征基线模型相当的准确率。

此外,基于 Matryoshka 表示的自适应初筛与重排序系统,在保证与基线相当的 mAP@100 表现的同时,将计算代价降低了 128×实际运行时间提速 14×。这些结果展示了 MRL 在资源受限环境中的显著优势。

最后,大多数现有的模型推理与向量搜索的效率优化技术,与 MRL 是互补的,能够进一步提升其在极端计算环境中的部署能力。

Acknowledgments

本研究得到了许多人的帮助与支持。作者对Srinadh Bhojanapalli、Lovish Madaan、Raghav Somani、Ludwig Schmidt和Venkata Sailesh Sanampudi表示感谢,感谢他们提供了有益的讨论和反馈。Aditya Kusupati特别感谢Tom Duerig和Rahul Sukthankar的支持。

论文中部分大规模实验工作得到了Google Cloud和Google Research提供的研究GCP信用额度的支持,这一点是实验部分的重要资助来源,应予以强调。

Gantavya Bhatt的部分研究得到了CONIX Research Center的资助,该中心是JUMP计划的六个中心之一,而JUMP是Semiconductor Research Corporation (SRC) 在DARPA赞助下的一个研究计划。

Sham Kakade的研究得到了美国国家科学基金会(NSF)CCF-1703574和美国海军研究办公室(ONR)N00014-22-1-2377项目的资助。Ali Farhadi的研究则部分由NSF项目IIS 1652052和IIS 17303166、DARPA项目N66001-19-2-4031和W911NF-15-1-0543,以及来自Allen Institute for Artificial Intelligence的捐赠支持。这些资助来源体现了本研究的广泛支持,尤其在资金和技术方面的协助对研究的开展至关重要。

Appendix A Code for Matryoshka ​Representation ​Learning

本附录提供了用于在 ImageNet-1K 数据集上训练监督式 ResNet50-MRL 模型的代码示例。该代码可作为一个模板,用于将 MRL 扩展到任何领域。


算法 1:Matryoshka 交叉熵损失的 PyTorch 实现

该算法定义了一个自定义的交叉熵损失函数 Matryoshka_CE_Loss,用于 MRL 模型中的多层分类损失计算。

重点内容:

  • 类定义

    • Matryoshka_CE_Loss 继承自 nn.Module

    • 初始化时接收 relative_importance 参数,通常设置为全 1,表示各层具有相同的重要性。

    • 使用标准的 CrossEntropyLoss 作为基础损失函数。

  • forward 方法

    • 接收 outputtarget,其中 output 是多个不同嵌套层级的模型输出。

    • 遍历每个输出层,将当前层的损失乘以其相对重要性后累加。

    • 最终返回所有层级损失的总和。

不重要内容:

  • 代码结构较为直接,未涉及复杂操作,重点在于多层损失的加权求和。


算法 2:MRL 线性层的 PyTorch 实现

该算法实现了 MRL 模型中的线性分类器部分,支持两种模式:标准 MRL 和高效 MRL(MRL-E)。

重点内容:

  • 类定义

    • MRL_Linear_Layer 继承自 nn.Module

    • 接收 nesting_list(嵌套层级配置)、num_classes(类别数)和 efficient(是否启用高效模式)等参数。

    • 如果不使用高效模式,则为每个嵌套层级创建独立的线性分类器。

    • 如果使用高效模式(MRL-E),则只创建一个线性分类器,用于所有嵌套层级的共享权重。

  • forward 方法

    • 接收输入特征 x,逐层计算每个嵌套层级的输出(logits)。

    • 如果启用了高效模式,则使用一个分类器和矩阵乘法计算所有层级的输出,减少计算开销。

    • 返回所有嵌套层级的输出 logits。

不重要内容:

  • 代码结构清晰,主要关注线性分类器的构造和前向传播的实现。


总结

这两部分代码分别实现了 MRL 中的损失计算和分类器构建,是 MRL 模型的核心组件。

  • Matryoshka_CE_Loss:用于对多层级输出进行加权损失计算。

  • MRL_Linear_Layer:实现了嵌套层级下的分类器,支持标准模式和高效模式(MRL-E)。

代码设计灵活,便于扩展和应用到其他任务或数据集。

Appendix B Datasets

ImageNet-1K

  • 重点内容:包含 1,281,167 张带标签的训练图像和 50,000 张带标签的验证图像,分布于 1,000 个类别中。

  • 图像处理:图像使用了 FFCV 中描述的标准处理流程进行转换。

ImageNet-4K

  • 重点内容:从 ImageNet-21K 中选取了 4,202 个与 ImageNet-1K 无重叠的类别,每个类别包含至少 1,050 张图像。

  • 数据划分:每个类别包含 1,000 张训练图像和 50 张查询/验证图像,总计约 420 万训练图像和 20 万验证图像。

  • 数据发布:将发布用于构建 ImageNet-4K 的图像列表。

JFT-300M

  • 数据规模:包含 3 亿张图像,覆盖 18,291 个类别。

  • 特点:是一个多标签数据集,适用于大规模图像任务。

ALIGN

  • 数据形式:包含 18 亿对图像和文本对。

  • 特点:是一个大规模、带噪声的图像-文本数据集,适用于多模态学习。


ImageNet 鲁棒性数据集(ImageNet Robustness Datasets)

这些数据集用于评估 MRL 模型的鲁棒性。

ImageNetV2

  • 构成:10,000 张图像,从 ImageNet-1K 的 1,000 个类别中各选 10 张。

  • 特点:采集于 ImageNet 原始构建之后十年,用于测试模型的泛化能力。

ImageNet-A

  • 重点内容:7,500 张现实世界中的图像,来自 ImageNet-1K 的 200 个类别。

  • 特点:图像经过对抗性筛选,用于测试模型在现实干扰下的鲁棒性。

ImageNet-R

  • 重点内容:30,000 张艺术风格的图像,覆盖 ImageNet-1K 中的 200 个类别。

  • 特点:图像风格多样,用于测试模型对艺术化图像的识别能力。

ImageNet-Sketch

  • 重点内容:50,000 张草图,均匀分布在 ImageNet-1K 的 1,000 个类别中。

  • 特点:测试模型对草图的识别能力,特别是在抽象表示下的表现。

ObjectNet

  • 数据规模:50,000 张图像,覆盖 313 个对象类别,每个类别约 160 张图像。

  • 特点:图像多样,用于评估模型对常见物体识别的鲁棒性。

Appendix C Matryoshka Representation Learning Model Training

本节介绍了使用 Matryoshka 表示学习(MRL)方法训练不同深度学习模型的配置与实验设置,主要包括 ResNet50、ViT-B/16、ALIGN 和 BERT-Base 模型。


ResNet50–MRL 模型训练

我们使用 FFCV 提供的高效数据加载器来训练所有 ResNet50–MRL 模型。训练配置基于 FFCV 项目中的 <rn50_40_epochs.yaml> 文件。

定义了几种模型变体如下:

  • MRL:完全版的 MRL 模型,使用 MRL_Linear_Layer(efficient=False) 替换 ResNet50 的全连接层。

  • MRL–E:高效版的 MRL 模型,使用 MRL_Linear_Layer(efficient=True) 替换全连接层。该模型在保持表示能力的同时提升了计算效率。

  • FF–k:普通全连接层模型,用 torch.nn.Linear(k, num_classes) 替换全连接层,其中 k ∈ [8,16,32,64,128,256,512,1024,2048]。我们将这些模型简称为 FF,并用 k 表示表示的维度大小。

训练参数如下:

  • 学习率:0.475,采用周期性学习率调度。

  • 批大小:每块 GPU 上使用 256。

  • 优化器:使用 SGD,动量为 0.9,权重衰减为 1e-4。

  • 硬件配置:使用 2x A100 NVIDIA GPU 进行训练,原始 FFCV 基准使用的是 8x A100,因此学习率进行了 0.25× 缩放以适应硬件配置。

我们的代码对 FFCV 提供的训练流程进行了最小程度的修改,以实现 Matryoshka 表示的训练。


ViT-B/16 模型训练

  • 数据集:JFT-300M。

  • 硬件:使用 8x8 云 TPU 节点。

  • 框架:TensorFlow。

  • 批大小:128。

  • 训练步数:300,000 步。

  • 优化器:Adafactor。

  • 学习率:从 1e-3 开始,线性衰减。


ALIGN 模型训练

  • 硬件:8x8 云 TPU 节点。

  • 框架:TensorFlow。

  • 批大小:每 TPU 上使用 64。

  • 训练步数:1,000,000 步。

  • 优化器:Adafactor。

  • 学习率:从 1e-3 开始,线性衰减。


BERT-Base 模型训练

  • 数据集:英文维基百科和 BookCorpus。

  • 硬件:4x4 云 TPU 节点。

  • 框架:TensorFlow。

  • 总批大小:1024。

  • 优化器:AdamW。

  • 学习率:从 1e-4 开始,线性衰减。

  • 训练步数:450,000 步。


一致性处理说明

在所有配置中,如果 FF 实现中对最终表示进行了归一化处理,那么 MRL 模型也会在每个嵌套维度上应用同样的归一化操作,以保证实验的公平性与可比性。


总结重点:

  • MRL 模型:包括标准版(MRL)和高效版(MRL–E),以及一系列 FF 模型(不同维度的全连接层模型)。训练参数统一,便于对比。

  • 不同模型的训练设置:ResNet50、ViT、ALIGN 和 BERT 均根据其结构和任务特点配置了不同的硬件、框架和优化策略。

  • 归一化处理:为了公平比较,在 FF 模型归一化时,MRL 模型也同步归一化每个嵌套维度。

这些训练配置和实现细节为后续的实验分析和模型评估提供了坚实的基础。

Appendix D Classification Results

表格 1:ResNet50 MRL 模型在 ImageNet-1K 上的 Top-1 分类准确率(%)与基线模型对比

本节主要对比了 ResNet50-MRL 模型(包括 MRL 和 MRL-E)与其他基线模型在不同表示维度下的分类性能。主要的基线模型包括:

  • FF:使用 FF-k 模型,针对 k ∈ {8, …, 2048}

  • SVD:对 FF-2048 的分类层进行低秩近似,秩为 1000

  • Rand. LP:使用随机特征的线性分类器

  • Slim. Net:使用预训练的可变宽度神经网络,在不同宽度下进行测试(25%、50%、75% 和全宽度)

重点内容

  • 在较低维度(d ≤ 128)时,MRL 明显优于所有基线模型,这表明传统预训练模型缺乏多保真度的表示能力,无法在低维空间中学习到有效的分类边界。

  • 从表中可以看出,随着维度的增加,MRL 与 FF 模型的性能差距逐渐缩小,但在低维时 MRL 优势显著。

表格 2:ResNet50 模型在 ImageNet-1K 上的 1-NN 分类准确率(%)

本节展示了 MRL 模型在使用 1-NN 分类器时的性能表现,并与以下基线进行比较:

  • Rand. FS:从 FF-2048 中随机选择 m 维进行分类

  • FF + SVD:对 FF-2048 表示进行 SVD 降维

  • FF + JL:根据 Johnson-Lindenstrauss 引理对 FF-2048 进行随机投影

  • Slimmable Net:使用可变宽度网络进行 1-NN 分类

重点内容

  • MRL 模型在所有维度上的 1-NN 分类准确率均优于其他基线,特别是在低维度(例如 8、16、32 维)下表现显著优越。

  • 可变宽度网络和基于随机特征选择的模型在低维下性能很差,说明它们未显式训练出 MRL 所具备的多保真度表示能力。


D.1 自适应分类(MRL–AC)

核心思想

  • 使用 MRL 表示的多保真度特性,在分类任务中自适应地选择最小的表示维度,以在保持高分类准确率的同时减少计算开销。

  • 通过训练一个基于预测置信度的策略,决定是否增加表示维度。

重点内容

  • 在 40,000 张验证图像上测试了 MRL–AC 的性能,结果表明:

    • 平均预期表示维度约为 37,即可达到 76.3% 的准确率,比 FF-512 模型小 14 倍。

    • 即使将所有表示维度加权求和,MRL–AC 的预期维度约为 62,仍比基线模型高效 8.2 倍。

    • MRL–E 能在不显著降低准确率的情况下减少计算开销。


D.2 JFT、ALIGN 和 BERT

本节探讨了 MRL 在不同大规模数据集和模型上的可扩展性,包括 JFT、ALIGN 和 BERT。

重点内容

1. ALIGN 和 JFT-ViT

  • MRL 模型在 ALIGN 和 JFT-ViT 上均显著提升了 k-NN 分类准确率,尤其是在低维表示下。

  • 表格 4 显示,MRL 模型在多个维度下的 Top-1 和 Top-5 分类准确率均优于原始模型。

  • 表格 5 显示,MRL 模型在未显式训练的插值维度上也能保持较高的分类性能,表明其具有良好的模型泛化能力。

  • 表格 6 显示,MRL 显著提升了图像-文本嵌入之间的余弦相似性跨度,说明其能更好地区分正样本与随机样本。

2. BERT 的 MLM 任务

  • 将 MRL 应用于 BERT 的掩码语言建模任务,结果表明其表现接近原始 BERT 模型(FF 表示),且在大多数维度上差距不超过 0.5%。

  • 这表明 MRL 可以扩展到自然语言处理任务,并为大规模自适应文档检索提供潜在支持。


总结

本附录主要验证了 MRL(Matryoshka Representation Learning)在多个模型和任务上的有效性,包括:

  • 在低维表示下显著优于传统基线模型

  • 支持自适应分类,降低计算开销而不显著损失准确率

  • 可扩展到大模型(如 JFT-ViT)和跨模态模型(如 ALIGN)

  • 在自然语言处理(如 BERT)任务中也表现出良好的性能

这些实验结果显示,MRL 是一种具有广泛适用性和高效性表示学习方法。

Appendix E Image Retrieval

实验设置与方法

本节评估了 Matryoshka Representation Learning(MRL) 模型在图像检索任务中的性能。评估数据集包括:

  • ImageNet-1K:训练集(训练分布)

  • ImageNetV2ImageNet-4K:域外数据集(out-of-domain)

使用的模型是 ResNet50,并评估其在多种不同表示维度下的性能。为了进行检索,作者使用了以下方法:

  1. 数据库与查询集生成

    • 使用标准的 PyTorch 前向传播生成数据库和查询集。

    • 表示向量的大小由 \( D_s \) 决定,其中数据库和查询集分别为 \( [N, D_s] \)\( [Q, D_s] \)

    • 通过 k-NN 检索生成一个大小为 \( [Q, k] \) 的邻近样本集。

  2. 评估指标

    • 采用 mAP@kP@k 作为评估指标,其中:

      • \( P@k = \frac{correct\_pred}{k} \)

      • \( correct\_pred \) 是在查询集中正确检索到的近邻样本数量的平均值。

  3. 检索方法

    • 使用 FAISS 库进行高效的相似性搜索。

    • 精确搜索(exact search):使用 faiss.IndexFlatL2 进行 L2 距离计算。

    • 近似搜索(approximate search):使用 faiss.IndexHNSWFlat(HNSW 算法),在 CPU 上运行。

    • 向量在构建索引和搜索前进行了 L2 归一化

  4. 实验参数

    • HNSW 参数 \( M = 32 \),简称 HNSW32。

    • 精确搜索索引加载到 GPU,以加快检索速度。

    • 表中展示了不同方法在索引构建时间和索引大小上的比较(见表 [20])。

实验结果概述

Table 8: ImageNet-1K 检索结果

  • FF 模型(直接在特定维度训练的模型)在 \( D_s = 8 \) 时性能较差,但在其他维度表现良好。

  • MRL 模型在几乎所有 \( D_s \) 维度下都优于 FF 模型,尤其是在 \( D_s \leq 32 \) 时表现显著更好。

  • MRL–E 模型(可能为 MRL 的变种)在 \( D_s = 8 \) 时性能不如 FF,但在其他维度表现良好。

  • 插值维度(如 12、24、48 等)的性能展示表明,MRL 模型在训练未明确指定的维度也能保持良好的检索性能。

Table 9: ImageNetV2 检索结果

  • MRL 模型在所有 \( D_s \) 维度下均优于 FF 模型。

  • MRL–E 模型在 \( D_s = 8 \) 时仍劣于 FF,但在其他维度表现优异。

  • 表明 MRL 学习的表示在域外数据集中也具有良好的鲁棒性。

Table 10: ImageNet-4K 检索结果

  • 主要分析了 MRL 模型的性能。

  • 随着 \( D_s \) 增大,MRL 模型的性能逐步提高,显示出其在高维表示下的优势。

  • 插值维度下的表现进一步验证了 MRL 模型在不同维度之间的性能连续性。

结论

  • MRL 模型在多个数据集(包括训练集和域外数据集)中展示了优于传统 FF 模型的图像检索性能。

  • 特别是在低维表示(\( D_s \leq 32 \))下性能显著优越,表明 MRL 在压缩模型和保持性能之间取得了良好的平衡。

  • MRL 模型在插值维度下的表现也表明其学习的 Matryoshka 表示具有良好的泛化能力,能够适应任意维度的检索任务。

  • 实验方法(如 FAISS、L2 归一化、HNSW)的选择有效提升了检索效率和准确性。

Appendix F Adaptive Retrieval

概述

本节讨论了自适应检索(Adaptive Retrieval)方法,旨在在较高维度特征 \( D_s \) 下获得与低计算成本相近的性能。由于 k-NN 检索的时间复杂度通常为 \( O(d) \),其中 \( d = D_s \),在较大 \( D_s \) 下检索成本会显著增加。例如,\( D_s = 2048 \) 相比 \( D_s = 8 \),理论计算成本会增加 256 倍。因此,使用低维 \( D_s \) 特征进行初步检索,再使用高维 \( D_r \) 特征进行重新排序,可以显著降低计算成本,同时保持检索性能。

自适应检索方法

自适应检索的基本思路是:

  • 第一步:使用较小维度 \( D_s \) 的特征进行初步检索,得到一个 k 个最近邻的短列表。

  • 第二步:使用更高维度 \( D_r \) 的特征对这个短列表进行重新排序(re-ranking),以提高检索的准确性。

在 ImageNet-1K 和 ImageNet-4K 上的实验表明:

  • 在 ImageNet-1K 中:使用 \( D_s = 16 \) 进行初步检索,再使用 \( D_r = 32 \) 重新排序,可以获得与 \( D_s = 2048 \) 相当的性能,但计算成本(MFLOPs)降低了 128 倍。

  • 在 ImageNet-4K 中:使用 \( D_s = 64 \) 进行初步检索,再使用 \( D_r = 128 \) 重新排序,可以获得与 \( D_s = 2048 \) 相当的性能,但计算成本降低了 32 倍。

这说明自适应检索是一种有效的降低计算成本的策略,同时保持了检索性能。

实验结果

表 11(ImageNet-1K)与 表 12(ImageNet-4K)

这两张表格展示了不同 \( D_s \)\( D_r \) 组合下的性能指标(Top-1、mAP@10、mAP@25 等)和计算成本(MFLOPs)。表格中,加粗的数值表示其性能接近于不使用重新排序时所能达到的最大性能值。

通过分析可以看出,随着 \( D_s \)\( D_r \) 的增加,性能指标普遍提升,但计算成本也随之上升。自适应检索通过选择合适的 \( D_s \)\( D_r \),在计算成本较低的情况下实现了接近高维检索的性能。

Funnel Retrieval(漏斗检索)

Funnel Retrieval 是一种更进一步的策略,通过级联重新排序(rerank cascade)和逐步缩短短列表长度(shortlist cascade)实现逐步提升检索质量。具体步骤如下:

  1. 初步检索:使用 \( D_s \) 进行初步检索,得到一个较长的短列表。

  2. 多次重新排序:在逐步增加 \( D_r \) 的同时,逐步缩短短列表的长度,形成类似“漏斗”的结构。

实验结果

  • 在 ImageNet-1K 中:使用 \( D_s = 16 \) 进行 Funnel Retrieval,可以达到与 \( D_s = 2048 \) 相当的 Top-1 准确率,但计算成本降低了 128 倍。

  • 在 ImageNet-4K 中:使用 \( D_s = 32 \) 进行 Funnel Retrieval,可以达到与 \( D_s = 2048 \) 相当的 Top-1 准确率,但计算成本降低了 64 倍。

通过 Funnel Retrieval,可以更高效地利用不同维度的 Matryoshka 表示,实现高质量的检索。

结论

本节展示了自适应检索和 Funnel Retrieval 的有效性。它们通过合理使用不同维度的特征表示,在显著降低计算成本的同时,保持了较高的检索性能。这些方法为高效、智能的检索系统设计提供了重要参考。

Appendix G Few-shot and Sample Efficiency

本节比较 MRL、MRL–E 和 FF 在不同基准上的表现,以观察表示维度大小对样本效率的影响。分类任务中使用了 Nearest Class Means 方法,该方法在少样本设置下已被证明有效。


ImageNetV2

在 ImageNetV2 上,我们评估了模型在 n 射 k 类(n-shot k-way)任务中的表示性能。ImageNetV2 是一个常用于评估模型对自然分布变化鲁棒性的数据集。我们分别在传统的 10 类(小规模)和 1000 类(大规模)设置中进行实验,测试 n ∈ {1,3,5,7,9} 的样本数,其中 n 最大为 9 是因为每类仅有 10 张图像。

重点结果:

  • MRL 和 FF 在所有表示大小和样本数下的性能基本一致。

  • 随着样本数减少,达到最佳准确率所需的表示维度也减小。例如,1 射时,32 维表示的准确率与 2048 维相当。

  • 当样本数减少时,表示在较低维度时即可达到性能饱和。

表格 15 显示了 MRL 和 FF 在 1000 类设置下的准确率对比,证实了上述结论。


FLUID

在长尾分布设置中,我们在 FLUID 数据集上评估 MRL 的性能。FLUID 包含预训练类(Pretrain)和新类(Novel),并分为头部(>50 例)和尾部(<50 例)。

重点结果:

  • MRL 在尾部新类上的准确率比基线高约 2%。

  • 对于预训练类,低维和高维表示的准确率差异不大。例如,64 维 MRL 在预训练头部类上的准确率比 2048 维低约 1%。

  • 在尾部类中,高维表示对准确率提升明显,例如尾部新类中,64 维的准确率为 6.22%,而 2048 维为 12.88%。

  • 作者认为,少数样本下更高维的表示有助于区分类别。

表格 16 显示了不同表示大小下各类别的准确率,进一步证实了上述结论,尤其是 MRL 在尾部类上的优势。


总结

  • 样本效率方面:MRL 和 FF 在 ImageNetV2 上表现一致,且低维表示在少样本设置下也能达到良好性能。

  • 长尾分布任务:MRL 在尾部新类上优于基线,在头部类上表现稳定,显示其在不同难度任务中的适应性。

  • 表示维度的作用:更高维表示在样本少、类别区分度低的情况下更有优势。

本节通过多个实验验证了 MRL 在表示学习中的有效性,尤其是在少样本和长尾分布任务中,具备良好的样本效率和泛化能力。

Appendix H Robustness Experiments

表 17:在 out-of-domain datasets 上的 Top-1 分类准确率(%)

本部分测试了 Matryoshka Representation Learning(MRL)模型在出域数据集(ImageNet-V2/R/A/Sketch)上的鲁棒性,并与 FF(固定深度)基线模型 进行了比较。需要注意的是,这些结果没有在这些数据集上进行微调

Rep. Size(表示维度)

FF

MRL–E

MRL

FF

MRL–E

MRL

FF

MRL–E

MRL

FF

MRL–E

MRL

FF

MRL–E

MRL

8

65.86

56.92

67.46

54.05

47.40

55.59

24.60

22.98

23.57

2.92

3.63

3.39

17.73

15.07

17.98

16

73.10

72.38

73.80

60.52

60.48

61.71

28.51

28.45

28.85

3.00

3.55

3.59

21.70

20.38

21.77

32

74.68

74.80

75.26

62.24

62.23

63.05

31.28

30.79

31.47

2.60

3.65

3.57

22.03

21.87

22.48

64

75.45

75.48

76.17

63.51

63.15

63.99

32.96

32.13

33.39

2.87

3.99

3.76

22.13

22.56

23.43

128

75.47

76.05

76.46

63.67

63.52

64.69

33.93

33.48

34.54

2.81

3.71

3.73

22.73

22.73

23.70

256

75.78

76.31

76.66

64.13

63.80

64.71

34.80

33.91

34.85

2.77

3.65

3.60

22.63

22.88

23.59

512

76.30

76.48

76.82

64.11

64.09

64.78

35.53

34.20

34.97

2.37

3.57

3.59

23.41

22.89

23.67

1024

76.74

76.60

76.93

64.43

64.20

64.95

36.06

34.22

34.99

2.53

3.56

3.68

23.44

22.98

23.72

2048

77.10

76.65

76.95

64.69

64.17

64.93

37.10

34.29

35.07

2.93

3.49

3.59

24.05

23.01

23.70

重点分析:

  • MRL 模型在各种表示维度下表现优于或与 FF 模型相当,尤其在 ImageNet-A 数据集上表现出显著提升

  • 在 ImageNet-R 和 ImageNet-Sketch 上,MRL 表现略逊于 FF,但仍保持稳定,表明其具备一定的鲁棒性。

  • 随着表示维度的增加,MRL 的性能逐渐提升,最终在高维下接近 FF 模型的最佳表现,说明 MRL 的多粒度表示能力有助于鲁棒学习


表 18:ALIGN-MRL 模型在不同数据集上的零样本 Top-1 分类准确率(%)

该部分评估了 ALIGN-MRL 模型在多个数据集(包括 ImageNet-V1/V2/A/R 和 ObjectNet)上的零样本分类准确率,以进一步测试其鲁棒性和泛化能力

Rep. Size(表示维度)

V1

V2

A

R

ObjectNet

12

30.57

23.98

14.59

24.24

25.52

24

45.64

37.71

22.75

46.40

35.89

48

53.84

46.16

28.88

60.71

42.76

96

58.31

51.34

33.21

70.12

45.20

192

60.95

53.56

36.10

74.41

48.24

384

62.06

54.77

37.95

76.51

49.10

768

62.26

55.15

37.84

76.73

49.26

Baseline(基线)

66.39

59.57

39.97

80.49

51.60

重点分析:

  • 随着表示维度的增加,模型在所有数据集上的性能稳步提升,表明 MRL 的多粒度表示有助于模型学习更具泛化能力的特征。

  • 与基线模型相比,ALIGN-MRL 在 ImageNet-R 上的性能接近,但在 ImageNet-V1 和 ObjectNet 上仍有提升空间。

  • 在 ObjectNet 等更具挑战性的数据集上,ALIGN-MRL 的准确率稳定提升,说明其具备一定的零样本迁移能力。


总结

本附录通过多个实验验证了 Matryoshka Representation Learning(MRL)模型 的鲁棒性,主要结论如下:

  1. MRL 模型在出域数据集(ImageNet-V2/R/A/Sketch)上表现出良好的泛化能力,尤其在 ImageNet-A 上优于 FF 模型。

  2. ALIGN-MRL 模型在零样本任务中也展现了良好的性能,随着表示维度的增加,模型在多个数据集上持续提升。

  3. MRL 的设计能够支持多粒度表征学习,从而提升模型在多样化任务中的鲁棒性和泛化能力。

这些实验结果进一步验证了 MRL 在表示学习中的有效性与鲁棒性。

Appendix I In Practice Costs

实验环境

  • 近似最近邻(NN)搜索实验:使用 HNSW32,运行在 Intel Xeon 2.20GHz CPU 上,有 24 核心

  • 精确搜索实验:在 2 块 40G RAM 的 A100-SXM4 NVIDIA GPU 上进行,使用 CUDA 11.0


MRL 模型

MRL(Matryoshka Representation Learning)模型对 ResNet50 的最后全连接(fc)层进行了最小修改,通过多头结构在不同尺度上生成表示。这使得 MRL 模型相比标准 ResNet50 模型仅增加 8MB 的存储开销

  • MRL–E:使用共享头结构在最后 fc 层生成 logits,因此没有额外的存储开销


检索(Retrieval)

搜索时间复杂度

  • 精确搜索时间复杂度为:O(dkN)

  • HNSW 搜索时间复杂度为:O(dk log(N))
    其中:

    • \( N \):数据库规模

    • \( d \):表示的维度

    • \( k \):短列表长度(top-k)

为了评估实际性能,作者在 ImageNet-1K 和 ImageNet-4K 验证集 上进行了实测,并分析了不同维度 \( d \) 和短列表长度 \( k \) 对搜索时间的影响。


表 19:不同表示维度下的检索时间(秒)

表示尺寸(d)

ImageNet-1K(精确搜索/近似搜索)

ImageNet-4K(精确搜索/近似搜索)

8

0.60 / 0.14

35.70 / 1.17

16

0.57 / 0.18

36.16 / 1.65

32

0.60 / 0.20

36.77 / 1.75

64

0.66 / 0.24

27.88 / 2.21

128

0.86 / 0.32

30.10 / 4.15

256

1.29 / 0.46

34.97 / 3.39

512

2.17 / 0.68

46.97 / 4.83

1024

3.89 / 1.05

70.59 / 7.14

2048

7.31 / 2.05

117.78 / 13.43

重点总结

  • 精确搜索时间随表示维度 \( d \) 的增大而显著增加。

  • 近似搜索(HNSW32) 时间增长较慢,且在高维度下仍保持高效,尤其在 ImageNet-4K 数据集上表现明显优势。


表 20:索引构建时间与索引大小

表示尺寸(d)

精确搜索(Index Size/Build Time)

HNSW32(Index Size/Build Time)

8

40MB / 0.04s, 131MB / 0.33s

381MB / 4.87s, 1248MB / 24.04s

16

80MB / 0.08s, 263MB / 0.27s

421MB / 6.15s, 1379MB / 33.31s

32

160MB / 0.16s, 525MB / 0.52s

501MB / 6.80s, 1642MB / 37.41s

64

320MB / 0.38s, 1051MB / 1.05s

661MB / 8.31s, 2167MB / 47.23s

128

641MB / 0.64s, 2101MB / 2.10s

981MB / 11.73s, 3218MB / 89.87s

256

1281MB / 1.27s, 4202MB / 4.20s

1622MB / 17.70s, 5319MB / 102.84s

512

2562MB / 2.52s, 8404MB / 8.39s

2903MB / 27.95s, 9521MB / 158.47s

1024

5125MB / 5.10s, 16808MB / 17.20s

5465MB / 44.02s, 17925MB / 236.30s

2048

10249MB / 10.36s, 33616MB / 41.05s

10590MB / 86.15s, 34733MB / 468.18s

重点总结

  • 索引构建时间随表示维度 \( d \) 增长而增加。

  • HNSW32 的索引构建时间显著高于精确搜索。

  • 索引大小随着维度增加而线性增长,HNSW32 需要更大的索引空间。


表 21:不同短列表长度(k)下的检索时间

索引类型

k=50

k=100

k=200

k=500

k=1000

k=2048

精确搜索(Exact L2)

0.4406

0.4605

0.5736

0.6060

1.2781

2.7047

近似搜索(HNSW32)

0.1193

0.1455

0.1833

0.2145

0.2333

0.2670

重点总结

  • 精确搜索时间随着 \( k \) 增大显著增加。

  • 近似搜索(HNSW32) 时间增长较慢,且在大 \( k \) 值下依然保持高效,显示出其在实际应用中的优势。


总结

  • MRL 模型在保持 ResNet50 性能的同时,具有较小的存储开销。

  • 检索性能方面,HNSW32 在高维表示和大规模数据集上表现出优越的效率,尤其在检索时间上显著优于精确搜索。

  • 索引构建和存储成本是近似搜索的代价,但这在实际应用中常被其速度优势所弥补。

Appendix J Analysis of Model Disagreement


Discussion of Oracle Accuracy

作者定义了Oracle Accuracy,即只要某个样本在任何维度上被正确分类,就视为“被正确预测”。

  • ImageNet-1K 验证集中 18.46% 的样本无法在所有维度上被正确分类,这是 MRL 模型的性能上限。

  • 表 22 展示了不同维度下首次正确分类的样本比例。

  • 表 23 对比了多种数据集(如 ImageNet-V2、A、R、Sketch)在 MRL-Oracle 和 FF-2048 下的性能。结果显示,MRL-Oracle 在多个数据集上的准确率均高于 FF-2048

此外,作者提出了一种自适应分类方法(Adaptive Classification),用于模拟 Oracle Accuracy。实验发现,达到 76.30% 的 top-1 准确率仅需约 37 维的表示,作者将更优策略的设计留作未来工作。


Grad-CAM Examples

作者利用 Grad-CAM 可视化 分析 MRL 模型在不同维度下的模型差异原因

  • 在“工具”、“蔬菜”、“肉刀”等类别中,模型常因背景干扰而误判,低维模型更容易被混淆

  • 同一超类(superclass)内的模型差异(如“蛇”类的误判)显示了模型对细粒度类别的区分能力有限


Superclass Performance

作者基于 WordNet 层级结构定义了 30 个超类(superclass),用于评估 MRL 模型在更高层次上的分类性能。

  • 表 24 列出了 30 个超类的示例。

  • 表 25 展示了不同维度下的分类准确率。虽然维度增加对性能提升有限,但总体趋势是随着表示维度增加,准确率逐渐趋稳,表明模型在更高维度下收敛较好。


总结

本附录主要探讨了 MRL 模型在不同表示维度下的性能变化及模型间的差异。核心发现包括:

  1. 并非所有类别都随着表示维度增加而持续提升,部分类在低维或高维下表现更优。

  2. Oracle Accuracy 是 MRL 模型性能的理论上限,部分样本无法在任何维度被正确分类。

  3. 低维表示在某些情况下甚至优于高维表示,提示模型设计中应考虑表示维度的灵活性。

  4. 模型差异的原因包括背景干扰、类间混淆等,Grad-CAM 有助于分析这些差异。

  5. 自适应分类方法 可在不牺牲准确率的前提下降低计算成本。

这一分析为后续研究提供了优化表示学习策略的理论和方法基础。

Appendix K Ablation Studies

K.1 MRL 训练范式

通过微调诱导嵌套表示

本节研究了是否可以在没有从头开始训练嵌套结构的模型中诱导嵌套表示。实验通过加载预训练的 FF-2048 ResNet50 模型,并初始化一个新的 MRL 层来完成。通过解冻不同层级的主干网络,观察非线性(未冻结的卷积层)对嵌套表示的影响。结果显示,仅微调线性层在低维度下无法学习到嵌套表示,但随着卷积+ReLU 层数的增加,d=8 的分类准确率显著提高(从 5% 提升到 60%),仅比端到端训练 40 个 epoch 的结果低 6%。维度大于 64 后,差异进一步缩小到 1.5% 以内。

表 26 展示了不同表示维度下的分类准确率,表明添加更多非线性有助于嵌套结构的形成。

相对重要性调整

本节进一步探讨了通过调整不同嵌套维度的相对重要性 \( c_m \) 来优化训练效果的方法。通过两个模型 MRL-8boost 和 MRL-8+16boost 实验发现,提升低维度的相对重要性可以显著提高 d=8 的 top-1 准确率 3%,同时在 16 到 256 维度间也有提升,但 512 到 2048 维度略有下降。这表明相对重要性的调整对优化模型性能具有重要意义,但具体设置方法还需进一步研究。

表 27 展示了不同模型在不同维度下的 top-1 和 top-5 准确率,表明相对重要性调整对低维表现有明显提升。

在任意粒度下训练嵌套表示

为了验证 MRL 在任意粒度下的性能,研究者对比了使用对数粒度(logarithmic granularity)和均匀粒度(uniform granularity)两种方式训练模型的效果。结果发现,在低维度下,对数粒度模型(MRL-log)表现更优,而均匀粒度模型(MRL-Uniform)在高维度下略有提升,但整体提升有限。这表明对数粒度在信息压缩和效率之间取得了更好的平衡。

表 29 展示了不同粒度设置下的 1-NN 准确率,表明对数粒度在低维度性能更好。

低维性实验

本节进一步探讨了使用小于 8 的维度训练 MRL 的效果。实验结果显示,使用 4 或 6 维度训练的模型在其他维度上表现稳定,但低于 8 维度的表示准确率较低且难以训练。此外,更高的维度也会受到最小维度优化难度的影响,略微下降。

表 28 展示了不同维度下的 top-1 准确率,表明 8 维是一个经验上较优的选择。


K.2 检索

自适应检索

本节研究了短列表长度(shortlist length)对检索性能和搜索时间的影响。在 ImageNet-1K 数据集上,随着短列表长度增加到 200,性能趋于饱和。而在 ImageNet-4K 数据集上,随着短列表长度增加到 2048,性能持续提升。这可能与 ImageNet-4K 数据集更大、略偏离训练分布有关。

表 30 和 31 展示了在不同短列表长度下,P@1、mAP@10 等指标的变化,表明短列表长度对检索性能有显著影响,尤其在 ImageNet-4K 中提升更明显。


总结

本附录通过一系列消融实验验证了 MRL 在不同维度、不同训练策略、不同检索设置下的性能表现。重点总结如下:

  1. 嵌套表示的诱导:通过增加非线性结构(如卷积+ReLU),可以有效诱导嵌套表示,即使模型未从头开始训练嵌套结构。

  2. 相对重要性调整:通过提升低维度的相对重要性,可以在不影响高维性能的前提下提升低维表示质量。

  3. 粒度选择:对数粒度在低维表现更优,而均匀粒度在高维略有提升,但提升有限。

  4. 低维训练:低于 8 维度的表示难以训练且性能低下,8 维是一个经验上较优的选择。

  5. 检索优化:短列表长度对检索性能有显著影响,尤其在大数据集(如 ImageNet-4K)中性能提升明显。

这些实验结果为 MRL 方法的进一步优化提供了重要参考。