File size: 4,798 Bytes
d3c95ed
9555522
 
 
 
 
 
 
 
d3c95ed
 
9555522
 
 
 
 
 
 
 
d3c95ed
 
 
 
9555522
d3c95ed
 
9555522
 
 
 
 
 
 
 
d3c95ed
9555522
 
 
 
 
 
 
 
 
d3c95ed
9555522
 
 
 
 
 
 
 
d3c95ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9555522
d3c95ed
 
9555522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import os
import zipfile
from pathlib import Path
import io
from tempfile import NamedTemporaryFile

from PIL import Image
import gradio as gr
import torch
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader
import spaces

# NOTE(review): replaces torch.jit.script with an identity function, so any
# @torch.jit.script-decorated code executed after this line runs eagerly instead
# of being TorchScript-compiled. Presumably needed so the pickled model below can
# be loaded/executed without scripting; keep it above the torch.load call — confirm.
torch.jit.script = lambda f: f
# Left disabled: batch inference in process_zip enters a local torch.autocast
# context per batch instead of enabling autocast globally here.
# torch.cuda.amp.autocast(enabled=True)

# Suffix appended to each image's archive path to name its caption file in the
# output zip produced by process_zip.
caption_ext = ".txt"
# The three rating tags. NOTE(review): not referenced anywhere in the code shown;
# the same names are appended to allowed_tags as a separate literal.
exclude_tags = ("explicit", "questionable", "safe")

# Preprocessing applied to every input image: resize to 384x384 (aspect ratio is
# not preserved), convert to a float tensor in [0, 1], then normalize each
# channel with mean 0.5 / std 0.5, mapping values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class ZipImageDataset(Dataset):
    """Map-style dataset over the image entries of an open zip archive.

    Each item is a dict with ``"image"`` (the tensor produced by the
    module-level ``transform``, cast to ``dtype``) and ``"image_name"`` (the
    entry's path inside the archive).

    Args:
        zip_file: An open ``zipfile.ZipFile``. It must stay open for as long as
            items are read from the dataset.
        dtype: torch dtype the image tensors are cast to (the caller passes the
            model's parameter dtype).
    """

    # Lower-case filename suffixes treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, zip_file, dtype):
        self.zip_file = zip_file
        self.dtype = dtype
        self.image_files = [
            file_info
            for file_info in zip_file.infolist()
            if self._is_image_entry(file_info)
        ]

    @classmethod
    def _is_image_entry(cls, file_info):
        """Return True if *file_info* is an image file that should be decoded.

        Directories are skipped, as are the ``__MACOSX/`` folder and ``._*``
        AppleDouble files that macOS adds to archives. Those metadata entries
        carry the original image extension but are not decodable images, so
        PIL would raise on them and abort the whole batch.
        """
        if file_info.is_dir():
            return False
        name = file_info.filename
        basename = name.rsplit("/", 1)[-1]
        if name.startswith("__MACOSX/") or basename.startswith("._"):
            return False
        return name.lower().endswith(cls.IMAGE_EXTENSIONS)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        file_info = self.image_files[index]
        with self.zip_file.open(file_info) as file:
            # convert() forces PIL to decode now, while the member stream is open.
            image = Image.open(file).convert("RGB")
        image = transform(image).to(self.dtype)
        return {
            "image": image,
            "image_name": file_info.filename,
        }

# Load the whole pickled model object onto the CPU and switch it to inference mode.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only point it at a
# trusted local checkpoint. PyTorch 2.6+ defaults to weights_only=True, which cannot
# load a full pickled module; confirm against the pinned torch version.
model = torch.load("./model.pth", map_location=torch.device('cpu'))
model.eval()

# Tag vocabulary: the tag names from the JSON file (sorted), followed by the three
# rating tags. Position i in this list names output logit i of the model.
# JSON is UTF-8 by specification; state it explicitly instead of relying on the
# platform's locale-dependent default encoding.
with open("tags_9940.json", "r", encoding="utf-8") as file:
    tags = json.load(file)
allowed_tags = sorted(tags) + ["explicit", "questionable", "safe"]

