File size: 6,257 Bytes
711ffc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import random
import time
import os
from glob import glob
from PIL import Image
import torchvision.transforms as transforms

# Number of examples to rank; once the next index reaches this value the UI
# shows "DONE!".
num_rank = 200
image_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/synth/"
mask_prefix = "/deep/u/eprakash/AngioSeg/diffusion/lung_seg_synthetic_60/orig/"
img_list = "/deep/u/eprakash/lung_seg/train_60.csv"

# Take the first CSV column of every line and wrap it as the literal string
# "('<id>',)": load_next embeds this exact string in the synthetic-image and
# mask filenames it opens.
with open(img_list) as csv_fp:
    image_ids = ["('" + row.strip().split(",")[0] + "',)" for row in csv_fp]
# Only this 200-entry slice of the list is presented for ranking.
image_ids = image_ids[301:501]
save_path = "lung_seg_ranks"

def is_int(s):
    """Report whether ``s`` can be parsed by ``int()``.

    Args:
        s: Value to test (a string in this file's usage).

    Returns:
        True when ``int(s)`` succeeds, False when it raises ``ValueError``.
    """
    try:
        int(s)
    except ValueError:
        return False
    return True

def load_img(img_path, size=512):
    """Open an image file, convert it to RGB and resize it to a square.

    Args:
        img_path: Path of the image file to open.
        size: Side length in pixels of the output; the image is resized to
            ``(size, size)``, so the aspect ratio is not preserved.

    Returns:
        The resized RGB PIL image.
    """
    # Close the file handle as soon as ``convert`` has copied the pixel data;
    # the previous ``Image.open(...).convert(...)`` chain left the source file
    # open until garbage collection, once per call.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    # A single transform needs no Compose wrapper.
    return transforms.Resize((size, size))(img)

def _is_valid_rank_line(line):
    """Return True if ``line`` has 5 comma-separated fields, the last 4 integers."""
    items = line.strip().split(",")
    return len(items) == 5 and all(is_int(item.strip()) for item in items[1:])


def find_completed_idxs(save_path=save_path):
    """Scan the rankings directory and report which examples are done.

    Each ranking file is named ``<example index>.<ext>`` inside ``save_path``.
    A file is malformed when it is empty, or when any of its lines does not
    have exactly 5 comma-separated fields whose last four are integers. The
    ``-1`` file is exempt from validation: ``load_next`` writes it when "Next"
    is pressed on the blank (-1) page, so its contents are not a real ranking.

    Args:
        save_path: Directory holding the per-example ranking files.

    Returns:
        A tuple ``(file_list, incorrect_files)``:
        - ``file_list`` is the sorted example indices that have a file.
        - ``incorrect_files`` is the sorted indices whose file is malformed,
          each listed once.
        When no index-named files exist, ``(file_list, incorrect_files)`` is
        ``([-1], [])``, so that ``file_list[-1] + 1`` is 0.
    """
    file_list = []
    incorrect_files = []
    for f in os.listdir(save_path):
        stem = f.split(".")[0]
        # Skip stray entries (hidden files, checkpoints, swap files, ...)
        # whose name is not an example index instead of crashing in int().
        if not is_int(stem):
            continue
        f_name = int(stem)
        file_list.append(f_name)
        if f_name == -1:
            # Never validated; skipping here also avoids indexing into a
            # short line of the sentinel file.
            continue
        with open(os.path.join(save_path, f)) as fp:
            lines = fp.readlines()
        # An empty file holds no ranking, so it must not pass vacuously.
        if not lines or not all(_is_valid_rank_line(line) for line in lines):
            incorrect_files.append(f_name)
    if not file_list:
        return [-1], []
    return sorted(file_list), sorted(incorrect_files)

