-
Notifications
You must be signed in to change notification settings - Fork 150
Expand file tree
/
Copy pathreference.py
More file actions
170 lines (135 loc) · 6.35 KB
/
reference.py
File metadata and controls
170 lines (135 loc) · 6.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from utils import make_match_reference, DisableCuDNNTF32
from task import input_t, output_t
import torch
from torch import nn, einsum
import math
# Reference code in PyTorch
class TriMul(nn.Module):
    """Triangle multiplicative update over a pairwise representation (PyTorch reference).

    Args:
        dim: Channel size of the input and output pair tensor.
        hidden_dim: Channel size of the projected left/right features that
            are multiplied together.
    """

    # Based on https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the triangle multiplicative update.

        Args:
            x: [bs, seq_len, seq_len, dim]
            mask: [bs, seq_len, seq_len]. It is multiplied into the left and
                right projections, so a zero at [b, i, k] removes position k
                from every sum that row i participates in.

        Returns:
            output: [bs, seq_len, seq_len, dim]

        Raises:
            ValueError: if ``x`` is not a rank-4 tensor.
        """
        if x.dim() != 4:
            raise ValueError(
                f"TriMul expects x of shape [bs, seq_len, seq_len, dim], got {tuple(x.shape)}"
            )
        x = self.norm(x)
        left = self.left_proj(x)
        right = self.right_proj(x)
        # Broadcast the mask over the hidden channel dimension.
        mask = mask.unsqueeze(-1)
        left = left * mask
        right = right * mask
        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()
        left = left * left_gate
        right = right * right_gate
        # Equivalent to:
        #   out[b, i, j, d] = sum_k left[b, i, k, d] * right[b, j, k, d]
        # The result has hidden_dim channels (not dim) until to_out projects it back.
        out = einsum('... i k d, ... j k d -> ... i j d', left, right)
        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)
def ref_kernel(data: input_t) -> output_t:
    """
    Reference implementation of TriMul using PyTorch.

    Args:
        data: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, seq_len, dim]
            - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
            - weights: Dictionary of model weights keyed by TriMul parameter
              name ('norm.weight', 'norm.bias', 'left_proj.weight', ...)
            - config: Dictionary with the keys "dim" and "hidden_dim"

    Returns:
        Output tensor of shape [batch_size, seq_len, seq_len, dim]. It is
        computed under torch.no_grad(), so it carries no autograd history.

    Raises:
        KeyError: if a required weight or config key is missing.
    """
    # Use deterministic kernels and disable TF32 for accuracy.
    # no_grad: the weights are wrapped in nn.Parameter (requires_grad=True),
    # so without it autograd would record the whole forward and keep its
    # saved activations alive for as long as the returned output lives.
    with DisableCuDNNTF32(), torch.no_grad():
        input_tensor, mask, weights, config = data
        trimul = TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(input_tensor.device)

        # Fill in the given weights of the model. nn.Parameter wraps the given
        # tensors without copying, so the module computes with exactly these weights.
        trimul.norm.weight = nn.Parameter(weights['norm.weight'])
        trimul.norm.bias = nn.Parameter(weights['norm.bias'])
        trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
        trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
        trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
        trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
        trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
        trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
        trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
        trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])

        output = trimul(input_tensor, mask)

    return output
# Input generation for the reference code
def generate_input(
    seqlen: int,
    bs: int,
    dim: int,
    hiddendim: int,
    seed: int,
    nomask: bool,
    distribution: str,
    device: str = "cuda",
) -> input_t:
    """Build a seeded ``(input, mask, weights, config)`` tuple for ``ref_kernel``.

    Every random tensor (input, mask and all weights) is drawn from one
    ``torch.Generator`` seeded with ``seed``, so identical arguments always
    produce identical data.

    Args:
        seqlen: Sequence length; the input is [bs, seqlen, seqlen, dim].
        bs: Batch size.
        dim: Channel size of the input / output.
        hiddendim: Hidden channel size of the TriMul projections.
        seed: Seed for the dedicated generator.
        nomask: If true the mask is all ones (float32); otherwise it holds
            random 0/1 entries (int64, from ``torch.randint``).
        distribution: ``"cauchy"`` draws the input from Cauchy(0, 2) (heavy
            tails); any other value draws from a standard normal.
        device: Device for the generator and every tensor. Defaults to "cuda".

    Returns:
        ``(input_tensor, mask, weights, config)``, where ``weights`` maps TriMul
        parameter names to tensors and ``config`` holds "hidden_dim" and "dim".
    """
    # Parameter names have no underscores because the harness's parser does
    # not handle "_" correctly yet; rebind them to readable names here.
    batch_size = bs
    seq_len = seqlen
    hidden_dim = hiddendim
    no_mask = nomask

    config = {
        "hidden_dim": hidden_dim,
        "dim": dim,
    }

    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    shape = (batch_size, seq_len, seq_len, dim)

    # Generate input tensor based on distribution.
    if distribution == "cauchy":
        # Heavier-tailed input. Tensor.cauchy_ accepts the seeded generator;
        # torch.distributions.Cauchy(...).sample() does not, so it ignored `seed`.
        input_tensor = torch.empty(shape, device=device, dtype=torch.float32).cauchy_(
            median=0.0, sigma=2.0, generator=gen
        )
    else:  # normal distribution
        input_tensor = torch.randn(
            shape,
            device=device,
            dtype=torch.float32,
            generator=gen,
        ).contiguous()

    if no_mask:
        mask = torch.ones(batch_size, seq_len, seq_len, device=device)
    else:
        mask = torch.randint(0, 2, (batch_size, seq_len, seq_len), device=device, generator=gen)

    def _randn(*size: int) -> torch.Tensor:
        # All weights come from the same seeded generator, in a fixed order.
        return torch.randn(*size, device=device, dtype=torch.float32, generator=gen)

    # Initialize model weights (insertion order fixes the generator draw order).
    weights = {}
    weights["norm.weight"] = _randn(dim)
    weights["norm.bias"] = _randn(dim)
    # NOTE(review): these weights are [hidden_dim, dim] (fan-in = dim) but are
    # scaled by 1/sqrt(hidden_dim), while to_out is scaled by 1/sqrt(dim).
    # Kept unchanged so the data distribution stays the same; confirm intent.
    proj_scale = math.sqrt(hidden_dim)
    for name in ("left_proj", "right_proj", "left_gate", "right_gate", "out_gate"):
        weights[f"{name}.weight"] = _randn(hidden_dim, dim) / proj_scale
    weights["to_out_norm.weight"] = _randn(hidden_dim)
    weights["to_out.weight"] = _randn(dim, hidden_dim) / math.sqrt(dim)
    weights["to_out_norm.bias"] = _randn(hidden_dim)

    return (input_tensor, mask, weights, config)
# Checker built from ref_kernel with rtol = atol = 2e-2.
# NOTE(review): the exact comparison semantics of these tolerances are
# defined by utils.make_match_reference; confirm there.
check_implementation = make_match_reference(
    ref_kernel,
    rtol=2e-2,
    atol=2e-2,
)