PyTorch DistributedDataParallel训练踩坑记录

问题所在:当需要用DDP同时训练多个模型时,正常情况下将每个模型用DistributedDataParallel类包裹一下即可。当时我想着这样挺麻烦的,就想了一个取巧的办法:我用torch.nn.Module.add_module将所有模型注册成为一个大模型,然后以整体一次传递给DistributedDataParallel。

这样是不是很完美?但是我自己感觉这样会有问题,所以我没忘记做测试。如何测试呢,当然是将网络部分参数打印出来,不同进程的同一个参数值应该是一模一样的。然而现实狠狠的给了我一巴掌,不同进程打印出来的结果除了第一次(刚初始化完)外完全不一样。(其实可以直接观察两个进程的输出日志是否同步来判断。如果每一个iter都是同步的,那么他们数据应该是同步的。而如果有的进程快,有的进程慢,那么他们的数据肯定没同步。因为DistributedDataParallel多个进程之间每一个loss.backward()都会进行数据同步。)

因为这个项目挺大的,我一直没以为是这个原因导致的,所以我DeBug了两天都没有解决,最后是怎么找到问题的呢?

我将pytorch examples下载下来运行确认该代码多个进程之间数据是同步的,然后将我的代码结构和参数往这边靠,我将所有可调的参数都靠过去了我的模型参数还是没有在多个进程间同步。这时我就开始将我的代码中的复杂的模块用这个样例中的简单的模块一一替换,包括data_loader, optimizer, models和迭代过程等...。最后的结果是替换了模型之后我的数据就同步了,由此才定位到了应该是上面采用的技巧出的问题。

此时我想起之前浏览DDP的API的时候闪现过去的一个warning

DDP_warning

别说了,都是泪!!!

DDP坑太多了,比如K80不支持nccl后端,要改为gloo。比如本应该占用其他GPU的进程在0号GPU也占用几百兆显存导致影响最大batch-size。

就挺难!!!

0%