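"""Causal self-attention implemented over an explicit graph.

Each token becomes a Node; the softmax attention weights become weighted
Edges from a token to every token in its causal history. test_attn() checks
the graph computation against standard masked scaled dot-product attention.
"""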
from __future__ import annotations

import math
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

DEBUG = True

@dataclass
class Node:
    idx: int
    value: Tensor
    adjacency_list: list[Edge] = field(default_factory=list)


@dataclass
class Edge:
    node: Node
    weight: Tensor
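
# The attention pattern is stored as an explicit graph: node i holds the value
# vector of token i, and each directed edge i -> j (j <= i) carries the softmax
# weight that token i assigns to token j. An illustrative sketch for a 2-token
# sequence (the names n0, n1, w10, w11 are made up for illustration):
#
#   n0 = Node(0, v0); n1 = Node(1, v1)
#   n0.adjacency_list == [Edge(n0, 1.0)]                  # token 0 sees only itself
#   n1.adjacency_list == [Edge(n0, w10), Edge(n1, w11)]   # causal: past + self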

# Entry point for multi-headed attention: runs graph attention per head.
def scaled_graph_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    batch, num_heads, seq_len, head_dim = query.shape
    assert batch == 1, "batch size must be one"
    outputs = []
    for head_idx in range(num_heads):
        # slice out one head: (batch, seq_len, head_dim)
        q = query[:, head_idx, :, :]
        k = key[:, head_idx, :, :]
        v = value[:, head_idx, :, :]
        outputs.append(causal_self_attention_with_graph(q, k, v))
    # re-assemble the heads: (batch, num_heads, seq_len, head_dim)
    return torch.stack(outputs, dim=1)
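
# Per head this reproduces standard causal scaled dot-product attention,
#
#   Attn(Q, K, V) = softmax(mask(Q K^T / sqrt(head_dim))) V,
#
# computed edge-by-edge over the graph instead of as one batched matmul.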

def causal_self_attention_with_graph(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    batch, seq_len, d_model = query.shape
    nodes = [Node(idx, value[:, idx, :], []) for idx in range(seq_len)]
    graph = build_graph(nodes, key, query)
    # traverse the graph: each node's output is the weighted sum of the
    # value vectors of the nodes it points to
    outputs = []
    for root in graph:
        curr_value = torch.zeros(1, 1, d_model, dtype=value.dtype, device=value.device)
        for edge in root.adjacency_list:
            curr_value += edge.node.value * edge.weight
        outputs.append(curr_value)
    output = torch.stack(outputs, dim=-2)
    return output.reshape(batch, seq_len, d_model)
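
# Note: the graph stores one Edge per (query, key) pair in the causal history,
# i.e. O(seq_len^2) edges, the same quadratic cost as the dense formulation;
# the point is to make every attention weight an inspectable object, not speed.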

def build_graph(nodes: list[Node], keys: Tensor, queries: Tensor) -> list[Node]:
    # note: when called per head, the last dimension here is head_dim
    batch, seq_len, d_model = queries.shape
    for idx, curr_node in enumerate(nodes):
        # keys for positions 0..idx (the causal history, including self)
        keys_history = keys[:, : idx + 1, :]
        # query for position idx
        curr_query = queries[:, idx, :]
        # dot-product similarity between the current query and every key
        # in the history of the current node (token)
        similarity_values = curr_query @ keys_history.transpose(-1, -2)
        # if DEBUG: print(f"{keys_history.shape=} {curr_query.shape=} {similarity_values.shape=}")
        similarity_values = similarity_values / math.sqrt(d_model)
        # softmax turns the scores into weights that indicate how much the
        # current node attends to each past node
        attn = F.softmax(similarity_values.float(), dim=-1).type_as(keys)
        attn = attn.reshape(-1)  # flatten to one weight per history position
        # if DEBUG: print(attn)
        # add a weighted edge from the current node to every past node
        for nidx, node in enumerate(nodes[: idx + 1]):
            edge_weight = attn[nidx]
            # if DEBUG: print(f"{idx} attends to {nidx} with weight {edge_weight:.2f}")
            curr_node.adjacency_list.append(Edge(node=node, weight=edge_weight))
    return nodes
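
# A minimal inspection sketch (purely illustrative; assumes q, k, v of shape
# (1, 4, 16), e.g. from torch.rand):
#
#   g = build_graph([Node(i, v[:, i, :], []) for i in range(4)], k, q)
#   len(g[3].adjacency_list)  # -> 4: token 3 attends to tokens 0..3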

@torch.no_grad()
def test_attn():
    torch.manual_seed(6)
    batch = 1
    seq_len = 8
    d_model = 2**10
    num_heads = 2
    head_dim = d_model // num_heads
    assert batch == 1, "Batch size must be 1 for this test"
    Wk = nn.Linear(d_model, d_model)
    Wq = nn.Linear(d_model, d_model)
    Wv = nn.Linear(d_model, d_model)
    x = torch.rand(batch, seq_len, d_model)
    key: Tensor = Wk(x)
    query: Tensor = Wq(x)
    value: Tensor = Wv(x)
    # split heads: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
    key = key.reshape(batch, seq_len, num_heads, head_dim).transpose(1, 2)
    query = query.reshape(batch, seq_len, num_heads, head_dim).transpose(1, 2)
    value = value.reshape(batch, seq_len, num_heads, head_dim).transpose(1, 2)
    # reference: standard masked scaled dot-product attention
    mask = torch.triu(torch.full((1, 1, seq_len, seq_len), float("-inf")), diagonal=1)
    scores = query @ key.transpose(-1, -2) / math.sqrt(head_dim)
    attn_mtx = F.softmax(mask + scores, dim=-1)
    out = attn_mtx @ value
    output = scaled_graph_attention(query, key, value)
    assert torch.isclose(output, out, atol=1e-5).all(), "you need to debug buddy"
    print("IT WORKS !!!")

if __name__ == "__main__":
    ITER = 3
    for _ in range(ITER):
        test_attn()
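
# Expected output (three runs, each comparing graph attention to the dense
# reference and asserting closeness):
#   IT WORKS !!!
#   IT WORKS !!!
#   IT WORKS !!!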