5to9 committed on
Commit
20d9962
·
1 Parent(s): 1a0b767

0.19 implementing flash_attn

Browse files
Files changed (2) hide show
  1. app.py +9 -0
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import gradio as gr
5
  import logging
6
  from huggingface_hub import login
 
7
 
8
  import os
9
  import traceback
@@ -66,6 +67,10 @@ def load_model_a(model_id):
66
  device_map="auto",
67
  trust_remote_code=True,
68
  ).eval()
 
 
 
 
69
  except Exception as e:
70
  logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
71
 
@@ -83,6 +88,10 @@ def load_model_b(model_id):
83
  device_map="auto",
84
  trust_remote_code=True,
85
  ).eval()
 
 
 
 
86
  except Exception as e:
87
  logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
88
  return gr.update(label=model_id)
 
4
  import gradio as gr
5
  import logging
6
  from huggingface_hub import login
7
+ from flash_attn.flash_attention import FlashAttention
8
 
9
  import os
10
  import traceback
 
67
  device_map="auto",
68
  trust_remote_code=True,
69
  ).eval()
70
+ for name, module in model_a.named_modules():
71
+ if isinstance(module, torch.nn.MultiheadAttention):
72
+ module.forward = FlashAttention(module.embed_dim)
73
+ logging.debug(f'{SPACER} forwarding module of {model_id_a} to flash_attn')
74
  except Exception as e:
75
  logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
76
 
 
88
  device_map="auto",
89
  trust_remote_code=True,
90
  ).eval()
91
+ for name, module in model_b.named_modules():
92
+ if isinstance(module, torch.nn.MultiheadAttention):
93
+ module.forward = FlashAttention(module.embed_dim)
94
+ logging.debug(f'{SPACER} forwarding module of {model_id_b} to flash_attn')
95
  except Exception as e:
96
  logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
97
  return gr.update(label=model_id)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ accelerate==0.33.0
5
  sentencepiece==0.2.0
6
  spaces==0.29.2
7
  gradio==4.39.0
8
- bitsandbytes==0.43.2
 
 
5
  sentencepiece==0.2.0
6
  spaces==0.29.2
7
  gradio==4.39.0
8
+ bitsandbytes==0.43.2
9
+ flash-attn