Skip to content

注意力机制(attention mechanisms)

查询、键和值

在注意力机制的背景下,自主性提示被称为查询(query)

给定任何查询,注意力机制通过注意力汇聚(attention pooling) 将选择引导至感官输入(sensory inputs,例如中间特征表示)。这些感官输入被称为值(value)

每个都与一个key)配对,这可以想象为感官输入的非自主提示

注意力机制

“是否包含自主性提示”将注意力机制全连接层或汇聚层区别开来。

特性注意力机制全连接层汇聚层
参数化是(动态权重)是(静态权重)否(固定规则)
自主性提示✅ 动态适应输入❌ 静态处理❌ 静态处理
输入依赖性高度依赖不依赖不依赖
典型应用Transformer, NLP传统分类模型CNN 的空间降维

通过注意力汇聚查询(自主性提示)和(非自主性提示)结合在一起,实现对(感官输入)的选择倾向(智能选择)。

注意力汇聚(attention pooling)

查询(自主提示)和键(非自主提示)之间交互形成注意力汇聚,注意力汇聚有选择地聚合值(感官输入)以生成最终输出。

平均汇聚忽略了键的作用,不够聪明。因此根据查询和键相对位置加权,是一个更好的想法。

f(x)=i=1nK(xxi)j=1nK(xxj)yi

其中K是核(kernel),公式所描述的估计器被称为 Nadaraya-Watson 核回归(Nadaraya-Watson kernel regression)。

受此启发,可以写一个更加通用的注意力汇聚attention pooling公式:

f(x)=i=1nα(x,xi)yi

其中x是查询,(xi,yi)是键值对。将查询x和键xi之间关系建模为注意力权重α(x,xi),这个权重将被分配给每一个对应值yi

下面考虑一个高斯核Gaussian kernel,其定义为:

K(u)=12πexp(u22)

代入高斯核,并推导公式:

f(x)=i=1nα(x,xi)yi=i=1nexp(12(xxi)2)j=1nexp(12(xxj)2)yi=i=1nsoftmax(12(xxi)2)yi.

如上可知,一个键xi越接近查询x,这个键注意力权重就越大,就获得更多注意力。

TIP

越接近、越相似、权重越高。“查询-键”对越近,注意力汇聚注意力权重就越高。

注意力评分函数

上面使用高斯核对查询和键之间关系建模,高斯核指数部分视为注意力评分函数,简称评分函数(scoring function)。

评分函数计算方式适用场景
高斯核exp(|qk|2/2σ2)局部相关性、空间或时间连续性
点积qk查询和键维度相同,高维向量,效率优先
加性注意力vTtanh(Wqq+Wkk)查询和键维度不同,复杂关系,可学习交互
余弦相似度qk|q||k|方向相似性(忽略向量长度)

缩放点积注意力(Scaled Dot-Product Attention)公式:

score(q,k)=qkdk

Transformer 默认缩放点积注意力,主要基于计算效率、理论合理性和实践效果综合考量。

key 相同时权重均匀分布

当所有的键(Key)相同时,计算出的注意力权重会均匀分布。所以键是决定性因素

详细解释

  1. 注意力权重的计算过程
  • 假设我们有一组查询 Q 和一组键 K。对于每一个查询 qi,我们计算它与所有键 kj 的相似度(通常通过点积或其他相似度度量)。
  • 然后,对这些相似度进行 softmax 操作,得到注意力权重 αij
αij=exp(qikj)jexp(qikj)
  • 这些权重表示在生成输出时,每个值 vj 的重要性。
  1. 所有键相同的情况
  • 如果所有的键 kj 都相同,即 k1=k2==kn=k,那么对于任意查询 qi
qikj=qik对于所有的 j
  • 因此,所有的相似度 qikj 都相同,假设这个相同的值为 s
  • 计算 softmax 时:
αij=exp(s)jexp(s)=exp(s)nexp(s)=1n
  • 因此,所有的注意力权重 αij 都等于 1n,即均匀分布。
  1. 直观理解
  • 键的作用是为查询提供“应该关注哪些部分”的信息。如果所有的键都相同,那么查询无法区分哪些部分更重要,因此只能均匀地关注所有的部分。
  • 类似于在人群中,如果所有人的声音完全相同,你无法区分谁在说什么,只能平均听取所有人的意见。

实例演示

假设:

  • 查询 q=[1,0]
  • k1=k2=k3=[1,1]
  • 计算点积:qkj=11+01=1对于所有的 j
  • 计算 softmax:αj=exp(1)exp(1)+exp(1)+exp(1)=e3e=13
  • 因此,注意力权重为 [13,13,13],即均匀分布。

可能的误区

  1. 查询的作用
  • 有人可能会认为查询的不同会影响权重,但实际上如果键相同,无论查询如何,相似度都是相同的(因为 qk 对所有 k 相同)。
  • 只有在键不同时,查询的不同才会导致不同的相似度。
  1. softmax 的性质
  • softmax 对相同的输入值会输出均匀分布。这是 softmax 的一个基本性质。
  • 如果输入到 softmax 的所有值相同,输出就是均匀的。

多头注意力(multihead attention)

在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。

因此,允许注意力机制组合查询、键和值的不同子空间表示representation subspaces可能是有益的。

与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 h 组不同的线性投影linear projections)来变换查询、键和值。然后,这 h 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 h 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)

对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)

代码实现

在实现过程中通常选择缩放点积注意力,作为每一个注意力头。

py
import math
import torch
from torch import nn
from d2l import torch as d2l

