Spaces:
Runtime error
Runtime error
File size: 6,208 Bytes
7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 4dab50d 65193db 4dab50d 8d0e872 f307fe5 a8416ee f307fe5 8d0e872 4dab50d 65193db 4dab50d 8d0e872 ab30850 8d0e872 79c7b01 09da12b 8d0e872 1df8b5f 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 7cc986f 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 b80df5c 8d0e872 250dc27 8d0e872 1599f11 ed768de b58ad35 8d0e872 7d1df38 4dab50d 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 5650fb4 1599f11 5650fb4 8d0e872 11ab28e 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 7d1df38 8d0e872 defbed4 8d0e872 7d1df38 8d0e872 5d3b8a6 8d0e872 5650fb4 8d0e872 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import glob
import json
import os
import random
import shutil

import numpy as np
import streamlit as st
import torch
import torchvision
import wordsegment as ws
from huggingface_hub import cached_download, hf_hub_url
from PIL import Image

from virtex.config import Config
from virtex.factories import PretrainingModelFactory, TokenizerFactory
from virtex.utils.checkpointing import CheckpointManager
# Local filenames of the artifacts that download_files() fetches from the
# "zamborg/redcaps" Hugging Face Hub repo into the working directory.
CONFIG_PATH = "config.yaml"  # virtex Config used to build tokenizer + model
MODEL_PATH = "checkpoint_last5.pth"  # weights loaded via CheckpointManager
VALID_SUBREDDITS_PATH = "subreddit_list.json"  # JSON list of accepted subreddit names
# Glob pattern for the bundled example images returned by get_samples().
SAMPLES_PATH = "./samples/*.jpg"
class ImageLoader:
    """Load and preprocess images for the captioning model and for display.

    ``image_transform`` turns an image into a tensor, resizes it (shorter side
    to 256), center-crops 224x224 and normalizes with the three per-channel
    mean/std values below. ``show_size`` is the target length, in pixels, of
    the longer side of images produced by :meth:`show_resize`.
    """

    def __init__(self):
        self.image_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Resize(256),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.Normalize(
                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
                ),
            ]
        )
        self.show_size = 500

    @staticmethod
    def _to_rgb(image):
        """Return a PIL-like ``image`` converted to RGB; other inputs unchanged."""
        # Normalize() above carries exactly three channel statistics, so
        # grayscale ("L"), palette ("P") or RGBA images must become RGB first.
        mode = getattr(image, "mode", None)
        if mode is not None and mode != "RGB":
            return image.convert("RGB")
        return image

    def load(self, im_path):
        """Open ``im_path`` and return ``{"image": 1xCxHxW float tensor}``."""
        # Context manager closes the file handle once pixels have been read.
        with Image.open(im_path) as img:
            im = self.image_transform(self._to_rgb(img)).float().unsqueeze(0)
        return {"image": im}

    def raw_load(self, im_path):
        """Return ``{"image": tensor}`` with the untransformed pixel values.

        The tensor is float32 in the image's native layout (H x W, or
        H x W x C for multi-channel images), with no resize or normalization.
        """
        # torch.FloatTensor cannot consume a PIL image directly (TypeError);
        # go through numpy's array interface instead.
        with Image.open(im_path) as img:
            im = torch.from_numpy(np.array(img, dtype=np.float32))
        return {"image": im}

    def transform(self, image):
        """Preprocess an already-opened image into ``{"image": 1xCxHxW tensor}``."""
        im = self.image_transform(self._to_rgb(image)).float().unsqueeze(0)
        return {"image": im}

    def text_transform(self, text):
        """Normalize caption text; currently this only lowercases."""
        return text.lower()

    def show_resize(self, image):
        """Scale ``image`` so its longer side is ``show_size`` pixels (for display)."""
        # Resized manually through tensors; per the original note, the pinned
        # torch/torchvision version lacks a more convenient API for this.
        image = torchvision.transforms.functional.to_tensor(image)
        height, width = image.shape[-2:]  # to_tensor yields C x H x W
        ratio = float(self.show_size / max((height, width)))
        image = torchvision.transforms.functional.resize(
            image, [int(height * ratio), int(width * ratio)]
        )
        return torchvision.transforms.functional.to_pil_image(image)
