傅里叶卷积Fourier Convolutions
对输入tensor进行FFT,然后提取出实部和虚部,对实部和虚部进行卷积计算,再还原为实部和虚部,还原为tensor。
中间可以对频率域使用SE,对不同的频率进行重标定。
class FourierUnit(nn.Module):
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
# bn_layer not used
super(FourierUnit, self).__init__()
self.groups = groups
self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
out_channels=out_channels * 2,
kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
self.bn = torch.nn.BatchNorm2d(out_channels