class MultiHeadAttention(nn.Module):
  def __init__(self, key_size, query_size, value_size, num_hiddens,
                num_heads, dropout, bias=False, **kwargs):
    super(MultiHeadAttention, self).__init__(**kwargs)
    self.num_heads = num_heads
    self.attention = d2l.DotProductAttention(dropout)
    self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
    self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
    self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
    self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

  def forward(self, queries, keys, values, valid_lens):
    # queries,keys,values的形状: (batch_size,查询或者“键-值”对的个数,num_hiddens)
    # valid_lens的形状: (batch_size,)或(batch_size,查询的个数)
    # 经过变换后,输出的queries,keys,values 的形状:
    # (batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    queries = transpose_qkv(self.W_q(queries), self.num_heads)
    keys = transpose_qkv(self.W_k(keys), self.num_heads)
    values = transpose_qkv(self.W_v(values), self.num_heads)

    if valid_lens is not None:
      # 在轴0,将第一项(标量或者矢量)复制num_heads次,
      # 然后如此复制第二项,然后诸如此类。
      valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

    # output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)
    output = self.attention(queries, keys, values, valid_lens)

    # output_concat的形状:(batch_size,查询的个数,num_hiddens)
    output_concat = transpose_output(output, self.num_heads)
    return self.W_o(output_concat)


# 先降维拆分:为了多注意力头的并行计算而变换形状
def transpose_qkv(X, num_heads):
  # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
  # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
  X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

  # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
  X = X.permute(0, 2, 1, 3)

  # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
  return X.reshape(-1, X.shape[2], X.shape[3])

# 再升维融合:逆转transpose_qkv函数的操作
def transpose_output(X, num_heads):
  X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
  X = X.permute(0, 2, 1, 3)
  return X.reshape(X.shape[0], X.shape[1], -1)

# 实例实用
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
# torch.Size([2, 4, 100])

自注意力和位置编码

自注意力定义

将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。

即每个查询都会关注所有键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,被称为自注意力self-attention),或内部注意力intra-attention)。

py
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape
# torch.Size([2, 4, 100])

与 CNN、RNN 比较

An Image目标:把 n 个词元序列映射到另一个长度相等序列,每个输入输出词元都是 d 维向量。

CNN(卷积神经网络)

  • 设卷积核大小 k,序列长度 n,输入输出通道数 d,卷积层计算复杂度O(knd2)
  • 卷积神经网络是分层的,因有 O(1) 个顺序操作,最大路径长度O(n/k)

RNN(循环神经网络)

  • 更新循环神经网络隐状态时,d×d 权重矩阵和 d 维隐状态乘法计算复杂度为 O(d2),序列长度为 n ,因此计算复杂度O(nd2)
  • O(n) 个顺序操作无法并行化,最大路径长度O(n)

自注意力

  • 查询、键和值都是 n×d 矩阵,考虑缩放点积注意力,n×d 矩阵乘以 d×n 矩阵。之后输出的 n×n 矩阵乘以 n×d 矩阵,计算复杂度O(n2d)
  • 每个词元都通过自注意力直接连接到任何其他词元,有 O(1) 个顺序操作可以并行计算,最大路径长度也是 O(1)

结论:卷积神经网络和自注意力都拥有并行计算优势,且自注意力最大路径长度最短。但是因为自注意力计算复杂度是关于序列长度的二次方,所以在很长序列中计算会非常慢

位置编码

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作

为了使用序列的顺序信息,通过在输入表示中添加位置编码positional encoding)来注入绝对或相对位置信息。位置编码可以通过学习得到也可以直接固定得到。

下面介绍固定位置编码

假设输入 XRn×d 包含一个序列中 n 个词元 d 维嵌入表示。位置编码使用相同形状的位置嵌入矩阵 PRn×d 输出 X+P

在位置嵌入矩阵 P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。相关位置如下:

矩阵第 i 行、第 2j 列的元素:

pi,2j=sin(i100002j/d)

矩阵第 i 行、第 2j+1 列的元素:

pi,2j+1=cos(i100002j/d)

代码实现如下:

py
import math
import torch
from torch import nn
from d2l import torch as d2l

# 位置编码
class PositionalEncoding(nn.Module):
  def __init__(self, num_hiddens, dropout, max_len=1000):
    super(PositionalEncoding, self).__init__()
    self.dropout = nn.Dropout(dropout)
    # 创建一个足够长的P
    self.P = torch.zeros((1, max_len, num_hiddens))
    X = torch.arange(max_len, dtype=torch.float32).reshape(
        -1, 1) / torch.pow(10000, torch.arange(
        0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
    self.P[:, :, 0::2] = torch.sin(X)
    self.P[:, :, 1::2] = torch.cos(X)

  def forward(self, X):
    X = X + self.P[:, :X.shape[1], :].to(X.device)
    return self.dropout(X)

# Go!
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

An Image效果如图:可以看到位置嵌入矩阵第 6、7 列和频率高于第 8、9 列。

点解?在二进制表示中,较高比特位的交替频率低于较低比特位。

py
for i in range(8):
  print(f'{i}的二进制是:{i:>03b}')

# 0的二进制是:000
# 1的二进制是:001
# 2的二进制是:010
# 3的二进制是:011
# 4的二进制是:100
# 5的二进制是:101
# 6的二进制是:110
# 7的二进制是:111

打印 (0、1、...、7) 的二进制表示形式,可以看到:每 1 个数字、每 2 个数字和每 4 个数字上的比特值在第 1 个最低位、第 2 个最低位和第 3 个最低位上分别交替 0 和 1

注意力机制 has loaded