spoorthibhat committed on
Commit a99acda · verified · 1 Parent(s): d6b67ed

Create app.py

Files changed (1)
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
+ import os
+
+ os.chdir("LLaVA_Med")
+ os.system('pip install -q -e .')
+
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ import io
+ from contextlib import redirect_stdout
+ import gradio as gr
+ from transformers import AutoTokenizer
+ from llava.model.builder import load_pretrained_model
+ from llava.mm_utils import get_model_name_from_path
+ from llava.eval.run_llava import eval_model
+
+ # Define the model path
+ model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
+
+ # Load the model
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+     model_path=model_path,
+     model_base=None,
+     model_name=get_model_name_from_path(model_path)
+ )
+
+ # Define the inference function
+ def run_inference(image, question):
+     args = type('Args', (), {
+         "model_path": model_path,
+         "model_base": None,
+         "image_file": image,
+         "query": question,
+         "conv_mode": None,
+         "sep": ",",
+         "temperature": 0,
+         "top_p": None,
+         "num_beams": 1,
+         "max_new_tokens": 512
+     })()
+
+     # Capture the printed output of eval_model
+     f = io.StringIO()
+     with redirect_stdout(f):
+         eval_model(args)
+     output = f.getvalue()
+     return output
+
+ # Create the Gradio interface
+ with gr.Blocks(theme=gr.themes.Monochrome()) as app:
+     with gr.Column(scale=1):
+         gr.Markdown("<center><h1>LLaVA-Med</h1></center>")
+
+     with gr.Row():
+         image = gr.Image(type="filepath", scale=2)
+         question = gr.Textbox(placeholder="Enter a question", scale=3)
+
+     with gr.Row():
+         answer = gr.Textbox(placeholder="Answer pops up here", scale=1)
+
+     with gr.Row():
+         btn = gr.Button("Run Inference", scale=1)
+
+     btn.click(fn=run_inference, inputs=[image, question], outputs=answer)
+
+ # Launch the app
+ if __name__ == "__main__":
+     app.queue().launch()