import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextStreamer
import torch
import gc
import os
# Limit PyTorch to 4 CPU threads; tune this to the host's core count
torch.set_num_threads(4)
device = "cpu"
def load_model():
    model_name = "forestav/unsloth_vision_radiography_finetune"
    base_model_name = "unsloth/Llama-3.2-11B-Vision-Instruct"  # Correct base model
    print("Loading tokenizer and processor...")
    # Load tokenizer from base model
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code=True
    )
    # Load processor from base model
    processor = AutoProcessor.from_pretrained(
        base_model_name,
        trust_remote_code=True
    )
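    # The processor for Llama-3.2 Vision bundles the image preprocessor and the
    # chat template, which is why it is loaded from the base model alongside the tokenizer.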
print("Loading model...")
# Load model with CPU optimizations
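    # Assumption: the fine-tuned checkpoint resolves via AutoModelForCausalLM as below;
    # Llama-3.2 Vision checkpoints are more commonly loaded with
    # MllamaForConditionalGeneration (or AutoModelForVision2Seq).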
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        offload_folder="offload",
        offload_state_dict=True,
        trust_remote_code=True
    )
print("Quantizing model...")
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )
    return model, tokenizer, processor
# Create offload directory if it doesn't exist
os.makedirs("offload", exist_ok=True)
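# (from_pretrained can spill the state dict here if CPU RAM runs short during loading)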
# Initialize model and tokenizer globally
print("Starting model initialization...")
try:
    model, tokenizer, processor = load_model()
    print("Model loaded and quantized successfully!")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise
def analyze_image(image, instruction):
    try:
        # Clear memory
        gc.collect()
        if instruction.strip() == "":
            instruction = "You are an expert radiographer. Describe accurately what you see in this image."
        # Prepare the messages
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]
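        # The {"type": "image"} entry becomes the model's image placeholder token when
        # the chat template is applied; the pixel data itself is passed to the processor.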
        # Apply the chat template and preprocess the image
        inputs = processor(
            images=image,
            text=processor.apply_chat_template(messages, add_generation_prompt=True),
            return_tensors="pt"
        )
        # Generate with conservative settings for CPU
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=True,  # required for temperature/min_p to take effect
                temperature=1.0,
                min_p=0.1,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                num_beams=1
            )
        # Decode only the newly generated tokens (skip the echoed prompt)
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )
        # Clean up
        del outputs
        gc.collect()
        return response
    except Exception as e:
        return f"Error processing image: {str(e)}\nPlease try again with a smaller image or different settings."
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("""
# Medical Image Analysis Assistant
Upload a medical image and receive a professional description from an AI radiographer.
""")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Medical Image"
            )
            instruction_input = gr.Textbox(
                label="Custom Instruction (optional)",
                placeholder="You are an expert radiographer. Describe accurately what you see in this image.",
                lines=2
            )
            submit_btn = gr.Button("Analyze Image")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", lines=10)
    # Handle the submission
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, instruction_input],
        outputs=output_text
    )
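    # Note: because each request can take a while on CPU, calling demo.queue()
    # before launch is a common way to avoid request timeouts (left as-is here).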
gr.Markdown("""
### Notes:
- The model runs on CPU and may take several moments to process each image
- For best results, upload images smaller than 1.5MP
- Please be patient during processing
""")
# Launch the app
if __name__ == "__main__":
    demo.launch()