class VirTexModel:
    """CPU wrapper around a VirTex model that captions as ``subreddit://text``.

    ``__init__`` builds the tokenizer and model from ``CONFIG_PATH``, loads
    the ``MODEL_PATH`` checkpoint, switches to eval mode and reads the JSON
    list of accepted subreddit names from ``VALID_SUBREDDITS_PATH``.
    """

    def __init__(self):
        self.config = Config(CONFIG_PATH)
        ws.load()  # wordsegment data must be loaded before ws.segment() is used
        self.device = "cpu"
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        # Context manager so the file handle is closed after parsing.
        with open(VALID_SUBREDDITS_PATH, encoding="utf-8") as f:
            self.valid_subs = json.load(f)

    def _subreddit_prefix(self, subreddit_text):
        """Return ``[SOS] + tokens(subreddit_text) + [SEP]`` as a 1-D long tensor."""
        token_ids = (
            [self.model.sos_index]
            + self.tokenizer.encode(subreddit_text)
            + [self.tokenizer.token_to_id("[SEP]")]
        )
        return torch.tensor(token_ids, device=self.device).long()

    def predict(self, image_dict, sub_prompt=None, prompt="", *, max_attempts=10):
        """Caption an image, returning ``(subreddit, caption_text)``.

        Args:
            image_dict: model input batch, e.g. the ``{"image": tensor}`` dict
                from ``ImageLoader.load``/``transform``. NOTE: mutated in
                place; its ``"decode_prompt"`` key is overwritten.
            sub_prompt: optional subreddit name to condition on; it is cleaned
                and word-segmented with ``wordsegment`` before tokenizing.
            prompt: optional caption prefix. Given without ``sub_prompt``, the
                subreddit prefix falls back to "pics".
            max_attempts: upper bound on generations tried while looking for
                a subreddit contained in ``self.valid_subs``. At least one
                generation always runs.

        Returns:
            ``(subreddit, rest_of_caption)``. If ``prompt`` is non-blank and is
            found in the caption, ``rest_of_caption`` holds only the text after
            it (so the UI can color the prompt separately). If no attempt
            yields a valid subreddit, the last attempt's result is returned.
        """
        sep_id = self.tokenizer.token_to_id("[SEP]")

        # Token prefix handed to the model as "decode_prompt" (presumably a
        # forced decoding prefix -- TODO confirm against the virtex fork).
        if sub_prompt is not None:
            words = " ".join(ws.segment(ws.clean(sub_prompt)))
            decode_prompt = self._subreddit_prefix(words)
        elif prompt != "":
            # Prompts without a subreddit prefix break the model, so fall back
            # to "pics" (carried over from the original TODO).
            decode_prompt = self._subreddit_prefix("pics")
        else:
            decode_prompt = torch.tensor(
                [self.model.sos_index], device=self.device
            ).long()
        if prompt != "":
            cap_tokens = torch.tensor(
                self.tokenizer.encode(prompt), device=self.device
            ).long()
            decode_prompt = torch.cat([decode_prompt, cap_tokens])
        image_dict["decode_prompt"] = decode_prompt

        stripped_prompt = prompt.strip()
        subreddit, rest_of_caption = "", ""
        # Bounded retry: an unbounded loop hangs forever whenever the model
        # cannot (or deterministically does not) produce a listed subreddit.
        for _ in range(max(1, max_attempts)):
            with torch.no_grad():
                tokens = self.model(image_dict)["predictions"][0].tolist()
            # Swap the first [SEP] for "://" so the text reads "sub :// caption".
            if sep_id in tokens:
                tokens[tokens.index(sep_id)] = self.tokenizer.token_to_id("://")
            caption = self.tokenizer.decode(tokens)
            # partition (not split): a later literal "://" (e.g. a URL in the
            # generated text) must not break the two-way unpacking.
            head, found_sep, tail = caption.partition("://")
            if found_sep:
                subreddit = "".join(head.split())
                rest_of_caption = tail.strip()
            else:
                subreddit, rest_of_caption = "", caption.strip()
            # Keep only the text after the prompt, for coloring. Searching
            # rest_of_caption avoids matches inside the subreddit; a missing,
            # repeated or blank prompt no longer raises ValueError.
            if stripped_prompt:
                _, found_prompt, after = rest_of_caption.partition(stripped_prompt)
                if found_prompt:
                    rest_of_caption = after
            if subreddit in self.valid_subs:
                break
        return subreddit, rest_of_caption