@spaces.GPU(duration=5)
def create_tags(image, threshold):
    """Predict tags for a single image.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the form is submitted without an image.
        threshold: Minimum sigmoid probability for a tag to be reported.

    Returns:
        A ``(tag_string, tag_scores)`` pair: the selected tags joined with
        ``", "`` in vocabulary order, and a dict mapping each selected tag to
        its probability (shown by the ``gr.Label`` output). Both are empty
        when no image was given.
    """
    if image is None:
        return "", {}

    # Cast to the model's parameter dtype (as process_zip does) so a
    # half-precision checkpoint does not fail with an input/weight dtype mismatch.
    dtype = next(model.parameters()).dtype
    tensor = transform(image.convert("RGB")).unsqueeze(0).to(dtype)

    with torch.no_grad():
        probabilities = torch.sigmoid(model(tensor)[0])

    # Boolean-mask selection keeps ascending index order, so indices and scores align.
    selected = probabilities > threshold
    indices = selected.nonzero(as_tuple=True)[0].tolist()
    scores = probabilities[selected].tolist()

    tag_scores = {allowed_tags[i]: score for i, score in zip(indices, scores)}
    tag_string = ", ".join(allowed_tags[i] for i in indices)
    return tag_string, tag_scores

@spaces.GPU(duration=180)
def process_zip(zip_file, threshold):
    """Tag every image in an uploaded zip and return a zip of caption files.

    Args:
        zip_file: The uploaded archive from ``gr.File``: either an object with
            a ``.name`` path attribute or a plain path string, or ``None`` when
            nothing was uploaded.
        threshold: Minimum sigmoid probability for a tag to be included.

    Returns:
        Path to a new ``.zip`` holding one ``<image path><caption_ext>`` text
        entry per image, containing its comma-separated tags in vocabulary
        order. Returns ``None`` when no file was uploaded.
    """
    if zip_file is None:
        return None
    zip_path = getattr(zip_file, "name", zip_file)

    captions = []  # (image path inside the archive, tag string) pairs
    with zipfile.ZipFile(zip_path) as archive:
        dataset = ZipImageDataset(archive, next(model.parameters()).dtype)
        dataloader = DataLoader(
            dataset,
            batch_size=64,
            shuffle=False,
            # Kept at 0: the dataset holds an open ZipFile handle, which is
            # presumably unsafe to share with worker processes.
            num_workers=0,
            pin_memory=True,
            drop_last=False,
        )
        with torch.no_grad():
            for batch in dataloader:
                # NOTE(review): neither the model nor the batch is moved to CUDA
                # in this file, so this autocast context presumably has no
                # effect unless the runtime relocates them. Confirm.
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    probabilities = torch.sigmoid(model(batch["image"]))
                for image_name, prob in zip(batch["image_name"], probabilities):
                    indices = torch.where(prob > threshold)[0].tolist()
                    captions.append(
                        (image_name, ", ".join(allowed_tags[i] for i in indices))
                    )

    # delete=False keeps the file on disk for Gradio to serve. The `with`
    # closes our handle, which the original code leaked.
    with NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
        with zipfile.ZipFile(temp_file, "w") as out_zip:
            for image_name, caption in captions:
                out_zip.writestr(image_name + caption_ext, caption)
    return temp_file.name

# Two-tab Gradio UI: tagging a single image, and batch-tagging a zip of images.
with gr.Blocks() as demo:
    # Single image + threshold -> (comma-separated tag string, per-tag scores).
    with gr.Tab("Single Image"):
        gr.Interface(
            create_tags,
            inputs=[gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold")],
            outputs=[
                gr.Textbox(label="Tag String"),
                # Displays at most the 200 highest-scoring tags from the dict.
                gr.Label(label="Tag Predictions", num_top_classes=200),
            ],
            # NOTE(review): newer Gradio releases rename this option to
            # flagging_mode; confirm against the pinned Gradio version.
            allow_flagging="never",
        )
    # Zip of images + threshold -> downloadable zip of caption text files.
    with gr.Tab("Multiple Images"):
        gr.Interface(fn=process_zip, inputs=[gr.File(label="Zip File", file_types=[".zip"]), gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="Threshold")],
                      outputs=gr.File(type="binary"))

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()