10-2 注意力评分函数

10-2节使用了高斯核来对查询和键之间的关系建模。 (10.2.6)中的 高斯核指数部分可以视为注意力评分函数(attention scoring function), 简称评分函数(scoring function), 然后把这个函数的输出结果输入到softmax函数中进行运算。 通过上述步骤,将得到与键对应的值的概率分布(即注意力权重)。 最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。

从宏观来看,上述算法可以用来实现 图10.1.3中的注意力机制框架。 图10.3.1说明了 如何将注意力汇聚的输出计算成为值的加权和, 其中aaa表示注意力评分函数。 由于注意力权重是概率分布, 因此加权和其本质上是加权平均值。

请添加图片描述
请添加图片描述
正如上图所示,选择不同的注意力评分函数aaa会导致不同的注意力汇聚操作。 本节将介绍两个流行的评分函数,稍后将用他们来实现更复杂的注意力机制。

import math
import torch
from torch import nn
from d2l import torch as d2l

掩蔽softmax操作

正如上面提到的,softmax操作用于输出一个概率分布作为注意力权重。 在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。 例如,为了在 9-5节中高效处理小批量数据集, 某些文本序列被填充了没有意义的特殊词元。 为了仅将有意义的词元作为值来获取注意力汇聚, 可以指定一个有效序列长度(即词元的个数), 以便在计算softmax时过滤掉超出指定范围的位置。 下面的masked_softmax函数 实现了这样的掩蔽softmax操作(masked softmax operation), 其中任何超出有效长度的位置都被掩蔽并置为0。

def masked_softmax(X, valid_lens):

    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    # if valid_lens is None::检查 valid_lens 是否为空。如果为空,则表示不需要进行掩蔽操作。
	# return nn.functional.softmax(X, dim=-1):如果 valid_lens 为空,直接对 X 在最后一个维度上(dim=-1)执行 softmax 操作,并返回结果。
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        
		# else::如果 valid_lens 不为空,则进入掩蔽处理流程。
		# shape = X.shape:获取张量 X 的形状,并将其存储在变量 shape 中。
		# if valid_lens.dim() == 1::检查 valid_lens 的维度是否为1。如果是,表示每个样本只有一个有效长度。
		# valid_lens = torch.repeat_interleave(valid_lens, shape[1]):将 valid_lens 中的每个值重复 shape[1] 次,以匹配 X 的形状,从而为每个样本生成一个对应的有效长度。
		# else::如果 valid_lens 是2D张量。
		# valid_lens = valid_lens.reshape(-1):将 valid_lens 展平成1D张量,以便后续处理。
		
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        # X.reshape(-1, shape[-1]):将 X 重塑为二维张量,其中第一个维度为 -1(自动计算),第二个维度为原来的最后一个维度 shape[-1]。
		# d2l.sequence_mask(...):对 X 进行掩蔽,sequence_mask 函数会将 X 中超出 valid_lens 范围的元素替换为一个非常大的负值(-1e6),以确保这些位置在 softmax 中的输出接近于0。
        return nn.functional.softmax(X.reshape(shape), dim=-1)
        # 将 X 还原为原来的形状,并在最后一个维度上执行 softmax 操作,最终返回结果。

为了演示此函数是如何工作的, 考虑由两个2×42\times 42×4矩阵表示的样本, 这两个样本的有效长度分别为222333。 经过掩蔽softmax操作,超出有效长度的值都被掩蔽为0。

masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

请添加图片描述
同样,也可以使用二维张量,为矩阵样本中的每一行指定有效长度。

masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))

请添加图片描述

加性注意力

请添加图片描述
下面来实现加性注意力。

class AdditiveAttention(nn.Module):
    """加性注意力"""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
    # 定义类的初始化函数,接收几个参数:key_size(键的大小)、query_size(查询的大小)、num_hiddens(隐藏层的维度数)和 dropout(dropout 概率),以及额外的关键字参数 **kwargs。
        super(AdditiveAttention, self).__init__(**kwargs)
		# 调用父类 nn.Module 的初始化函数。
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 对 queries 和 keys 分别应用线性变换 W_q 和 W_k,将它们转换为相同的隐藏层维度 num_hiddens。
        # 在维度扩展后,
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使用广播方式进行求和
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        # 对 queries 张量在第3个维度(索引为2)添加一个维度,使得 queries 的形状变为 (batch_size, 查询的个数, 1, num_hiddens)
        # 对 keys 张量在第2个维度(索引为1)添加一个维度,使得 keys 的形状变为 (batch_size, 1, 键-值对的个数, num_hiddens)。
        # 利用广播机制,将 queries 和 keys 对应位置的张量元素相加,得到 features,其形状为 (batch_size, 查询的个数, 键-值对的个数, num_hiddens)。
        features = torch.tanh(features)
        # 对 features 张量应用 tanh 激活函数,增加非线性能力。
        
        # self.w_v仅有一个输出,因此从形状中移除最后那个维度。
        # scores的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # 调用 masked_softmax 函数,对 scores 进行掩蔽后应用 softmax,得到注意力权重 self.attention_weights。形状仍然为 (batch_size, 查询的个数, 键-值对的个数)。
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)
        # self.dropout(self.attention_weights):对注意力权重 self.attention_weights 应用 Dropout 操作,防止过拟合。
		# torch.bmm(..., values):使用 torch.bmm(批量矩阵乘法)将注意力权重和 values 进行矩阵乘法。最终输出的形状为 (batch_size, 查询的个数, 值的维度)。
		# 返回结果即为加权后的值,代表加性注意力机制的输出。

用一个小例子来演示上面的AdditiveAttention类, 其中查询、键和值的形状为(批量大小,步数或词元序列长度,特征大小), 实际输出为(2,1,20)(2,1,20)(2,1,20)(2,10,2)(2,10,2)(2,10,2)(2,10,4)(2,10,4)(2,10,4)。 注意力汇聚输出的形状为(批量大小,查询的步数,值的维度)。

queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)

请添加图片描述
尽管加性注意力包含了可学习的参数,但由于本例子中每个键都是相同的, 所以注意力权重是均匀的,由指定的有效长度决定。

d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

请添加图片描述

缩放点积注意力

请添加图片描述
下面的缩放点积注意力的实现使用了暂退法进行模型正则化。

class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 设置transpose_b=True为了交换keys的最后两个维度
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

为了演示上述的DotProductAttention类, 我们使用与先前加性注意力例子中相同的键、值和有效长度。 对于点积操作,我们令查询的特征维度与键的特征维度大小相同。

queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

请添加图片描述
与加性注意力演示相同,由于键包含的是相同的元素, 而这些元素无法通过任何查询进行区分,因此获得了均匀的注意力权重。

d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

请添加图片描述

小结

  • 将注意力汇聚的输出计算可以作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作。

  • 当查询和键是不同长度的矢量时,可以使用可加性注意力评分函数。当它们的长度相同时,使用缩放的“点-积”注意力评分函数的计算效率更高

请添加图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Alkali!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值