CheXRay / app.py
Tonic's picture
Update app.py
b3cab48 verified
raw
history blame
4.83 kB
import spaces
import io
import torch
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
# Markdown/HTML header rendered at the top of the Gradio app.
title = """# Welcome to 🌟Tonic's CheXRay⚕⚛!
You can use this ZeroGPU Space to test out the current model [StanfordAIMI/CheXagent-8b](https://huggingface.co/StanfordAIMI/CheXagent-8b). CheXRay⚕⚛ is fine-tuned to analyze chest X-rays, with different and generally better results than other multimodal models.
You can also use CheXRay⚕⚛ by cloning this Space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/CheXRay?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
### How to use
Upload a medical image and enter a prompt to receive an AI-generated analysis.
Simply upload an image with the right prompt (coming soon!) and analyze your X-ray!
Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️ community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Hugging Face: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐GitHub: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [Poly](https://github.com/tonic-ai/poly) 🤗Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""
# Inference settings: the model weights and every input tensor built in `generate`
# are placed on `device` in half precision (`dtype`).
MODEL_ID = "StanfordAIMI/CheXagent-8b"
device = "cuda"
dtype = torch.float16

# trust_remote_code=True runs the processing/modeling code shipped in the model repo.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
generation_config = GenerationConfig.from_pretrained(MODEL_ID)
# Move the weights to `device` so they sit on the same device as the inputs that
# `generate` sends to `device`; left on CPU, generation fails with a device mismatch.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=dtype, trust_remote_code=True
).to(device)
def _to_rgb_image(image):
    """Return *image* as an RGB ``PIL.Image.Image``.

    The UI components use ``gr.Image(type="pil")`` and therefore pass PIL images.
    File-like objects (anything with ``.read()``), filesystem paths and array
    inputs are also accepted, so callers using other input types keep working.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if isinstance(image, Image.Image):
        pil_image = image
    elif hasattr(image, "read"):
        # Raw file upload: decode the bytes ourselves.
        pil_image = Image.open(io.BytesIO(image.read()))
    elif isinstance(image, str):
        pil_image = Image.open(image)
    else:
        # Presumably a numpy array (e.g. gr.Image(type="numpy")).
        pil_image = Image.fromarray(image)
    return pil_image.convert("RGB")


@spaces.GPU
def _run_chexagent(image, prompt):
    """Run the model on one RGB PIL image and a non-empty prompt; return decoded text."""
    inputs = processor(
        images=[image],
        text=f" USER: <s>{prompt} ASSISTANT: <s>",
        return_tensors="pt",
    ).to(device=device, dtype=dtype)
    # generate() returns a batch of sequences; there is exactly one input here.
    output = model.generate(**inputs, generation_config=generation_config)[0]
    return processor.tokenizer.decode(output, skip_special_tokens=True)


def generate(image, prompt):
    """Analyze an uploaded image with a free-text prompt and return the model's answer.

    Input validation runs before the GPU-decorated helper is called, so a
    missing image or prompt is reported to the user without requesting a GPU.

    Args:
        image: the uploaded image (PIL image, file-like object, path or array).
        prompt: instruction text inserted into the USER/ASSISTANT prompt template.

    Returns:
        The decoded model output with special tokens removed.

    Raises:
        gr.Error: if the image or prompt is missing.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    return _run_chexagent(_to_rgb_image(image), prompt)
def _describe_feature(image, feature):
    """Ask the model to describe one anatomical feature of the uploaded image.

    Raises:
        gr.Error: if no feature is selected.
    """
    if not feature:
        raise gr.Error("Please select an anatomical feature.")
    return generate(image, f'Describe "{feature}"')


def _analyze_abnormality(image, abnormality):
    """Ask the model to assess the uploaded image for one common abnormality.

    Raises:
        gr.Error: if no abnormality is selected.
    """
    if not abnormality:
        raise gr.Error("Please select a common abnormality.")
    return generate(image, f'Analyze for "{abnormality}"')


# Three independent panels, each with its own image input, all backed by `generate`.
with gr.Blocks() as demo:
    gr.Markdown(title)

    with gr.Accordion("Custom Prompt Analysis"):
        with gr.Row():
            image_input_custom = gr.Image(type="pil")
            prompt_input_custom = gr.Textbox(label="Enter your custom prompt")
        generate_button_custom = gr.Button("Generate")
        output_text_custom = gr.Textbox(label="Response")
        generate_button_custom.click(
            fn=generate,
            inputs=[image_input_custom, prompt_input_custom],
            outputs=output_text_custom,
        )

    with gr.Accordion("Anatomical Feature Analysis"):
        anatomies = [
            "Airway", "Breathing", "Cardiac", "Diaphragm",
            "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)"
        ]
        with gr.Row():
            image_input_feature = gr.Image(type="pil")
            # Preselect a choice so an untouched dropdown does not send None to the model.
            prompt_select = gr.Dropdown(
                label="Select an anatomical feature",
                choices=anatomies,
                value=anatomies[0],
            )
        generate_button_feature = gr.Button("Analyze Feature")
        output_text_feature = gr.Textbox(label="Response")
        generate_button_feature.click(
            fn=_describe_feature,
            inputs=[image_input_feature, prompt_select],
            outputs=output_text_feature,
        )

    with gr.Accordion("Common Abnormalities Analysis"):
        common_abnormalities = ["Lung Nodule", "Pleural Effusion", "Pneumonia"]
        with gr.Row():
            image_input_abnormality = gr.Image(type="pil")
            abnormality_select = gr.Dropdown(
                label="Select a common abnormality",
                choices=common_abnormalities,
                value=common_abnormalities[0],
            )
        generate_button_abnormality = gr.Button("Analyze Abnormality")
        output_text_abnormality = gr.Textbox(label="Response")
        generate_button_abnormality.click(
            fn=_analyze_abnormality,
            inputs=[image_input_abnormality, abnormality_select],
            outputs=output_text_abnormality,
        )

demo.launch()