Multi-Head Self-Attention

  1. MHA Implementation in PyTorch

MHA Implementation in PyTorch
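
The reference implementation below follows the multi-head self-attention of "Attention Is All You Need". Each head computes scaled dot-product attention, Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, with d_k = d_model / num_heads. The input x is projected to Q, K, V, split into num_heads heads of size d_k, attended per head, then the heads are concatenated and projected back to d_model through w_o. An optional mask of shape [batch_size, seq_len, seq_len] excludes positions where the mask is 0 by setting their scores to -inf before the softmax.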

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each head
        self.w_q = nn.Linear(d_model, d_model)  # query projection
        self.w_k = nn.Linear(d_model, d_model)  # key projection
        self.w_v = nn.Linear(d_model, d_model)  # value projection
        self.w_o = nn.Linear(d_model, d_model)  # output projection after concatenating heads
        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1))                           # [batch_size, num_heads, seq_len_q, seq_len_k]
        scaled_scores = scores / torch.sqrt(torch.tensor(d_k, dtype=q.dtype))   # scale by sqrt(d_k) so softmax does not saturate and gradients do not vanish
        if mask is not None:
            scaled_scores = scaled_scores.masked_fill(mask == 0, float('-inf')) # block positions where mask == 0 before the softmax
        attn_weights = F.softmax(scaled_scores, dim=-1)                         # [batch_size, num_heads, seq_len_q, seq_len_k]
        output = torch.matmul(attn_weights, v)                                  # [batch_size, num_heads, seq_len_q, d_v]
        return output, attn_weights

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = x.size(0)
        seq_len = x.size(1)
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        q = q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # [B, H, L, d_k]
        k = k.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # [B, H, L, d_k]
        v = v.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)  # [B, H, L, d_k]
        if mask is not None:
            mask = mask.unsqueeze(1)  # [batch_size, seq_len, seq_len] → [batch_size, 1, seq_len, seq_len], broadcast across heads
        attn_output, attn_weights = self.scaled_dot_product_attention(q, k, v, mask)  # attn_output: [batch_size, num_heads, seq_len, d_k]
        attn_output = attn_output.transpose(1, 2).contiguous()        # [batch_size, num_heads, seq_len, d_k] → [batch_size, seq_len, num_heads, d_k]
        output = attn_output.view(batch_size, seq_len, self.d_model)  # [B, L, d_model] (H * d_k = d_model)
        output = self.w_o(output)      # project the concatenated heads back to d_model
        output = self.dropout(output)  # dropout to reduce overfitting
        return output, attn_weights

if __name__ == "__main__":
    batch_size = 2
    seq_len = 5
    d_model = 64
    num_heads = 8
    dropout = 0.1

    mha = MultiHeadSelfAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)

    x = torch.randn(batch_size, seq_len, d_model)  # [2, 5, 64]
    mask = torch.tril(torch.ones(batch_size, seq_len, seq_len))  # causal (lower-triangular) mask, [2, 5, 5]; guarantees every row attends to at least one position
    output, attn_weights = mha(x, mask=mask)

    print("="*50)
    print(f"输入形状: {x.shape}")                  # 期望:torch.Size([2, 5, 64])
    print(f"输出形状: {output.shape}")              # 期望:torch.Size([2, 5, 64])
    print(f"注意力权重形状: {attn_weights.shape}")  # 期望:torch.Size([2, 8, 5, 5])
    print("="*50)
    assert output.shape == x.shape, "输出维度与输入维度不一致!"
    assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len), "注意力权重维度错误!"
    print("✅ 维度验证通过!")

Please credit the source when reposting: goldandrabbit.github.io