import gradio as gr
from PIL import Image
import skimage.io as io
import difflib

# project modules implementing the individual captioning models
import clipGPT
import vitGPT
import ViTCoAtt
import cnnrnn
from build_vocab import Vocabulary
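
# The model modules above are expected to expose the entry points called below:
#   clipGPT.generate_caption_clipgpt(image, max_tokens, temperature)
#   vitGPT.generate_caption(image, max_tokens, temperature)
#   ViTCoAtt.CaptionSampler.main(image)
# cnnrnn and Vocabulary appear to belong to the CNN-RNN baseline, which is not
# wired up in this app yet.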


def generate_caption_clipgpt(image, max_tokens, temperature):
    """Generate a report with the CLIP-GPT2 model."""
    caption = clipGPT.generate_caption_clipgpt(image, max_tokens, temperature)
    return caption


def generate_caption_vitgpt(image, max_tokens, temperature):
    """Generate a report with the ViT-GPT2 model."""
    caption = vitGPT.generate_caption(image, max_tokens, temperature)
    return caption


def generate_caption_vitCoAtt(image):
    """Generate a report with the ViT co-attention model (ViTCoAtt)."""
    caption = ViTCoAtt.CaptionSampler.main(image)
    return caption
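
# max_tokens caps the length of the generated report and temperature controls
# sampling randomness (lower values give more deterministic text); only the
# CLIP-GPT2 and ViT-GPT2 wrappers above accept these parameters.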


def generate_caption_cnnrnn(image):
    """Placeholder for the CNN-RNN baseline; returns an empty string for now."""
    caption = ""
    return caption
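
# A working CNN-RNN path would load the trained model from cnnrnn together with
# the pickled Vocabulary (see build_vocab) and decode a report from the image;
# that wiring is not implemented in this app yet.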


# UI components; gr.Interface below arranges and renders them itself, so no
# explicit gr.Blocks()/gr.Row() layout is needed here.
image = gr.Image(label="Upload Chest X-ray", type="pil")

max_tokens = gr.Dropdown(list(range(50, 101)), label="Max Tokens", value=75)
temperature = gr.Slider(0.5, 0.9, step=0.1, label="Temperature", value=0.7)

model_choice = gr.Radio(["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"], label="Select Model")
generate_button = gr.Button("Generate Caption")

caption = gr.Textbox(label="Generated Caption")
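
# Note: gr.Interface supplies its own Submit/Clear buttons, so generate_button
# above is currently unused by the interface.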


def predict(img, model_name, max_tokens, temperature):
    """Dispatch the uploaded image to the captioning model selected in the UI."""
    if model_name == "CLIP-GPT2":
        return generate_caption_clipgpt(img, max_tokens, temperature)
    elif model_name == "ViT-GPT2":
        return generate_caption_vitgpt(img, max_tokens, temperature)
    elif model_name == "ViT-CoAttention":
        return generate_caption_vitCoAtt(img)
    elif model_name == "Baseline Model CNN-RNN":
        # not offered in model_choice above; kept for the CNN-RNN baseline stub
        print(img)
        return generate_caption_cnnrnn(img)
    else:
        return "Caption generation for this model is not yet implemented."


examples = [[f"example{i}.jpg"] for i in range(1, 4)]
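# example1.jpg through example3.jpg are expected to be available at these
# relative paths; Gradio shows them as a clickable examples table under the inputs.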

description = (
    "Generate a caption by uploading a chest X-ray and selecting a model of your choice below. "
    "The Max Tokens and Temperature settings apply only to the CLIP-GPT2 and ViT-GPT2 models."
)
title = "MedViT: A Vision Transformer-Driven Method for Generating Medical Reports 🏥🤖"

interface = gr.Interface(
    fn=predict,
    inputs=[image, model_choice, max_tokens, temperature],
    theme="soft",
    outputs=caption,
    examples=examples,
    title=title,
    description=description,
)
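
# debug=True surfaces errors in the console while testing; pass share=True to
# launch() if a temporary public URL is needed.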
interface.launch(debug=True)