davanstrien HF staff commited on
Commit
3555196
·
1 Parent(s): ec8173c
Files changed (1) hide show
  1. app.py +8 -47
app.py CHANGED
@@ -1,16 +1,14 @@
1
- # import subprocess # 🥲 need for flash attention in QWEN model
2
 
3
- # subprocess.run(
4
- # "pip install flash-attn --no-build-isolation",
5
- # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
6
- # shell=True,
7
- # )
8
 
9
  import spaces
10
  import gradio as gr
11
- from transformers import AutoModelForCausalLM, AutoProcessor
12
- # from transformers import Qwen2VLForConditionalGeneration # Uncomment when adding QWEN back
13
- # from qwen_vl_utils import process_vision_info # Uncomment when adding QWEN back
14
  import torch
15
  import os
16
  import json
@@ -33,15 +31,6 @@ processor = AutoProcessor.from_pretrained(
33
  device_map='auto'
34
  )
35
 
36
- # # Load Qwen model (commented out for now)
37
- # qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
38
- # "Qwen/Qwen2-VL-7B-Instruct",
39
- # torch_dtype=torch.bfloat16,
40
- # attn_implementation="flash_attention_2",
41
- # device_map="auto",
42
- # )
43
- # qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
44
-
45
  class GeneralRetrievalQuery(BaseModel):
46
  broad_topical_query: str
47
  broad_topical_explanation: str
@@ -91,34 +80,6 @@ Generate the queries based on this image and provide the response in the specifi
91
 
92
  prompt, pydantic_model = get_retrieval_prompt("general")
93
 
94
- # def _prep_data_for_input_qwen(image):
95
- # messages = [
96
- # {
97
- # "role": "user",
98
- # "content": [
99
- # {
100
- # "type": "image",
101
- # "image": image,
102
- # },
103
- # {"type": "text", "text": prompt},
104
- # ],
105
- # }
106
- # ]
107
- #
108
- # text = qwen_processor.apply_chat_template(
109
- # messages, tokenize=False, add_generation_prompt=True
110
- # )
111
- #
112
- # image_inputs, video_inputs = process_vision_info(messages)
113
- #
114
- # return qwen_processor(
115
- # text=[text],
116
- # images=image_inputs,
117
- # videos=video_inputs,
118
- # padding=True,
119
- # return_tensors="pt",
120
- # )
121
-
122
  def _prep_data_for_input(image):
123
  return processor.process(
124
  images=[image],
@@ -131,7 +92,7 @@ def generate_response(image):
131
  inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
132
  output = model.generate_from_batch(
133
  inputs,
134
- gr.GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
135
  tokenizer=processor.tokenizer
136
  )
137
  generated_tokens = output[0, inputs['input_ids'].size(1):]
 
1
+ import subprocess # 🥲
2
 
3
+ subprocess.run(
4
+ "pip install flash-attn --no-build-isolation",
5
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
6
+ shell=True,
7
+ )
8
 
9
  import spaces
10
  import gradio as gr
11
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
 
 
12
  import torch
13
  import os
14
  import json
 
31
  device_map='auto'
32
  )
33
 
 
 
 
 
 
 
 
 
 
34
  class GeneralRetrievalQuery(BaseModel):
35
  broad_topical_query: str
36
  broad_topical_explanation: str
 
80
 
81
  prompt, pydantic_model = get_retrieval_prompt("general")
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def _prep_data_for_input(image):
84
  return processor.process(
85
  images=[image],
 
92
  inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
93
  output = model.generate_from_batch(
94
  inputs,
95
+ GenerationConfig(max_new_tokens=200, stop_token="<|endoftext|>"),
96
  tokenizer=processor.tokenizer
97
  )
98
  generated_tokens = output[0, inputs['input_ids'].size(1):]