SwinGAN(二)——[子模块]ciRPE(上下文图像相对位置编码)

文章详细探讨了Transformer模型中位置编码的几种方法,如绝对位置编码、diff操作、r、c变换以及lookup_table的构建过程,强调了线性层在处理位置信息的核心作用,并提出了对传统编码方法的反思和优化建议。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

计算主体:rpe_k(q),q.shape=(batch_size,heads_num,patches_num_each_window,embed_dim)

一、绝对位置

  • 输入:窗口的高和宽。
  • 获取方式:sqrt(patches_num_each_window)
  • 输出:窗口坐标的编码。
  • [sqrt(L),sqrt(L),2]

二、diff

  • 输入:绝对位置坐标
    计算每个坐标和其他坐标的差异,包含本身。
    diff.csv文件
    (截取部分)
  • [L,L,2]

在这里插入图片描述

三、r、c

  • 输入:分别输入横坐标代表的diff和纵坐标代表的diff。
  • 算法描述:
    取绝对值大于alpha=1.9的位置,将其置为一个新值([-3,3]闭区间的整数),原来的范围从[-7,7]就变为了[-3,3]。更通俗易懂点,逻辑很简单,只有判断和赋值,遍历矩阵的每个元素值,判断其绝对值是否大于1.9,如果满足,则将其置为另一个[-3,3]之间的整数,符号保持一致。加上3,改变其范围。7倍放大r。
    例如,某位置上的坐标是-6(任取的,只要在[-7,7]之间即可),-6的绝对值大于1.9,于是将-6输入函数f,映射成-2,于是这点就从-6变成了-2。f包含的逻辑比较多,包含放缩、平移、截断、取整。
    参见r.csv和c.csv(一个横坐标一个纵坐标)
  • 都是[L,L]

四、rp_bucket

  • 输入:r、c
  • 算法描述:
    r、c相加
  • [L,L]

五、num_buckets

  • 输入:计算方式
  • 输出:49

六、self._ctx_p_bucket_flatten

  • 输入:rp_bucket
  • 算法描述:
    逐行扫描rp_bucket,加上首项为0、公差为49、个数为64的等差数列。展平。
  • [L*L]

七、lookup_table*

  • 输入:q、self._ctx_p_bucket_flatten,num_buckets
  • 计算方式:
    线性层,将原来每个头32dim先变为num_buckets的dim,用self._ctx_p_bucket_flatten从每个num_buckets(49)的向量中索引出64个,所以会有重复的。
  • [batch_size,num_heads,L,L]

总结(絮叨)

lookup_table的shape是(8,3,64,64)和attn的shape相同,后续二者相加。
其实前面都是准备数据,第七步才用到q,本质上看,它只是为了得到索引和线性层的中间维度,前面那么多逻辑是为了让索引值更合理吧。所以我们来看看,它最本质的一步还是线形层,也就是对q的通道数据进行加权改变向量的长度,改变通道的维度,那何不让最后一步也改成线形层而不是构建复杂的索引取选择,我们知道,每个patch对应一组q/k/v,如果要考虑位置信息,进行位置编码,直观的想法是直接将位置编码作用在patch上(我是这么想的),而作用再q或k或v上,想想也未尝不可,因为q/k/v本身就是x经过线形层得来的,本身数据就是x数据的加权,一个patch被编码为一个向量,一个向量和一套权重相乘是一个加权求和的过程,最终一个有多个元素的向量只得到一个值,我称为加权求和的“湮灭”过程,要想最终还得到一个向量,那么就需要多套权重,可以说最终向量的长度取决于权重组的个数,而与数据向量的长度无关,只要求数据向量的长度和一组权重向量的长度相等即可,不要求权重组的个数与数据向量的长度相同,它的个数是自由的,取决于最终想要得到多长的向量。
然后,思考它是怎么就把位置信息考虑进去了。不很明白,感觉就是对坐标加加减减、乘乘除除,设置的那些阈值也感觉没有合理的解释,反正只要考虑到位置信息了,就考虑到位置信息了。

代码实现:

import torch
import pandas as pd
import csv
import math
import numpy as np

# 一、绝对位置
def get_absolute_positions(height, width, dtype, device):
    rows = torch.arange(height, dtype=dtype, device=device).view(
        height, 1).repeat(1, width)
    cols = torch.arange(width, dtype=dtype, device=device).view(
        1, width).repeat(height, 1)
    return torch.stack([rows, cols], 2)
pos=get_absolute_positions(8,8,dtype=torch.long,device='cuda:0')

# 二、diff
max_L = 8 * 8
pos1 = pos.view((max_L, 1, 2))
pos2 = pos.view((1, max_L, 2))
diff = pos1 - pos2

# 三、r、c
diff1=diff[:,:,0]
diff2=diff[:,:,1]
# tensor_list1 = diff1.tolist()
# tensor_list2 = diff2.tolist()
# # 保存到 CSV 文件
# with open('tensor_data1.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in tensor_list1:
#         writer.writerow(row)
#
# with open('tensor_data2.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in tensor_list2:
#         writer.writerow(row)
def piecewise_index(relative_position, alpha, beta, gamma, dtype):
    rp_abs = relative_position.abs()
    mask = rp_abs <= alpha
    not_mask = ~mask
    rp_out = relative_position[not_mask]
    rp_abs_out = rp_abs[not_mask]
    y_out = (torch.sign(rp_out) * (alpha +
                                   torch.log(rp_abs_out / alpha) /
                                   math.log(gamma / alpha) *
                                   (beta - alpha)).round().clip(max=beta)).to(dtype)

    idx = relative_position.clone()
    if idx.dtype in [torch.float32, torch.float64]:
        idx = idx.round().to(dtype)
    idx[not_mask] = y_out
    return idx
r=piecewise_index(diff1,alpha=1.9,beta=3.8,gamma=15.2,dtype=torch.long)
r_list=r.tolist()
# with open('r.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in r_list:
#         writer.writerow(row)
c=piecewise_index(diff2,alpha=1.9,beta=3.8,gamma=15.2,dtype=torch.long)
c_list=c.tolist()
# with open('c.csv', 'w', newline='') as csvfile:
#     writer = csv.writer(csvfile)
#     for row in c_list:
#         writer.writerow(row)

r=r+3
c=c+3
r=r*7

# 四、rp_bucket
rp_bucket=r+c

# 五、num_buckets
num_buckets=49

# 六、self._ctx_rp_bucket_flatten
_ctx_rp_bucket_flatten = None
offset=torch.arange(0, 64 * 49, 49,dtype=torch.long, device='cuda:0').view(-1, 1)
_ctx_rp_bucket_flatten = (rp_bucket + offset).flatten()

# 七、lookup_table
num_heads=3
head_dim=32
lookup_table_weight=torch.zeros(num_heads,head_dim,num_buckets)
batch_size=8
patches_num_each_window=64
q=torch.randn(batch_size,num_heads,patches_num_each_window,head_dim)
lookup_table=torch.matmul(
                q.transpose(0, 1).reshape(-1, batch_size * 64, head_dim),
                lookup_table_weight).view(-1, batch_size, 64, num_buckets).transpose(0, 1)
'''
(8,3,64,32)->(3,8*64,32)
(3,32,49)
(8,3,64,49)
'''
lookup_table=lookup_table.flatten(2)[:, :, _ctx_rp_bucket_flatten].view(batch_size, -1, 64, 64)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值