def download_files():
    """Fetch the config, checkpoint and subreddit list into the working dir.

    Each file is obtained (via the local cache when already present) from the
    ``zamborg/redcaps`` Hugging Face Hub repo and copied to ``./<filename>``.
    A failed copy raises instead of being silently ignored.
    """
    # NOTE(review): cached_download is deprecated in newer huggingface_hub
    # releases (hf_hub_download replaces it) -- confirm the pinned version.
    for filename in (CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH):
        cached_path = cached_download(hf_hub_url("zamborg/redcaps", filename=filename))
        # Copy in-process rather than via os.system("cp ..."): no shell
        # quoting problems with spaces in cache paths, and errors propagate.
        shutil.copyfile(cached_path, os.path.join(".", filename))
def get_samples(pattern=None):
    """Return the sample image paths matching ``pattern``, sorted.

    Args:
        pattern: glob pattern to match; defaults to ``SAMPLES_PATH``.

    Returns:
        Sorted list of matching paths (empty if nothing matches). Sorting makes
        the order deterministic; ``glob.glob`` returns entries in arbitrary,
        filesystem-dependent order.
    """
    if pattern is None:
        pattern = SAMPLES_PATH
    return sorted(glob.glob(pattern))
def get_rand_idx(samples):
    """Pick a uniformly random valid index into ``samples``.

    Raises ValueError when ``samples`` is empty (no valid index exists).
    """
    upper = len(samples)
    return random.randrange(upper)
@st.cache(allow_output_mutation=True)  # allow mutation to update nucleus size
def create_objects():
    """Build the app's long-lived objects once (cached by Streamlit).

    Returns:
        ``(virtex_model, image_loader, sample_images, valid_subs)``, where
        ``valid_subs`` is the model's subreddit list with ``None`` prepended
        (a "no subreddit" choice).
    """
    sample_images = get_samples()
    virtex_model = VirTexModel()
    image_loader = ImageLoader()
    # Reuse the list the model already loaded instead of re-reading the JSON
    # (the old re-read also leaked its file handle). Build a new list so the
    # None placeholder never reaches the model's own validity check.
    valid_subs = [None, *virtex_model.valid_subs]
    return virtex_model, image_loader, sample_images, valid_subs
# Raw HTML/CSS for a page footer pinned to the bottom of the viewport
# (position: fixed), with link colors and the model/paper disclaimer.
# NOTE(review): presumably rendered elsewhere with
# st.markdown(footer, unsafe_allow_html=True) -- confirm at the call site.
footer = """<style>
a:link , a:visited{
color: blue;
background-color: transparent;
text-decoration: underline;
}
a:hover, a:active {
color: red;
background-color: transparent;
text-decoration: underline;
}
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
}
</style>
<div class="footer">
<p>
*Please note that this model was explicitly not trained on images of people, and as a result is not designed to caption images with humans.
This demo accompanies our paper RedCaps.
Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
</p>
</div>
"""
|