2209.05433_FP8: FP8 Formats For Deep Learning

Abstract

  • FP8是一种比现有16位浮点格式更快的深度学习训练和推理加速方案。

  • 文中提出了两种8位浮点格式:E4M3(4位指数+3位尾数)和E5M2(5位指数+2位尾数)。

  • E5M2遵循IEEE标准处理特殊值,E4M3通过不表示无穷大和简化NaN的方式扩展了动态范围。

  • 实验显示,FP8在图像和语言任务中,能达到和16位训练相当的效果,涵盖了主流神经网络结构(CNN、RNN、Transformer),且训练参数未作调整。

  • 还测试了1750亿参数的大型语言模型,以及用16位训练的模型用FP8后训练量化的效果,解决了int8量化失败的问题。

1. Introduction

  • 随着深度学习模型越来越大,训练它们需要的计算资源也不断增加。

    • 为了加速训练和推理,研究人员一直在用更低精度的数据格式来表示数字,比如16位浮点数(FP16)、8位整数(int8)等。

    • 虽然有人尝试过极端的低位数表示(比如1位二进制),但效果不好,难以保证模型质量。

  • FP8(8位浮点数)是从16位浮点数发展而来的一种新格式,能进一步减少计算需求,同时在推理时比int8更有优势。

    • 已有多篇研究验证了FP8在训练卷积神经网络和语言模型时的有效性,包括不同的指数位设计和改进方案。

  • 本文介绍了一种新的FP8格式(使用两种编码方式),详细说明了设计原理和实际表现。

    • 实验表明,使用FP8训练的模型在各种任务和规模下,效果可以和FP16或bfloat16持平,且不需调整模型或优化器参数,适用于大规模语言模型(最高到1750亿参数)。

2. Aspects of FP8 Usage in Deep Learning

  • 这段内容主要讲了深度学习中使用FP8(8位浮点数)的一些关键点:

    1. FP8格式选择与缩放

      • 不同网络对数值范围的需求不同,因此需要两种FP8格式,并且缩放因子通常在软件中处理,而不是用指数偏置。

      • 缩放因子用来把更高精度的数值调整到FP8能表示的范围内。

    2. 计算时精度转换

      • FP8输入的数学运算通常会先转换成更高精度计算,结果再选性转回FP8保存。

      • 类似现有FP16的做法。

    3. 缩放因子的作用与限制

      • 缩放因子让数值更适合FP8表达,但FP8动态范围有限,有时需要每个张量单独缩放。

      • 溢出时值会被饱和到最大值,跳过更新不适合FP8,因为溢出概率太高。

    4. 反缩放操作

      • 高精度值在转换为FP8之前乘以缩放因子,计算后再用反缩放(乘以缩放因子的倒数)恢复,开销很小。

    5. 特殊值处理

      • 转换时,特殊值(无穷大、NaN)会被转换成FP8对应的NaN,方便混合精度训练时检测溢出。转换的舍入方式和溢出处理方式可灵活设置。

  • 总结:FP8用在深度学习时,关键是用缩放因子调整数值范围,计算过程保持高精度,溢出处理与特殊值转换设计上需特别注意。

3. FP8 Binary Interchange Format

  • 这段内容主要讲了 FP8(8位浮点数)格式的两种编码方式:E4M3E5M2,用于深度学习中不同的张量类型。

两种 FP8 编码格式

  • E4M3:4位指数 + 3位尾数(mantissa),适合用于权重和激活值

  • E5M2:5位指数 + 2位尾数,适合用于梯度张量

  • 它们的设计权衡了精度与动态范围的需求。

  • Exponent (E) and Mantissa (M)

格式细节

  • E5M2 更符合 IEEE 754 标准(例如支持 inf, NaN),便于与 FP16 转换。

  • E4M3 舍弃了大部分特殊值(如 inf),以换取更大的动态范围(从原本 17 个 binades 扩展到 18 个)。

    • 最大数可达 448(而不是标准的 240)。

    • 仍保留正负零和 NaN,保持 IEEE 对称性,避免影响排序等操作。

4. 指数偏置(Exponent Bias)

  • E4M3 的 bias 为 7,E5M2 的 bias 为 15(与 IEEE 类似)。

  • 不对 bias 做自定义,而是在软件中按张量设置缩放因子(更灵活,能处理不同张量的动态范围)。

总结

  • E4M3 和 E5M2 是为深度学习定制的轻量浮点格式,各有适用场景。

  • 设计上兼顾 IEEE 兼容性和实际训练需求,通过删减特殊值或保留其标准表示,权衡了精度与动态范围。

示例讲解

  • 具体的浮点表示在 IEEE-754 中有详细说明

Table 1: Details of FP8 Binary Formats

Normal

  • 公式

值 = (−1)^sign × (1 + fraction) × 2^(exponent - bias)
  • E5M2示例: 最大值 0 11110 11

    • 格式: [ S ][ EEEEE ][ MM ]

    • 数值: 0 11110 11

    • 参数:

      • sign: 0

      • exponent长度: 5

      • bais: 2^(exponent长度-1)-1 = 2^4-1=15

      • exponent: 11110=30

      • fraction: 0.11(二进制)=3/4=0.75

    • 计算过程

      • = (−1)^sign × (1 + fraction) × 2^(exponent - bias)

      • = (-1)^0 × (1 + 0.75) × 2^(30 - 15)

      • = 1.75*2^15

      • = 57,344

  • E4M3示例: 最大值 0 1111 110

    • 格式: [ S ][ EEEE ][ MMM ]

    • 数值: 0 1111 110

    • 参数:

      • sign: 0

      • exponent长度: 4

      • bais: 2^(exponent长度-1)-1 = 2^3-1=7

      • exponent: 1111=15

      • fraction: 0.110(二进制)=3/4=0.75

    • 计算过程

      • = (−1)^sign × (1 + fraction) × 2^(exponent - bias)

      • = (-1)^0 × (1 + 0.75) × 2^(15 - 7)

      • = 1.75*2^8

      • = 448

