import gradio as gr
from PIL import Image
import clipGPT
import vitGPT
import skimage.io as io
import PIL.Image
import difflib
import ViTCoAtt
import cnnrnn
from build_vocab import Vocabulary

# NOTE(review): Image, io, PIL.Image, difflib, cnnrnn and Vocabulary are not
# referenced by the code in this file. They are kept on purpose. Vocabulary in
# particular may be required so that a pickled vocabulary object can be loaded
# by the model modules (presumably; verify before pruning).

# Shown when the selected model cannot (yet) produce a caption.
NOT_IMPLEMENTED_MESSAGE = "Caption generation for this model is not yet implemented."


# ---------------------------------------------------------------------------
# Caption generation functions
# ---------------------------------------------------------------------------
def generate_caption_clipgpt(image, max_tokens, temperature):
    """Caption ``image`` with the CLIP-GPT2 model.

    Delegates to ``clipGPT.generate_caption_clipgpt`` with the same arguments
    and returns its result unchanged.
    """
    return clipGPT.generate_caption_clipgpt(image, max_tokens, temperature)


def generate_caption_vitgpt(image, max_tokens, temperature):
    """Caption ``image`` with the ViT-GPT2 model.

    Delegates to ``vitGPT.generate_caption`` with the same arguments and
    returns its result unchanged.
    """
    return vitGPT.generate_caption(image, max_tokens, temperature)


def generate_caption_vitCoAtt(image):
    """Caption ``image`` with the ViT co-attention model.

    Delegates to ``ViTCoAtt.CaptionSampler.main``. This model takes no
    max-token or temperature settings.
    """
    return ViTCoAtt.CaptionSampler.main(image)


def generate_caption_cnnrnn(image):
    """Placeholder for the CNN-RNN baseline; always returns an empty string.

    TODO: wire up the baseline. An earlier draft looked up precomputed encoder
    features for the image in '/content/Image_features_ecoder_decoder.pickle'
    and passed them to a ``get_result`` function (presumably in ``cnnrnn``).
    If that approach is revived, only ``pickle.load`` files that are trusted.
    """
    return ""


# ---------------------------------------------------------------------------
# Input / output components (handed to gr.Interface further down the file)
# ---------------------------------------------------------------------------
# NOTE(review): these components are created outside a gr.Blocks context and
# later passed to gr.Interface, which arranges inputs/outputs itself. The
# Row/Column grouping below therefore probably has no visible effect; confirm.
with gr.Row():
    image = gr.Image(label="Upload Chest X-ray", type="pil")

with gr.Row():
    with gr.Column():  # Column for dropdowns and model choice
        max_tokens = gr.Dropdown(list(range(50, 101)), label="Max Tokens", value=75)
        temperature = gr.Slider(0.5, 0.9, step=0.1, label="Temperature", value=0.7)
        # No default value: until the user picks one, predict receives None.
        model_choice = gr.Radio(
            ["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention", "Baseline Model CNN-RNN"],
            label="Select Model",
        )
    # Not connected to any event handler. gr.Interface renders its own Submit
    # button, and that button is what calls predict.
    generate_button = gr.Button("Generate Caption")

caption = gr.Textbox(label="Generated Caption")


def predict(img, model_name, max_tokens, temperature):
    """Generate a caption for ``img`` with the model selected in the UI.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        model_name: Label selected in the model radio, or ``None`` when no
            model has been chosen.
        max_tokens: Maximum number of generated tokens. Only used by the
            CLIP-GPT2 and ViT-GPT2 models.
        temperature: Sampling temperature. Only used by the CLIP-GPT2 and
            ViT-GPT2 models.

    Returns:
        The generated caption. If the input is missing, or the selected model
        cannot produce a caption, returns a human-readable message instead.
    """
    # Guard the inputs before handing them to any model.
    if img is None:
        return "Please upload a chest X-ray image."
    if model_name is None:
        return "Please select a model."

    if model_name == "CLIP-GPT2":
        return generate_caption_clipgpt(img, max_tokens, temperature)
    if model_name == "ViT-GPT2":
        return generate_caption_vitgpt(img, max_tokens, temperature)
    if model_name == "ViT-CoAttention":
        return generate_caption_vitCoAtt(img)
    if model_name == "Baseline Model CNN-RNN":
        # The baseline is still a stub that returns "". Tell the user so,
        # rather than showing an empty textbox.
        return generate_caption_cnnrnn(img) or NOT_IMPLEMENTED_MESSAGE
    return NOT_IMPLEMENTED_MESSAGE
# Sample inputs. Each row pre-fills only the first input (the X-ray image),
# using example1.jpg, example2.jpg and example3.jpg in turn.
examples = [[f"example{number}.jpg"] for number in (1, 2, 3)]

description = (
    "You can generate captions by uploading an X-Ray and selecting a model of "
    "your choice below. Please select the number of Max Tokens and Temperature "
    "setting, if you are testing CLIP GPT2 and VIT GPT2 Models"
)
title = "MedViT: A Vision Transformer-Driven Method for Generating Medical Reports 🏥🤖"

# Build the app from the components and the predict callback defined earlier
# in the file. The inputs list must follow predict's parameter order
# (img, model_name, max_tokens, temperature).
interface = gr.Interface(
    fn=predict,
    inputs=[image, model_choice, max_tokens, temperature],
    outputs=caption,
    examples=examples,
    title=title,
    description=description,
    theme="soft",
)

# debug=True keeps launch() blocking the main thread and surfaces errors in
# the console output.
interface.launch(debug=True)