6.5.2. Zero Redundancy Optimizer (ZeRO)

  • 【原始】Parallelism technique which performs sharding of the tensors somewhat similar to TensorParallel, except the whole tensor gets reconstructed in time for a forward or backward computation, therefore the model doesn’t need to be modified. This method also supports various offloading techniques to compensate for limited GPU memory. Learn more about ZeRO here.

  • 【定义】Zero Redundancy Optimizer (ZeRO) 是由 DeepSpeed 框架提出的一种分布式优化策略。是一种优化器并行技术,旨在大幅减少深度学习模型训练中的内存使用,特别是在多 GPU 环境下。它通过 分片(sharding)管理 模型参数、梯度和优化器状态,极大地提高了 GPU 的内存效率,从而允许更大规模的模型在有限的 GPU 资源上进行训练。ZeRO 是一种与 Tensor Parallelism 相似的技术,但有一些独特的特性,使其能够进一步优化内存使用。

  • 【核心思想】只在计算需要时才将张量重新构建,而在计算完成后将其再次分片到各个 GPU。因此,整个模型的张量不需要一直驻留在单个 GPU 上,从而有效减少了显存的占用。

  • 【三个阶段】

    • Stage 1:Optimizer State Sharding(优化器状态分片)在训练过程中,优化器会维护很多状态信息(例如动量、二阶梯度等),这些状态会消耗大量的内存。在 Stage 1,ZeRO 将这些优化器状态分片到不同的 GPU 上,每个 GPU 只维护一部分优化器状态。因此,内存消耗从所有 GPU 都需要存储完整的优化器状态,变成每个 GPU 只存储一部分。

    • Stage 2:Gradient Sharding(梯度分片)梯度通常在反向传播过程中计算并存储。Stage 2 将这些梯度也进行分片,各个 GPU 只存储一部分的梯度。这样,反向传播中的梯度计算仍然是并行的,但存储的内存显著减少。

    • Stage 3:Parameter Sharding(参数分片)在 Stage 3,模型的参数本身也被分片。每个 GPU 只存储模型参数的一部分,而不是整个模型的副本。在前向传播和反向传播的过程中,模型的参数只会在需要计算的时刻被重构和使用,之后再重新分片和同步。

  • 【原理】在 ZeRO 中,模型参数、梯度和优化器状态不再被每个 GPU 完整地复制,而是被划分成多个块(shards),分布在不同的 GPU 上。当模型进行前向传播、反向传播或者更新参数时:

    • 在需要计算的时候,ZeRO 会将分片的张量重构(reconstruct)成完整的张量,进行计算。

    • 计算结束后,张量会再次被切分并分配回各个 GPU,这样整个训练过程中每个 GPU 只需要处理一部分的张量。

  • ZeRO 与 Tensor Parallelism 的区别:Tensor Parallelism 是在前向传播和反向传播过程中将张量水平切分,不同 GPU 并行处理张量的不同部分,然后同步计算结果。ZeRO 通过分片模型参数、梯度和优化器状态,使得每个 GPU 不需要同时存储整个模型的副本。在计算时,ZeRO 动态重构张量,计算完后再进行分片,而不是一开始就将张量水平切分并固定下来。

  • 【优势】无需对模型结构进行修改,而只需要调整张量的存储和处理方式,因此模型本身不需要特别为 ZeRO 进行重写。不仅通过分片减少 GPU 内存的占用,还支持offloading 技术,这意味着可以将部分计算或数据从 GPU 卸载到 CPU,甚至 NVMe 存储(硬盘),通过这种技术,即使 GPU 内存非常有限,也能支持大模型训练:

  • 【总结】Zero Redundancy Optimizer (ZeRO) 是一种高效的并行优化技术,它通过 分片(sharding) 模型的参数、梯度和优化器状态,将这些信息分布在多个 GPU 上,从而极大地减少了单个 GPU 的内存占用。它能够支持训练非常大的模型,同时保持较高的计算效率。与 Tensor Parallelism 不同,ZeRO 允许动态重构张量,且不需要对模型进行修改。此外,ZeRO 还支持将计算和存储卸载到 CPU 或硬盘,进一步优化有限资源上的模型训练。