IIWAB transformer中的encoder,decoder, cross attention, self attention - IIWAB

transformer中的encoder,decoder, cross attention, self attention

IIWAB 16天前 ⋅ 62 阅读

在Transformer架构里,EncoderDecoderCross AttentionSelf Attention是关键组件,

1. Self Attention

Self Attention是Transformer的核心机制,它能够让模型在处理序列时,关注序列里不同位置的元素,进而捕获长距离依赖关系。具体而言,Self Attention会为序列中的每个位置计算一个加权和,这些权重反映了该位置与序列中其他位置的相关性。

Self Attention的计算步骤如下:

  • 针对输入序列中的每个元素,分别计算查询(Query)、键(Key)和值(Value)向量。
  • 计算Query和Key之间的相似度,一般采用点积运算。
  • 对相似度得分进行Softmax操作,得到注意力权重。
  • 用注意力权重对Value向量进行加权求和,从而得到输出。

以下是Self Attention的简单Python示例:

import torch
import torch.nn.functional as F

def self_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    return output, attn_weights

# 示例输入
seq_length = 4
d_model = 8
Q = torch.randn(seq_length, d_model)
K = torch.randn(seq_length, d_model)
V = torch.randn(seq_length, d_model)

output, attn_weights = self_attention(Q, K, V)
print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)

2. Cross Attention

Cross Attention主要用于解码器(Decoder),它能够让解码器关注编码器(Encoder)的输出。与Self Attention不同,Cross Attention的Query来自解码器的输入,而Key和Value则来自编码器的输出。

Cross Attention的计算步骤和Self Attention类似,不过其Query、Key和Value的来源不同。通过Cross Attention,解码器可以利用编码器提取的信息,生成目标序列。

3. Encoder

Encoder由多个相同的层堆叠而成,每一层包含两个子层:Self Attention层和前馈神经网络(Feed Forward Network)。Self Attention层让编码器能够捕获输入序列中的长距离依赖关系,而前馈神经网络则对Self Attention的输出进行非线性变换。

Encoder的主要作用是对输入序列进行编码,将其转换为一系列的特征表示,这些特征表示会被传递给解码器。

4. Decoder

Decoder同样由多个相同的层堆叠而成,每一层包含三个子层:Self Attention层、Cross Attention层和前馈神经网络。Self Attention层让解码器能够捕获目标序列中的长距离依赖关系,Cross Attention层让解码器能够关注编码器的输出,前馈神经网络则对Cross Attention的输出进行非线性变换。

Decoder的主要作用是根据编码器的输出和之前生成的目标序列,生成下一个目标元素。在生成过程中,会使用掩码(Mask)来保证解码器只能关注到之前生成的元素。

以下是一个简化的Transformer Encoder和Decoder的Python示例:

import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

# 示例使用
d_model = 512
nhead = 8
encoder_layer = EncoderLayer(d_model, nhead)
decoder_layer = DecoderLayer(d_model, nhead)

src = torch.randn(10, 32, d_model)  # 输入序列
tgt = torch.randn(20, 32, d_model)  # 目标序列

encoder_output = encoder_layer(src)
decoder_output = decoder_layer(tgt, encoder_output)

print("Encoder output shape:", encoder_output.shape)
print("Decoder output shape:", decoder_output.shape)

综上所述,Self Attention和Cross Attention是Transformer架构的核心机制,它们让模型能够捕获序列中的长距离依赖关系;Encoder负责对输入序列进行编码,Decoder则根据编码器的输出和之前生成的目标序列,生成下一个目标元素。


全部评论: 0

    我有话说: