Prasi21 commited on
Commit
20013e4
·
verified ·
1 Parent(s): 1d006e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import Blip2ForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
4
+ from peft import PeftModel, PeftConfig
5
+
6
+ # Load the PEFT model configuration and quantization settings
7
+ peft_model_id = "Prasi21/blip2-opt-2.7b-strep-throat-caption-adapters3"
8
+ config = PeftConfig.from_pretrained(peft_model_id)
9
+ config.base_model_name_or_path = "Prasi21/blip2-opt-2.7b-strep-throat-caption-adapters3"
10
+
11
+ # Enable 8-bit quantization for more efficient loading
12
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
13
+
14
+ # Load the base model with quantization
15
+ model = Blip2ForConditionalGeneration.from_pretrained(
16
+ config.base_model_name_or_path,
17
+ quantization_config=quantization_config,
18
+ device_map="auto"
19
+ )
20
+
21
+ # Load the fine-tuned PEFT model
22
+ model = PeftModel.from_pretrained(model, peft_model_id)
23
+
24
+ # Load the processor
25
+ processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
26
+
27
+ # Define the prediction function
28
+ def predict(image):
29
+ # Preprocess the image
30
+ inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
31
+ new_eos_token_id = 13
32
+ with torch.no_grad():
33
+ generated_ids = modelA.generate(**inputs, max_length=100,
34
+ eos_token_id=new_eos_token_id)
35
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
36
+ return f"{generated_caption[0]}"
37
+
38
+ # Set up the Gradio interface
39
+ demo = gr.Interface(
40
+ fn=predict,
41
+ inputs=gr.inputs.Image(type="pil"), # Upload an image in PIL format
42
+ outputs="text", # The output will be the generated caption
43
+ title="Strep Throat Image Assessment",
44
+ description="Upload an image of a throat and receive a medical assessment caption based on the model's output."
45
+ )
46
+
47
+ # Launch the Gradio app
48
+ demo.launch()