1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
# Nhập các thư viện cần thiết
import math
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings(action="ignore")
# Tạo một lớp Block
class SimpleDecoder(nn.Module):
def __init__(self, hidden_dim, nums_head, dropout=0.1):
super().__init__()
self.nums_head = nums_head
self.head_dim = hidden_dim // nums_head
self.dropout = dropout
# Thực hiện theo cách post_norm trong decoder Transformers, chú ý có residual connection
# eps nhằm ngăn chặn tràn số; LLaMA thường dùng RMSnorm và pre-norm cho sự ổn định
# RMSNorm sử dụng bình phương trung bình của w để chuẩn hóa $\sqrt{\frac{1}{n} \sum_{i=1}^{n}{a_i^2} }$
self.layernorm_att = nn.LayerNorm(hidden_dim, eps=0.00001)
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.o_proj = nn.Linear(hidden_dim, hidden_dim)
self.drop_att = nn.Dropout(self.dropout)
# Chuẩn bị cho FFN
self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)
self.layernorm_ffn = nn.LayerNorm(hidden_dim, eps=0.00001)
self.act_fn = nn.ReLU()
self.drop_ffn = nn.Dropout(self.dropout)
def attention_output(self, query, key, value, attention_mask=None):
# Tính tương quan giữa query và key
key = key.transpose(2, 3) # (batch, num_head, head_dim, seq)
att_weight = torch.matmul(query, key) / math.sqrt(self.head_dim)
# Điều chỉnh attention mask để tạo causal_attention
if attention_mask is not None:
# Chuyển thành ma trận tam giác dưới
attention_mask = attention_mask.tril()
att_weight = att_weight.masked_fill(attention_mask == 0, float("-1e20"))
else:
# Tạo thủ công một attention mask tam giác dưới
attention_mask = torch.ones_like(att_weight).tril()
att_weight = att_weight.masked_fill(attention_mask == 0, float("-1e20"))
att_weight = torch.softmax(att_weight, dim=-1)
print(att_weight)
att_weight = self.drop_att(att_weight)
mid_output = torch.matmul(att_weight, value)
# mid_output shape là: (batch, nums_head, seq, head_dim)
mid_output = mid_output.transpose(1, 2).contiguous()
batch, seq, _, _ = mid_output.size()
mid_output = mid_output.view(batch, seq, -1)
output = self.o_proj(mid_output)
return output
def attention_block(self, X, attention_mask=None):
batch, seq, _ = X.size()
query = self.q_proj(X).view(batch, seq, self.nums_head, -1).transpose(1, 2)
key = self.k_proj(X).view(batch, seq, self.nums_head, -1).transpose(1, 2)
value = self.v_proj(X).view(batch, seq, self.nums_head, -1).transpose(1, 2)
output = self.attention_output(
query,
key,
value,
attention_mask=attention_mask,
)
return self.layernorm_att(X + output)
def ffn_block(self, X):
up = self.act_fn(
self.up_proj(X),
)
down = self.down_proj(up)
# Áp dụng dropout
down = self.drop_ffn(down)
# Thực hiện chuẩn hóa
return self.layernorm_ffn(X + down)
def forward(self, X, attention_mask=None):
# X giả định đã qua embedding, kích thước (batch, seq, hidden_dim)
# attention_mask chỉ ra các mẫu nào cần bỏ qua
# Kích thước thường là: (batch, nums_head, seq)
att_output = self.attention_block(X, attention_mask=attention_mask)
ffn_output = self.ffn_block(att_output)
return ffn_output
# Kiểm thử
x = torch.rand(3, 4, 64)
net = SimpleDecoder(64, 8)
mask = (
torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])
.unsqueeze(1)
.unsqueeze(2)
.repeat(1, 8, 4, 1)
)
net(x, mask).shape
|