6. 使用LARS / LAMB 优化分布式超大batch 训练¶
6.1. 简介¶
在数据并行分布式训练场景中, 常使用增加GPU数量的方式来加速训练. 为了保证GPU的算力得到充分利用, 每张GPU卡上的batch size都需要足够大. 因此在增加GPU 数量同时, 训练的全局batch size 也会变大.
但越大的全局batch size 会带来训练的收敛问题[1] [2]:
模型最终精度损失
收敛速度变慢, 需要更多的epoch 才能收敛
LARS[3] 和 LAMB[4] 两个优化策略常用来解决上述超大batch 训练中的收敛问题.
Paddle 实现了这两种优化策略,paddle.distributed.fleet 作为Paddle通用的分布式训练API提供了简单易用的接口, 用户只需要添加几行代码 就可将策略加入到原有的训练中。 通过这两个优化策略, 我们在超大batch 场景中实现了更快的收敛速度和无损的精度, 结合Fleet 中其他的策略(e.g. AMP) 可以极大缩短的训练整体的time2train.
6.1.1. 试验效果¶
resnet50 imagenet |
Global batch size |
epoch |
top1 |
---|---|---|---|
[Goyal et al] |
8k |
90 |
76.3% |
LARS Paper |
32k |
90 |
72.3% |
[fleet: lars + amp] |
16k |
60 |
76.2% |
[fleet: lars + amp] |
32k |
62 |
75.9% |
6.2. LARS¶
我们以在单机多卡上Resent50 训练为简单例子介绍fleet 中 LARS的用法。
import os
import fleetx as X
import paddle
paddle.enable_staic()
import paddle.fluid as fluid
import paddle.distributed.fleet.base.role_maker as role_maker
import time
import paddle.distributed.fleet as fleet
通过X.parse_train_configs()
接口,用户可以定义训练相关的参数,如:学习率、衰减率等。同时通过fleet.init()
接口定义了分布式模型,下面代码中的is_collective=True
表示采用集合通信的GPU分布式模式训练模型。
paddle.enable_static()
configs = X.parse_train_configs()
fleet.init(is_collective=True)
用户可以通过X.applications
接口加载我们预先定义好的模型,如:Resnet50、VGG16、BERT等。并使用定制化的data_loader加载模型,同时可以定义训练中使用的batch_size等参数。
model = X.applications.Resnet50()
downloader = X.utils.Downloader()
local_path = downloader.download_from_bos(
fs_yaml='https://fleet.bj.bcebos.com/test/loader/small_imagenet.yaml',
local_path='./data')
batch_size = 32
loader = model.get_train_dataloader(local_path, batch_size=batch_size)
LARS 优化算法的公式如下:
可以看到LARS 其实是在 带weight decay
的momentum
优化器的基础上加入了local learning rate
的逻辑,
对每一层的learning rate
进行了放缩. fleet 将 LARS实现为一个 fleet
meta optimizer, 在使用时需要设置一下几点:
LARS meta optimizer 的 inner optimizer 必须为
momentum
, 并在 momentum 中定义mu
和lr
参数.在 fleet dist_strategy 定义LARS 特有的
lars_coeff
参数和lars_weight_decay
参数.LARS 已经将
weight decay
包含进公式中, 用户不需要再在 optimizer中设置regularization
.fleet 中还提供 lars_weight_decay 过滤策略, 可以通过在
exclude_from_weight_decay
参数加入对应layer 的name string
, 让这一 layer 的参数不进行 lars weight decay. (通常我们将``BN`` 参数 和FC_bias
从lars weight decay 中过滤)
dist_strategy = fleet.DistributedStrategy()
dist_strategy.lars = True
dist_strategy.lars_configs = {
"lars_coeff": 0.001,
"lars_weight_decay": 0.0005,
"exclude_from_weight_decay": ['batch_norm', '.b_0']
}
optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
optimizer.minimize(model.loss)
这一部分和fleet 中其他任务基本相同:
place = fluid.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for i, data in enumerate(loader()):
start_time = time.time()
cost_val = exe.run(model.main_prog,
feed=data,
fetch_list=[model.loss.name])
end_time = time.time()
print(
"worker_index: %d, step%d cost = %f, speed: %f"
% (fleet.worker_index(), i, cost_val[0], batch_size / (end_time - start_time)))
6.2.1. 运行训练脚本¶
完成上述脚本的编写后,我们就可以使用以下命令一行启动单机多卡分布式训练:
fleetrun --gpus 0,1,2,3,4,5,6,7 --log_dir log example_lars.py
6.3. LAMB¶
我们以在单机多卡上Bert 训练为简单例子介绍fleet 中LAMB 的用法.
import os
import fleetx as X
import paddle
paddle.enable_staic()
import paddle.fluid as fluid
import paddle.distributed.fleet.base.role_maker as role_maker
import time
import paddle.distributed.fleet as fleet
这一步和上文中的LARS 一致。
paddle.enable_static()
configs = X.parse_train_configs()
fleet.init(is_collective=True)
这一步和上文中的LARS 一致。
model = X.applications.Resnet50()
downloader = X.utils.Downloader()
local_path = downloader.download_from_bos(
fs_yaml='https://fleet.bj.bcebos.com/test/loader/small_imagenet.yaml',
local_path='./data')
batch_size = 32
loader = model.get_train_dataloader(local_path, batch_size=batch_size)
LAMB 优化算法的公式如下:
和LARS 类似, LAMB 也是在内层优化器的基础上,
套了一个local learning rate
的逻辑, 对每一层的learning rate
进行了放缩. fleet 将 LAMB实现为一个 fleet meta optimizer,
在使用时需要设置一下几点:
LAMB meta optimizer 的 inner optimizer 必须为
adam
, 并在 adam 中定义 学习率lr
, 一阶 moment 的指数衰减率beta1
和二阶moment 的指数衰减率beta2
参数.在 fleet dist_strategy 定义LAMB 特有的
lamb_weight_decay
参数.LAMB 已经将
weight decay
包含进公式中, 用户不需要再在 optimizer中设置regularization
.fleet 中还提供 lamb_weight_decay 过滤策略, 可以通过在
exclude_from_weight_decay
参数加入对应layer 的name string
, 让这一 layer 的参数不进行 lars weight decay. (通常我们将``LN`` 从lamb weight decay 中过滤)
dist_strategy = fleet.DistributedStrategy()
dist_strategy.lamb = True
dist_strategy.lamb_configs = {
'lamb_weight_decay': 0.01,
'exclude_from_weight_decay': ['layer_norm'],
}
optimizer = paddle.optimizer.Adam(learning_rate=0.01, beta1=0.9, beta2=0.999)
optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
optimizer.minimize(model.loss)
这一部分和fleet 中其他任务基本相同:
place = fluid.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for i, data in enumerate(loader()):
start_time = time.time()
cost_val = exe.run(model.main_prog,
feed=data,
fetch_list=[model.loss.name])
end_time = time.time()
print(
"worker_index: %d, step%d cost = %f, speed: %f"
% (fleet.worker_index(), i, cost_val[0], batch_size / (end_time - start_time)))
6.3.1. 运行训练脚本¶
完成上述脚本的编写后,我们就可以使用以下命令一行启动单机多卡分布式训练:
fleetrun --gpus 0,1,2,3,4,5,6,7 --log_dir log resnet50_lamb.py