limingj commited on
Commit
76c6b8a
1 Parent(s): 7ee46a0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM
3
+ from PIL import Image
4
+ import torch
5
+
6
+ # 加载模型和处理器
7
+ model_name = "microsoft/llava-med-v1.5-mistral-7b"
8
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
9
+ processor = AutoProcessor.from_pretrained(model_name)
10
+
11
+ def predict(image, question):
12
+ # 将图像和问题处理为模型输入格式
13
+ inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
14
+
15
+ # 生成答案
16
+ with torch.no_grad():
17
+ outputs = model.generate(**inputs)
18
+
19
+ # 解码输出
20
+ answer = processor.batch_decode(outputs, skip_special_tokens=True)[0]
21
+ return answer
22
+
23
+ # 创建 Gradio 界面
24
+ interface = gr.Interface(
25
+ fn=predict,
26
+ inputs=[gr.inputs.Image(type="pil"), gr.inputs.Textbox(label="Question")],
27
+ outputs="text",
28
+ title="Medical Visual Question Answering"
29
+ )
30
+
31
+ if __name__ == "__main__":
32
+ interface.launch()