在计算机视觉领域,UNet因其优异的性能在图像分割任务中广受欢迎。本文将介绍一种改进的UNet架构——UNetWithCrossAttention,它通过引入交叉注意力机制来增强模型的特征融合能力。
1. 交叉注意力机制
交叉注意力(Cross Attention)是一种让模型能够动态地从辅助特征中提取相关信息来增强主特征的机制。在我们的实现中,CrossAttention
类实现了这一功能:
class CrossAttention(nn.Module):
def __init__(self, channels):
super(CrossAttention, self).__init__()
self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self,