1. O4-mini-Cursor
kernels = [1, 3, 5]
strides = [1, 3, 5]
# Split out_ch as evenly as possible across the branches: the first `rem`
# branches each take one extra channel, so the widths always add up to out_ch.
# Example: rem == 1 gives [base_ch + 1, base_ch, base_ch].
base_ch, rem = divmod(out_ch, len(kernels))
branch_chs = [base_ch + 1] * rem + [base_ch] * (len(kernels) - rem)
# One convolution per (kernel, stride) pair, padded by half the kernel size.
# NOTE(review): the branches use different strides (1, 3, 5), so for inputs
# larger than 1x1 their outputs have different spatial sizes. The code that
# combines them before `fuse` is not visible here -- confirm it resizes the
# branch outputs before concatenating them.
self.branches = nn.ModuleList([
    nn.Conv2d(in_ch, width, k, stride=s, padding=k // 2, bias=bias)
    for width, k, s in zip(branch_chs, kernels, strides)
])
# The 1x1 fuse convolution takes the summed branch widths as its input
# channel count and maps them to out_ch.
fuse_in = sum(branch_chs)
self.fuse = nn.Conv2d(fuse_in, out_ch, kernel_size=1, bias=bias)