Appearance
注意力机制(attention mechanisms)
查询、键和值
在注意力机制的背景下,自主性提示被称为查询(query)。
给定任何查询,注意力机制通过注意力汇聚(attention pooling) 将选择引导至感官输入(sensory inputs,例如中间特征表示)。这些感官输入被称为值(value)。
每个值都与一个键(key)配对,这可以想象为感官输入的非自主提示。
注意力机制
“是否包含自主性提示”将注意力机制与全连接层或汇聚层区别开来。
| 特性 | 注意力机制 | 全连接层 | 汇聚层 |
|---|---|---|---|
| 参数化 | 是(动态权重) | 是(静态权重) | 否(固定规则) |
| 自主性提示 | ✅ 动态适应输入 | ❌ 静态处理 | ❌ 静态处理 |
| 输入依赖性 | 高度依赖 | 不依赖 | 不依赖 |
| 典型应用 | Transformer, NLP | 传统分类模型 | CNN 的空间降维 |
通过注意力汇聚将查询(自主性提示)和键(非自主性提示)结合在一起,实现对值(感官输入)的选择倾向(智能选择)。
注意力汇聚(attention pooling)
查询(自主提示)和键(非自主提示)之间交互形成注意力汇聚,注意力汇聚有选择地聚合值(感官输入)以生成最终输出。
平均汇聚忽略了键的作用,不够聪明。因此根据查询和键相对位置加权,是一个更好的想法。
其中
受此启发,可以写一个更加通用的注意力汇聚attention pooling公式:
其中
下面考虑一个高斯核Gaussian kernel,其定义为:
代入高斯核,并推导公式:
如上可知,一个键
TIP
越接近、越相似、权重越高。“查询-键”对越近,注意力汇聚注意力权重就越高。
注意力评分函数
上面使用高斯核对查询和键之间关系建模,高斯核指数部分视为注意力评分函数,简称评分函数(scoring function)。
| 评分函数 | 计算方式 | 适用场景 |
|---|---|---|
| 高斯核 | 局部相关性、空间或时间连续性 | |
| 点积 | 查询和键维度相同,高维向量,效率优先 | |
| 加性注意力 | 查询和键维度不同,复杂关系,可学习交互 | |
| 余弦相似度 | 方向相似性(忽略向量长度) |
缩放点积注意力(Scaled Dot-Product Attention)公式:
Transformer 默认缩放点积注意力,主要基于计算效率、理论合理性和实践效果综合考量。
key 相同时权重均匀分布
当所有的键(Key)相同时,计算出的注意力权重会均匀分布。所以键是决定性因素。
详细解释
- 注意力权重的计算过程:
- 假设我们有一组查询
和一组键 。对于每一个查询 ,我们计算它与所有键 的相似度(通常通过点积或其他相似度度量)。 - 然后,对这些相似度进行 softmax 操作,得到注意力权重
:
- 这些权重表示在生成输出时,每个值
的重要性。
- 所有键相同的情况:
- 如果所有的键
都相同,即 ,那么对于任意查询 :
- 因此,所有的相似度
都相同,假设这个相同的值为 。 - 计算 softmax 时:
- 因此,所有的注意力权重
都等于 ,即均匀分布。
- 直观理解:
- 键的作用是为查询提供“应该关注哪些部分”的信息。如果所有的键都相同,那么查询无法区分哪些部分更重要,因此只能均匀地关注所有的部分。
- 类似于在人群中,如果所有人的声音完全相同,你无法区分谁在说什么,只能平均听取所有人的意见。
实例演示
假设:
- 查询
。 - 键
。 - 计算点积:
- 计算 softmax:
- 因此,注意力权重为
,即均匀分布。
可能的误区
- 查询的作用:
- 有人可能会认为查询的不同会影响权重,但实际上如果键相同,无论查询如何,相似度都是相同的(因为
对所有 相同)。 - 只有在键不同时,查询的不同才会导致不同的相似度。
- softmax 的性质:
- softmax 对相同的输入值会输出均匀分布。这是 softmax 的一个基本性质。
- 如果输入到 softmax 的所有值相同,输出就是均匀的。
多头注意力(multihead attention)
在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。
因此,允许注意力机制组合查询、键和值的不同子空间表示representation subspaces可能是有益的。
与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 linear projections)来变换查询、键和值。然后,这
对于
代码实现
在实现过程中通常选择缩放点积注意力,作为每一个注意力头。
py
import math
import torch
from torch import nn
from d2l import torch as d2l
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# queries,keys,values的形状: (batch_size,查询或者“键-值”对的个数,num_hiddens)
# valid_lens的形状: (batch_size,)或(batch_size,查询的个数)
# 经过变换后,输出的queries,keys,values 的形状:
# (batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 在轴0,将第一项(标量或者矢量)复制num_heads次,
# 然后如此复制第二项,然后诸如此类。
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
# output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)
output = self.attention(queries, keys, values, valid_lens)
# output_concat的形状:(batch_size,查询的个数,num_hiddens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
# 先降维拆分:为了多注意力头的并行计算而变换形状
def transpose_qkv(X, num_heads):
# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3)
# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
return X.reshape(-1, X.shape[2], X.shape[3])
# 再升维融合:逆转transpose_qkv函数的操作
def transpose_output(X, num_heads):
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
# 实例实用
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
# torch.Size([2, 4, 100])自注意力和位置编码
自注意力定义
将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。
即每个查询都会关注所有键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,被称为自注意力(self-attention),或内部注意力(intra-attention)。
py
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape
# torch.Size([2, 4, 100])与 CNN、RNN 比较
目标:把
CNN(卷积神经网络):
- 设卷积核大小
,序列长度 ,输入输出通道数 ,卷积层计算复杂度为 。 - 卷积神经网络是分层的,因有
个顺序操作,最大路径长度为 。
RNN(循环神经网络):
- 更新循环神经网络隐状态时,
权重矩阵和 维隐状态乘法计算复杂度为 ,序列长度为 ,因此计算复杂度为 。 - 有
个顺序操作无法并行化,最大路径长度为 。
自注意力:
- 查询、键和值都是
矩阵,考虑缩放点积注意力, 矩阵乘以 矩阵。之后输出的 矩阵乘以 矩阵,计算复杂度为 。 - 每个词元都通过自注意力直接连接到任何其他词元,有
个顺序操作可以并行计算,最大路径长度也是 。
结论:卷积神经网络和自注意力都拥有并行计算优势,且自注意力最大路径长度最短。但是因为自注意力计算复杂度是关于序列长度的二次方,所以在很长序列中计算会非常慢。
位置编码
在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。
为了使用序列的顺序信息,通过在输入表示中添加位置编码(positional encoding)来注入绝对或相对位置信息。位置编码可以通过学习得到也可以直接固定得到。
下面介绍固定位置编码:
假设输入
在位置嵌入矩阵
矩阵第
矩阵第
代码实现如下:
py
import math
import torch
from torch import nn
from d2l import torch as d2l
# 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
# Go!
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])效果如图:可以看到位置嵌入矩阵第 6、7 列和频率高于第 8、9 列。
点解?在二进制表示中,较高比特位的交替频率低于较低比特位。
py
for i in range(8):
print(f'{i}的二进制是:{i:>03b}')
# 0的二进制是:000
# 1的二进制是:001
# 2的二进制是:010
# 3的二进制是:011
# 4的二进制是:100
# 5的二进制是:101
# 6的二进制是:110
# 7的二进制是:111打印 (0、1、...、7) 的二进制表示形式,可以看到:每 1 个数字、每 2 个数字和每 4 个数字上的比特值在第 1 个最低位、第 2 个最低位和第 3 个最低位上分别交替 0 和 1。