lukiod committed on
Commit
bd52e59
1 Parent(s): 5bc3f56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
4
  from PIL import Image
5
  from byaldi import RAGMultiModalModel
6
  from qwen_vl_utils import process_vision_info
@@ -8,24 +8,24 @@ from qwen_vl_utils import process_vision_info
8
  # Model and processor names
9
  RAG_MODEL = "vidore/colpali"
10
  QWN_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
11
- QWN_PROCESSOR = "Qwen/Qwen2-VL-2B-Instruct"
12
 
13
  @st.cache_resource
14
  def load_models():
15
  RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)
16
 
17
- model = AutoModelForCausalLM.from_pretrained(
18
  QWN_MODEL,
19
  torch_dtype=torch.bfloat16,
 
 
20
  trust_remote_code=True
21
- ).cuda().eval()
22
 
23
- processor = AutoProcessor.from_pretrained(QWN_PROCESSOR, trust_remote_code=True)
24
- tokenizer = AutoTokenizer.from_pretrained(QWN_PROCESSOR, trust_remote_code=True)
25
 
26
- return RAG, model, processor, tokenizer
27
 
28
- RAG, model, processor, tokenizer = load_models()
29
 
30
  def document_rag(text_query, image):
31
  messages = [
@@ -40,7 +40,7 @@ def document_rag(text_query, image):
40
  ],
41
  }
42
  ]
43
- text = tokenizer.apply_chat_template(
44
  messages, tokenize=False, add_generation_prompt=True
45
  )
46
  image_inputs, video_inputs = process_vision_info(messages)
@@ -51,12 +51,12 @@ def document_rag(text_query, image):
51
  padding=True,
52
  return_tensors="pt",
53
  )
54
- inputs = inputs.to("cuda")
55
  generated_ids = model.generate(**inputs, max_new_tokens=50)
56
  generated_ids_trimmed = [
57
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
58
  ]
59
- output_text = tokenizer.batch_decode(
60
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
61
  )
62
  return output_text[0]
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
  from PIL import Image
5
  from byaldi import RAGMultiModalModel
6
  from qwen_vl_utils import process_vision_info
 
8
  # Model and processor names
9
  RAG_MODEL = "vidore/colpali"
10
  QWN_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
 
11
 
12
  @st.cache_resource
13
  def load_models():
14
  RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)
15
 
16
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
17
  QWN_MODEL,
18
  torch_dtype=torch.bfloat16,
19
+ attn_implementation="flash_attention_2",
20
+ device_map="auto",
21
  trust_remote_code=True
22
+ ).eval()
23
 
24
+ processor = AutoProcessor.from_pretrained(QWN_MODEL, trust_remote_code=True)
 
25
 
26
+ return RAG, model, processor
27
 
28
+ RAG, model, processor = load_models()
29
 
30
  def document_rag(text_query, image):
31
  messages = [
 
40
  ],
41
  }
42
  ]
43
+ text = processor.apply_chat_template(
44
  messages, tokenize=False, add_generation_prompt=True
45
  )
46
  image_inputs, video_inputs = process_vision_info(messages)
 
51
  padding=True,
52
  return_tensors="pt",
53
  )
54
+ inputs = inputs.to(model.device)
55
  generated_ids = model.generate(**inputs, max_new_tokens=50)
56
  generated_ids_trimmed = [
57
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
58
  ]
59
+ output_text = processor.batch_decode(
60
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
61
  )
62
  return output_text[0]