首先是错误内容截图:(抱歉因为打码有点糊)
我在训练修改后的TransformerXL时,发现了如上的错误,此前代码已经成功地在单GPU下运行过,切换到多卡运行出现该问题。尝试进行解决。
使用的环境是: Pytorch1.11 transformers:4.18
在网上进行查阅后大部分人都说可能是pytorch版本的问题,当前所使用的pytorch版本过高,需要降级到1.4.0版本。
降级听起来比较简单,但是我不想降级到太低的版本,只能走第二条路,修改代码。
首先定位到出错的非源码的最后一行,
param = next(self.parameters())
经过上网查找,发现可能是在训练过程中部分数据的精度不同导致的问题,可能同时存在16位精度和32位精度的数据,尝试在这里进行修改,将其直接指定为torch.float32 进行训练。
原始代码为:
def init_mems(self):
if self.mem_len > 0:
mems = []
param = next(self.parameters())
for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=param.dtype, device=param.device)
mems.append(empty)
return mems
else:
return None
更改后的代码是:
def init_mems(self):
if self.mem_len > 0:
mems = []
for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=torch.float32).cuda()
mems.append(empty)
return mems
else:
return None
成功! 问题解决!