
The Transformer Architecture

Self-attention enjoys both parallel computation and the shortest possible maximum path length, which makes it attractive to design deep architectures around self-attention.

The Transformer model is based entirely on attention mechanisms, without any convolutional or recurrent layers.

Although the Transformer was originally proposed for sequence-to-sequence learning on text data, it has since spread throughout modern deep learning, for example in language, vision, speech, and reinforcement learning.

The Model

The Transformer is an instance of the encoder-decoder architecture, built by stacking self-attention modules.

The embeddings of the source sequence (input) and the target sequence (output) both have positional encodings added before being fed into the encoder and the decoder, respectively.

Encoder: a stack of multiple identical layers, each containing two sublayers

  • First sublayer: multi-head self-attention pooling.
  • Second sublayer: positionwise feed-forward network (FFN).
  • In the computation, the queries, keys, and values all come from the output of the previous encoder layer. Each sublayer uses a residual connection, and layer normalization is applied after the residual addition.

Decoder: also a stack of multiple identical layers, again using residual connections and layer normalization

  • Third sublayer: inserted between the two sublayers above and called the encoder-decoder attention layer: the queries come from the output of the previous decoder layer, while the keys and values come from the output of the entire encoder.
  • In decoder self-attention, the queries, keys, and values all come from the output of the previous decoder layer, but each position may only attend to positions up to and including itself. This masked attention preserves the auto-regressive property, ensuring that predictions depend only on output tokens that have already been generated (a small sketch of the masking follows below).
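
A minimal sketch of this kind of masking (illustrative only, not from the original text): a causal mask hides all future positions before the softmax, so row t of the attention matrix only attends to positions up to t.

py
import torch

T = 4
scores = torch.randn(T, T)                                 # raw attention scores (queries x keys)
future = torch.triu(torch.ones(T, T), diagonal=1).bool()   # True above the diagonal = future positions
scores = scores.masked_fill(future, float('-inf'))         # hide positions after the query position
attn = scores.softmax(dim=-1)                              # row t only attends to positions <= t
print(attn)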

Positionwise Feed-Forward Network (FFN)

The positionwise feed-forward network in the Transformer applies the same MLP to transform the representation at every position of the sequence.

py
import math
import torch
from torch import nn
from d2l import torch as d2l

class PositionWiseFFN(nn.Module):
  """The same MLP applied independently at every sequence position."""
  def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
    super(PositionWiseFFN, self).__init__(**kwargs)
    self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
    self.relu = nn.ReLU()
    self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

  def forward(self, X):
    # (batch_size, num_steps, ffn_num_input) -> (batch_size, num_steps, ffn_num_outputs)
    return self.dense2(self.relu(self.dense1(X)))

# Instantiate
ffn = PositionWiseFFN(4, 4, 8)
  • In a standard Transformer, ffn_num_outputs usually equals the model's main dimension (i.e., the input dimension), while the hidden layer is 4× that dimension (e.g., input 512 → hidden 2048).
  • In this example the output dimension (8) differs from the input dimension (4), which breaks the requirement for a residual connection (input dimension == output dimension); see the quick shape check below.
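
A quick shape check of the instance above, using a dummy all-ones batch (following the d2l book's example):

py
ffn.eval()
ffn(torch.ones((2, 3, 4))).shape
# torch.Size([2, 3, 8]): only the last dimension changes, and it no longer matches
# the input dimension, so this particular ffn could not sit inside a residual connection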

Positionwise

Default behavior of nn.Linear: it computes each position independently and does not mix information across positions.

  • When the input has shape (B=batch_size, T=sequence_length, D_in=input feature dimension), the computation is carried out independently for each of the T positions, i.e., it is applied "positionwise".
| Layer type | Positionwise? | Reason |
| --- | --- | --- |
| nn.Linear | ✅ Yes | By default processes each of the T positions in (B, T, D) independently |
| nn.Conv1d | ❌ No | Its sliding window mixes information from neighboring positions |
| nn.LSTM | ❌ No | The hidden state depends on computations at earlier positions |
| MultiHeadAttention | ❌ No | Explicitly computes attention weights between all positions |
| PositionWiseFFN | ✅ Yes | A stack of nn.Linear layers; processes each position independently |
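
A minimal sketch verifying that nn.Linear is positionwise (the permutation check below is an illustrative assumption, not from the original text): permuting the T positions before or after the layer gives the same result.

py
linear = nn.Linear(4, 4)
X = torch.randn(2, 3, 4)                # (B, T, D)
perm = torch.tensor([2, 0, 1])          # shuffle the T positions
out1 = linear(X)[:, perm, :]            # apply, then permute positions
out2 = linear(X[:, perm, :])            # permute positions, then apply
print(torch.allclose(out1, out2))       # True: each position is transformed independently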

Feedforward Networks

A feedforward neural network (Feedforward Neural Network, FNN) is the most basic neural network structure. Its defining property: data flows in one direction only (input layer → hidden layer → output layer), with no recurrent or feedback connections.

  • "Feed": data is "fed" into the network.
  • "Forward": data only flows forward, never backward or in loops (see the minimal MLP sketch below).
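
A minimal sketch of such a network as a two-layer MLP (the layer sizes below are illustrative assumptions):

py
mlp = nn.Sequential(
  nn.Linear(16, 32),   # input layer -> hidden layer
  nn.ReLU(),
  nn.Linear(32, 10),   # hidden layer -> output layer
)
out = mlp(torch.randn(8, 16))   # data flows strictly forward, no feedback
print(out.shape)                # torch.Size([8, 10])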

Mathematical Form

Given an input $X \in \mathbb{R}^{B \times T \times D}$ (B = batch size, T = sequence length, D = feature dimension):

$$\mathrm{FFN}(X) = \mathrm{ReLU}(X W_1 + b_1)\, W_2 + b_2$$
  • $W_1 \in \mathbb{R}^{D \times D_{\text{hidden}}}$, $W_2 \in \mathbb{R}^{D_{\text{hidden}} \times D}$
  • The output at each position depends only on the input at that position, never on other positions.

Feedforward Networks vs. Other Networks

| Network type | Typical use | Data flow | Examples |
| --- | --- | --- | --- |
| Feedforward | image classification, feature extraction | one-way (input → output) | MLP, PositionWiseFFN |
| Recurrent | sequential data (text, speech) | recurrent (feedback across time steps) | LSTM, GRU |
| Convolutional | images, spatial data | local connectivity + weight sharing | ResNet, VGG |
| Transformer | sequence modeling (e.g., machine translation) | self-attention + feedforward | BERT, GPT |

Residual Connection and Layer Normalization (Add & Norm)

The residual connections and layer normalization in the Transformer are important tools for training very deep models.

Layer normalization has the same goal as batch normalization, but layer normalization normalizes over the feature dimension. Although batch normalization is widely used in computer vision, in natural language processing tasks (whose inputs are usually variable-length sequences) it tends to work less well than layer normalization (see the small comparison below).
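
A minimal comparison, following the example in the d2l book (the tiny 2×2 input is just for illustration):

py
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# LayerNorm normalizes each row (over the feature dimension);
# BatchNorm normalizes each column (over the batch dimension).
print('layer norm:', ln(X), '\nbatch norm:', bn(X))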

A residual connection requires its two inputs to have the same shape, so that the output tensor after the addition keeps that shape as well.

py
# Layer normalization after a residual connection
class AddNorm(nn.Module):
  def __init__(self, normalized_shape, dropout, **kwargs):
    super(AddNorm, self).__init__(**kwargs)
    self.dropout = nn.Dropout(dropout)
    self.ln = nn.LayerNorm(normalized_shape)

  def forward(self, X, Y):
    return self.ln(self.dropout(Y) + X)

# Example
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
# torch.Size([2, 3, 4])

Encoder Implementation

py
# Encoder block
class EncoderBlock(nn.Module):
  def __init__(self, key_size, query_size, value_size, num_hiddens,
              norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
              dropout, use_bias=False, **kwargs):
    super(EncoderBlock, self).__init__(**kwargs)
    self.attention = d2l.MultiHeadAttention(
        key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
    self.addnorm1 = AddNorm(norm_shape, dropout)
    self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
    self.addnorm2 = AddNorm(norm_shape, dropout)

  def forward(self, X, valid_lens):
    Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
    return self.addnorm2(Y, self.ffn(Y))

# Transformer encoder
class TransformerEncoder(d2l.Encoder):
  def __init__(self, vocab_size, key_size, query_size, value_size,
              num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
              num_heads, num_layers, dropout, use_bias=False, **kwargs):
    super(TransformerEncoder, self).__init__(**kwargs)
    self.num_hiddens = num_hiddens
    self.embedding = nn.Embedding(vocab_size, num_hiddens)
    self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
    self.blks = nn.Sequential()
    for i in range(num_layers):
      self.blks.add_module("block"+str(i), EncoderBlock(
        key_size, query_size, value_size, num_hiddens,
        norm_shape, ffn_num_input, ffn_num_hiddens,
        num_heads, dropout, use_bias))

  def forward(self, X, valid_lens, *args):
    # Because the positional-encoding values lie between -1 and 1,
    # the embedding values are scaled by the square root of the embedding dimension
    # before being added to the positional encoding.
    X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
    self.attention_weights = [None] * len(self.blks)
    for i, blk in enumerate(self.blks):
      X = blk(X, valid_lens)
      self.attention_weights[i] = blk.attention.attention.attention_weights
    return X
  • Adding the embedding and the positional encoding yields "one number" per dimension; the resulting vector is really a coordinate of a point in a high-dimensional space.
  • That point can be viewed as a combination of the original embedding vector and the positional-encoding vector.
  • Through training, the model learns to disentangle the original embedding from the positional information (a quick shape check of the encoder follows below).
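
A minimal shape check of the encoder, continuing from the class above (the hyperparameters follow the d2l book's example and are otherwise arbitrary):

py
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
valid_lens = torch.tensor([3, 2])
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape
# torch.Size([2, 100, 24])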

Decoder Implementation

py
# The i-th block in the decoder
class DecoderBlock(nn.Module):
  def __init__(self, key_size, query_size, value_size, num_hiddens,
                norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                dropout, i, **kwargs):
    super(DecoderBlock, self).__init__(**kwargs)
    self.i = i
    self.attention1 = d2l.MultiHeadAttention(
        key_size, query_size, value_size, num_hiddens, num_heads, dropout)
    self.addnorm1 = AddNorm(norm_shape, dropout)
    self.attention2 = d2l.MultiHeadAttention(
        key_size, query_size, value_size, num_hiddens, num_heads, dropout)
    self.addnorm2 = AddNorm(norm_shape, dropout)
    self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
    self.addnorm3 = AddNorm(norm_shape, dropout)

  def forward(self, X, state):
    enc_outputs, enc_valid_lens = state[0], state[1]
    # During training, all tokens of the output sequence are processed at the same time,
    # so state[2][self.i] is initialized as None.
    # During prediction, the output sequence is decoded one token at a time,
    # so state[2][self.i] holds the representations decoded by the i-th block up to the current time step.
    if state[2][self.i] is None:
      key_values = X
    else:
      key_values = torch.cat((state[2][self.i], X), axis=1)
    state[2][self.i] = key_values
    if self.training:
      batch_size, num_steps, _ = X.shape
      # Shape of dec_valid_lens: (batch_size, num_steps),
      # where every row is [1, 2, ..., num_steps]
      dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
    else:
      dec_valid_lens = None

    # Self-attention
    X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
    Y = self.addnorm1(X, X2)
    # Encoder-decoder attention.
    # Shape of enc_outputs: (batch_size, num_steps, num_hiddens)
    Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
    Z = self.addnorm2(Y, Y2)
    return self.addnorm3(Z, self.ffn(Z)), state

# Decoder
class TransformerDecoder(d2l.AttentionDecoder):
  def __init__(self, vocab_size, key_size, query_size, value_size,
                num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                num_heads, num_layers, dropout, **kwargs):
    super(TransformerDecoder, self).__init__(**kwargs)
    self.num_hiddens = num_hiddens
    self.num_layers = num_layers
    self.embedding = nn.Embedding(vocab_size, num_hiddens)
    self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
    self.blks = nn.Sequential()
    for i in range(num_layers):
      self.blks.add_module("block"+str(i),
          DecoderBlock(key_size, query_size, value_size, num_hiddens,
                        norm_shape, ffn_num_input, ffn_num_hiddens,
                        num_heads, dropout, i))
    self.dense = nn.Linear(num_hiddens, vocab_size)

  def init_state(self, enc_outputs, enc_valid_lens, *args):
    return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

  def forward(self, X, state):
    X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
    self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
    for i, blk in enumerate(self.blks):
      X, state = blk(X, state)
      # Decoder self-attention weights
      self._attention_weights[0][i] = blk.attention1.attention.attention_weights
      # Encoder-decoder attention weights
      self._attention_weights[1][i] = blk.attention2.attention.attention_weights
    return self.dense(X), state

  @property
  def attention_weights(self):
    return self._attention_weights
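
A minimal shape check of a single decoder block, continuing from the classes above (hyperparameters follow the d2l book; the dummy tensor standing in for the encoder output is an illustrative assumption):

py
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
state = [X, valid_lens, [None]]   # dummy encoder output, its valid lengths, per-block cache
decoder_blk(X, state)[0].shape
# torch.Size([2, 100, 24])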

Training

py
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

d2l.EncoderDecoder combines the encoder and the decoder into a single sequence-to-sequence model.
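
After training, the model can be used for translation; the sketch below follows the d2l book's seq2seq prediction utilities (assuming d2l.predict_seq2seq and d2l.bleu are available as in the book):

py
engs = ['go .', "i lost .", "he's calm .", "i'm home ."]
fras = ['va !', "j'ai perdu .", 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
  translation, dec_attention_weight_seq = d2l.predict_seq2seq(
      net, eng, src_vocab, tgt_vocab, num_steps, device, True)
  print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')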

ViT: Vision Transformer

ViT, short for Vision Transformer, is a pure-Transformer architecture for image classification proposed by a Google team in 2020 (paper: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale").

Its core idea in one sentence: treat a picture as a "sentence" by cutting it into many small square patches, treat each patch as a "word", and hand the resulting sequence straight to a Transformer.

It essentially takes the classic Transformer encoder from NLP (the same family of blocks behind BERT and GPT) and moves it to the image domain almost unchanged.

Below is a simplified core implementation of ViT-Base/16 (16×16 patches, 12 layers, 768-dimensional embeddings, 12 heads), commonly used for teaching and quick experiments.

py
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
  # Split the image into patches and project them linearly
  def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
    super().__init__()
    self.num_patches = (img_size // patch_size) ** 2
    self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

  def forward(self, x):
    # x: (B, C, H, W) → (B, embed_dim, num_patches_h, num_patches_w)
    x = self.proj(x)              # (B, embed_dim, n_h, n_w)
    x = x.flatten(2)              # (B, embed_dim, n_patches)
    x = x.transpose(1, 2)         # (B, n_patches, embed_dim)
    return x


class MultiHeadAttention(nn.Module):
  def __init__(self, embed_dim=768, num_heads=12, dropout=0.1):
    super().__init__()
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.scale = self.head_dim ** -0.5

    self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
    self.proj = nn.Linear(embed_dim, embed_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
    qkv = qkv.permute(2, 0, 3, 1, 4)          # [q,k,v] -> (3, B, heads, N, head_dim)
    q, k, v = qkv[0], qkv[1], qkv[2]

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.dropout(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    return x


class TransformerEncoderLayer(nn.Module):
  def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.1):
    super().__init__()
    self.norm1 = nn.LayerNorm(embed_dim)
    self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)

    self.norm2 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(
      nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
      nn.Dropout(dropout)
    )

  def forward(self, x):
    x = x + self.attn(self.norm1(x))      # residual + attention
    x = x + self.mlp(self.norm2(x))       # residual + MLP
    return x


class VisionTransformer(nn.Module):
  """完整的ViT模型"""
  def __init__(
    self,
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=1000,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    dropout=0.1,
    emb_dropout=0.1
  ):
    super().__init__()

    self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
    num_patches = self.patch_embed.num_patches

    # [CLS] token + positional embedding
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
    self.dropout = nn.Dropout(emb_dropout)

    # Stack of Transformer encoder layers
    self.encoder = nn.Sequential(*[
      TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio, dropout)
      for _ in range(depth)
    ])

    # Classification head
    self.norm = nn.LayerNorm(embed_dim)
    self.head = nn.Linear(embed_dim, num_classes)

  def forward(self, x):
    B = x.shape[0]

    # patch embedding
    x = self.patch_embed(x)                     # (B, n_patches, embed_dim)

    # Prepend the [CLS] token
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)       # (B, 1 + n_patches, embed_dim)

    # Add positional embedding + dropout
    x = x + self.pos_embed
    x = self.dropout(x)

    # Pass through the Transformer encoder
    x = self.encoder(x)

    # Take the [CLS] token output for classification
    x = self.norm(x)
    cls_out = x[:, 0]                           # take only the first token
    logits = self.head(cls_out)

    return logits


# Example: instantiate a ViT-Base/16 model (matching ViT-B/16 from the original paper)
if __name__ == "__main__":
  model = VisionTransformer(
    img_size=224,
    patch_size=16,
    num_classes=1000,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    dropout=0.1
  )

  x = torch.randn(2, 3, 224, 224)  # random input with batch size 2
  out = model(x)
  print(out.shape)          # torch.Size([2, 1000])