Jyothirmai committed
Commit afda258
1 Parent(s): 8a13188

Upload 8 files

CXR191_IM-0591-1001.png ADDED
CXR192_IM-0598-1001.png ADDED
CXR193_IM-0601-1001.png ADDED
CXR194_IM-0609-1001.png ADDED
CXR195_IM-0618-1001.png ADDED
app.py CHANGED
@@ -1,37 +1,49 @@
  import gradio as gr
- from PIL import Image  # For image handling
-
- # Replace with paths or loading functions for your specific models
- def load_model_1():
-     # ... load your first model
-     return model_1
-
- def load_model_2():
-     # ... load your second model
-     return model_2
-
- def load_model_3():
-     # ... load your third model
-     return model_3
-
- def generate_caption(model, image):
-     # ... perform inference with your model
      return caption
-
- # models = [load_model_1(), load_model_2(), load_model_3()]

  with gr.Blocks() as demo:
      with gr.Row():
-         image = gr.Image(label="Upload Chest X-ray")
-     with gr.Row():
-         gr.Radio(["Model 1", "Model 2", "Model 3"], label="Select Model")
      with gr.Row():
          caption = gr.Textbox(label="Generated Caption")

-     # image.change(
-     #     fn=generate_caption,
-     #     inputs=[image, gr.inputs.Radio],
-     #     outputs=caption
-     # )

  demo.launch()
 
  import gradio as gr
+ from PIL import Image
+ import clipGPT
+
+ # Define model loading functions (if needed)
+ def load_model_1():  # CLIP-GPT2
+     # Load model components here if necessary
+     return None
+
+ # ... load_model_2(), load_model_3() - Define if and when needed
+
+ # Caption generation functions
+ def generate_caption_clipgpt(image):
+     caption = clipGPT.generate_caption_clipgpt(image)
      return caption
+
+ # ... Add more caption generation functions for future models
+
+ # Sample image paths (the PNG files uploaded in this commit)
+ sample_images = [
+     "CXR191_IM-0591-1001.png",
+     "CXR192_IM-0598-1001.png",
+     "CXR193_IM-0601-1001.png",
+     "CXR194_IM-0609-1001.png",
+     "CXR195_IM-0618-1001.png"
+ ]
+
+ # Gradio interface
  with gr.Blocks() as demo:
      with gr.Row():
+         image = gr.Image(label="Upload Chest X-ray", source="upload", type="filepath")  # filepath so clipGPT can read it with skimage.io.imread
+         sample_image_gallery = gr.Gallery(value=sample_images, label="Sample Images")
      with gr.Row():
+         model_choice = gr.Radio(["CLIP-GPT2", "ViT-GPT2", "ViT-CoAttention"], label="Select Model")
      with gr.Row():
          caption = gr.Textbox(label="Generated Caption")
+
+     def predict(img, model_name):
+         if model_name == "CLIP-GPT2":
+             return generate_caption_clipgpt(img)
+         # Add elif blocks for "ViT-GPT2" and "ViT-CoAttention" as they are implemented
+         else:
+             return "Caption generation for this model is not yet implemented."
+
+     # Handle changes for both uploaded and sample images; events are bound on the
+     # component instances, not on the gr.Image class
+     image.change(predict, [image, model_choice], caption)
+     # Note: a Gallery's change event passes its whole list value; single-sample
+     # selection via the gallery's select event is sketched after this diff
+     sample_image_gallery.change(predict, [sample_image_gallery, model_choice], caption)

  demo.launch()
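Note on the sample gallery wiring above: a Gallery's change event passes the component's whole list value, so predict would receive all of sample_images rather than the clicked thumbnail. A minimal sketch of single-image selection through the gallery's select event, assuming Gradio 3.x with gr.SelectData; predict_from_sample is an illustrative helper that is not part of the commit and would live inside the same gr.Blocks() context:

    def predict_from_sample(model_name, evt: gr.SelectData):
        # evt.index is the position of the clicked thumbnail in sample_images
        return predict(sample_images[evt.index], model_name)

    sample_image_gallery.select(predict_from_sample, inputs=[model_choice], outputs=caption)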
clipGPT.py ADDED
@@ -0,0 +1,164 @@
+ import gc
+ import time
+
+ import clip
+ import numpy as np
+ import PIL.Image
+ import skimage.io as io
+ import torch
+ import torch.nn as nn
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+
+ class ClipGPT2Model(nn.Module):
+     def __init__(self, img_feature_length, img_feature_size=512):
+         super(ClipGPT2Model, self).__init__()
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.img_feature_length = img_feature_length
+
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         # Adapter maps a CLIP image feature to img_feature_length GPT-2 token
+         # embeddings. It is not defined in this commit; see the note after this file.
+         self.clip_project = Adapter((img_feature_size,
+                                      (self.gpt_embedding_size * img_feature_length) // 2,
+                                      self.gpt_embedding_size * img_feature_length))
+         torch.cuda.empty_cache()
+
+     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         return torch.zeros(batch_size, self.img_feature_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: torch.Tensor, feature: torch.Tensor, mask=None, labels=None):
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         embedding_text = self.gpt.transformer.wte(tokens)
+         feature_projections = self.clip_project(feature).view(-1, self.img_feature_length, self.gpt_embedding_size)
+         embedding_cat = torch.cat((feature_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+
+ def generate_beam(
+     model,
+     tokenizer,
+     beam_size: int = 10,
+     prompt=None,
+     embed=None,
+     entry_length=76,
+     temperature=0.9,
+     stop_token: str = ".",
+ ):
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 # First step: take the top beam_size tokens and fan the prefix out into beams
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 # Later steps: extend each beam, normalising scores by sequence length
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
+                     beam_size, -1
+                 )
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
+                 generated.shape[0], 1, -1
+             )
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [
+         tokenizer.decode(output[: int(length)])
+         for output, length in zip(output_list, seq_lengths)
+     ]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+
+ def generate_caption_clipgpt(img):
+     prefix_length = 10
+     feature_dim = 512  # CLIP ViT-B/32 image embeddings are 512-dimensional
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model = ClipGPT2Model(prefix_length, img_feature_size=feature_dim)
+     model.load_state_dict(torch.load('model_train_best_run_clipGPT.pt', map_location=device))
+     model = model.eval()
+     model = model.to(device)
+
+     clip_model, preprocess = clip.load('ViT-B/32', device, jit=False)
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+     start_time = time.time()
+     image = io.imread(img)
+     pil_image = PIL.Image.fromarray(image)
+     image = preprocess(pil_image).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+         beam_caption = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+
+     end_time = time.time()
+     print("--- Time taken to generate: %s seconds ---" % (end_time - start_time))
+
+     return beam_caption
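clipGPT.py builds its projection layer with Adapter, a class that is not defined in this file or anywhere else in this commit, so the module cannot run as uploaded. A minimal sketch of what such a module could look like, assuming the usual MLP that maps a 512-dim CLIP feature to prefix_length GPT-2 token embeddings; whether its state_dict keys match those stored in model_train_best_run_clipGPT.pt depends on how the model was actually trained, so treat this as an illustration of the expected interface rather than the trained architecture:

    import torch.nn as nn

    class Adapter(nn.Module):
        # Hypothetical stand-in: an MLP over the given layer sizes, e.g.
        # (512, prefix_length * 768 // 2, prefix_length * 768) for GPT-2 base.
        def __init__(self, sizes, act=nn.Tanh):
            super().__init__()
            layers = []
            for i in range(len(sizes) - 1):
                layers.append(nn.Linear(sizes[i], sizes[i + 1]))
                if i < len(sizes) - 2:
                    layers.append(act())
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            return self.model(x)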
model_train_best_run_clipGPT.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d75b4bf1a982290d2675a78b1f2bc39fa212178f5f609a555a1725150fe5275
+ size 561159626
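The checkpoint itself is stored via Git LFS, so only this small pointer lives in the Git history; the roughly 560 MB of weights are resolved by LFS. A Hugging Face Space materialises LFS files automatically, but a local checkout has to fetch them explicitly, for example as below (the repository id is a placeholder, not taken from this commit):

    # Option 1: clone with Git LFS
    #   git lfs install
    #   git clone <repo-url> && git -C <repo> lfs pull

    # Option 2: download just the checkpoint with huggingface_hub
    from huggingface_hub import hf_hub_download

    ckpt_path = hf_hub_download(
        repo_id="<user>/<space-name>",  # placeholder repository id
        filename="model_train_best_run_clipGPT.pt",
        repo_type="space",
    )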