def load_next(rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example, ids=image_ids, image_prefix=image_prefix, save_path=save_path):
    """Save the ranking typed for the current example and show the next one.

    The current ranking is written to ``<save_path>/<example>.txt`` only when
    that example has no file yet or its file was flagged as malformed. The
    next example shown is either the highest-indexed malformed file (so it
    can be re-entered) or the example after the highest completed one. Once
    that index reaches ``num_rank``, all image slots show the blank
    placeholder, the textbox shows "DONE!" and the example number becomes -1.

    Args:
        rank: Ranking text from the textbox.
        img_1 ... mask_4: Current image components; replaced, not read.
        example: Index of the example currently displayed (may be a float
            coming from ``gr.Number``).
        ids: Image id strings, indexed by example number.
        image_prefix: Directory prefix of the synthetic sample images.
        save_path: Directory where ranking files are stored.

    Returns:
        ``[rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4,
        example]`` with fresh ``gr.Image`` components for the next example.
    """
    file_list, incorrect_files = find_completed_idxs(save_path)
    print(str(file_list) + " " + str(incorrect_files))
    current = int(example)
    if current not in file_list or current in incorrect_files:
        # ids entries look like "('<id>',)"; keep the quoted "'<id>'" part.
        image_id = str(ids[current]).split(",")[0].split("(")[1]
        with open(os.path.join(save_path, str(current) + ".txt"), "w") as r_fp:
            r_fp.write(image_id + "," + rank + "\n")

    file_list, incorrect_files = find_completed_idxs(save_path)
    if incorrect_files:
        example = incorrect_files[-1]
    else:
        example = file_list[-1] + 1

    # ">=" rather than "==": an index past the end (e.g. a stray high-numbered
    # ranking file) must end the session instead of crashing on ids[...].
    if int(example) >= num_rank:
        rank = "DONE!"
        example = -1
        # Read the placeholder once instead of once per image slot.
        mask = load_img("/deep/u/eprakash/blank.jpg")
        samples = [mask] * 4
    else:
        rank = ""
        stem = str(ids[int(example)])
        samples = [load_img(image_prefix + stem + "_synthetic_" + str(k) + ".png") for k in range(4)]
        # All four rows show the same mask, so load it only once.
        mask = load_img(mask_prefix + stem + "_mask.png")

    outputs = [rank]
    for k, sample in enumerate(samples, start=1):
        outputs.append(gr.Image(label="Sample #" + str(k), value=sample, interactive=False))
        outputs.append(gr.Image(label="Mask", value=mask, interactive=False))
    outputs.append(example)
    return outputs

with gr.Blocks() as demo:
    last_idx = -1
    example = gr.Number(label="Example #. Click next for #-1 (blank starting page).", value=last_idx, interactive=False)
    rank = gr.Textbox(label="Rankings (Best to worst, comma-separated, no spaces).")
    # Read the placeholder image once and reuse it for all eight slots.
    blank = load_img("/deep/u/eprakash/blank.jpg")
    with gr.Column(scale=1):
        with gr.Row():
            mask_1 = gr.Image(label="Mask", value=blank, interactive=False)
            img_1 = gr.Image(label="Sample #1", value=blank, interactive=False)
        with gr.Row():
            mask_2 = gr.Image(label="Mask", value=blank, interactive=False)
            img_2 = gr.Image(label="Sample #2", value=blank, interactive=False)
        with gr.Row():
            mask_3 = gr.Image(label="Mask", value=blank, interactive=False)
            img_3 = gr.Image(label="Sample #3", value=blank, interactive=False)
        with gr.Row():
            mask_4 = gr.Image(label="Mask", value=blank, interactive=False)
            img_4 = gr.Image(label="Sample #4", value=blank, interactive=False)
    next_btn = gr.Button(value="Next")
    # load_next takes and returns the components in this same order.
    components = [rank, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4, example]
    next_btn.click(fn=load_next, inputs=components, outputs=components, queue=False)

# Enable queuing and start the server only after the Blocks context has
# closed, i.e. once the layout and event wiring above are finalized.
demo.queue()
demo.launch(share=True)