计算主体:rpe_k(q),q.shape=(batch_size,heads_num,patches_num_each_window,embed_dim)
一、绝对位置
- 输入:窗口的高和宽。
- 获取方式:sqrt(patches_num_each_window)
- 输出:窗口坐标的编码。
- [sqrt(L),sqrt(L),2]
二、diff
- 输入:绝对位置坐标
计算每个坐标和其他坐标的差异,包含本身。
diff.csv文件
(截取部分) - [L,L,2]
三、r、c
- 输入:分别输入横坐标代表的diff和纵坐标代表的diff。
- 算法描述:
取绝对值大于alpha=1.9的位置,将其置为一个新值([-3,3]闭区间的整数),原来的范围从[-7,7]就变为了[-3,3]。更通俗易懂点,逻辑很简单,只有判断和赋值,遍历矩阵的每个元素值,判断其绝对值是否大于1.9,如果满足,则将其置为另一个[-3,3]之间的整数,符号保持一致。加上3,改变其范围。7倍放大r。
例如,某位置上的坐标是-6(任取的,只要在[-7,7]之间即可),-6的绝对值大于1.9,于是将-6输入函数f,映射成-2,于是这点就从-6变成了-2。f包含的逻辑比较多,包含放缩、平移、截断、取整。
参见r.csv和c.csv(一个横坐标一个纵坐标) - 都是[L,L]
四、rp_bucket
- 输入:r、c
- 算法描述:
r、c相加 - [L,L]
五、num_buckets
- 输入:计算方式
- 输出:49
六、self._ctx_p_bucket_flatten
- 输入:rp_bucket
- 算法描述:
逐行扫描rp_bucket,加上首项为0、公差为49、个数为64的等差数列。展平。 - [L*L]
七、lookup_table*
- 输入:q、self._ctx_p_bucket_flatten,num_buckets
- 计算方式:
线性层,将原来每个头32dim先变为num_buckets的dim,用self._ctx_p_bucket_flatten从每个num_buckets(49)的向量中索引出64个,所以会有重复的。 - [batch_size,num_heads,L,L]
总结(絮叨)
lookup_table的shape是(8,3,64,64)和attn的shape相同,后续二者相加。
其实前面都是准备数据,第七步才用到q,本质上看,它只是为了得到索引和线性层的中间维度,前面那么多逻辑是为了让索引值更合理吧。所以我们来看看,它最本质的一步还是线形层,也就是对q的通道数据进行加权改变向量的长度,改变通道的维度,那何不让最后一步也改成线形层而不是构建复杂的索引取选择,我们知道,每个patch对应一组q/k/v,如果要考虑位置信息,进行位置编码,直观的想法是直接将位置编码作用在patch上(我是这么想的),而作用再q或k或v上,想想也未尝不可,因为q/k/v本身就是x经过线形层得来的,本身数据就是x数据的加权,一个patch被编码为一个向量,一个向量和一套权重相乘是一个加权求和的过程,最终一个有多个元素的向量只得到一个值,我称为加权求和的“湮灭”过程,要想最终还得到一个向量,那么就需要多套权重,可以说最终向量的长度取决于权重组的个数,而与数据向量的长度无关,只要求数据向量的长度和一组权重向量的长度相等即可,不要求权重组的个数与数据向量的长度相同,它的个数是自由的,取决于最终想要得到多长的向量。
然后,思考它是怎么就把位置信息考虑进去了。不很明白,感觉就是对坐标加加减减、乘乘除除,设置的那些阈值也感觉没有合理的解释,反正只要考虑到位置信息了,就考虑到位置信息了。
代码实现:
import torch
import pandas as pd
import csv
import math
import numpy as np
# 一、绝对位置
def get_absolute_positions(height, width, dtype, device):
rows = torch.arange(height, dtype=dtype, device=device).view(
height, 1).repeat(1, width)
cols = torch.arange(width, dtype=dtype, device=device).view(
1, width).repeat(height, 1)
return torch.stack([rows, cols], 2)
pos=get_absolute_positions(8,8,dtype=torch.long,device='cuda:0')
# 二、diff
max_L = 8 * 8
pos1 = pos.view((max_L, 1, 2))
pos2 = pos.view((1, max_L, 2))
diff = pos1 - pos2
# 三、r、c
diff1=diff[:,:,0]
diff2=diff[:,:,1]
# tensor_list1 = diff1.tolist()
# tensor_list2 = diff2.tolist()
# # 保存到 CSV 文件
# with open('tensor_data1.csv', 'w', newline='') as csvfile:
# writer = csv.writer(csvfile)
# for row in tensor_list1:
# writer.writerow(row)
#
# with open('tensor_data2.csv', 'w', newline='') as csvfile:
# writer = csv.writer(csvfile)
# for row in tensor_list2:
# writer.writerow(row)
def piecewise_index(relative_position, alpha, beta, gamma, dtype):
rp_abs = relative_position.abs()
mask = rp_abs <= alpha
not_mask = ~mask
rp_out = relative_position[not_mask]
rp_abs_out = rp_abs[not_mask]
y_out = (torch.sign(rp_out) * (alpha +
torch.log(rp_abs_out / alpha) /
math.log(gamma / alpha) *
(beta - alpha)).round().clip(max=beta)).to(dtype)
idx = relative_position.clone()
if idx.dtype in [torch.float32, torch.float64]:
idx = idx.round().to(dtype)
idx[not_mask] = y_out
return idx
r=piecewise_index(diff1,alpha=1.9,beta=3.8,gamma=15.2,dtype=torch.long)
r_list=r.tolist()
# with open('r.csv', 'w', newline='') as csvfile:
# writer = csv.writer(csvfile)
# for row in r_list:
# writer.writerow(row)
c=piecewise_index(diff2,alpha=1.9,beta=3.8,gamma=15.2,dtype=torch.long)
c_list=c.tolist()
# with open('c.csv', 'w', newline='') as csvfile:
# writer = csv.writer(csvfile)
# for row in c_list:
# writer.writerow(row)
r=r+3
c=c+3
r=r*7
# 四、rp_bucket
rp_bucket=r+c
# 五、num_buckets
num_buckets=49
# 六、self._ctx_rp_bucket_flatten
_ctx_rp_bucket_flatten = None
offset=torch.arange(0, 64 * 49, 49,dtype=torch.long, device='cuda:0').view(-1, 1)
_ctx_rp_bucket_flatten = (rp_bucket + offset).flatten()
# 七、lookup_table
num_heads=3
head_dim=32
lookup_table_weight=torch.zeros(num_heads,head_dim,num_buckets)
batch_size=8
patches_num_each_window=64
q=torch.randn(batch_size,num_heads,patches_num_each_window,head_dim)
lookup_table=torch.matmul(
q.transpose(0, 1).reshape(-1, batch_size * 64, head_dim),
lookup_table_weight).view(-1, batch_size, 64, num_buckets).transpose(0, 1)
'''
(8,3,64,32)->(3,8*64,32)
(3,32,49)
(8,3,64,49)
'''
lookup_table=lookup_table.flatten(2)[:, :, _ctx_rp_bucket_flatten].view(batch_size, -1, 64, 64)