MHA (Multi-Head Self-Attention) Implementation in PyTorch
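For reference, the module below implements the standard scaled dot-product attention and its multi-head extension from "Attention Is All You Need":

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,
\qquad
\mathrm{head}_i = \mathrm{Attention}(XW_i^Q,\, XW_i^K,\, XW_i^V),
\qquad
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O
$$

In the code, the per-head projections $W_i^Q, W_i^K, W_i^V$ are fused into the single w_q, w_k, w_v linear layers and then split into heads by reshaping, which is equivalent.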
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each head

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    @staticmethod
    def scaled_dot_product_attention(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1))  # [batch_size, num_heads, seq_len_q, seq_len_k]
        scaled_scores = scores / (d_k ** 0.5)  # scale to keep softmax out of its saturated (vanishing-gradient) region
        if mask is not None:
            # positions where mask == 0 are blocked from attention
            scaled_scores = scaled_scores.masked_fill(mask == 0, float("-inf"))
        attn_weights = F.softmax(scaled_scores, dim=-1)  # [batch_size, num_heads, seq_len_q, seq_len_k]
        output = torch.matmul(attn_weights, v)  # [batch_size, num_heads, seq_len_q, d_v]
        return output, attn_weights
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = x.size(0)
        seq_len = x.size(1)

        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # split into heads: [B, L, d_model] -> [B, H, L, d_k]
        q = q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)  # [B, L, L] -> [B, 1, L, L] so it broadcasts over heads

        # attn_output: [batch_size, num_heads, seq_len, d_k]
        attn_output, attn_weights = self.scaled_dot_product_attention(q, k, v, mask)

        # merge heads: [B, H, L, d_k] -> [B, L, H, d_k] -> [B, L, d_model] (H * d_k = d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        output = attn_output.view(batch_size, seq_len, self.d_model)

        output = self.w_o(output)  # final linear projection back to d_model
        output = self.dropout(output)  # regularization against overfitting
        return output, attn_weights
if __name__ == "__main__":
    batch_size = 2
    seq_len = 5
    d_model = 64
    num_heads = 8
    dropout = 0.1

    mha = MultiHeadSelfAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
    x = torch.randn(batch_size, seq_len, d_model)  # [2, 5, 64]
    # causal (lower-triangular) mask, 1 = may attend, 0 = blocked;
    # guarantees every query can attend to at least itself, so no all-masked rows
    mask = torch.tril(torch.ones(batch_size, seq_len, seq_len))  # [2, 5, 5]

    output, attn_weights = mha(x, mask=mask)

    print("=" * 50)
    print(f"input shape: {x.shape}")                          # expected: torch.Size([2, 5, 64])
    print(f"output shape: {output.shape}")                    # expected: torch.Size([2, 5, 64])
    print(f"attention weights shape: {attn_weights.shape}")   # expected: torch.Size([2, 8, 5, 5])
    print("=" * 50)

    assert output.shape == x.shape, "output shape does not match input shape!"
    assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len), "attention weights have the wrong shape!"
    print("✅ shape checks passed!")
Please credit the source when reposting: goldandrabbit.github.io