Subnormal

  • subnormal的公式

值 = (−1)^sign × (0 + fraction) × 2^(1 − bias)
  • E5M2示例: 最大值 0 00000 11

    • 格式: [ S ][ EEEEE ][ MM ]

    • 数值: 0 00000 11

    • 参数:

      • sign: 0

      • exponent长度: 5

      • bais: 2^(exponent长度-1)-1 = 2^4-1=15

      • fraction: 0.11(二进制)=3/4=0.75

    • 计算过程

      • = (−1)^sign × (0 + fraction) × 2^(1 − bias)

      • = (-1)^0 × (0 + 0.75) × 2^(1 - 15)

      • = 0.75*2^(-14)

  • E4M3示例: 最大值 0 0000 111

    • 格式: [ S ][ EEEE ][ MMM ]

    • 数值: 0 0000 111

    • 参数:

      • sign: 0

      • exponent长度: 4

      • bais: 2^(exponent长度-1)-1 = 2^3-1=7

      • fraction: 0.111(二进制)=7/8=0.875

    • 计算过程

      • = (−1)^sign × (0 + fraction) × 2^(1 − bias)

      • = (-1)^0 × (0 + 0.875) × 2^(1 - 7)

      • = 0.875*2^(-6)

4. Empirical Results

  • 研究目标

    • 评估使用FP8格式在训练和推理中的表现,重点是能否在不明显损失精度的前提下降低计算和存储成本。

4.1 Training

  1. 怎么做的

    • 将参与矩阵乘法(GEMM操作)的激活值、权重和梯度剪裁为FP8可表示的值,但计算仍在更高精度(FP16或bfloat16)中进行。

    • 输出张量仍保留高精度,因为它们通常会传给非线性或归一化操作。

  2. 图像分类结果(ImageNet数据集)

    • 各种CNN和Transformer模型在FP8训练下的准确率与原始高精度训练几乎一致(差距在正常训练波动范围内)。

  3. 语言翻译任务(WMT16英德翻译)

    • GNMT(RNN)和Transformer模型在BLEU分数上与原始训练也基本一致。

  4. 语言模型(如GPT)

    • 在维基百科和The Pile等数据集上训练的Transformer/GPT模型,其损失(困惑度)也与高精度训练持平。

    • 即使是175B的GPT模型也成功用FP8训练完成,表现稳定。

4.2 Inference

  1. 优势

    • 如果训练就是用FP8,推理也可以直接用FP8,不需要额外量化或校准,简化了部署。

    • 相比之下,传统的int8推理需要额外的量化步骤且有较大精度损失。

  2. 结果对比

    • 对BERT和GPT等模型的测试显示:

      • FP8推理的精度/困惑度几乎和16位原始模型一样。

      • 而int8会出现明显的精度下降。

4.3 Per-tensor Scaling Factors

  • 在某些情况下,使用相同缩放因子(或相同指数偏移)会造成精度下降。

  • 特别是当更多非GEMM操作(如残差连接)也转为FP8格式时,必须使用每个张量独立的缩放因子,才能保持精度。

  • 实验显示,只对GEMM输入使用FP8并设置合适偏移,可以保持精度;但更多张量参与FP8时必须校准。

✅ 总结

  • FP8在训练和推理中都能保持高精度,表现接近FP16/bfloat16。

  • 它能显著减少存储、内存带宽和计算开销。

  • 部分模型(如MobileNet v2)还有优化空间。

  • 更深入的FP8支持(比如所有张量都FP8)需要进一步研究和更复杂的缩放策略。

5. Conclusions

  • 本文提出了一种FP8格式(E4M3 和 E5M2),兼容IEEE-754标准,便于使用整数操作进行比较和排序。

  • FP8的主要目的是加速深度学习训练和推理,减少资源消耗。

  • 实验表明,使用FP8训练各种图像和语言模型能达到与16位训练相同的精度,同时也简化了8位推理的部署流程,无需复杂的校准或微调。

  • 原文:In this paper we propose an FP8 binary interchange format, consisting of E4M3 and E5M2 encodings. By minimally deviating from IEEE-754 conventions for binary encoding of floating point values, we ensure that that software implementations can continue to rely on such IEEE FP properties as ability to compare and sort values using integer operations. The primary motivator for the format is acceleration of Deep Learning training and inference, by enabling smaller and more power efficient math pipelines as well as reducing memory bandwidth pressure. We demonstrate that a wide variety of neural network models for image and language tasks can be trained in FP8 to match model accuracy achieved with 16-bit training sessions, using the same model, optimizer, and training hyperparameters. Using FP8 not only accelerates and reduces resources required to train, but also simplifies 8-bit inference deployment by using the same datatypes for training and inference. Prior to FP8 8-bit inference required calibrating or fine-tuning for int8 models trained in floating point, which added complexity to the deployment process and in some cases failed to maintain accuracy.