引言:很久之前读blip2,对ITC和ITM大致有个印象,一个对比学习,一个图文匹配的二分类,咋一听好像没什么难理解的,最近好好看了一下源码,觉得实现上很巧妙,值得与诸君共享
这里小编没有一句一句分析,直接源码+注释,觉得这样看比较方便,因为只分析ITC和ITM,所以这里只放了blip2里面的Blip2Qformer的forward函数内容,如有出入,还请各位小伙伴留言斧正!
Image-text Contrastive
###============== Image-text Contrastive ===================###
"""
因为在多张卡上训练,所以这里需要将所有卡上的图像特征收集起来,维度为[batch_size*num_gpu, num_query_tokens, embed_dim],
其中,num_query_tokens是视觉tokens数量,embed_dim是维度
"""
image_feats_all = concat_all_gather(
image_feats
) # [batch_size*num_gpu, num_query_tokens, embed_dim]
# 文本这一步操作与上述同理
text_feat_all = concat_all_gather(text_feat) # [batch_size*num_gpu, embed_dim]
"""
求图像与所有文本的相似度
这里image_feats.unsqueeze(1)之后的维度是[batch_size,1, num_query_tokens, embed_dim]
text_feat_all.unsqueeze(-1)之后的维度是[batch_size*num_gpu, embed_dim,1]
为了求每个图像跟所有文本的相似度,图像特征[batch_size,1, num_query_tokens, embed_dim]第2个维度会被广播到batch_size*num_gpu变成[batch_size*,batch_size*num_gpu, num_query_tokens, embed_dim]
然后矩阵乘法会沿着image_feats和text_feat_all最后两个维度进行相乘,embed_dim维度相乘消失,所以得到的结果为[batch_size,batch_size*num_gpu, num_query_tokens,1]
相乘之后的结果再squeeze()就得到了[batch_size,batch_size*num_gpu, num_query_tokens]
"""
sim_q2t = torch.matmul(
image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
).squeeze() # [batch_size, batch_size*num_gpu, num_query_tokens]
"""
max(-1)表示在最后一个维度上,寻找最大值
也就是说,对每个图像到文本的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
"""
sim_i2t, _ = sim_q2t.max(-1)
# 通过温度参数self.temp进行相似度的缩放控制
sim_i2t = sim_i2t / self.temp
"""
求文本与所有图像的相似度
text_feat.unsqueeze(1).unsqueeze(1)之后的维度为[batch_size,1,1,embed_dim]
image_feats_all.permute(0, 2, 1)交换后面两个维度之后的特征维度为[batch_size*num_gpu, embed_dim, num_query_token]
同理,文本特征[batch_size,1,1,embed_dim]会广播第2个维度到batch_size*num_gpu,变成[batch_size,batch_size*num_gpu,1,embed_dim]
然后最后两个维度做矩阵乘法得到[batch_size,batch_size*num_gpu,1,num_query_token]
squeeze()之后的特征为[batch_size,batch_size*num_gpu,num_query_token]
"""
sim_t2q = torch.matmul(
text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
).squeeze()
# 对每个文本到图像的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
sim_t2i, _ = sim_t2q.max(-1)
sim_t2i = sim_t2i / self.temp
rank = dist.get_rank()
bs = image.size(0)
"""
torch.linspace(start, end, steps, dtype=int)的作用是生成从 start