Jyothirmai committed on
Commit
2e77581
•
1 Parent(s): a7e52ef

Update app.py

Browse files
Files changed (1)
app.py +30 -58
app.py CHANGED
@@ -1,3 +1,5 @@
+
+
 import gradio as gr
 from PIL import Image
 import clipGPT
@@ -11,12 +13,12 @@ from build_vocab import Vocabulary
 
 
 # Caption generation functions
-def generate_caption_clipgpt(image, max_tokens, temperature):
-    caption = clipGPT.generate_caption_clipgpt(image, max_tokens, temperature)
+def generate_caption_clipgpt(image):
+    caption = clipGPT.generate_caption_clipgpt(image)
     return caption
 
-def generate_caption_vitgpt(image, max_tokens, temperature):
-    caption = vitGPT.generate_caption(image, max_tokens, temperature)
+def generate_caption_vitgpt(image):
+    caption = vitGPT.generate_caption(image)
     return caption
 
 def generate_caption_vitCoAtt(image):
@@ -26,78 +28,48 @@ def generate_caption_vitCoAtt(image):
 
 with gr.Blocks() as demo:
 
+
     gr.HTML("<h1 style='text-align: center;'>MedViT: A Vision Transformer-Driven Method for Generating Medical Reports 🏥🤖</h1>")
     gr.HTML("<p style='text-align: center;'>You can generate captions by uploading an X-Ray and selecting a model of your choice below</p>")
 
-    sample_data = [
-        ['sample/CXR192_IM-0598-1001.png', '75', '0.7', 'CLIP-GPT2, ViT-GPT2, ViT-CoAttention', '...'],
-        ['https://imgur.com/MLJaWnf', '50', '0.8', 'CLIP-GPT2, ViT-CoAttention', '...']
-    ]
-
+
     with gr.Row():
-        image_display = gr.Image(label="Selected Image")  # Area to display the selected image
-        image_table = gr.Dataframe(sample_data, headers=['image', 'max token', 'temp', 'model supported', 'ground truth'], datatype=['str', 'str', 'str', 'str', 'str'])  # Changed to 'str'
+        sample_images = [
+            'https://imgur.com/W1pIr9b',
+            'https://imgur.com/MLJaWnf',
+            'https://imgur.com/6XymFW1',
+            'https://imgur.com/zdPjZZ1',
+            'https://imgur.com/DKUlZbF'
+        ]
+
+
+        image = gr.Image(label="Upload Chest X-ray", type="pil")
+
+        sample_images_gallery = gr.Gallery(value=sample_images, label="Sample Images")
 
     with gr.Row():
-        image = gr.Image(label="Upload Chest X-ray", type="pil")
+        model_choice = gr.Radio(["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"], label="Select Model")
 
+        generate_button = gr.Button("Generate Caption")
 
-    gr.HTML("<p style='text-align: center;'> Please select the Number of Max Tokens and Temperature setting, if you are testing CLIP GPT2 and VIT GPT2 Models</p>")
-
-    with gr.Row():
-        with gr.Column():
-            max_tokens = gr.Dropdown(list(range(50, 101)), label="Max Tokens", value=75)
-            temperature = gr.Slider(0.5, 0.9, step=0.1, label="Temperature", value=0.7)
-
-        model_choice = gr.Radio(["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"], label="Select Model")
 
-    generate_button = gr.Button("Generate Caption")
+
     caption = gr.Textbox(label="Generated Caption")
-
 
-    def predict(img, model_name, max_tokens, temperature):
+    def predict(img, model_name):
         if model_name == "CLIP-GPT2":
-            return generate_caption_clipgpt(img, max_tokens, temperature)
+            return generate_caption_clipgpt(img)
         elif model_name == "ViT-GPT2":
-            return generate_caption_vitgpt(img, max_tokens, temperature)
+            return generate_caption_vitgpt(img)
         elif model_name == "ViT-CoAttention":
             return generate_caption_vitCoAtt(img)
         else:
             return "Caption generation for this model is not yet implemented."
 
-    def predict_from_table(row, model_name):
-        img_url = row['image']
-        max_tokens = row['max token']
-        temperature = row['temp']
-
-        image_display.update(value=img_url)  # Update the image display
-
-        # Load the image
-        img = Image.open(io.imread(img_url))
-
-        # Generate the caption
-        if model_name == "CLIP-GPT2":
-            caption = generate_caption_clipgpt(img, max_tokens, temperature)
-        elif model_name == "ViT-GPT2":
-            caption = generate_caption_vitgpt(img, max_tokens, temperature)
-        elif model_name == "ViT-CoAttention":
-            caption = generate_caption_vitCoAtt(img)
-        else:
-            caption = "Caption generation for this model is not yet implemented."
-
-        return image_display, caption  # Return both the image and the caption
-
-    # Event handlers
-    generate_button.click(predict, [image, model_choice, max_tokens, temperature], caption)
 
-    # Create an event handler for the entire Dataframe
-    def dataframe_selected(index, row, model_choice):
-        if row is not None:  # Check if an actual row was selected
-            return predict_from_table(row, model_choice)
-
-    # Attach the function to the dataframe
-    image_table.change(dataframe_selected, inputs=[image_table, model_choice], outputs=[image_display, caption])
+    # Event handlers
+    generate_button.click(predict, [image, model_choice], caption)  # Trigger prediction on button click
+    sample_images_gallery.change(predict, [sample_images_gallery, model_choice], caption)  # Handle sample images
 
 
 demo.launch()
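
One wiring detail in the new event handlers is worth flagging: for a gr.Gallery, .change fires when the component's value is set, not when a thumbnail is clicked, and it passes the whole list of images to predict as img. A minimal sketch of click-based wiring instead, using Gradio's .select event and gr.SelectData (available in recent Gradio releases); the names sample_images, predict, model_choice, caption, and sample_images_gallery come from the diff above, and the requests/PIL loading step assumes the sample URLs return raw image bytes (the imgur links here are page URLs, so direct i.imgur.com forms may be needed):

    import io
    import requests
    from PIL import Image
    import gradio as gr

    def on_gallery_select(evt: gr.SelectData, model_name):
        # evt.index is the position of the clicked thumbnail; look the URL up
        # in the same list that populated the gallery rather than relying on
        # evt.value, whose shape differs across Gradio versions.
        img_url = sample_images[evt.index]
        img = Image.open(io.BytesIO(requests.get(img_url, timeout=10).content))
        return predict(img, model_name)

    # Inside the gr.Blocks() context, in place of the .change wiring:
    sample_images_gallery.select(on_gallery_select, inputs=[model_choice], outputs=caption)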
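The button path does not need this treatment: gr.Image(type="pil") already hands predict a single PIL image, so generate_button.click(predict, [image, model_choice], caption) works as written; only the gallery path benefits from the selection event.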