File size: 6,257 Bytes
711ffc5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import gradio as gr
import random
import time
import os
from glob import glob
from PIL import Image
import torchvision.transforms as transforms
# Example index at which load_next reports "DONE!".
num_rank = 200

# Directories holding the synthetic samples and the corresponding masks.
image_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/synth/"
mask_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/orig/"

# CSV whose first column is the image id. Each id is wrapped as "('<id>',)"
# because that exact text is used as the stem of the image file names.
img_list = "/deep/u/eprakash/lung_seg/train_60.csv"
with open(img_list) as fp:
    image_ids = ["('" + row.strip().split(",")[0] + "',)" for row in fp]

# Only rows 301..500 of the CSV are ranked by this instance of the app.
image_ids = image_ids[301:501]

# Directory where one "<example index>.txt" file is written per ranking.
save_path = "lung_seg_ranks"
def is_int(s):
    """Return True if ``int(s)`` succeeds, False if it raises ValueError."""
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
def load_img(img_path, size=512):
    """Open the image at ``img_path``, convert it to RGB and resize it.

    Args:
        img_path: Path of the image file to open.
        size: Edge length in pixels; the image is resized to (size, size).

    Returns:
        The converted image after the resize transform has been applied.
    """
    rgb = Image.open(img_path).convert('RGB')
    resize = transforms.Compose([transforms.Resize((size, size))])
    return resize(rgb)
def find_completed_idxs(save_path=save_path):
    """Scan the ranking directory for saved and malformed rankings.

    Every file in ``save_path`` whose name starts with an integer followed by
    "." (e.g. ``12.txt``) counts as the saved ranking for that example index.
    A line of such a file is well-formed when it has exactly five
    comma-separated fields and the last four all parse as integers.

    Args:
        save_path: Directory containing the ranking files.

    Returns:
        A tuple ``(file_list, incorrect_files)`` of sorted lists of example
        indices: all indices that have a file, and those whose file contains
        a malformed line. ``([-1], [])`` is returned when the directory holds
        no ranking files, so the caller's ``file_list[-1] + 1`` starts at 0.
    """
    file_list = []
    incorrect_files = []
    for f in os.listdir(save_path):
        stem = f.split(".")[0]
        if not is_int(stem):
            # Stray files (e.g. ".DS_Store") are not rankings; skipping them
            # avoids a ValueError from int().
            continue
        f_name = int(stem)
        file_list.append(f_name)
        if f_name == -1:
            # The -1 file belongs to the blank page and is never validated;
            # skipping it also avoids indexing into a short line.
            continue
        with open(os.path.join(save_path, f)) as fp:
            for line in fp:
                items = line.strip().split(",")
                well_formed = len(items) == 5 and all(
                    is_int(item.strip()) for item in items[1:]
                )
                if not well_formed:
                    incorrect_files.append(f_name)
                    break  # one bad line is enough; avoid duplicate entries
    if not file_list:
        return [-1], []
    return sorted(file_list), sorted(incorrect_files)
def load_next(rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example, ids=image_ids, image_prefix=image_prefix, save_path=save_path):
    """Save the ranking for the current example and load the next one.

    The ranking is written to ``<save_path>/<example>.txt`` as
    ``<image id>,<rank text>`` when the example has no file yet or its file
    is flagged as malformed. The next example shown is the highest-numbered
    malformed one if any exist, otherwise the one after the highest saved
    index. When that index equals ``num_rank`` the page is blanked and the
    rank box shows "DONE!".

    Args:
        rank: Text typed in the rankings box.
        img_1, mask_1, ..., img_4, mask_4: Current image component values;
            accepted because they are wired as inputs, but not read.
        example: Index of the example currently displayed (-1 = blank page).
        ids: Image ids, indexed by example number.
        image_prefix: Directory prefix of the synthetic sample images.
        save_path: Directory where ranking files are stored.

    Returns:
        ``[rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4,
        example]`` with the new rank text, eight ``gr.Image`` components and
        the new example index.
    """
    current = int(example)
    file_list, incorrect_files = find_completed_idxs(save_path)
    print(str(file_list) + " " + str(incorrect_files))

    needs_save = current not in file_list or current in incorrect_files
    # -1 is the blank start/finish page: there is nothing to rank, and
    # ids[-1] would silently pick the last image and write a junk "-1.txt".
    if needs_save and current != -1:
        image_id = str(ids[current]).split(",")[0].split("(")[1]
        with open(os.path.join(save_path, str(current) + ".txt"), "w") as r_fp:
            r_fp.write(image_id + "," + rank + "\n")
        # Re-scan so the file just written is validated and counted.
        file_list, incorrect_files = find_completed_idxs(save_path)

    if incorrect_files:
        example = incorrect_files[-1]
    else:
        example = file_list[-1] + 1

    # Panels are collected in output order: img_1, mask_1, ..., img_4, mask_4.
    panels = []
    if int(example) == num_rank:
        rank = "DONE!"
        example = -1
        blank = load_img("/deep/u/eprakash/blank.jpg")  # load once, reuse
        for n in range(1, 5):
            panels.append(gr.Image(label="Sample #" + str(n), value=blank, interactive=False))
            panels.append(gr.Image(label="Mask", value=blank, interactive=False))
    else:
        rank = ""
        image_id = str(ids[int(example)])
        # All four samples share the same mask, so it is loaded only once.
        mask = load_img(mask_prefix + image_id + "_mask.png")
        for n in range(4):
            sample = load_img(image_prefix + image_id + "_synthetic_" + str(n) + ".png")
            panels.append(gr.Image(label="Sample #" + str(n + 1), value=sample, interactive=False))
            panels.append(gr.Image(label="Mask", value=mask, interactive=False))
    return [rank] + panels + [example]
# Page layout: an example counter, a rankings text box, four rows each
# showing a mask next to one sample, and a "Next" button wired to load_next.
with gr.Blocks() as demo:
    last_idx = -1
    example = gr.Number(
        label="Example #. Click next for #-1 (blank starting page).",
        value=last_idx,
        interactive=False,
    )
    rank = gr.Textbox(label="Rankings (Best to worst, comma-separated, no spaces).")

    blank_path = "/deep/u/eprakash/blank.jpg"
    panels = []
    with gr.Column(scale=1):
        for sample_no in range(1, 5):
            with gr.Row():
                # The mask is created first so it sits to the left of its sample.
                mask = gr.Image(label="Mask", value=load_img(blank_path), interactive=False)
                img = gr.Image(label="Sample #" + str(sample_no), value=load_img(blank_path), interactive=False)
            panels.extend([img, mask])
    img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4 = panels

    next_btn = gr.Button(value="Next")
    # The same components serve as inputs and outputs of load_next.
    io_components = [rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example]
    next_btn.click(fn=load_next, inputs=io_components, outputs=list(io_components), queue=False)
demo.queue()
demo.launch(share=True)
|