From f12c307d32b17f2fb2a7aeb6d17bd063dfc95966 Mon Sep 17 00:00:00 2001
From: Zhiyu Zhao <41675271+JerryFlymi@users.noreply.github.com>
Date: Tue, 5 Sep 2023 15:44:20 +0800
Subject: [PATCH 1/3] Update for flash-attn

---
 tome/patch/mae.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tome/patch/mae.py b/tome/patch/mae.py
index 8a9e842d..83cce7fd 100644
--- a/tome/patch/mae.py
+++ b/tome/patch/mae.py
@@ -15,7 +15,7 @@
 
 from tome.utils import parse_r
 
-from .timm import ToMeBlock, ToMeAttention
+from .timm import ToMeBlock, FlashAttnToMeAttention
 
 
 def make_tome_class(transformer_class):
@@ -100,4 +100,4 @@ def apply_patch(
             module.__class__ = ToMeBlock
             module._tome_info = model._tome_info
         elif isinstance(module, Attention):
-            module.__class__ = ToMeAttention
+            module.__class__ = FlashAttnToMeAttention

From 5f803816c78f382f309ff910158a01ceaed85a0e Mon Sep 17 00:00:00 2001
From: Zhiyu Zhao <41675271+JerryFlymi@users.noreply.github.com>
Date: Tue, 5 Sep 2023 15:55:13 +0800
Subject: [PATCH 2/3] Update for flash-attn

---
 tome/patch/timm.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/tome/patch/timm.py b/tome/patch/timm.py
index ae2b8fc8..2a387181 100644
--- a/tome/patch/timm.py
+++ b/tome/patch/timm.py
@@ -96,6 +96,63 @@ def forward(
         return x, k.mean(1)
 
 
+class FlashAttnToMeAttention(Attention):
+    """
+    Modifications:
+     - Compute attention with flash-attn's packed-QKV kernel
+     - Do not apply proportional attention for MAE models (size is unused)
+     - Return the mean of k over heads from attention
+
+    NOTE(review): flash-attn is expected to require CUDA fp16/bf16 inputs;
+    confirm that callers cast the model or run under autocast.
+    """
+
+    def forward(
+        self, x: torch.Tensor, size: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+         - x: input tokens, unpacked below as (B, N, C).
+         - size: accepted for interface compatibility; never read here.
+
+        Returns the projected attention output and k averaged over heads.
+        """
+        # Note: this is copied from timm.models.vision_transformer.Attention with modifications.
+        B, N, C = x.shape
+
+        # Modules that expose separate q_bias / v_bias get a fused bias with
+        # zeros in the k slot. Every other module uses the bias of its own
+        # qkv projection (None when it has none), so a bias stored on the
+        # fused projection is not lost and unrelated errors are not hidden.
+        q_bias = getattr(self, "q_bias", None)
+        v_bias = getattr(self, "v_bias", None)
+        if q_bias is not None and v_bias is not None:
+            qkv_bias = torch.cat(
+                (q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)
+            )
+        else:
+            qkv_bias = self.qkv.bias
+
+        # Spelled out through torch so no extra module-level alias is needed.
+        qkv = torch.nn.functional.linear(
+            input=x, weight=self.qkv.weight, bias=qkv_bias
+        )
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1)
+        # Index 1 of the packed axis is k, laid out as (B, heads, N, head_dim).
+        k = qkv.permute(2, 0, 3, 1, 4)[1]
+
+        # Attention dropout is fixed at 0.0 for this kernel call.
+        x = flash_attn_qkvpacked_func(
+            qkv, dropout_p=0.0, softmax_scale=self.scale, causal=False
+        )
+        x = x.reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        # Return k as well here
+        return x, k.mean(1)
+
+
 def make_tome_class(transformer_class):
     class ToMeVisionTransformer(transformer_class):
         """

From d8d64c67bffbdcf616144e2a46fe0e66297c0aaf Mon Sep 17 00:00:00 2001
From: Zhiyu Zhao <41675271+JerryFlymi@users.noreply.github.com>
Date: Tue, 5 Sep 2023 15:59:52 +0800
Subject: [PATCH 3/3] Update for flash-attn

---
 tome/patch/timm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tome/patch/timm.py b/tome/patch/timm.py
index 2a387181..5577db72 100644
--- a/tome/patch/timm.py
+++ b/tome/patch/timm.py
@@ -17,6 +17,8 @@
 from tome.merge import bipartite_soft_matching, merge_source, merge_wavg
 from tome.utils import parse_r
 
+from flash_attn import flash_attn_qkvpacked_func
+
 
 class ToMeBlock(Block):
     """